From 3df249c127e637f8af6c99e5e4fed9c466803d79 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 6 Aug 2020 16:25:26 +0800 Subject: [PATCH 01/42] Use table expr when inserting table, close #3239 --- callbacks/create.go | 8 ++------ tests/go.mod | 4 ++-- tests/table_test.go | 11 +++++++++++ 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index b41a3ef2..3a414dd7 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -43,9 +43,7 @@ func Create(config *Config) func(db *gorm.DB) { if db.Statement.SQL.String() == "" { db.Statement.SQL.Grow(180) - db.Statement.AddClauseIfNotExists(clause.Insert{ - Table: clause.Table{Name: db.Statement.Table}, - }) + db.Statement.AddClauseIfNotExists(clause.Insert{}) db.Statement.AddClause(ConvertToCreateValues(db.Statement)) db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") @@ -105,9 +103,7 @@ func CreateWithReturning(db *gorm.DB) { } if db.Statement.SQL.String() == "" { - db.Statement.AddClauseIfNotExists(clause.Insert{ - Table: clause.Table{Name: db.Statement.Table}, - }) + db.Statement.AddClauseIfNotExists(clause.Insert{}) db.Statement.AddClause(ConvertToCreateValues(db.Statement)) db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") diff --git a/tests/go.mod b/tests/go.mod index 6eb6eb07..82d4fdc8 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,8 +8,8 @@ require ( github.com/lib/pq v1.6.0 gorm.io/driver/mysql v0.3.1 gorm.io/driver/postgres v0.2.6 - gorm.io/driver/sqlite v1.0.8 - gorm.io/driver/sqlserver v0.2.5 + gorm.io/driver/sqlite v1.0.9 + gorm.io/driver/sqlserver v0.2.6 gorm.io/gorm v0.2.19 ) diff --git a/tests/table_test.go b/tests/table_test.go index faee6499..647b5e19 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -40,6 +40,17 @@ func TestTable(t *testing.T) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } + r = dryDB.Create(&UserWithTable{}).Statement + if DB.Dialector.Name() != "sqlite" { + if !regexp.MustCompile(`INSERT INTO .gorm.\..user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } else { + if !regexp.MustCompile(`INSERT INTO .user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } + r = dryDB.Table("(?) as u", DB.Model(&User{}).Select("name")).Find(&User{}).Statement if !regexp.MustCompile("SELECT \\* FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) From 39c8d6220b75b5a28dfff6ae88da17485b35dc46 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 6 Aug 2020 17:48:46 +0800 Subject: [PATCH 02/42] Fix soft delete panic when using unaddressable value --- soft_delete.go | 2 +- tests/delete_test.go | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/soft_delete.go b/soft_delete.go index 6b88b1a5..180bf745 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -64,7 +64,7 @@ func (SoftDeleteClause) ModifyStatement(stmt *Statement) { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) } - if stmt.Dest != stmt.Model && stmt.Model != nil { + if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil { _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) diff --git a/tests/delete_test.go b/tests/delete_test.go index 3d461f65..f5b3e784 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -43,6 +43,14 @@ func TestDelete(t *testing.T) { t.Errorf("no error should returns when query %v, but got %v", user.ID, err) } } + + if err := DB.Delete(users[0]).Error; err != nil { + t.Errorf("errors happened when delete: %v", err) + } + + if err := DB.Where("id = ?", users[0].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should returns record not found error, but got %v", err) + } } func TestDeleteWithTable(t *testing.T) { From 15b96ed3f482a29201b2c6c15fa0d3936d4d9a17 Mon Sep 17 00:00:00 2001 From: Caelansar Date: Mon, 10 Aug 2020 15:34:20 +0800 Subject: [PATCH 03/42] add testcase --- tests/scanner_valuer_test.go | 69 +++++++++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 16 deletions(-) diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 63a7c63c..6b8f086e 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -35,7 +35,9 @@ func TestScannerValuer(t *testing.T) { {"name1", "value1"}, {"name2", "value2"}, }, - Role: Role{Name: "admin"}, + Role: Role{Name: "admin"}, + ExampleStruct: ExampleStruct1{"name", "value"}, + ExampleStructPtr: &ExampleStruct1{"name", "value"}, } if err := DB.Create(&data).Error; err != nil { @@ -49,6 +51,14 @@ func TestScannerValuer(t *testing.T) { } AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs") + + if result.ExampleStructPtr.Val != "value" { + t.Errorf(`ExampleStructPtr.Val should equal to "value", but got %v`, result.ExampleStructPtr.Val) + } + + if result.ExampleStruct.Val != "value" { + t.Errorf(`ExampleStruct.Val should equal to "value", but got %v`, result.ExampleStruct.Val) + } } func TestScannerValuerWithFirstOrCreate(t *testing.T) { @@ -124,21 +134,23 @@ func TestInvalidValuer(t *testing.T) { type ScannerValuerStruct struct { gorm.Model - Name sql.NullString - Gender *sql.NullString - Age sql.NullInt64 - Male sql.NullBool - Height sql.NullFloat64 - Birthday sql.NullTime - Password EncryptedData - Bytes []byte - Num Num - Strings StringsSlice - Structs StructsSlice - Role Role - UserID *sql.NullInt64 - User User - EmptyTime EmptyTime + Name sql.NullString + Gender *sql.NullString + Age sql.NullInt64 + Male sql.NullBool + Height sql.NullFloat64 + Birthday sql.NullTime + Password EncryptedData + Bytes []byte + Num Num + Strings StringsSlice + Structs StructsSlice + Role Role + UserID *sql.NullInt64 + User User + EmptyTime EmptyTime + ExampleStruct ExampleStruct1 + ExampleStructPtr *ExampleStruct1 } type EncryptedData []byte @@ -207,6 +219,31 @@ type ExampleStruct struct { Value string } +type ExampleStruct1 struct { + Name string `json:"name,omitempty"` + Val string `json:"val,omitempty"` +} + +func (s ExampleStruct1) Value() (driver.Value, error) { + if len(s.Name) == 0 { + return nil, nil + } + //for test, has no practical meaning + s.Name = "" + return json.Marshal(s) +} + +func (s *ExampleStruct1) Scan(src interface{}) error { + switch value := src.(type) { + case string: + return json.Unmarshal([]byte(value), s) + case []byte: + return json.Unmarshal(value, s) + default: + return errors.New("not supported") + } +} + type StructsSlice []ExampleStruct func (l StructsSlice) Value() (driver.Value, error) { From 4a9d3a688aa47a7db7611902f6467f0b311aee79 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 11 Aug 2020 21:22:51 +0800 Subject: [PATCH 04/42] Don't parse ignored anonymous field --- schema/field.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/schema/field.go b/schema/field.go index 1ca4cb6d..ea6364a4 100644 --- a/schema/field.go +++ b/schema/field.go @@ -311,7 +311,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer) { + if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable)) { if reflect.Indirect(fieldValue).Kind() == reflect.Struct { var err error field.Creatable = false From a3dda47afac01b7430efb200d27473e24fe2fca9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 11 Aug 2020 21:22:51 +0800 Subject: [PATCH 05/42] Don't parse ignored anonymous field --- schema/field.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/schema/field.go b/schema/field.go index 1ca4cb6d..ea6364a4 100644 --- a/schema/field.go +++ b/schema/field.go @@ -311,7 +311,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer) { + if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable)) { if reflect.Indirect(fieldValue).Kind() == reflect.Struct { var err error field.Creatable = false From 7d45833f3e309f9c15bb9ca301c1782b23cb9f0e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Aug 2020 12:05:55 +0800 Subject: [PATCH 06/42] Fix driver.Valuer interface returns nil, close #3248 --- schema/field.go | 60 +++++++++++++++++------------------- tests/scanner_valuer_test.go | 52 ++++++++++++++++--------------- 2 files changed, 56 insertions(+), 56 deletions(-) diff --git a/schema/field.go b/schema/field.go index ea6364a4..84fdb695 100644 --- a/schema/field.go +++ b/schema/field.go @@ -731,40 +731,10 @@ func (field *Field) setupValuerAndSetter() { return nil } default: - if _, ok := fieldValue.Interface().(sql.Scanner); ok { - // struct scanner - field.Set = func(value reflect.Value, v interface{}) (err error) { - if valuer, ok := v.(driver.Valuer); ok { - v, _ = valuer.Value() - } - - reflectV := reflect.ValueOf(v) - if !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) - } else if reflectV.Kind() == reflect.Ptr { - if reflectV.Elem().IsNil() || !reflectV.Elem().IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) - } else { - return field.Set(value, reflectV.Elem().Interface()) - } - } else { - err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) - } - return - } - } else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { + if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { // pointer scanner field.Set = func(value reflect.Value, v interface{}) (err error) { reflectV := reflect.ValueOf(v) - - if valuer, ok := v.(driver.Valuer); ok { - if valuer == nil || reflectV.IsNil() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) - } else { - v, _ = valuer.Value() - } - } - if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { @@ -778,10 +748,38 @@ func (field *Field) setupValuerAndSetter() { if fieldValue.IsNil() { fieldValue.Set(reflect.New(field.FieldType.Elem())) } + + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + err = fieldValue.Interface().(sql.Scanner).Scan(v) } return } + } else if _, ok := fieldValue.Interface().(sql.Scanner); ok { + // struct scanner + field.Set = func(value reflect.Value, v interface{}) (err error) { + reflectV := reflect.ValueOf(v) + if !reflectV.IsValid() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Type().AssignableTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV) + } else if reflectV.Kind() == reflect.Ptr { + if reflectV.IsNil() || !reflectV.IsValid() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + return field.Set(value, reflectV.Elem().Interface()) + } + } else { + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + + err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) + } + return + } } else { field.Set = func(value reflect.Value, v interface{}) (err error) { return fallbackSetter(value, v, field.Set) diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 6b8f086e..b8306af7 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -36,8 +36,8 @@ func TestScannerValuer(t *testing.T) { {"name2", "value2"}, }, Role: Role{Name: "admin"}, - ExampleStruct: ExampleStruct1{"name", "value"}, - ExampleStructPtr: &ExampleStruct1{"name", "value"}, + ExampleStruct: ExampleStruct{"name", "value1"}, + ExampleStructPtr: &ExampleStruct{"name", "value2"}, } if err := DB.Create(&data).Error; err != nil { @@ -46,19 +46,18 @@ func TestScannerValuer(t *testing.T) { var result ScannerValuerStruct - if err := DB.Find(&result).Error; err != nil { + if err := DB.Find(&result, "id = ?", data.ID).Error; err != nil { t.Fatalf("no error should happen when query scanner, valuer struct, but got %v", err) } + if result.ExampleStructPtr.Val != "value2" { + t.Errorf(`ExampleStructPtr.Val should equal to "value2", but got %v`, result.ExampleStructPtr.Val) + } + + if result.ExampleStruct.Val != "value1" { + t.Errorf(`ExampleStruct.Val should equal to "value1", but got %#v`, result.ExampleStruct) + } AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs") - - if result.ExampleStructPtr.Val != "value" { - t.Errorf(`ExampleStructPtr.Val should equal to "value", but got %v`, result.ExampleStructPtr.Val) - } - - if result.ExampleStruct.Val != "value" { - t.Errorf(`ExampleStruct.Val should equal to "value", but got %v`, result.ExampleStruct.Val) - } } func TestScannerValuerWithFirstOrCreate(t *testing.T) { @@ -68,9 +67,11 @@ func TestScannerValuerWithFirstOrCreate(t *testing.T) { } data := ScannerValuerStruct{ - Name: sql.NullString{String: "name", Valid: true}, - Gender: &sql.NullString{String: "M", Valid: true}, - Age: sql.NullInt64{Int64: 18, Valid: true}, + Name: sql.NullString{String: "name", Valid: true}, + Gender: &sql.NullString{String: "M", Valid: true}, + Age: sql.NullInt64{Int64: 18, Valid: true}, + ExampleStruct: ExampleStruct{"name", "value1"}, + ExampleStructPtr: &ExampleStruct{"name", "value2"}, } var result ScannerValuerStruct @@ -109,7 +110,9 @@ func TestInvalidValuer(t *testing.T) { } data := ScannerValuerStruct{ - Password: EncryptedData("xpass1"), + Password: EncryptedData("xpass1"), + ExampleStruct: ExampleStruct{"name", "value1"}, + ExampleStructPtr: &ExampleStruct{"name", "value2"}, } if err := DB.Create(&data).Error; err == nil { @@ -149,8 +152,8 @@ type ScannerValuerStruct struct { UserID *sql.NullInt64 User User EmptyTime EmptyTime - ExampleStruct ExampleStruct1 - ExampleStructPtr *ExampleStruct1 + ExampleStruct ExampleStruct + ExampleStructPtr *ExampleStruct } type EncryptedData []byte @@ -215,25 +218,24 @@ func (l *StringsSlice) Scan(input interface{}) error { } type ExampleStruct struct { - Name string - Value string + Name string + Val string } -type ExampleStruct1 struct { - Name string `json:"name,omitempty"` - Val string `json:"val,omitempty"` +func (ExampleStruct) GormDataType() string { + return "bytes" } -func (s ExampleStruct1) Value() (driver.Value, error) { +func (s ExampleStruct) Value() (driver.Value, error) { if len(s.Name) == 0 { return nil, nil } - //for test, has no practical meaning + // for test, has no practical meaning s.Name = "" return json.Marshal(s) } -func (s *ExampleStruct1) Scan(src interface{}) error { +func (s *ExampleStruct) Scan(src interface{}) error { switch value := src.(type) { case string: return json.Unmarshal([]byte(value), s) From 045d5f853838b9800acdb8ae204969ba3d93e00a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Aug 2020 12:18:36 +0800 Subject: [PATCH 07/42] Fix count with join and no model, close #3255 --- callbacks/query.go | 2 +- tests/count_test.go | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/callbacks/query.go b/callbacks/query.go index be829fbc..5ae1e904 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -96,7 +96,7 @@ func BuildQuerySQL(db *gorm.DB) { // inline joins if len(db.Statement.Joins) != 0 { - if len(db.Statement.Selects) == 0 { + if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil { clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) for idx, dbName := range db.Statement.Schema.DBNames { clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} diff --git a/tests/count_test.go b/tests/count_test.go index 05661ae8..216fa3a1 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -67,4 +67,9 @@ func TestCount(t *testing.T) { if !regexp.MustCompile(`SELECT COUNT\(DISTINCT\(.name.\)\) FROM .*users.*`).MatchString(result.Statement.SQL.String()) { t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String()) } + + var count4 int64 + if err := DB.Debug().Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 { + t.Errorf("count with join, got error: %v, count %v", err, count) + } } From ecc946be6e93a108bbdcc10cf2719d08baa50f3f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Aug 2020 16:05:06 +0800 Subject: [PATCH 08/42] Test update from sub query --- callbacks/update.go | 9 +++++++-- tests/update_test.go | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index 12806af6..0ced3ffb 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -174,11 +174,16 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { sort.Strings(keys) for _, k := range keys { + kv := value[k] + if _, ok := kv.(*gorm.DB); ok { + kv = []interface{}{kv} + } + if stmt.Schema != nil { if field := stmt.Schema.LookUpField(k); field != nil { if field.DBName != "" { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv}) assignValue(field, value[k]) } } else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) { @@ -189,7 +194,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { - set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]}) + set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv}) } } diff --git a/tests/update_test.go b/tests/update_test.go index 2ff150dd..83a7b9a2 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -545,3 +545,21 @@ func TestUpdatesTableWithIgnoredValues(t *testing.T) { t.Errorf("element's ignored field should not be updated") } } + +func TestUpdateFromSubQuery(t *testing.T) { + user := *GetUser("update_from_sub_query", Config{Company: true}) + if err := DB.Create(&user).Error; err != nil { + t.Errorf("failed to create user, got error: %v", err) + } + + if err := DB.Model(&user).Update("name", DB.Model(&Company{}).Select("name").Where("companies.id = users.company_id")).Error; err != nil { + t.Errorf("failed to update with sub query, got error %v", err) + } + + var result User + DB.First(&result, user.ID) + + if result.Name != user.Company.Name { + t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name) + } +} From dea93edb6acdccdb398a5f9d89412f9bd0be5b39 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Aug 2020 16:28:21 +0800 Subject: [PATCH 09/42] Copy TableExpr when clone statement --- statement.go | 1 + tests/update_test.go | 10 ++++++++++ 2 files changed, 11 insertions(+) diff --git a/statement.go b/statement.go index e9d826c4..b5b5db5a 100644 --- a/statement.go +++ b/statement.go @@ -392,6 +392,7 @@ func (stmt *Statement) Parse(value interface{}) (err error) { func (stmt *Statement) clone() *Statement { newStmt := &Statement{ + TableExpr: stmt.TableExpr, Table: stmt.Table, Model: stmt.Model, Dest: stmt.Dest, diff --git a/tests/update_test.go b/tests/update_test.go index 83a7b9a2..a59a8856 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -562,4 +562,14 @@ func TestUpdateFromSubQuery(t *testing.T) { if result.Name != user.Company.Name { t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name) } + + DB.Model(&user.Company).Update("Name", "new company name") + if err := DB.Table("users").Where("1 = 1").Update("name", DB.Table("companies").Select("name").Where("companies.id = users.company_id")).Error; err != nil { + t.Errorf("failed to update with sub query, got error %v", err) + } + + DB.First(&result, user.ID) + if result.Name != "new company name" { + t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name) + } } From 2c4e8571259bf6193cf5d396594104fca7fa727d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Aug 2020 18:09:04 +0800 Subject: [PATCH 10/42] Should ignore association conditions when querying with struct --- statement.go | 12 ++++++------ tests/query_test.go | 16 ++++++++++++++++ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/statement.go b/statement.go index b5b5db5a..6114f468 100644 --- a/statement.go +++ b/statement.go @@ -309,10 +309,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c for _, field := range s.Fields { if field.Readable { if v, isZero := field.ValueOf(reflectValue); !isZero { - if field.DBName == "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) - } else { + if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) + } else if field.DataType != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) } } } @@ -322,10 +322,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c for _, field := range s.Fields { if field.Readable { if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { - if field.DBName == "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) - } else { + if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) + } else if field.DataType != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) } } } diff --git a/tests/query_test.go b/tests/query_test.go index 59f1130b..4c2a2abd 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -103,6 +103,22 @@ func TestFind(t *testing.T) { }) } +func TestQueryWithAssociation(t *testing.T) { + user := *GetUser("query_with_association", Config{Account: true, Pets: 2, Toys: 1, Company: true, Manager: true, Team: 2, Languages: 1, Friends: 3}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create user: %v", err) + } + + if err := DB.Where(&user).First(&User{}).Error; err != nil { + t.Errorf("search with struct with association should returns no error, but got %v", err) + } + + if err := DB.Where(user).First(&User{}).Error; err != nil { + t.Errorf("search with struct with association should returns no error, but got %v", err) + } +} + func TestFindInBatches(t *testing.T) { var users = []User{ *GetUser("find_in_batches", Config{}), From 2faff25dfbcfff9e3fb37c8fcf1a20a468f887a8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Aug 2020 18:38:39 +0800 Subject: [PATCH 11/42] Fix FirstOr(Init/Create) when assigning with association --- finisher_api.go | 67 +++++++++++++++++++++++++++++++-------------- tests/query_test.go | 2 ++ 2 files changed, 48 insertions(+), 21 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 33a4f121..8a3d4199 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -8,6 +8,7 @@ import ( "strings" "gorm.io/gorm/clause" + "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) @@ -132,19 +133,47 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat return } -func (tx *DB) assignExprsToValue(exprs []clause.Expression) { - for _, expr := range exprs { - if eq, ok := expr.(clause.Eq); ok { - switch column := eq.Column.(type) { - case string: - if field := tx.Statement.Schema.LookUpField(column); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) +func (tx *DB) assignInterfacesToValue(values ...interface{}) { + for _, value := range values { + switch v := value.(type) { + case []clause.Expression: + for _, expr := range v { + if eq, ok := expr.(clause.Eq); ok { + switch column := eq.Column.(type) { + case string: + if field := tx.Statement.Schema.LookUpField(column); field != nil { + tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + } + case clause.Column: + if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { + tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + } + default: + } } - case clause.Column: - if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + } + case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}: + exprs := tx.Statement.BuildCondition(value) + tx.assignInterfacesToValue(exprs) + default: + if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil { + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + switch reflectValue.Kind() { + case reflect.Struct: + for _, f := range s.Fields { + if f.Readable { + if v, isZero := f.ValueOf(reflectValue); !isZero { + if field := tx.Statement.Schema.LookUpField(f.Name); field != nil { + tx.AddError(field.Set(tx.Statement.ReflectValue, v)) + } + } + } + } } - default: + } else if len(values) > 0 { + exprs := tx.Statement.BuildCondition(values[0], values[1:]...) + tx.assignInterfacesToValue(exprs) + return } } } @@ -154,22 +183,20 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) { if c, ok := tx.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { - tx.assignExprsToValue(where.Exprs) + tx.assignInterfacesToValue(where.Exprs) } } // initialize with attrs, conds if len(tx.Statement.attrs) > 0 { - exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...) - tx.assignExprsToValue(exprs) + tx.assignInterfacesToValue(tx.Statement.attrs...) } tx.Error = nil } // initialize with attrs, conds if len(tx.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) - tx.assignExprsToValue(exprs) + tx.assignInterfacesToValue(tx.Statement.assigns...) } return } @@ -180,20 +207,18 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { if c, ok := tx.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { - tx.assignExprsToValue(where.Exprs) + tx.assignInterfacesToValue(where.Exprs) } } // initialize with attrs, conds if len(tx.Statement.attrs) > 0 { - exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...) - tx.assignExprsToValue(exprs) + tx.assignInterfacesToValue(tx.Statement.attrs...) } // initialize with attrs, conds if len(tx.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) - tx.assignExprsToValue(exprs) + tx.assignInterfacesToValue(tx.Statement.assigns...) } return tx.Create(dest) diff --git a/tests/query_test.go b/tests/query_test.go index 4c2a2abd..72dd89b9 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -110,6 +110,8 @@ func TestQueryWithAssociation(t *testing.T) { t.Fatalf("errors happened when create user: %v", err) } + user.CreatedAt = time.Time{} + user.UpdatedAt = time.Time{} if err := DB.Where(&user).First(&User{}).Error; err != nil { t.Errorf("search with struct with association should returns no error, but got %v", err) } From 6834c25cec6b037299970cc845de1a186e04ba1f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 17 Aug 2020 12:02:41 +0800 Subject: [PATCH 12/42] Fix stack overflow for embedded self-referred associations, close #3269 --- schema/field.go | 8 +++++++- schema/model_test.go | 22 ++++++++++++++++++++++ schema/relationship.go | 17 +++++++---------- schema/schema_test.go | 6 ++++++ 4 files changed, 42 insertions(+), 11 deletions(-) diff --git a/schema/field.go b/schema/field.go index 84fdb695..78eeccdc 100644 --- a/schema/field.go +++ b/schema/field.go @@ -317,7 +317,13 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Creatable = false field.Updatable = false field.Readable = false - if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { + + cacheStore := schema.cacheStore + if _, embedded := schema.cacheStore.Load("embedded_cache_store"); !embedded { + cacheStore = &sync.Map{} + cacheStore.Store("embedded_cache_store", true) + } + if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, schema.namer); err != nil { schema.err = err } for _, ef := range field.EmbeddedSchema.Fields { diff --git a/schema/model_test.go b/schema/model_test.go index a13372b5..84c7b327 100644 --- a/schema/model_test.go +++ b/schema/model_test.go @@ -39,3 +39,25 @@ type AdvancedDataTypeUser struct { Active mybool Admin *mybool } + +type BaseModel struct { + ID uint `gorm:"primarykey"` + CreatedAt time.Time + CreatedBy *int + Created *VersionUser `gorm:"foreignKey:CreatedBy"` + UpdatedAt time.Time + DeletedAt gorm.DeletedAt `gorm:"index"` +} + +type VersionModel struct { + BaseModel + Version int + CompanyID int +} + +type VersionUser struct { + VersionModel + Name string + Age uint + Birthday *time.Time +} diff --git a/schema/relationship.go b/schema/relationship.go index 93080105..537a3582 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -5,7 +5,6 @@ import ( "reflect" "regexp" "strings" - "sync" "github.com/jinzhu/inflection" "gorm.io/gorm/clause" @@ -67,16 +66,14 @@ func (schema *Schema) parseRelation(field *Field) { } ) + cacheStore := schema.cacheStore if field.OwnerSchema != nil { - if relation.FieldSchema, err = Parse(fieldValue, &sync.Map{}, schema.namer); err != nil { - schema.err = err - return - } - } else { - if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { - schema.err = err - return - } + cacheStore = field.OwnerSchema.cacheStore + } + + if relation.FieldSchema, err = Parse(fieldValue, cacheStore, schema.namer); err != nil { + schema.err = err + return } if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { diff --git a/schema/schema_test.go b/schema/schema_test.go index 99781e47..966f80e4 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -160,3 +160,9 @@ func TestCustomizeTableName(t *testing.T) { t.Errorf("Failed to customize table with TableName method") } } + +func TestNestedModel(t *testing.T) { + if _, err := schema.Parse(&VersionUser{}, &sync.Map{}, schema.NamingStrategy{}); err != nil { + t.Fatalf("failed to parse nested user, got error %v", err) + } +} From 2a716e04e6528f1979dc0a7a2de509f0350e9e04 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 17 Aug 2020 12:16:42 +0800 Subject: [PATCH 13/42] Avoid panic for invalid transaction, close #3271 --- finisher_api.go | 6 ++++-- tests/transaction_test.go | 20 ++++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 8a3d4199..19534460 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -445,7 +445,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { // Commit commit a transaction func (db *DB) Commit() *DB { - if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { + if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() { db.AddError(committer.Commit()) } else { db.AddError(ErrInvalidTransaction) @@ -456,7 +456,9 @@ func (db *DB) Commit() *DB { // Rollback rollback a transaction func (db *DB) Rollback() *DB { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { - db.AddError(committer.Rollback()) + if !reflect.ValueOf(committer).IsNil() { + db.AddError(committer.Rollback()) + } } else { db.AddError(ErrInvalidTransaction) } diff --git a/tests/transaction_test.go b/tests/transaction_test.go index c101388a..aea151d9 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "context" "errors" "testing" @@ -57,6 +58,25 @@ func TestTransaction(t *testing.T) { } } +func TestCancelTransaction(t *testing.T) { + ctx := context.Background() + ctx, cancelFunc := context.WithCancel(ctx) + cancelFunc() + + user := *GetUser("cancel_transaction", Config{}) + DB.Create(&user) + + err := DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + var result User + tx.First(&result, user.ID) + return nil + }) + + if err == nil { + t.Fatalf("Transaction should get error when using cancelled context") + } +} + func TestTransactionWithBlock(t *testing.T) { assertPanic := func(f func()) { defer func() { From 681268cc43a2aa665e5577680b88ac77b9e5b64c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 17 Aug 2020 16:31:09 +0800 Subject: [PATCH 14/42] Refactor Create/Query/Update/DeleteClauses interface --- schema/field.go | 22 -------------------- schema/interfaces.go | 8 ++++---- schema/schema.go | 17 ++++++++++++++++ soft_delete.go | 48 +++++++++++++++++++++++++++++++++----------- 4 files changed, 57 insertions(+), 38 deletions(-) diff --git a/schema/field.go b/schema/field.go index 78eeccdc..bc47e543 100644 --- a/schema/field.go +++ b/schema/field.go @@ -88,23 +88,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } fieldValue := reflect.New(field.IndirectFieldType) - - if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { - field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses()...) - } - - if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { - field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses()...) - } - - if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses()...) - } - - if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses()...) - } - // if field is valuer, used its value or first fields as data type valuer, isValuer := fieldValue.Interface().(driver.Valuer) if isValuer { @@ -353,11 +336,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.TagSettings[k] = v } } - - field.Schema.CreateClauses = append(field.Schema.CreateClauses, field.EmbeddedSchema.CreateClauses...) - field.Schema.QueryClauses = append(field.Schema.QueryClauses, field.EmbeddedSchema.QueryClauses...) - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, field.EmbeddedSchema.UpdateClauses...) - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, field.EmbeddedSchema.DeleteClauses...) } else { schema.err = fmt.Errorf("invalid embedded struct for %v's field %v, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType) } diff --git a/schema/interfaces.go b/schema/interfaces.go index f5d07843..e8e51e4c 100644 --- a/schema/interfaces.go +++ b/schema/interfaces.go @@ -7,17 +7,17 @@ type GormDataTypeInterface interface { } type CreateClausesInterface interface { - CreateClauses() []clause.Interface + CreateClauses(*Field) []clause.Interface } type QueryClausesInterface interface { - QueryClauses() []clause.Interface + QueryClauses(*Field) []clause.Interface } type UpdateClausesInterface interface { - UpdateClauses() []clause.Interface + UpdateClauses(*Field) []clause.Interface } type DeleteClausesInterface interface { - DeleteClauses() []clause.Interface + DeleteClauses(*Field) []clause.Interface } diff --git a/schema/schema.go b/schema/schema.go index 9206c24e..d81da4b8 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -219,6 +219,23 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) return schema, schema.err } } + + fieldValue := reflect.New(field.IndirectFieldType) + if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { + field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { + field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) + } } } diff --git a/soft_delete.go b/soft_delete.go index 180bf745..875623bc 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -24,37 +24,61 @@ func (n DeletedAt) Value() (driver.Value, error) { return n.Time, nil } -func (DeletedAt) QueryClauses() []clause.Interface { +func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface { return []clause.Interface{ clause.Where{Exprs: []clause.Expression{ clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: "deleted_at"}, + Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: nil, }, }}, } } -func (DeletedAt) DeleteClauses() []clause.Interface { - return []clause.Interface{SoftDeleteClause{}} +type SoftDeleteQueryClause struct { + Field *schema.Field } -type SoftDeleteClause struct { -} - -func (SoftDeleteClause) Name() string { +func (sd SoftDeleteQueryClause) Name() string { return "" } -func (SoftDeleteClause) Build(clause.Builder) { +func (sd SoftDeleteQueryClause) Build(clause.Builder) { } -func (SoftDeleteClause) MergeClause(*clause.Clause) { +func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) { } -func (SoftDeleteClause) ModifyStatement(stmt *Statement) { +func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { + if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{ + clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil}, + }}) + stmt.Clauses["soft_delete_enabled"] = clause.Clause{} + } +} + +func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface { + return []clause.Interface{SoftDeleteDeleteClause{Field: f}} +} + +type SoftDeleteDeleteClause struct { + Field *schema.Field +} + +func (sd SoftDeleteDeleteClause) Name() string { + return "" +} + +func (sd SoftDeleteDeleteClause) Build(clause.Builder) { +} + +func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) { +} + +func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { if stmt.SQL.String() == "" { - stmt.AddClause(clause.Set{{Column: clause.Column{Name: "deleted_at"}, Value: stmt.DB.NowFunc()}}) + stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: stmt.DB.NowFunc()}}) if stmt.Schema != nil { _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) From 9fcc337bd1ccfccfddcdbd4a9b8b08ad08bf465c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 17 Aug 2020 17:41:36 +0800 Subject: [PATCH 15/42] Fix create from map --- callbacks/associations.go | 59 ++++++++++++++++++++++++--------------- callbacks/create.go | 22 ++++++++++++--- callbacks/helper.go | 10 ++++++- tests/create_test.go | 39 ++++++++++++++++++++++++++ tests/go.mod | 2 +- 5 files changed, 103 insertions(+), 29 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 3508335a..2710ffe9 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -48,14 +48,19 @@ func SaveBeforeAssociations(db *gorm.DB) { elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) for i := 0; i < db.Statement.ReflectValue.Len(); i++ { obj := db.Statement.ReflectValue.Index(i) - if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value - rv := rel.Field.ReflectValueOf(obj) // relation reflect value - objs = append(objs, obj) - if isPtr { - elems = reflect.Append(elems, rv) - } else { - elems = reflect.Append(elems, rv.Addr()) + + if reflect.Indirect(obj).Kind() == reflect.Struct { + if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value + rv := rel.Field.ReflectValueOf(obj) // relation reflect value + objs = append(objs, obj) + if isPtr { + elems = reflect.Append(elems, rv) + } else { + elems = reflect.Append(elems, rv.Addr()) + } } + } else { + break } } @@ -112,22 +117,24 @@ func SaveAfterAssociations(db *gorm.DB) { for i := 0; i < db.Statement.ReflectValue.Len(); i++ { obj := db.Statement.ReflectValue.Index(i) - if _, zero := rel.Field.ValueOf(obj); !zero { - rv := rel.Field.ReflectValueOf(obj) - if rv.Kind() != reflect.Ptr { - rv = rv.Addr() - } - - for _, ref := range rel.References { - if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(obj) - db.AddError(ref.ForeignKey.Set(rv, fv)) - } else if ref.PrimaryValue != "" { - db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) + if reflect.Indirect(obj).Kind() == reflect.Struct { + if _, zero := rel.Field.ValueOf(obj); !zero { + rv := rel.Field.ReflectValueOf(obj) + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() } - } - elems = reflect.Append(elems, rv) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(obj) + db.AddError(ref.ForeignKey.Set(rv, fv)) + } else if ref.PrimaryValue != "" { + db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) + } + } + + elems = reflect.Append(elems, rv) + } } } @@ -207,7 +214,10 @@ func SaveAfterAssociations(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - appendToElems(db.Statement.ReflectValue.Index(i)) + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) + } } case reflect.Struct: appendToElems(db.Statement.ReflectValue) @@ -277,7 +287,10 @@ func SaveAfterAssociations(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - appendToElems(db.Statement.ReflectValue.Index(i)) + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) + } } case reflect.Struct: appendToElems(db.Statement.ReflectValue) diff --git a/callbacks/create.go b/callbacks/create.go index 3a414dd7..4cc0f555 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -61,16 +61,26 @@ func Create(config *Config) func(db *gorm.DB) { case reflect.Slice, reflect.Array: if config.LastInsertIDReversed { for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i)) + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) insertID-- } } } else { for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i)); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) insertID++ } } @@ -140,6 +150,10 @@ func CreateWithReturning(db *gorm.DB) { for rows.Next() { BEGIN: reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected)) + if reflect.Indirect(reflectValue).Kind() != reflect.Struct { + break + } + for idx, field := range fields { fieldValue := field.ReflectValueOf(reflectValue) diff --git a/callbacks/helper.go b/callbacks/helper.go index 7bd910f6..80fbc2a1 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -26,6 +26,10 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { values.Columns = append(values.Columns, clause.Column{Name: k}) + if len(values.Values) == 0 { + values.Values = [][]interface{}{{}} + } + values.Values[0] = append(values.Values[0], value) } } @@ -61,11 +65,15 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st sort.Strings(columns) values.Values = make([][]interface{}, len(mapValues)) + values.Columns = make([]clause.Column, len(columns)) for idx, column := range columns { + values.Columns[idx] = clause.Column{Name: column} + for i, v := range result[column] { - if i == 0 { + if len(values.Values[i]) == 0 { values.Values[i] = make([]interface{}, len(columns)) } + values.Values[i][idx] = v } } diff --git a/tests/create_test.go b/tests/create_test.go index ae6e1232..ab0a78d4 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -39,6 +39,45 @@ func TestCreate(t *testing.T) { } } +func TestCreateFromMap(t *testing.T) { + if err := DB.Model(&User{}).Create(map[string]interface{}{"Name": "create_from_map", "Age": 18}).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + + var result User + if err := DB.Where("name = ?", "create_from_map").First(&result).Error; err != nil || result.Age != 18 { + t.Fatalf("failed to create from map, got error %v", err) + } + + if err := DB.Model(&User{}).Create(map[string]interface{}{"name": "create_from_map_1", "age": 18}).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + + var result1 User + if err := DB.Where("name = ?", "create_from_map_1").First(&result1).Error; err != nil || result1.Age != 18 { + t.Fatalf("failed to create from map, got error %v", err) + } + + datas := []map[string]interface{}{ + {"Name": "create_from_map_2", "Age": 19}, + {"name": "create_from_map_3", "Age": 20}, + } + + if err := DB.Model(&User{}).Create(datas).Error; err != nil { + t.Fatalf("failed to create data from slice of map, got error: %v", err) + } + + var result2 User + if err := DB.Where("name = ?", "create_from_map_2").First(&result2).Error; err != nil || result2.Age != 19 { + t.Fatalf("failed to query data after create from slice of map, got error %v", err) + } + + var result3 User + if err := DB.Where("name = ?", "create_from_map_3").First(&result3).Error; err != nil || result3.Age != 20 { + t.Fatalf("failed to query data after create from slice of map, got error %v", err) + } +} + func TestCreateWithAssociations(t *testing.T) { var user = *GetUser("create_with_associations", Config{ Account: true, diff --git a/tests/go.mod b/tests/go.mod index 82d4fdc8..54a808d0 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( gorm.io/driver/mysql v0.3.1 gorm.io/driver/postgres v0.2.6 gorm.io/driver/sqlite v1.0.9 - gorm.io/driver/sqlserver v0.2.6 + gorm.io/driver/sqlserver v0.2.7 gorm.io/gorm v0.2.19 ) From dc48e04896aa529bb4014390347e21e2c4c509b2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 18 Aug 2020 11:21:40 +0800 Subject: [PATCH 16/42] Fix nested embedded struct, close #3278 --- schema/field.go | 8 +++----- schema/model_test.go | 5 ++--- schema/schema.go | 37 ++++++++++++++++++----------------- schema/schema_test.go | 18 ++++++++++++++++- schema/utils.go | 2 ++ tests/embedded_struct_test.go | 4 ++-- utils/tests/utils.go | 2 +- 7 files changed, 46 insertions(+), 30 deletions(-) diff --git a/schema/field.go b/schema/field.go index bc47e543..35c1e44d 100644 --- a/schema/field.go +++ b/schema/field.go @@ -301,14 +301,12 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Updatable = false field.Readable = false - cacheStore := schema.cacheStore - if _, embedded := schema.cacheStore.Load("embedded_cache_store"); !embedded { - cacheStore = &sync.Map{} - cacheStore.Store("embedded_cache_store", true) - } + cacheStore := &sync.Map{} + cacheStore.Store(embeddedCacheKey, true) if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, schema.namer); err != nil { schema.err = err } + for _, ef := range field.EmbeddedSchema.Fields { ef.Schema = schema ef.OwnerSchema = field.EmbeddedSchema diff --git a/schema/model_test.go b/schema/model_test.go index 84c7b327..1f2b0948 100644 --- a/schema/model_test.go +++ b/schema/model_test.go @@ -41,7 +41,7 @@ type AdvancedDataTypeUser struct { } type BaseModel struct { - ID uint `gorm:"primarykey"` + ID uint CreatedAt time.Time CreatedBy *int Created *VersionUser `gorm:"foreignKey:CreatedBy"` @@ -51,8 +51,7 @@ type BaseModel struct { type VersionModel struct { BaseModel - Version int - CompanyID int + Version int } type VersionUser struct { diff --git a/schema/schema.go b/schema/schema.go index d81da4b8..458256d1 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -212,29 +212,30 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } if _, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded { - // parse relations for unidentified fields - for _, field := range schema.Fields { - if field.DataType == "" && field.Creatable { - if schema.parseRelation(field); schema.err != nil { - return schema, schema.err + if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { + for _, field := range schema.Fields { + if field.DataType == "" && field.Creatable { + if schema.parseRelation(field); schema.err != nil { + return schema, schema.err + } } - } - fieldValue := reflect.New(field.IndirectFieldType) - if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { - field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) - } + fieldValue := reflect.New(field.IndirectFieldType) + if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { + field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) + } - if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { - field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) - } + if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { + field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) + } - if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) - } + if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) + } - if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) + if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) + } } } } diff --git a/schema/schema_test.go b/schema/schema_test.go index 966f80e4..c0ad3c25 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -162,7 +162,23 @@ func TestCustomizeTableName(t *testing.T) { } func TestNestedModel(t *testing.T) { - if _, err := schema.Parse(&VersionUser{}, &sync.Map{}, schema.NamingStrategy{}); err != nil { + versionUser, err := schema.Parse(&VersionUser{}, &sync.Map{}, schema.NamingStrategy{}) + + if err != nil { t.Fatalf("failed to parse nested user, got error %v", err) } + + fields := []schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"VersionModel", "BaseModel", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true}, + {Name: "CreatedBy", DBName: "created_by", BindNames: []string{"VersionModel", "BaseModel", "CreatedBy"}, DataType: schema.Int, Size: 64}, + {Name: "Version", DBName: "version", BindNames: []string{"VersionModel", "Version"}, DataType: schema.Int, Size: 64}, + } + + for _, f := range fields { + checkSchemaField(t, versionUser, &f, func(f *schema.Field) { + f.Creatable = true + f.Updatable = true + f.Readable = true + }) + } } diff --git a/schema/utils.go b/schema/utils.go index 1481d428..29f2fefb 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -9,6 +9,8 @@ import ( "gorm.io/gorm/utils" ) +var embeddedCacheKey = "embedded_cache_store" + func ParseTagSetting(str string, sep string) map[string]string { settings := map[string]string{} names := strings.Split(str, sep) diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index fb0d6f23..c29078bd 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -160,9 +160,9 @@ func TestEmbeddedRelations(t *testing.T) { Advanced bool } - DB.Debug().Migrator().DropTable(&AdvancedUser{}) + DB.Migrator().DropTable(&AdvancedUser{}) - if err := DB.Debug().AutoMigrate(&AdvancedUser{}); err != nil { + if err := DB.AutoMigrate(&AdvancedUser{}); err != nil { t.Errorf("Failed to auto migrate advanced user, got error %v", err) } } diff --git a/utils/tests/utils.go b/utils/tests/utils.go index 0067d5c6..817e4b0b 100644 --- a/utils/tests/utils.go +++ b/utils/tests/utils.go @@ -76,7 +76,7 @@ func AssertEqual(t *testing.T, got, expect interface{}) { } } else { name := reflect.ValueOf(got).Type().Elem().Name() - t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len()) + t.Errorf("%v expects length: %v, got %v (expects: %+v, got %+v)", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len(), expect, got) } return } From 50826742fd0bd26caf55a7a5a96b2c85b612f4ae Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 18 Aug 2020 18:00:36 +0800 Subject: [PATCH 17/42] Add error gorm.ErrInvalidData --- callbacks/create.go | 2 ++ callbacks/update.go | 2 ++ errors.go | 2 ++ tests/update_test.go | 9 +++++++++ 4 files changed, 15 insertions(+) diff --git a/callbacks/create.go b/callbacks/create.go index 4cc0f555..7a32ed5c 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -309,6 +309,8 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { } } } + default: + stmt.AddError(gorm.ErrInvalidData) } } diff --git a/callbacks/update.go b/callbacks/update.go index 0ced3ffb..5656d166 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -252,6 +252,8 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } } + default: + stmt.AddError(gorm.ErrInvalidData) } } diff --git a/errors.go b/errors.go index 115b8e25..32ff8ec1 100644 --- a/errors.go +++ b/errors.go @@ -19,6 +19,8 @@ var ( ErrPrimaryKeyRequired = errors.New("primary key required") // ErrModelValueRequired model value required ErrModelValueRequired = errors.New("model value required") + // ErrInvalidData unsupported data + ErrInvalidData = errors.New("unsupported data") // ErrUnsupportedDriver unsupported driver ErrUnsupportedDriver = errors.New("unsupported driver") // ErrRegistered registered diff --git a/tests/update_test.go b/tests/update_test.go index a59a8856..49a13be9 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -334,6 +334,15 @@ func TestSelectWithUpdateWithMap(t *testing.T) { AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages") } +func TestWithUpdateWithInvalidMap(t *testing.T) { + user := *GetUser("update_with_invalid_map", Config{}) + DB.Create(&user) + + if err := DB.Model(&user).Updates(map[string]string{"name": "jinzhu"}).Error; !errors.Is(err, gorm.ErrInvalidData) { + t.Errorf("should returns error for unsupported updating data") + } +} + func TestOmitWithUpdate(t *testing.T) { user := *GetUser("omit_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) DB.Create(&user) From b5de8aeb425cc9eccf92b8c3252fc0a7201ed52e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 18 Aug 2020 18:58:53 +0800 Subject: [PATCH 18/42] Fix overrite SELECT clause --- chainable_api.go | 3 +++ finisher_api.go | 2 +- tests/query_test.go | 5 +++++ 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/chainable_api.go b/chainable_api.go index 4df8780e..78724cc8 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -91,6 +91,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { return } } + delete(tx.Statement.Clauses, "SELECT") case string: fields := strings.FieldsFunc(v, utils.IsChar) @@ -112,6 +113,8 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { return } } + + delete(tx.Statement.Clauses, "SELECT") } else { tx.Statement.AddClause(clause.Select{ Distinct: db.Statement.Distinct, diff --git a/finisher_api.go b/finisher_api.go index 19534460..88873948 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -294,7 +294,7 @@ func (db *DB) Count(count *int64) (tx *DB) { if len(tx.Statement.Selects) == 0 { tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}}) - defer tx.Statement.AddClause(clause.Select{}) + defer delete(tx.Statement.Clauses, "SELECT") } else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") { expr := clause.Expr{SQL: "count(1)"} diff --git a/tests/query_test.go b/tests/query_test.go index 72dd89b9..d71c813a 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -346,6 +346,11 @@ func TestSelect(t *testing.T) { if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) { t.Fatalf("Build Select with u.*, but got %v", r.Statement.SQL.String()) } + + r = dryDB.Select("count(*)").Select("u.*").Table("users as u").First(&User{}, user.ID) + if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Select with u.*, but got %v", r.Statement.SQL.String()) + } } func TestOmit(t *testing.T) { From 3411425d651e540cf19f9845d83cc507d929f2e6 Mon Sep 17 00:00:00 2001 From: deepoli <67894732+deepoil@users.noreply.github.com> Date: Tue, 18 Aug 2020 20:03:09 +0900 Subject: [PATCH 19/42] fix return value and delete unused default (#3280) --- chainable_api.go | 2 +- finisher_api.go | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 78724cc8..9b46a95b 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -69,7 +69,7 @@ func (db *DB) Distinct(args ...interface{}) (tx *DB) { if len(args) > 0 { tx = tx.Select(args[0], args[1:]...) } - return tx + return } // Select specify fields that you want when querying, creating, updating diff --git a/finisher_api.go b/finisher_api.go index 88873948..db069c5c 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -148,7 +148,6 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) } - default: } } } From c1782d60c149483111b021e29c412d9139bd46ba Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 19 Aug 2020 15:47:08 +0800 Subject: [PATCH 20/42] Fix embedded scanner/valuer, close #3283 --- schema/field.go | 34 +++++++++++++++++++++------------- tests/scanner_valuer_test.go | 6 ++++++ 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/schema/field.go b/schema/field.go index 35c1e44d..59367399 100644 --- a/schema/field.go +++ b/schema/field.go @@ -92,32 +92,40 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { valuer, isValuer := fieldValue.Interface().(driver.Valuer) if isValuer { if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok { - var overrideFieldValue bool - if v, err := valuer.Value(); v != nil && err == nil { - overrideFieldValue = true + if v, err := valuer.Value(); reflect.ValueOf(v).IsValid() && err == nil { fieldValue = reflect.ValueOf(v) } - if field.IndirectFieldType.Kind() == reflect.Struct { - for i := 0; i < field.IndirectFieldType.NumField(); i++ { - if !overrideFieldValue { - newFieldType := field.IndirectFieldType.Field(i).Type + var getRealFieldValue func(reflect.Value) + getRealFieldValue = func(v reflect.Value) { + rv := reflect.Indirect(v) + if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + for i := 0; i < rv.Type().NumField(); i++ { + newFieldType := rv.Type().Field(i).Type for newFieldType.Kind() == reflect.Ptr { newFieldType = newFieldType.Elem() } fieldValue = reflect.New(newFieldType) - overrideFieldValue = true - } - // copy tag settings from valuer - for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") { - if _, ok := field.TagSettings[key]; !ok { - field.TagSettings[key] = value + if rv.Type() != reflect.Indirect(fieldValue).Type() { + getRealFieldValue(fieldValue) + } + + if fieldValue.IsValid() { + return + } + + for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") { + if _, ok := field.TagSettings[key]; !ok { + field.TagSettings[key] = value + } } } } } + + getRealFieldValue(fieldValue) } } diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index b8306af7..ce8a2b50 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -27,6 +27,7 @@ func TestScannerValuer(t *testing.T) { Male: sql.NullBool{Bool: true, Valid: true}, Height: sql.NullFloat64{Float64: 1.8888, Valid: true}, Birthday: sql.NullTime{Time: time.Now(), Valid: true}, + Allergen: NullString{sql.NullString{String: "Allergen", Valid: true}}, Password: EncryptedData("pass1"), Bytes: []byte("byte"), Num: 18, @@ -143,6 +144,7 @@ type ScannerValuerStruct struct { Male sql.NullBool Height sql.NullFloat64 Birthday sql.NullTime + Allergen NullString Password EncryptedData Bytes []byte Num Num @@ -299,3 +301,7 @@ func (t *EmptyTime) Scan(v interface{}) error { func (t EmptyTime) Value() (driver.Value, error) { return time.Now() /* pass tests, mysql 8 doesn't support 0000-00-00 by default */, nil } + +type NullString struct { + sql.NullString +} From 3313c11888538af30abed9b168550b426a4af082 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 19 Aug 2020 19:02:32 +0800 Subject: [PATCH 21/42] Fix embedded struct containing field named ID, close #3286 --- schema/field.go | 8 ++++++++ schema/schema_helper_test.go | 9 +++++++-- schema/schema_test.go | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 2 deletions(-) diff --git a/schema/field.go b/schema/field.go index 59367399..de937132 100644 --- a/schema/field.go +++ b/schema/field.go @@ -336,6 +336,14 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.PrimaryKey = true } else { ef.PrimaryKey = false + + if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) { + ef.AutoIncrement = false + } + + if ef.DefaultValue == "" { + ef.HasDefaultValue = false + } } for k, v := range field.TagSettings { diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index f202b487..4e916f84 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -49,7 +49,12 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* } } - if parsedField, ok := s.FieldsByName[f.Name]; !ok { + parsedField, ok := s.FieldsByDBName[f.DBName] + if !ok { + parsedField, ok = s.FieldsByName[f.Name] + } + + if !ok { t.Errorf("schema %v failed to look up field with name %v", s, f.Name) } else { tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") @@ -62,7 +67,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* for _, name := range []string{f.DBName, f.Name} { if name != "" { - if field := s.LookUpField(name); field == nil || parsedField != field { + if field := s.LookUpField(name); field == nil || (field.Name != name && field.DBName != name) { t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) } } diff --git a/schema/schema_test.go b/schema/schema_test.go index c0ad3c25..c28812af 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -182,3 +182,35 @@ func TestNestedModel(t *testing.T) { }) } } + +func TestEmbeddedStruct(t *testing.T) { + type Company struct { + ID int + Name string + } + + type Corp struct { + ID uint + Base Company `gorm:"embedded;embeddedPrefix:company_"` + } + + cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, schema.NamingStrategy{}) + + if err != nil { + t.Fatalf("failed to parse embedded struct with primary key, got error %v", err) + } + + fields := []schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true}, + {Name: "ID", DBName: "company_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Name", DBName: "company_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + } + + for _, f := range fields { + checkSchemaField(t, cropSchema, &f, func(f *schema.Field) { + f.Creatable = true + f.Updatable = true + f.Readable = true + }) + } +} From 528e5ba5c41b647367d48e527b9fe9ad7dfcdd72 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 19 Aug 2020 20:30:39 +0800 Subject: [PATCH 22/42] Cleanup Model after Count --- finisher_api.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/finisher_api.go b/finisher_api.go index db069c5c..cf46f78a 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -289,6 +289,9 @@ func (db *DB) Count(count *int64) (tx *DB) { tx = db.getInstance() if tx.Statement.Model == nil { tx.Statement.Model = tx.Statement.Dest + defer func() { + tx.Statement.Model = nil + }() } if len(tx.Statement.Selects) == 0 { From 0c9870d1ae52a466837daf7f8386e3f2c0c1505c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 20 Aug 2020 10:39:01 +0800 Subject: [PATCH 23/42] Test Association Mode with conditions --- tests/associations_has_many_test.go | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/tests/associations_has_many_test.go b/tests/associations_has_many_test.go index d8befd8a..173e9231 100644 --- a/tests/associations_has_many_test.go +++ b/tests/associations_has_many_test.go @@ -21,6 +21,23 @@ func TestHasManyAssociation(t *testing.T) { DB.Model(&user2).Association("Pets").Find(&user2.Pets) CheckUser(t, user2, user) + var pets []Pet + DB.Model(&user).Where("name = ?", user.Pets[0].Name).Association("Pets").Find(&pets) + + if len(pets) != 1 { + t.Fatalf("should only find one pets, but got %v", len(pets)) + } + + CheckPet(t, pets[0], *user.Pets[0]) + + if count := DB.Model(&user).Where("name = ?", user.Pets[1].Name).Association("Pets").Count(); count != 1 { + t.Fatalf("should only find one pets, but got %v", count) + } + + if count := DB.Model(&user).Where("name = ?", "not found").Association("Pets").Count(); count != 0 { + t.Fatalf("should only find no pet with invalid conditions, but got %v", count) + } + // Count AssertAssociationCount(t, user, "Pets", 2, "") @@ -40,13 +57,13 @@ func TestHasManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Pets", 3, "AfterAppend") - var pets = []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}} + var pets2 = []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}} - if err := DB.Model(&user2).Association("Pets").Append(&pets); err != nil { + if err := DB.Model(&user2).Association("Pets").Append(&pets2); err != nil { t.Fatalf("Error happened when append pet, got %v", err) } - for _, pet := range pets { + for _, pet := range pets2 { var pet = pet if pet.ID == 0 { t.Fatalf("Pet's ID should be created") From 06de6e8834baf8ed56230727cdf715809e2c7f27 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 20 Aug 2020 10:58:35 +0800 Subject: [PATCH 24/42] Test same field name from embedded field, close #3291 --- schema/schema_helper_test.go | 2 +- schema/schema_test.go | 17 +++++++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 4e916f84..cc0306e0 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -57,7 +57,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* if !ok { t.Errorf("schema %v failed to look up field with name %v", s, f.Name) } else { - tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") + tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "TagSettings") if f.DBName != "" { if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { diff --git a/schema/schema_test.go b/schema/schema_test.go index c28812af..8bd1e5ca 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -4,6 +4,7 @@ import ( "sync" "testing" + "gorm.io/gorm" "gorm.io/gorm/schema" "gorm.io/gorm/utils/tests" ) @@ -184,13 +185,19 @@ func TestNestedModel(t *testing.T) { } func TestEmbeddedStruct(t *testing.T) { + type CorpBase struct { + gorm.Model + OwnerID string + } + type Company struct { - ID int - Name string + ID int + OwnerID int + Name string } type Corp struct { - ID uint + CorpBase Base Company `gorm:"embedded;embeddedPrefix:company_"` } @@ -201,9 +208,11 @@ func TestEmbeddedStruct(t *testing.T) { } fields := []schema.Field{ - {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true}, + {Name: "ID", DBName: "id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}}, {Name: "ID", DBName: "company_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, {Name: "Name", DBName: "company_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "company_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String}, } for _, f := range fields { From f88e8b072c6e9dc5ecb0530823ee957f9cff5f6f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 20 Aug 2020 18:13:29 +0800 Subject: [PATCH 25/42] Check valid pointer before use it as Valuer --- schema/field.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/schema/field.go b/schema/field.go index de937132..497aa02d 100644 --- a/schema/field.go +++ b/schema/field.go @@ -473,16 +473,16 @@ func (field *Field) setupValuerAndSetter() { } } - if valuer, ok := v.(driver.Valuer); ok { - if v, err = valuer.Value(); err == nil { - err = setter(value, v) - } - } else if reflectV.Kind() == reflect.Ptr { + if reflectV.Kind() == reflect.Ptr { if reflectV.IsNil() { field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) } else { err = setter(value, reflectV.Elem().Interface()) } + } else if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + err = setter(value, v) + } } else { return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) } From 2b510d6423f6299d53eee6a69252a6acc4c431c7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 21 Aug 2020 15:40:50 +0800 Subject: [PATCH 26/42] Don't create index for join table, close #3294 --- schema/relationship.go | 4 ++-- schema/utils.go | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index 537a3582..c8d129f2 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -225,7 +225,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Name: joinFieldName, PkgPath: ownField.StructField.PkgPath, Type: ownField.StructField.Type, - Tag: removeSettingFromTag(removeSettingFromTag(ownField.StructField.Tag, "column"), "autoincrement"), + Tag: removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), }) } @@ -248,7 +248,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Name: joinFieldName, PkgPath: relField.StructField.PkgPath, Type: relField.StructField.Type, - Tag: removeSettingFromTag(removeSettingFromTag(relField.StructField.Tag, "column"), "autoincrement"), + Tag: removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), }) } diff --git a/schema/utils.go b/schema/utils.go index 29f2fefb..41bd9d60 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -51,8 +51,11 @@ func toColumns(val string) (results []string) { return } -func removeSettingFromTag(tag reflect.StructTag, name string) reflect.StructTag { - return reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}")) +func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.StructTag { + for _, name := range names { + tag = reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}")) + } + return tag } // GetRelationsValues get relations's values from a reflect value From 3a97639880a6a965c5e8209e2ff5557008e8b191 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Aug 2020 10:40:37 +0800 Subject: [PATCH 27/42] Fix unordered joins, close #3267 --- callbacks/query.go | 8 ++++---- chainable_api.go | 5 +---- statement.go | 13 +++++++++---- tests/joins_test.go | 8 ++++++++ 4 files changed, 22 insertions(+), 12 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 5ae1e904..f6cb32d5 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -104,12 +104,12 @@ func BuildQuerySQL(db *gorm.DB) { } joins := []clause.Join{} - for name, conds := range db.Statement.Joins { + for _, join := range db.Statement.Joins { if db.Statement.Schema == nil { joins = append(joins, clause.Join{ - Expression: clause.Expr{SQL: name, Vars: conds}, + Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, }) - } else if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok { + } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { tableAliasName := relation.Name for _, s := range relation.FieldSchema.DBNames { @@ -149,7 +149,7 @@ func BuildQuerySQL(db *gorm.DB) { }) } else { joins = append(joins, clause.Join{ - Expression: clause.Expr{SQL: name, Vars: conds}, + Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, }) } } diff --git a/chainable_api.go b/chainable_api.go index 9b46a95b..e1b73457 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -172,10 +172,7 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() - if tx.Statement.Joins == nil { - tx.Statement.Joins = map[string][]interface{}{} - } - tx.Statement.Joins[query] = args + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args}) return } diff --git a/statement.go b/statement.go index 6114f468..214a15bb 100644 --- a/statement.go +++ b/statement.go @@ -29,7 +29,7 @@ type Statement struct { Distinct bool Selects []string // selected columns Omits []string // omit columns - Joins map[string][]interface{} + Joins []join Preloads map[string][]interface{} Settings sync.Map ConnPool ConnPool @@ -44,6 +44,11 @@ type Statement struct { assigns []interface{} } +type join struct { + Name string + Conds []interface{} +} + // StatementModifier statement modifier interface type StatementModifier interface { ModifyStatement(*Statement) @@ -401,7 +406,6 @@ func (stmt *Statement) clone() *Statement { Distinct: stmt.Distinct, Selects: stmt.Selects, Omits: stmt.Omits, - Joins: map[string][]interface{}{}, Preloads: map[string][]interface{}{}, ConnPool: stmt.ConnPool, Schema: stmt.Schema, @@ -417,8 +421,9 @@ func (stmt *Statement) clone() *Statement { newStmt.Preloads[k] = p } - for k, j := range stmt.Joins { - newStmt.Joins[k] = j + if len(stmt.Joins) > 0 { + newStmt.Joins = make([]join, len(stmt.Joins)) + copy(newStmt.Joins, stmt.Joins) } stmt.Settings.Range(func(k, v interface{}) bool { diff --git a/tests/joins_test.go b/tests/joins_test.go index e54d3784..f78ddf67 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "regexp" "sort" "testing" @@ -88,6 +89,13 @@ func TestJoinConds(t *testing.T) { if db5.Error != nil { t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error()) } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + stmt := dryDB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5).Statement + + if !regexp.MustCompile("SELECT .* FROM .users. left join pets.*join accounts.*").MatchString(stmt.SQL.String()) { + t.Errorf("joins should be ordered, but got %v", stmt.SQL.String()) + } } func TestJoinsWithSelect(t *testing.T) { From cc6a64adfb0ed47d5f8ccf8de13eaf8145656973 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Aug 2020 15:40:19 +0800 Subject: [PATCH 28/42] Support smart migrate, close #3078 --- migrator.go | 1 + migrator/migrator.go | 63 ++++++++++++++++++++++++++++++++-- schema/field.go | 5 +++ statement.go | 1 - tests/go.mod | 6 ++-- tests/migrate_test.go | 80 +++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 149 insertions(+), 7 deletions(-) diff --git a/migrator.go b/migrator.go index 37051f81..ed8a8e26 100644 --- a/migrator.go +++ b/migrator.go @@ -42,6 +42,7 @@ type Migrator interface { AddColumn(dst interface{}, field string) error DropColumn(dst interface{}, field string) error AlterColumn(dst interface{}, field string) error + MigrateColumn(dst interface{}, field *schema.Field, columnType *sql.ColumnType) error HasColumn(dst interface{}, field string) bool RenameColumn(dst interface{}, oldName, field string) error ColumnTypes(dst interface{}) ([]*sql.ColumnType, error) diff --git a/migrator/migrator.go b/migrator/migrator.go index d50159dd..d93b8a6d 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "reflect" + "regexp" "strings" "gorm.io/gorm" @@ -80,7 +81,6 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { // AutoMigrate func (m Migrator) AutoMigrate(values ...interface{}) error { - // TODO smart migrate data type for _, value := range m.ReorderModels(values, true) { tx := m.DB.Session(&gorm.Session{}) if !tx.Migrator().HasTable(value) { @@ -89,11 +89,26 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } } else { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { + columnTypes, _ := m.DB.Migrator().ColumnTypes(value) + for _, field := range stmt.Schema.FieldsByDBName { - if !tx.Migrator().HasColumn(value, field.DBName) { + var foundColumn *sql.ColumnType + + for _, columnType := range columnTypes { + if columnType.Name() == field.DBName { + foundColumn = columnType + break + } + } + + if foundColumn == nil { + // not found, add column if err := tx.Migrator().AddColumn(value, field.DBName); err != nil { return err } + } else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil { + // found, smart migrate + return err } } @@ -120,7 +135,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } return nil }); err != nil { - fmt.Println(err) return err } } @@ -327,6 +341,49 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error }) } +func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType *sql.ColumnType) error { + // found, smart migrate + fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL) + realDataType := strings.ToLower(columnType.DatabaseTypeName()) + + alterColumn := false + + // check size + if length, _ := columnType.Length(); length != int64(field.Size) { + if length > 0 && field.Size > 0 { + alterColumn = true + } else { + // has size in data type and not equal + matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllString(realDataType, 1) + matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1) + if len(matches) > 0 && matches[1] != fmt.Sprint(field.Size) || len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) { + alterColumn = true + } + } + } + + // check precision + if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { + if strings.Contains(fullDataType, fmt.Sprint(field.Precision)) { + alterColumn = true + } + } + + // check nullable + if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull { + // not primary key & database is nullable + if !field.PrimaryKey && nullable { + alterColumn = true + } + } + + if alterColumn { + return m.DB.Migrator().AlterColumn(value, field.Name) + } + + return nil +} + func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) { err = m.RunWithValue(value, func(stmt *gorm.Statement) error { rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows() diff --git a/schema/field.go b/schema/field.go index 497aa02d..524d19fb 100644 --- a/schema/field.go +++ b/schema/field.go @@ -55,6 +55,7 @@ type Field struct { Comment string Size int Precision int + Scale int FieldType reflect.Type IndirectFieldType reflect.Type StructField reflect.StructField @@ -160,6 +161,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Precision, _ = strconv.Atoi(p) } + if s, ok := field.TagSettings["SCALE"]; ok { + field.Scale, _ = strconv.Atoi(s) + } + if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) { field.NotNull = true } diff --git a/statement.go b/statement.go index 214a15bb..95d23fa5 100644 --- a/statement.go +++ b/statement.go @@ -379,7 +379,6 @@ func (stmt *Statement) Build(clauses ...string) { } } } - // TODO handle named vars } func (stmt *Statement) Parse(value interface{}) (err error) { diff --git a/tests/go.mod b/tests/go.mod index 54a808d0..9d4e892d 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,11 +6,11 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.3.1 - gorm.io/driver/postgres v0.2.6 + gorm.io/driver/mysql v0.3.2 + gorm.io/driver/postgres v0.2.9 gorm.io/driver/sqlite v1.0.9 gorm.io/driver/sqlserver v0.2.7 - gorm.io/gorm v0.2.19 + gorm.io/gorm v0.2.36 ) replace gorm.io/gorm => ../ diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 1b002049..4cc8a7c3 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -47,6 +47,86 @@ func TestMigrate(t *testing.T) { } } +func TestSmartMigrateColumn(t *testing.T) { + type UserMigrateColumn struct { + ID uint + Name string + Salary float64 + Birthday time.Time + } + + DB.Migrator().DropTable(&UserMigrateColumn{}) + + DB.AutoMigrate(&UserMigrateColumn{}) + + type UserMigrateColumn2 struct { + ID uint + Name string `gorm:"size:128"` + Salary float64 `gorm:"precision:2"` + Birthday time.Time `gorm:"precision:2"` + } + + if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{}) + if err != nil { + t.Fatalf("failed to get column types, got error: %v", err) + } + + for _, columnType := range columnTypes { + switch columnType.Name() { + case "name": + if length, _ := columnType.Length(); length != 0 && length != 128 { + t.Fatalf("name's length should be 128, but got %v", length) + } + case "salary": + if precision, o, _ := columnType.DecimalSize(); precision != 0 && precision != 2 { + t.Fatalf("salary's precision should be 2, but got %v %v", precision, o) + } + case "birthday": + if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 2 { + t.Fatalf("birthday's precision should be 2, but got %v", precision) + } + } + } + + type UserMigrateColumn3 struct { + ID uint + Name string `gorm:"size:256"` + Salary float64 `gorm:"precision:3"` + Birthday time.Time `gorm:"precision:3"` + } + + if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn3{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + columnTypes, err = DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{}) + if err != nil { + t.Fatalf("failed to get column types, got error: %v", err) + } + + for _, columnType := range columnTypes { + switch columnType.Name() { + case "name": + if length, _ := columnType.Length(); length != 0 && length != 256 { + t.Fatalf("name's length should be 128, but got %v", length) + } + case "salary": + if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 { + t.Fatalf("salary's precision should be 2, but got %v", precision) + } + case "birthday": + if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 { + t.Fatalf("birthday's precision should be 2, but got %v", precision) + } + } + } + +} + func TestMigrateWithComment(t *testing.T) { type UserWithComment struct { gorm.Model From ebdb4edda8363fdd79c87ab323ca19b2be7a8872 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Aug 2020 20:08:23 +0800 Subject: [PATCH 29/42] Add AllowGlobalUpdate mode --- callbacks/delete.go | 2 +- callbacks/update.go | 2 +- gorm.go | 7 +++++++ soft_delete.go | 2 +- tests/delete_test.go | 4 ++++ tests/update_test.go | 4 ++++ 6 files changed, 18 insertions(+), 3 deletions(-) diff --git a/callbacks/delete.go b/callbacks/delete.go index 288f2d69..f444f020 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -51,7 +51,7 @@ func Delete(db *gorm.DB) { } } - if _, ok := db.Statement.Clauses["WHERE"]; !ok { + if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { db.AddError(gorm.ErrMissingWhereClause) return } diff --git a/callbacks/update.go b/callbacks/update.go index 5656d166..bd8a4150 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -69,7 +69,7 @@ func Update(db *gorm.DB) { db.Statement.Build("UPDATE", "SET", "WHERE") } - if _, ok := db.Statement.Clauses["WHERE"]; !ok { + if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { db.AddError(gorm.ErrMissingWhereClause) return } diff --git a/gorm.go b/gorm.go index 1ace0099..3c187f42 100644 --- a/gorm.go +++ b/gorm.go @@ -32,6 +32,8 @@ type Config struct { DisableAutomaticPing bool // DisableForeignKeyConstraintWhenMigrating DisableForeignKeyConstraintWhenMigrating bool + // AllowGlobalUpdate allow global update + AllowGlobalUpdate bool // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder @@ -61,6 +63,7 @@ type Session struct { PrepareStmt bool WithConditions bool SkipDefaultTransaction bool + AllowGlobalUpdate bool Context context.Context Logger logger.Interface NowFunc func() time.Time @@ -154,6 +157,10 @@ func (db *DB) Session(config *Session) *DB { tx.Config.SkipDefaultTransaction = true } + if config.AllowGlobalUpdate { + txConfig.AllowGlobalUpdate = true + } + if config.Context != nil { tx.Statement = tx.Statement.clone() tx.Statement.DB = tx diff --git a/soft_delete.go b/soft_delete.go index 875623bc..d33bf866 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -98,7 +98,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { } } - if _, ok := stmt.Clauses["WHERE"]; !ok { + if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok { stmt.DB.AddError(ErrMissingWhereClause) return } diff --git a/tests/delete_test.go b/tests/delete_test.go index f5b3e784..09c1a075 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -118,4 +118,8 @@ func TestBlockGlobalDelete(t *testing.T) { if err := DB.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { t.Errorf("should returns missing WHERE clause while deleting error") } + + if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&User{}).Error; err != nil { + t.Errorf("should returns no error while enable global update, but got err %v", err) + } } diff --git a/tests/update_test.go b/tests/update_test.go index 49a13be9..e52dc652 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -222,6 +222,10 @@ func TestBlockGlobalUpdate(t *testing.T) { if err := DB.Model(&User{}).Update("name", "jinzhu").Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { t.Errorf("should returns missing WHERE clause while updating error, got err %v", err) } + + if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&User{}).Update("name", "jinzhu").Error; err != nil { + t.Errorf("should returns no error while enable global update, but got err %v", err) + } } func TestSelectWithUpdate(t *testing.T) { From 84dbb36d3bd91a5e7b3c1ee5a617ea923a4098d2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 24 Aug 2020 20:24:25 +0800 Subject: [PATCH 30/42] Add Golang v1.15 --- .github/workflows/tests.yml | 10 +++++----- tests/default_value_test.go | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b626ce94..4388c31d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,7 +13,7 @@ jobs: sqlite: strategy: matrix: - go: ['1.14', '1.13'] + go: ['1.15', '1.14', '1.13'] platform: [ubuntu-latest, macos-latest] # can not run in windows OS runs-on: ${{ matrix.platform }} @@ -38,7 +38,7 @@ jobs: sqlite_windows: strategy: matrix: - go: ['1.14', '1.13'] + go: ['1.15', '1.14', '1.13'] platform: [windows-latest] runs-on: ${{ matrix.platform }} @@ -64,7 +64,7 @@ jobs: strategy: matrix: dbversion: ['mysql:latest', 'mysql:5.7', 'mysql:5.6', 'mariadb:latest'] - go: ['1.14', '1.13'] + go: ['1.15', '1.14', '1.13'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -108,7 +108,7 @@ jobs: strategy: matrix: dbversion: ['postgres:latest', 'postgres:11', 'postgres:10'] - go: ['1.14', '1.13'] + go: ['1.15', '1.14', '1.13'] platform: [ubuntu-latest] # can not run in macOS and widnowsOS runs-on: ${{ matrix.platform }} @@ -150,7 +150,7 @@ jobs: sqlserver: strategy: matrix: - go: ['1.14', '1.13'] + go: ['1.15', '1.14', '1.13'] platform: [ubuntu-latest] # can not run test in macOS and windows runs-on: ${{ matrix.platform }} diff --git a/tests/default_value_test.go b/tests/default_value_test.go index ea496d60..aa4a511a 100644 --- a/tests/default_value_test.go +++ b/tests/default_value_test.go @@ -10,7 +10,7 @@ func TestDefaultValue(t *testing.T) { type Harumph struct { gorm.Model Email string `gorm:"not null;index:,unique"` - Name string `gorm:"not null;default:'foo'"` + Name string `gorm:"not null;default:foo"` Name2 string `gorm:"size:233;not null;default:'foo'"` Name3 string `gorm:"size:233;not null;default:''"` Age int `gorm:"default:18"` From 3dfa8a66f1bef0a7469c34968cb298c208e59fb9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 25 Aug 2020 17:27:28 +0800 Subject: [PATCH 31/42] Fix panic when delet without pointer, close #3308 --- callbacks/delete.go | 12 ++++++------ soft_delete.go | 5 ----- tests/delete_test.go | 4 ++++ 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/callbacks/delete.go b/callbacks/delete.go index f444f020..76b78fb4 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -41,7 +41,7 @@ func Delete(db *gorm.DB) { db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) } - if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { + if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) @@ -51,15 +51,15 @@ func Delete(db *gorm.DB) { } } - if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { - db.AddError(gorm.ErrMissingWhereClause) - return - } - db.Statement.AddClauseIfNotExists(clause.From{}) db.Statement.Build("DELETE", "FROM", "WHERE") } + if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { + db.AddError(gorm.ErrMissingWhereClause) + return + } + if !db.DryRun && db.Error == nil { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) diff --git a/soft_delete.go b/soft_delete.go index d33bf866..484f565c 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -98,11 +98,6 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { } } - if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok { - stmt.DB.AddError(ErrMissingWhereClause) - return - } - stmt.AddClauseIfNotExists(clause.Update{}) stmt.Build("UPDATE", "SET", "WHERE") } diff --git a/tests/delete_test.go b/tests/delete_test.go index 09c1a075..17299677 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -48,6 +48,10 @@ func TestDelete(t *testing.T) { t.Errorf("errors happened when delete: %v", err) } + if err := DB.Delete(User{}).Error; err != gorm.ErrMissingWhereClause { + t.Errorf("errors happened when delete: %v", err) + } + if err := DB.Where("id = ?", users[0].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { t.Errorf("should returns record not found error, but got %v", err) } From 0f3201e73b97c358d2b7d98d24185fab91e5dd73 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 25 Aug 2020 18:18:16 +0800 Subject: [PATCH 32/42] friendly invalid field error message --- schema/relationship.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index c8d129f2..dad2e629 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -336,7 +336,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue primarySchema, foreignSchema = schema, relation.FieldSchema ) - reguessOrErr := func(err string, args ...interface{}) { + reguessOrErr := func() { switch gl { case guessHas: schema.guessRelation(relation, field, guessEmbeddedHas) @@ -345,7 +345,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue case guessBelongs: schema.guessRelation(relation, field, guessEmbeddedBelongs) default: - schema.err = fmt.Errorf(err, args...) + schema.err = fmt.Errorf("invalid field found for struct %v's field %v, need to define a foreign key for relations or it need to implement the Valuer/Scanner interface", schema, field.Name) } } @@ -354,7 +354,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue if field.OwnerSchema != nil { primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema } else { - reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) + reguessOrErr() return } case guessBelongs: @@ -363,7 +363,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue if field.OwnerSchema != nil { primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema } else { - reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) + reguessOrErr() return } } @@ -373,7 +373,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue if f := foreignSchema.LookUpField(foreignKey); f != nil { foreignFields = append(foreignFields, f) } else { - reguessOrErr("unsupported relations %v for %v on field %v with foreign keys %v", relation.FieldSchema, schema, field.Name, relation.foreignKeys) + reguessOrErr() return } } @@ -392,7 +392,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue } if len(foreignFields) == 0 { - reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) + reguessOrErr() return } else if len(relation.primaryKeys) > 0 { for idx, primaryKey := range relation.primaryKeys { @@ -400,11 +400,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue if len(primaryFields) < idx+1 { primaryFields = append(primaryFields, f) } else if f != primaryFields[idx] { - reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.primaryKeys) + reguessOrErr() return } } else { - reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.primaryKeys) + reguessOrErr() return } } @@ -414,7 +414,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue } else if len(primarySchema.PrimaryFields) == len(foreignFields) { primaryFields = append(primaryFields, primarySchema.PrimaryFields...) } else { - reguessOrErr("unsupported relations %v for %v on field %v", relation.FieldSchema, schema, field.Name) + reguessOrErr() return } } From 3195ae12072f51d15064a3428f4e906c6873c4e2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 25 Aug 2020 18:59:19 +0800 Subject: [PATCH 33/42] Allow override alias table in preload conditions --- callbacks/preload.go | 6 +++--- tests/preload_test.go | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index cd09a6d6..25b8cb2b 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -50,7 +50,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { joinResults := rel.JoinTable.MakeSlice().Elem() column, values := schema.ToQueryValues(rel.JoinTable.Table, joinForeignKeys, joinForeignValues) - tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()) + db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error) // convert join identity map to relation identity map fieldValues := make([]interface{}, len(joinForeignFields)) @@ -93,7 +93,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } reflectResults := rel.FieldSchema.MakeSlice().Elem() - column, values := schema.ToQueryValues(rel.FieldSchema.Table, relForeignKeys, foreignValues) + column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) for _, cond := range conds { if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { @@ -103,7 +103,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } } - tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...) + db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error) fieldValues := make([]interface{}, len(relForeignFields)) diff --git a/tests/preload_test.go b/tests/preload_test.go index 3caa17b4..7e5d2622 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -5,6 +5,7 @@ import ( "strconv" "testing" + "gorm.io/gorm" "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -108,6 +109,20 @@ func TestPreloadWithConds(t *testing.T) { } CheckUser(t, users2[0], users[0]) + + var users3 []User + if err := DB.Preload("Account", func(tx *gorm.DB) *gorm.DB { + return tx.Table("accounts AS a").Select("a.*") + }).Find(&users3, "id IN ?", userIDs).Error; err != nil { + t.Errorf("failed to query, got error %v", err) + } + sort.Slice(users3, func(i, j int) bool { + return users2[i].ID < users2[j].ID + }) + + for i, u := range users3 { + CheckUser(t, u, users[i]) + } } func TestNestedPreloadWithConds(t *testing.T) { From 0d96f99499f2501a0d3a5e0d93ef157cc287e44f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 26 Aug 2020 12:22:11 +0800 Subject: [PATCH 34/42] Update README --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index b51297c4..c727e2cf 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. * Composite Primary Key * Auto Migrations * Logger -* Extendable, write Plugins based on GORM callbacks +* Extendable, flexible plugin API: Database Resolver (Multiple Databases, Read/Write Splitting) / Prometheus… * Every feature comes with tests * Developer Friendly @@ -40,4 +40,3 @@ The fantastic ORM library for Golang, aims to be developer friendly. © Jinzhu, 2013~time.Now Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License) - From ce8853e7a6142420a786be1b0f0c5ffeb8778778 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 27 Aug 2020 15:03:57 +0800 Subject: [PATCH 35/42] Add GormValuer interface support --- README.md | 2 +- callbacks/create.go | 8 +++--- callbacks/delete.go | 4 +-- callbacks/interfaces.go | 39 ++++++++++++++++++++++++++++ callbacks/query.go | 2 +- callbacks/update.go | 8 +++--- interfaces.go | 37 +++------------------------ schema/interfaces.go | 4 ++- statement.go | 2 ++ tests/scanner_valuer_test.go | 49 ++++++++++++++++++++++++++++++++++++ 10 files changed, 108 insertions(+), 47 deletions(-) create mode 100644 callbacks/interfaces.go diff --git a/README.md b/README.md index c727e2cf..9c0aded0 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. * Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point * Context, Prepared Statment Mode, DryRun Mode * Batch Insert, FindInBatches, Find To Map -* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg +* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr * Composite Primary Key * Auto Migrations * Logger diff --git a/callbacks/create.go b/callbacks/create.go index 7a32ed5c..cc7e2671 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -12,14 +12,14 @@ func BeforeCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.BeforeSave { - if i, ok := value.(gorm.BeforeSaveInterface); ok { + if i, ok := value.(BeforeSaveInterface); ok { called = true db.AddError(i.BeforeSave(tx)) } } if db.Statement.Schema.BeforeCreate { - if i, ok := value.(gorm.BeforeCreateInterface); ok { + if i, ok := value.(BeforeCreateInterface); ok { called = true db.AddError(i.BeforeCreate(tx)) } @@ -203,14 +203,14 @@ func AfterCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.AfterSave { - if i, ok := value.(gorm.AfterSaveInterface); ok { + if i, ok := value.(AfterSaveInterface); ok { called = true db.AddError(i.AfterSave(tx)) } } if db.Statement.Schema.AfterCreate { - if i, ok := value.(gorm.AfterCreateInterface); ok { + if i, ok := value.(AfterCreateInterface); ok { called = true db.AddError(i.AfterCreate(tx)) } diff --git a/callbacks/delete.go b/callbacks/delete.go index 76b78fb4..e95117a1 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -11,7 +11,7 @@ import ( func BeforeDelete(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { callMethod(db, func(value interface{}, tx *gorm.DB) bool { - if i, ok := value.(gorm.BeforeDeleteInterface); ok { + if i, ok := value.(BeforeDeleteInterface); ok { db.AddError(i.BeforeDelete(tx)) return true } @@ -75,7 +75,7 @@ func Delete(db *gorm.DB) { func AfterDelete(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { callMethod(db, func(value interface{}, tx *gorm.DB) bool { - if i, ok := value.(gorm.AfterDeleteInterface); ok { + if i, ok := value.(AfterDeleteInterface); ok { db.AddError(i.AfterDelete(tx)) return true } diff --git a/callbacks/interfaces.go b/callbacks/interfaces.go new file mode 100644 index 00000000..2302470f --- /dev/null +++ b/callbacks/interfaces.go @@ -0,0 +1,39 @@ +package callbacks + +import "gorm.io/gorm" + +type BeforeCreateInterface interface { + BeforeCreate(*gorm.DB) error +} + +type AfterCreateInterface interface { + AfterCreate(*gorm.DB) error +} + +type BeforeUpdateInterface interface { + BeforeUpdate(*gorm.DB) error +} + +type AfterUpdateInterface interface { + AfterUpdate(*gorm.DB) error +} + +type BeforeSaveInterface interface { + BeforeSave(*gorm.DB) error +} + +type AfterSaveInterface interface { + AfterSave(*gorm.DB) error +} + +type BeforeDeleteInterface interface { + BeforeDelete(*gorm.DB) error +} + +type AfterDeleteInterface interface { + AfterDelete(*gorm.DB) error +} + +type AfterFindInterface interface { + AfterFind(*gorm.DB) error +} diff --git a/callbacks/query.go b/callbacks/query.go index f6cb32d5..0703b92e 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -214,7 +214,7 @@ func Preload(db *gorm.DB) { func AfterQuery(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind { callMethod(db, func(value interface{}, tx *gorm.DB) bool { - if i, ok := value.(gorm.AfterFindInterface); ok { + if i, ok := value.(AfterFindInterface); ok { db.AddError(i.AfterFind(tx)) return true } diff --git a/callbacks/update.go b/callbacks/update.go index bd8a4150..73c062b4 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -32,14 +32,14 @@ func BeforeUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.BeforeSave { - if i, ok := value.(gorm.BeforeSaveInterface); ok { + if i, ok := value.(BeforeSaveInterface); ok { called = true db.AddError(i.BeforeSave(tx)) } } if db.Statement.Schema.BeforeUpdate { - if i, ok := value.(gorm.BeforeUpdateInterface); ok { + if i, ok := value.(BeforeUpdateInterface); ok { called = true db.AddError(i.BeforeUpdate(tx)) } @@ -90,14 +90,14 @@ func AfterUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.AfterSave { - if i, ok := value.(gorm.AfterSaveInterface); ok { + if i, ok := value.(AfterSaveInterface); ok { called = true db.AddError(i.AfterSave(tx)) } } if db.Statement.Schema.AfterUpdate { - if i, ok := value.(gorm.AfterUpdateInterface); ok { + if i, ok := value.(AfterUpdateInterface); ok { called = true db.AddError(i.AfterUpdate(tx)) } diff --git a/interfaces.go b/interfaces.go index b2ce59b3..e933952b 100644 --- a/interfaces.go +++ b/interfaces.go @@ -53,38 +53,7 @@ type TxCommitter interface { Rollback() error } -type BeforeCreateInterface interface { - BeforeCreate(*DB) error -} - -type AfterCreateInterface interface { - AfterCreate(*DB) error -} - -type BeforeUpdateInterface interface { - BeforeUpdate(*DB) error -} - -type AfterUpdateInterface interface { - AfterUpdate(*DB) error -} - -type BeforeSaveInterface interface { - BeforeSave(*DB) error -} - -type AfterSaveInterface interface { - AfterSave(*DB) error -} - -type BeforeDeleteInterface interface { - BeforeDelete(*DB) error -} - -type AfterDeleteInterface interface { - AfterDelete(*DB) error -} - -type AfterFindInterface interface { - AfterFind(*DB) error +// Valuer gorm valuer interface +type Valuer interface { + GormValue(context.Context, *DB) clause.Expr } diff --git a/schema/interfaces.go b/schema/interfaces.go index e8e51e4c..98abffbd 100644 --- a/schema/interfaces.go +++ b/schema/interfaces.go @@ -1,6 +1,8 @@ package schema -import "gorm.io/gorm/clause" +import ( + "gorm.io/gorm/clause" +) type GormDataTypeInterface interface { GormDataType() string diff --git a/statement.go b/statement.go index 95d23fa5..fba1991d 100644 --- a/statement.go +++ b/statement.go @@ -161,6 +161,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { stmt.Vars = append(stmt.Vars, v.Value) case clause.Column, clause.Table: stmt.QuoteTo(writer, v) + case Valuer: + stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) case clause.Expr: var varStr strings.Builder var sql = v.SQL diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index ce8a2b50..ec16ccf6 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -1,16 +1,20 @@ package tests_test import ( + "context" "database/sql" "database/sql/driver" "encoding/json" "errors" + "fmt" "reflect" + "regexp" "strconv" "testing" "time" "gorm.io/gorm" + "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -305,3 +309,48 @@ func (t EmptyTime) Value() (driver.Value, error) { type NullString struct { sql.NullString } + +type Point struct { + X, Y int +} + +func (point *Point) Scan(v interface{}) error { + return nil +} + +func (point Point) GormDataType() string { + return "geo" +} + +func (point Point) GormValue(ctx context.Context, db *gorm.DB) clause.Expr { + return clause.Expr{ + SQL: "ST_PointFromText(?)", + Vars: []interface{}{fmt.Sprintf("POINT(%d %d)", point.X, point.Y)}, + } +} + +func TestGORMValuer(t *testing.T) { + type UserWithPoint struct { + Name string + Point Point + } + + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + stmt := dryRunDB.Create(&UserWithPoint{ + Name: "jinzhu", + Point: Point{X: 100, Y: 100}, + }).Statement + + if stmt.SQL.String() == "" || len(stmt.Vars) != 2 { + t.Errorf("Failed to generate sql, got %v", stmt.SQL.String()) + } + + if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } +} From 7a90496701f7b81e06daaa134a8f8853c1f935d5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 27 Aug 2020 16:27:59 +0800 Subject: [PATCH 36/42] Test create from sql expr with map --- callbacks/create.go | 4 ++++ callbacks/helper.go | 12 ++++++++---- tests/scanner_valuer_test.go | 26 ++++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index cc7e2671..c59b14b5 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -225,8 +225,12 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { switch value := stmt.Dest.(type) { case map[string]interface{}: values = ConvertMapToValuesForCreate(stmt, value) + case *map[string]interface{}: + values = ConvertMapToValuesForCreate(stmt, *value) case []map[string]interface{}: values = ConvertSliceOfMapToValuesForCreate(stmt, value) + case *[]map[string]interface{}: + values = ConvertSliceOfMapToValuesForCreate(stmt, *value) default: var ( selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) diff --git a/callbacks/helper.go b/callbacks/helper.go index 80fbc2a1..e0a66dd2 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -20,8 +20,10 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter for _, k := range keys { value := mapValue[k] - if field := stmt.Schema.LookUpField(k); field != nil { - k = field.DBName + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(k); field != nil { + k = field.DBName + } } if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { @@ -46,8 +48,10 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st for idx, mapValue := range mapValues { for k, v := range mapValue { - if field := stmt.Schema.LookUpField(k); field != nil { - k = field.DBName + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(k); field != nil { + k = field.DBName + } } if _, ok := result[k]; !ok { diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index ec16ccf6..dbf5adac 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -353,4 +353,30 @@ func TestGORMValuer(t *testing.T) { if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { t.Errorf("generated vars is not equal, got %v", stmt.Vars) } + + stmt = dryRunDB.Model(UserWithPoint{}).Create(map[string]interface{}{ + "Name": "jinzhu", + "Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}}, + }).Statement + + if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } + + stmt = dryRunDB.Table("user_with_points").Create(&map[string]interface{}{ + "Name": "jinzhu", + "Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}}, + }).Statement + + if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.Name.,.Point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } } From cd54dddd94a992edd446611aeccc939a64ad2658 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 27 Aug 2020 18:42:40 +0800 Subject: [PATCH 37/42] Test update with GormValuer --- tests/go.mod | 2 +- tests/scanner_valuer_test.go | 19 +++++++++++++++---- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index 9d4e892d..b0ed4497 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( gorm.io/driver/mysql v0.3.2 gorm.io/driver/postgres v0.2.9 gorm.io/driver/sqlite v1.0.9 - gorm.io/driver/sqlserver v0.2.7 + gorm.io/driver/sqlserver v0.2.8 gorm.io/gorm v0.2.36 ) diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index dbf5adac..f42daae7 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -314,10 +314,6 @@ type Point struct { X, Y int } -func (point *Point) Scan(v interface{}) error { - return nil -} - func (point Point) GormDataType() string { return "geo" } @@ -379,4 +375,19 @@ func TestGORMValuer(t *testing.T) { if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { t.Errorf("generated vars is not equal, got %v", stmt.Vars) } + + stmt = dryRunDB.Session(&gorm.Session{ + AllowGlobalUpdate: true, + }).Model(&UserWithPoint{}).Updates(UserWithPoint{ + Name: "jinzhu", + Point: Point{X: 100, Y: 100}, + }).Statement + + if !regexp.MustCompile(`UPDATE .user_with_points. SET .name.=.+,.point.=ST_PointFromText\(.+\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } } From d50dbb0896100640d61a8b4017aa46946f3bc6c5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 27 Aug 2020 19:15:40 +0800 Subject: [PATCH 38/42] Fix check valid db name, close #3315 --- chainable_api.go | 6 +++--- finisher_api.go | 2 +- utils/utils.go | 4 ++-- utils/utils_test.go | 14 ++++++++++++++ 4 files changed, 20 insertions(+), 6 deletions(-) create mode 100644 utils/utils_test.go diff --git a/chainable_api.go b/chainable_api.go index e1b73457..c8417a6d 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -93,7 +93,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { } delete(tx.Statement.Clauses, "SELECT") case string: - fields := strings.FieldsFunc(v, utils.IsChar) + fields := strings.FieldsFunc(v, utils.IsValidDBNameChar) // normal field names if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") { @@ -133,7 +133,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) { tx = db.getInstance() if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { - tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsChar) + tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar) } else { tx.Statement.Omits = columns } @@ -180,7 +180,7 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() - fields := strings.FieldsFunc(name, utils.IsChar) + fields := strings.FieldsFunc(name, utils.IsValidDBNameChar) tx.Statement.AddClause(clause.GroupBy{ Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}}, }) diff --git a/finisher_api.go b/finisher_api.go index cf46f78a..2cde3c31 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -362,7 +362,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { tx.AddError(ErrModelValueRequired) } - fields := strings.FieldsFunc(column, utils.IsChar) + fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) tx.Statement.AddClauseIfNotExists(clause.Select{ Distinct: tx.Statement.Distinct, Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, diff --git a/utils/utils.go b/utils/utils.go index e93f3055..71336f4b 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -29,8 +29,8 @@ func FileWithLineNum() string { return "" } -func IsChar(c rune) bool { - return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' +func IsValidDBNameChar(c rune) bool { + return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' } func CheckTruth(val interface{}) bool { diff --git a/utils/utils_test.go b/utils/utils_test.go new file mode 100644 index 00000000..5737c511 --- /dev/null +++ b/utils/utils_test.go @@ -0,0 +1,14 @@ +package utils + +import ( + "strings" + "testing" +) + +func TestIsValidDBNameChar(t *testing.T) { + for _, db := range []string{"db", "dbName", "db_name", "db1", "1dbname", "db$name"} { + if fields := strings.FieldsFunc(db, IsValidDBNameChar); len(fields) != 1 { + t.Fatalf("failed to parse db name %v", db) + } + } +} From dacbaa5f02bf40efa5d8841047c047f7a5340d9f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 27 Aug 2020 19:52:01 +0800 Subject: [PATCH 39/42] Fix update attrs order --- callbacks/update.go | 6 ++++-- tests/scanner_valuer_test.go | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index 73c062b4..46f59157 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -199,7 +199,8 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } if !stmt.UpdatingColumn && stmt.Schema != nil { - for _, field := range stmt.Schema.FieldsByDBName { + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.LookUpField(dbName) if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { now := stmt.DB.NowFunc() @@ -222,7 +223,8 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { switch updatingValue.Kind() { case reflect.Struct: set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) - for _, field := range stmt.Schema.FieldsByDBName { + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.LookUpField(dbName) if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { value, isZero := field.ValueOf(updatingValue) diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index f42daae7..fb1f5791 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -384,7 +384,7 @@ func TestGORMValuer(t *testing.T) { }).Statement if !regexp.MustCompile(`UPDATE .user_with_points. SET .name.=.+,.point.=ST_PointFromText\(.+\)`).MatchString(stmt.SQL.String()) { - t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + t.Errorf("update with sql.Expr, but got %v", stmt.SQL.String()) } if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { From c19a3abefb2aef853e4541ae1af7fa93f2dc0848 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 28 Aug 2020 11:31:13 +0800 Subject: [PATCH 40/42] Fix self-referential belongs to, close #3319 --- association.go | 4 ++-- schema/relationship.go | 34 +++++++++++++++++++--------------- schema/relationship_test.go | 14 ++++++++++++++ schema/schema_test.go | 2 +- 4 files changed, 36 insertions(+), 18 deletions(-) diff --git a/association.go b/association.go index e59b8938..25e1fe8d 100644 --- a/association.go +++ b/association.go @@ -54,7 +54,7 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro for _, queryClause := range association.Relationship.JoinTable.QueryClauses { joinStmt.AddClause(queryClause) } - joinStmt.Build("WHERE", "LIMIT") + joinStmt.Build("WHERE") tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) } @@ -112,7 +112,7 @@ func (association *Association) Replace(values ...interface{}) error { updateMap[ref.ForeignKey.DBName] = nil } - association.DB.UpdateColumns(updateMap) + association.Error = association.DB.UpdateColumns(updateMap).Error } case schema.HasOne, schema.HasMany: var ( diff --git a/schema/relationship.go b/schema/relationship.go index dad2e629..5132ff74 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -82,7 +82,9 @@ func (schema *Schema) parseRelation(field *Field) { schema.buildMany2ManyRelation(relation, field, many2many) } else { switch field.IndirectFieldType.Kind() { - case reflect.Struct, reflect.Slice: + case reflect.Struct: + schema.guessRelation(relation, field, guessBelongs) + case reflect.Slice: schema.guessRelation(relation, field, guessHas) default: schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name) @@ -324,10 +326,10 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel type guessLevel int const ( - guessHas guessLevel = iota - guessEmbeddedHas - guessBelongs + guessBelongs guessLevel = iota guessEmbeddedBelongs + guessHas + guessEmbeddedHas ) func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) { @@ -338,25 +340,19 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue reguessOrErr := func() { switch gl { - case guessHas: - schema.guessRelation(relation, field, guessEmbeddedHas) - case guessEmbeddedHas: - schema.guessRelation(relation, field, guessBelongs) case guessBelongs: schema.guessRelation(relation, field, guessEmbeddedBelongs) + case guessEmbeddedBelongs: + schema.guessRelation(relation, field, guessHas) + case guessHas: + schema.guessRelation(relation, field, guessEmbeddedHas) + // case guessEmbeddedHas: default: schema.err = fmt.Errorf("invalid field found for struct %v's field %v, need to define a foreign key for relations or it need to implement the Valuer/Scanner interface", schema, field.Name) } } switch gl { - case guessEmbeddedHas: - if field.OwnerSchema != nil { - primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema - } else { - reguessOrErr() - return - } case guessBelongs: primarySchema, foreignSchema = relation.FieldSchema, schema case guessEmbeddedBelongs: @@ -366,6 +362,14 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue reguessOrErr() return } + case guessHas: + case guessEmbeddedHas: + if field.OwnerSchema != nil { + primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema + } else { + reguessOrErr() + return + } } if len(relation.foreignKeys) > 0 { diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 2c09f528..2e85c538 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -55,6 +55,20 @@ func TestBelongsToOverrideReferences(t *testing.T) { }) } +func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) { + type User struct { + ID int32 `gorm:"primaryKey"` + Name string + CreatedBy *int32 + Creator *User `gorm:"foreignKey:CreatedBy;references:ID"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Creator", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User", + References: []Reference{{"ID", "User", "CreatedBy", "User", "", false}}, + }) +} + func TestHasOneOverrideForeignKey(t *testing.T) { type Profile struct { gorm.Model diff --git a/schema/schema_test.go b/schema/schema_test.go index 8bd1e5ca..4d13ebd2 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -171,7 +171,7 @@ func TestNestedModel(t *testing.T) { fields := []schema.Field{ {Name: "ID", DBName: "id", BindNames: []string{"VersionModel", "BaseModel", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true}, - {Name: "CreatedBy", DBName: "created_by", BindNames: []string{"VersionModel", "BaseModel", "CreatedBy"}, DataType: schema.Int, Size: 64}, + {Name: "CreatedBy", DBName: "created_by", BindNames: []string{"VersionModel", "BaseModel", "CreatedBy"}, DataType: schema.Uint, Size: 64}, {Name: "Version", DBName: "version", BindNames: []string{"VersionModel", "Version"}, DataType: schema.Int, Size: 64}, } From 94c6bb980b8c3775d98121d5d42109cefe596c5c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 28 Aug 2020 12:25:25 +0800 Subject: [PATCH 41/42] Refactor association --- association.go | 92 ++++++++++++++++++++------------------------------ 1 file changed, 37 insertions(+), 55 deletions(-) diff --git a/association.go b/association.go index 25e1fe8d..db77cc4e 100644 --- a/association.go +++ b/association.go @@ -43,32 +43,8 @@ func (db *DB) Association(column string) *Association { func (association *Association) Find(out interface{}, conds ...interface{}) error { if association.Error == nil { - var ( - queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) - tx = association.DB.Model(out) - ) - - if association.Relationship.JoinTable != nil { - if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { - joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} - for _, queryClause := range association.Relationship.JoinTable.QueryClauses { - joinStmt.AddClause(queryClause) - } - joinStmt.Build("WHERE") - tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) - } - - tx.Clauses(clause.From{Joins: []clause.Join{{ - Table: clause.Table{Name: association.Relationship.JoinTable.Table}, - ON: clause.Where{Exprs: queryConds}, - }}}) - } else { - tx.Clauses(clause.Where{Exprs: queryConds}) - } - - association.Error = tx.Find(out, conds...).Error + association.Error = association.buildCondition().Find(out, conds...).Error } - return association.Error } @@ -80,7 +56,7 @@ func (association *Association) Append(values ...interface{}) error { association.Error = association.Replace(values...) } default: - association.saveAssociation(false, values...) + association.saveAssociation( /*clear*/ false, values...) } } @@ -90,7 +66,7 @@ func (association *Association) Append(values ...interface{}) error { func (association *Association) Replace(values ...interface{}) error { if association.Error == nil { // save associations - association.saveAssociation(true, values...) + association.saveAssociation( /*clear*/ true, values...) // set old associations's foreign key to null reflectValue := association.DB.Statement.ReflectValue @@ -234,7 +210,7 @@ func (association *Association) Delete(values ...interface{}) error { var ( primaryFields, relPrimaryFields []*schema.Field joinPrimaryKeys, joinRelPrimaryKeys []string - modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + joinValue = reflect.New(rel.JoinTable.ModelType).Interface() ) for _, ref := range rel.References { @@ -259,10 +235,11 @@ func (association *Association) Delete(values ...interface{}) error { relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) - association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue).Error + association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error } if association.Error == nil { + // clean up deleted values's foreign key relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) cleanUpDeletedRelations := func(data reflect.Value) { @@ -328,33 +305,8 @@ func (association *Association) Clear() error { func (association *Association) Count() (count int64) { if association.Error == nil { - var ( - conds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) - modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() - tx = association.DB.Model(modelValue) - ) - - if association.Relationship.JoinTable != nil { - if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { - joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} - for _, queryClause := range association.Relationship.JoinTable.QueryClauses { - joinStmt.AddClause(queryClause) - } - joinStmt.Build("WHERE", "LIMIT") - tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) - } - - tx.Clauses(clause.From{Joins: []clause.Join{{ - Table: clause.Table{Name: association.Relationship.JoinTable.Table}, - ON: clause.Where{Exprs: conds}, - }}}) - } else { - tx.Clauses(clause.Where{Exprs: conds}) - } - - association.Error = tx.Count(&count).Error + association.Error = association.buildCondition().Count(&count).Error } - return } @@ -435,6 +387,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ switch reflectValue.Kind() { case reflect.Slice, reflect.Array: if len(values) != reflectValue.Len() { + // clear old data if clear && len(values) == 0 { for i := 0; i < reflectValue.Len(); i++ { if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { @@ -467,6 +420,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error } case reflect.Struct: + // clear old data if clear && len(values) == 0 { association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) @@ -498,3 +452,31 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } } } + +func (association *Association) buildCondition() *DB { + var ( + queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) + modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() + tx = association.DB.Model(modelValue) + ) + + if association.Relationship.JoinTable != nil { + if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { + joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} + for _, queryClause := range association.Relationship.JoinTable.QueryClauses { + joinStmt.AddClause(queryClause) + } + joinStmt.Build("WHERE") + tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) + } + + tx.Clauses(clause.From{Joins: []clause.Join{{ + Table: clause.Table{Name: association.Relationship.JoinTable.Table}, + ON: clause.Where{Exprs: queryConds}, + }}}) + } else { + tx.Clauses(clause.Where{Exprs: queryConds}) + } + + return tx +} From 06461b32549fb13090b92713703228da2e8290aa Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 28 Aug 2020 21:16:47 +0800 Subject: [PATCH 42/42] GORM V2.0.0 --- tests/go.mod | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index b0ed4497..1a6fe7a8 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,11 +6,11 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.3.2 - gorm.io/driver/postgres v0.2.9 - gorm.io/driver/sqlite v1.0.9 - gorm.io/driver/sqlserver v0.2.8 - gorm.io/gorm v0.2.36 + gorm.io/driver/mysql v1.0.0 + gorm.io/driver/postgres v1.0.0 + gorm.io/driver/sqlite v1.1.0 + gorm.io/driver/sqlserver v1.0.0 + gorm.io/gorm v1.9.19 ) replace gorm.io/gorm => ../