From 0546b59743ec2759051cb921a4dc5f7c31f36e3d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 22 Jul 2020 11:28:00 +0800 Subject: [PATCH 01/65] Fix save many2many associations with UUID primary key, close #3182 --- callbacks/create.go | 9 ++++++++- tests/postgres_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/callbacks/create.go b/callbacks/create.go index eecb80a1..de5bf1f8 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -149,10 +149,17 @@ func CreateWithReturning(db *gorm.DB) { for rows.Next() { BEGIN: + reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected)) for idx, field := range fields { - fieldValue := field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))) + fieldValue := field.ReflectValueOf(reflectValue) + if onConflict.DoNothing && !fieldValue.IsZero() { db.RowsAffected++ + + if int(db.RowsAffected) >= db.Statement.ReflectValue.Len() { + return + } + goto BEGIN } diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 98302d87..ab47a548 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -37,3 +37,36 @@ func TestPostgres(t *testing.T) { t.Errorf("No error should happen, but got %v", err) } } + +type Post struct { + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4()"` + Title string + Categories []*Category `gorm:"Many2Many:post_categories"` +} + +type Category struct { + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4()"` + Title string + Posts []*Post `gorm:"Many2Many:post_categories"` +} + +func TestMany2ManyWithDefaultValueUUID(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + t.Skip() + } + + DB.Migrator().DropTable(&Post{}, &Category{}, "post_categories") + DB.AutoMigrate(&Post{}, &Category{}) + + post := Post{ + Title: "Hello World", + Categories: []*Category{ + {Title: "Coding"}, + {Title: "Golang"}, + }, + } + + if err := DB.Create(&post).Error; err != nil { + t.Errorf("Failed, got error: %v", err) + } +} From da16f7b4756ead84856448fab67ff6aeddf91f60 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 22 Jul 2020 12:13:40 +0800 Subject: [PATCH 02/65] Create extension uuid-ossp for postgres test database --- tests/postgres_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/postgres_test.go b/tests/postgres_test.go index ab47a548..a0b1fddb 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -55,6 +55,10 @@ func TestMany2ManyWithDefaultValueUUID(t *testing.T) { t.Skip() } + if err := DB.Exec(`create extension if not exists "uuid-ossp"`).Error; err != nil { + t.Fatalf("Failed to create 'uuid-ossp' extension, but got error %v", err) + } + DB.Migrator().DropTable(&Post{}, &Category{}, "post_categories") DB.AutoMigrate(&Post{}, &Category{}) From 87112ab1c711db2d8dd26ee32a4ccd0bb9307261 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 22 Jul 2020 15:05:38 +0800 Subject: [PATCH 03/65] Fix row callback name --- callbacks/callbacks.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index f61252d4..0a12468c 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -45,6 +45,6 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { updateCallback.Register("gorm:after_update", AfterUpdate) updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) - db.Callback().Row().Register("gorm:raw", RowQuery) + db.Callback().Row().Register("gorm:row", RowQuery) db.Callback().Raw().Register("gorm:raw", RawExec) } From 7021db3655381405b8c3f848319a66128b96041b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 22 Jul 2020 19:03:19 +0800 Subject: [PATCH 04/65] Fix FieldsWithDefaultDBValue for primary field, close #3187 --- schema/schema.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/schema/schema.go b/schema/schema.go index bcf65939..1106f0c5 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -184,11 +184,11 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if field := schema.PrioritizedPrimaryField; field != nil { switch field.GORMDataType { case Int, Uint: - if !field.HasDefaultValue || field.DefaultValueInterface != nil { - schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) - } - if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok { + if !field.HasDefaultValue || field.DefaultValueInterface != nil { + schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) + } + field.HasDefaultValue = true field.AutoIncrement = true } From 6ed697dd0225631c19bcfc43bf8762ced235742c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 23 Jul 2020 23:41:56 +0800 Subject: [PATCH 05/65] TestFirstOrCreateWithPrimaryKey, close #3192 --- callbacks/create.go | 10 +--------- tests/create_test.go | 19 +++++++++++++++++++ tests/go.mod | 6 +++--- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index de5bf1f8..707b94c1 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -70,16 +70,8 @@ func Create(config *Config) func(db *gorm.DB) { } } } else { - allUpdated := int(db.RowsAffected) == db.Statement.ReflectValue.Len() - isZero := true - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - - if !allUpdated { - _, isZero = db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i)) - } - - if isZero { + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i)); isZero { db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) insertID++ } diff --git a/tests/create_test.go b/tests/create_test.go index 46cc06c6..ae6e1232 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -352,3 +352,22 @@ func TestOmitWithCreate(t *testing.T) { CheckUser(t, result2, user2) } + +func TestFirstOrCreateWithPrimaryKey(t *testing.T) { + company := Company{ID: 100, Name: "company100_with_primarykey"} + DB.FirstOrCreate(&company) + + if company.ID != 100 { + t.Errorf("invalid primary key after creating, got %v", company.ID) + } + + companies := []Company{ + {ID: 101, Name: "company101_with_primarykey"}, + {ID: 102, Name: "company102_with_primarykey"}, + } + DB.Create(&companies) + + if companies[0].ID != 101 || companies[1].ID != 102 { + t.Errorf("invalid primary key after creating, got %v, %v", companies[0].ID, companies[1].ID) + } +} diff --git a/tests/go.mod b/tests/go.mod index 3a5b4224..6eb6eb07 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,10 +6,10 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.2.9 - gorm.io/driver/postgres v0.2.5 + gorm.io/driver/mysql v0.3.1 + gorm.io/driver/postgres v0.2.6 gorm.io/driver/sqlite v1.0.8 - gorm.io/driver/sqlserver v0.2.4 + gorm.io/driver/sqlserver v0.2.5 gorm.io/gorm v0.2.19 ) From c3f52cee8b1e3d26fd0618399cc2a0cc012ff216 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 23 Jul 2020 23:56:13 +0800 Subject: [PATCH 06/65] Don't scan last insert id 0 --- callbacks/create.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/callbacks/create.go b/callbacks/create.go index 707b94c1..c86cefe4 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -78,7 +78,9 @@ func Create(config *Config) func(db *gorm.DB) { } } case reflect.Struct: - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + if insertID > 0 { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } } } else { db.AddError(err) From 69d81118936a761a140d35eb07f1cd249067a1a9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 24 Jul 2020 08:32:50 +0800 Subject: [PATCH 07/65] Fix panic when using invalid data, close #3193 --- callbacks/create.go | 6 +++--- callbacks/delete.go | 2 +- callbacks/query.go | 2 +- callbacks/update.go | 2 +- errors.go | 6 ------ statement.go | 4 +++- 6 files changed, 9 insertions(+), 13 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index c86cefe4..b41a3ef2 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -51,7 +51,7 @@ func Create(config *Config) func(db *gorm.DB) { db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") } - if !db.DryRun { + if !db.DryRun && db.Error == nil { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { @@ -130,7 +130,7 @@ func CreateWithReturning(db *gorm.DB) { db.Statement.WriteQuoted(field.DBName) } - if !db.DryRun { + if !db.DryRun && db.Error == nil { rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { @@ -179,7 +179,7 @@ func CreateWithReturning(db *gorm.DB) { db.AddError(err) } } - } else if !db.DryRun { + } else if !db.DryRun && db.Error == nil { if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { db.RowsAffected, _ = result.RowsAffected() } else { diff --git a/callbacks/delete.go b/callbacks/delete.go index 51a33bf0..288f2d69 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -60,7 +60,7 @@ func Delete(db *gorm.DB) { db.Statement.Build("DELETE", "FROM", "WHERE") } - if !db.DryRun { + if !db.DryRun && db.Error == nil { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { diff --git a/callbacks/query.go b/callbacks/query.go index 5c322a05..66bbf805 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -23,7 +23,7 @@ func Query(db *gorm.DB) { BuildQuerySQL(db) } - if !db.DryRun { + if !db.DryRun && db.Error == nil { rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err != nil { db.AddError(err) diff --git a/callbacks/update.go b/callbacks/update.go index d549f97b..e492cfc9 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -74,7 +74,7 @@ func Update(db *gorm.DB) { return } - if !db.DryRun { + if !db.DryRun && db.Error == nil { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { diff --git a/errors.go b/errors.go index e1b58835..12e64611 100644 --- a/errors.go +++ b/errors.go @@ -7,20 +7,14 @@ import ( var ( // ErrRecordNotFound record not found error ErrRecordNotFound = errors.New("record not found") - // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL - ErrInvalidSQL = errors.New("invalid SQL") // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` ErrInvalidTransaction = errors.New("no valid transaction") - // ErrUnaddressable unaddressable value - ErrUnaddressable = errors.New("using unaddressable value") // ErrNotImplemented not implemented ErrNotImplemented = errors.New("not implemented") // ErrMissingWhereClause missing where clause ErrMissingWhereClause = errors.New("WHERE conditions required") // ErrUnsupportedRelation unsupported relations ErrUnsupportedRelation = errors.New("unsupported relations") - // ErrPtrStructSupported only ptr of struct supported - ErrPtrStructSupported = errors.New("only ptr of struct supported") // ErrorPrimaryKeyRequired primary keys required ErrorPrimaryKeyRequired = errors.New("primary key required") // ErrorModelValueRequired model value required diff --git a/statement.go b/statement.go index 5f4238ef..310484d8 100644 --- a/statement.go +++ b/statement.go @@ -95,7 +95,9 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { } if v.Name == clause.PrimaryKey { - if stmt.Schema != nil && stmt.Schema.PrioritizedPrimaryField != nil { + if stmt.Schema == nil { + stmt.DB.AddError(ErrorModelValueRequired) + } else if stmt.Schema.PrioritizedPrimaryField != nil { stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName) } else if len(stmt.Schema.DBNames) > 0 { stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.DBNames[0]) From f4cfa9411bc3eae4488d52c30272cd3cdb6e2127 Mon Sep 17 00:00:00 2001 From: Qt Date: Sun, 26 Jul 2020 10:03:58 +0800 Subject: [PATCH 08/65] define err with the same code style (#3199) --- association.go | 2 +- errors.go | 8 ++++---- finisher_api.go | 2 +- statement.go | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/association.go b/association.go index aa740fc5..e59b8938 100644 --- a/association.go +++ b/association.go @@ -170,7 +170,7 @@ func (association *Association) Replace(values ...interface{}) error { if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 { tx.Where(clause.IN{Column: column, Values: values}) } else { - return ErrorPrimaryKeyRequired + return ErrPrimaryKeyRequired } _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) diff --git a/errors.go b/errors.go index 12e64611..115b8e25 100644 --- a/errors.go +++ b/errors.go @@ -15,10 +15,10 @@ var ( ErrMissingWhereClause = errors.New("WHERE conditions required") // ErrUnsupportedRelation unsupported relations ErrUnsupportedRelation = errors.New("unsupported relations") - // ErrorPrimaryKeyRequired primary keys required - ErrorPrimaryKeyRequired = errors.New("primary key required") - // ErrorModelValueRequired model value required - ErrorModelValueRequired = errors.New("model value required") + // ErrPrimaryKeyRequired primary keys required + ErrPrimaryKeyRequired = errors.New("primary key required") + // ErrModelValueRequired model value required + ErrModelValueRequired = errors.New("model value required") // ErrUnsupportedDriver unsupported driver ErrUnsupportedDriver = errors.New("unsupported driver") // ErrRegistered registered diff --git a/finisher_api.go b/finisher_api.go index 6bfe5d20..77bea578 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -325,7 +325,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { } } } else if tx.Statement.Table == "" { - tx.AddError(ErrorModelValueRequired) + tx.AddError(ErrModelValueRequired) } fields := strings.FieldsFunc(column, utils.IsChar) diff --git a/statement.go b/statement.go index 310484d8..e9d826c4 100644 --- a/statement.go +++ b/statement.go @@ -96,7 +96,7 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { if v.Name == clause.PrimaryKey { if stmt.Schema == nil { - stmt.DB.AddError(ErrorModelValueRequired) + stmt.DB.AddError(ErrModelValueRequired) } else if stmt.Schema.PrioritizedPrimaryField != nil { stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName) } else if len(stmt.Schema.DBNames) > 0 { From c7667e9299134799da6f16e19eaf50cb8419736f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 28 Jul 2020 14:26:09 +0800 Subject: [PATCH 09/65] Refactor Prepared Statement --- gorm.go | 22 +++++++++++++++------- prepare_stmt.go | 14 +++++++++----- tests/.gitignore | 1 + 3 files changed, 25 insertions(+), 12 deletions(-) create mode 100644 tests/.gitignore diff --git a/gorm.go b/gorm.go index 338a1473..c786b5a5 100644 --- a/gorm.go +++ b/gorm.go @@ -108,11 +108,15 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { err = config.Dialector.Initialize(db) } + preparedStmt := &PreparedStmtDB{ + ConnPool: db.ConnPool, + Stmts: map[string]*sql.Stmt{}, + PreparedSQL: make([]string, 0, 100), + } + db.cacheStore.Store("preparedStmt", preparedStmt) + if config.PrepareStmt { - db.ConnPool = &PreparedStmtDB{ - ConnPool: db.ConnPool, - Stmts: map[string]*sql.Stmt{}, - } + db.ConnPool = preparedStmt } db.Statement = &Statement{ @@ -157,9 +161,13 @@ func (db *DB) Session(config *Session) *DB { } if config.PrepareStmt { - tx.Statement.ConnPool = &PreparedStmtDB{ - ConnPool: db.Config.ConnPool, - Stmts: map[string]*sql.Stmt{}, + if v, ok := db.cacheStore.Load("preparedStmt"); ok { + preparedStmt := v.(*PreparedStmtDB) + tx.Statement.ConnPool = &PreparedStmtDB{ + ConnPool: db.Config.ConnPool, + mux: preparedStmt.mux, + Stmts: preparedStmt.Stmts, + } } } diff --git a/prepare_stmt.go b/prepare_stmt.go index 197c257c..2f4e1d57 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -7,16 +7,19 @@ import ( ) type PreparedStmtDB struct { - Stmts map[string]*sql.Stmt - mux sync.RWMutex + Stmts map[string]*sql.Stmt + PreparedSQL []string + mux sync.RWMutex ConnPool } func (db *PreparedStmtDB) Close() { db.mux.Lock() - for k, stmt := range db.Stmts { - delete(db.Stmts, k) - stmt.Close() + for _, query := range db.PreparedSQL { + if stmt, ok := db.Stmts[query]; ok { + delete(db.Stmts, query) + stmt.Close() + } } db.mux.Unlock() @@ -40,6 +43,7 @@ func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { stmt, err := db.ConnPool.PrepareContext(context.Background(), query) if err == nil { db.Stmts[query] = stmt + db.PreparedSQL = append(db.PreparedSQL, query) } db.mux.Unlock() diff --git a/tests/.gitignore b/tests/.gitignore new file mode 100644 index 00000000..08cb523c --- /dev/null +++ b/tests/.gitignore @@ -0,0 +1 @@ +go.sum From a140908839f5f6f3b2e493fbe7b779fb9fffc3ff Mon Sep 17 00:00:00 2001 From: Qt Date: Tue, 28 Jul 2020 17:25:03 +0800 Subject: [PATCH 10/65] refactor function convertParams's default case (#3208) --- logger/sql.go | 32 ++++++++++++-------------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index d3c0bf10..02d559c5 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -50,30 +50,22 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a case string: vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper default: - if v == nil { + rv := reflect.ValueOf(v) + if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() { vars[idx] = "NULL" + } else if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + convertParams(v, idx) + } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { + convertParams(reflect.Indirect(rv).Interface(), idx) } else { - rv := reflect.ValueOf(v) - - if !rv.IsValid() { - vars[idx] = "NULL" - } else if rv.Kind() == reflect.Ptr && rv.IsNil() { - vars[idx] = "NULL" - } else if valuer, ok := v.(driver.Valuer); ok { - v, _ = valuer.Value() - convertParams(v, idx) - } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { - convertParams(reflect.Indirect(rv).Interface(), idx) - } else { - for _, t := range convertableTypes { - if rv.Type().ConvertibleTo(t) { - convertParams(rv.Convert(t).Interface(), idx) - return - } + for _, t := range convertableTypes { + if rv.Type().ConvertibleTo(t) { + convertParams(rv.Convert(t).Interface(), idx) + return } - - vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper } + vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper } } } From 2cbdd29f26eeb81e7c1b9f014bf1a0a8066f76ba Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 29 Jul 2020 10:23:14 +0800 Subject: [PATCH 11/65] Returns error for invalid embedded field, close #3209 --- schema/field.go | 78 ++++++++++++++++++++++++++----------------------- 1 file changed, 41 insertions(+), 37 deletions(-) diff --git a/schema/field.go b/schema/field.go index a170e60e..f377a34a 100644 --- a/schema/field.go +++ b/schema/field.go @@ -304,44 +304,48 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer) { - var err error - field.Creatable = false - field.Updatable = false - field.Readable = false - if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { - schema.err = err + if reflect.Indirect(fieldValue).Kind() == reflect.Struct { + var err error + field.Creatable = false + field.Updatable = false + field.Readable = false + if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { + schema.err = err + } + for _, ef := range field.EmbeddedSchema.Fields { + ef.Schema = schema + ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) + // index is negative means is pointer + if field.FieldType.Kind() == reflect.Struct { + ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...) + } else { + ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...) + } + + if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok { + ef.DBName = prefix + ef.DBName + } + + if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { + ef.PrimaryKey = true + } else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { + ef.PrimaryKey = true + } else { + ef.PrimaryKey = false + } + + for k, v := range field.TagSettings { + ef.TagSettings[k] = v + } + } + + field.Schema.CreateClauses = append(field.Schema.CreateClauses, field.EmbeddedSchema.CreateClauses...) + field.Schema.QueryClauses = append(field.Schema.QueryClauses, field.EmbeddedSchema.QueryClauses...) + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, field.EmbeddedSchema.UpdateClauses...) + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, field.EmbeddedSchema.DeleteClauses...) + } else { + schema.err = fmt.Errorf("invalid embedded struct for %v's field %v, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType) } - for _, ef := range field.EmbeddedSchema.Fields { - ef.Schema = schema - ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) - // index is negative means is pointer - if field.FieldType.Kind() == reflect.Struct { - ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...) - } else { - ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...) - } - - if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok { - ef.DBName = prefix + ef.DBName - } - - if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { - ef.PrimaryKey = true - } else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { - ef.PrimaryKey = true - } else { - ef.PrimaryKey = false - } - - for k, v := range field.TagSettings { - ef.TagSettings[k] = v - } - } - - field.Schema.CreateClauses = append(field.Schema.CreateClauses, field.EmbeddedSchema.CreateClauses...) - field.Schema.QueryClauses = append(field.Schema.QueryClauses, field.EmbeddedSchema.QueryClauses...) - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, field.EmbeddedSchema.UpdateClauses...) - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, field.EmbeddedSchema.DeleteClauses...) } return field From 7c2ecdfc1c738f118b892d593ac3899d8e92b74b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 30 Jul 2020 10:23:35 +0800 Subject: [PATCH 12/65] Fix use pointer of Valuer as foreign key, close #3212 --- schema/field.go | 5 +++-- tests/scanner_valuer_test.go | 2 ++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/schema/field.go b/schema/field.go index f377a34a..329ae41c 100644 --- a/schema/field.go +++ b/schema/field.go @@ -742,15 +742,16 @@ func (field *Field) setupValuerAndSetter() { } else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { // pointer scanner field.Set = func(value reflect.Value, v interface{}) (err error) { + reflectV := reflect.ValueOf(v) + if valuer, ok := v.(driver.Valuer); ok { - if valuer == nil { + if valuer == nil || reflectV.IsNil() { field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) } else { v, _ = valuer.Value() } } - reflectV := reflect.ValueOf(v) if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 632bd74a..bee0ae98 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -136,6 +136,8 @@ type ScannerValuerStruct struct { Strings StringsSlice Structs StructsSlice Role Role + UserID *sql.NullInt64 + User User } type EncryptedData []byte From 47a5196734de9f4d8486a1be568c8341991b4ac8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 30 Jul 2020 17:36:39 +0800 Subject: [PATCH 13/65] Fix uninitialized Valuer return time.Time, close #3214 --- schema/field.go | 2 ++ tests/scanner_valuer_test.go | 44 ++++++++++++++++++++++++------------ 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/schema/field.go b/schema/field.go index 329ae41c..6d0fd1cc 100644 --- a/schema/field.go +++ b/schema/field.go @@ -213,6 +213,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { case reflect.Struct: if _, ok := fieldValue.Interface().(*time.Time); ok { field.DataType = Time + } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + field.DataType = Time } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) { field.DataType = Time } diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index bee0ae98..2c2c1e18 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -124,20 +124,21 @@ func TestInvalidValuer(t *testing.T) { type ScannerValuerStruct struct { gorm.Model - Name sql.NullString - Gender *sql.NullString - Age sql.NullInt64 - Male sql.NullBool - Height sql.NullFloat64 - Birthday sql.NullTime - Password EncryptedData - Bytes []byte - Num Num - Strings StringsSlice - Structs StructsSlice - Role Role - UserID *sql.NullInt64 - User User + Name sql.NullString + Gender *sql.NullString + Age sql.NullInt64 + Male sql.NullBool + Height sql.NullFloat64 + Birthday sql.NullTime + Password EncryptedData + Bytes []byte + Num Num + Strings StringsSlice + Structs StructsSlice + Role Role + UserID *sql.NullInt64 + User User + EmptyTime EmptyTime } type EncryptedData []byte @@ -244,3 +245,18 @@ func (role Role) Value() (driver.Value, error) { func (role Role) IsAdmin() bool { return role.Name == "admin" } + +type EmptyTime struct { + time.Time +} + +func (t *EmptyTime) Scan(v interface{}) error { + nullTime := sql.NullTime{} + err := nullTime.Scan(v) + t.Time = nullTime.Time + return err +} + +func (t EmptyTime) Value() (driver.Value, error) { + return t.Time, nil +} From 7bb883b665082c0506991f8c87e5f02d86254920 Mon Sep 17 00:00:00 2001 From: lninl Date: Thu, 30 Jul 2020 17:39:57 +0800 Subject: [PATCH 14/65] Auto creating/updating time with unix (milli) second (#3213) * Auto creating/updating time with unix (milli) second * add test for 'Auto creating/updating time with unix (milli) second' --- callbacks/update.go | 10 +++++++--- schema/field.go | 13 +++++++++++-- tests/customize_field_test.go | 36 +++++++++++++++++++++++------------ 3 files changed, 42 insertions(+), 17 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index e492cfc9..12806af6 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -140,7 +140,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - var priamryKeyExprs []clause.Expression + var primaryKeyExprs []clause.Expression for i := 0; i < stmt.ReflectValue.Len(); i++ { var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) var notZero bool @@ -150,10 +150,10 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { notZero = notZero || !isZero } if notZero { - priamryKeyExprs = append(priamryKeyExprs, clause.And(exprs...)) + primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...)) } } - stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(priamryKeyExprs...)}}) + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}}) case reflect.Struct: for _, field := range stmt.Schema.PrimaryFields { if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { @@ -202,6 +202,8 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field.AutoUpdateTime == schema.UnixNanosecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) + } else if field.AutoUpdateTime == schema.UnixMillisecond { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6}) } else if field.GORMDataType == schema.Time { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) } else { @@ -223,6 +225,8 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if field.AutoUpdateTime > 0 { if field.AutoUpdateTime == schema.UnixNanosecond { value = stmt.DB.NowFunc().UnixNano() + } else if field.AutoUpdateTime == schema.UnixMillisecond { + value = stmt.DB.NowFunc().UnixNano() / 1e6 } else if field.GORMDataType == schema.Time { value = stmt.DB.NowFunc() } else { diff --git a/schema/field.go b/schema/field.go index 6d0fd1cc..4eb95b98 100644 --- a/schema/field.go +++ b/schema/field.go @@ -19,8 +19,9 @@ type DataType string type TimeType int64 const ( - UnixSecond TimeType = 1 - UnixNanosecond TimeType = 2 + UnixSecond TimeType = 1 + UnixMillisecond TimeType = 2 + UnixNanosecond TimeType = 3 ) const ( @@ -233,6 +234,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if strings.ToUpper(v) == "NANO" { field.AutoCreateTime = UnixNanosecond + } else if strings.ToUpper(v) == "MILLI" { + field.AutoCreateTime = UnixMillisecond } else { field.AutoCreateTime = UnixSecond } @@ -241,6 +244,8 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { if strings.ToUpper(v) == "NANO" { field.AutoUpdateTime = UnixNanosecond + } else if strings.ToUpper(v) == "MILLI" { + field.AutoUpdateTime = UnixMillisecond } else { field.AutoUpdateTime = UnixSecond } @@ -551,6 +556,8 @@ func (field *Field) setupValuerAndSetter() { case time.Time: if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { field.ReflectValueOf(value).SetInt(data.UnixNano()) + } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { + field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6) } else { field.ReflectValueOf(value).SetInt(data.Unix()) } @@ -558,6 +565,8 @@ func (field *Field) setupValuerAndSetter() { if data != nil { if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { field.ReflectValueOf(value).SetInt(data.UnixNano()) + } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { + field.ReflectValueOf(value).SetInt(data.UnixNano() / 1e6) } else { field.ReflectValueOf(value).SetInt(data.Unix()) } diff --git a/tests/customize_field_test.go b/tests/customize_field_test.go index 9c6ab948..bf3c78fa 100644 --- a/tests/customize_field_test.go +++ b/tests/customize_field_test.go @@ -61,18 +61,20 @@ func TestCustomColumnAndIgnoredFieldClash(t *testing.T) { func TestCustomizeField(t *testing.T) { type CustomizeFieldStruct struct { gorm.Model - Name string - FieldAllowCreate string `gorm:"<-:create"` - FieldAllowUpdate string `gorm:"<-:update"` - FieldAllowSave string `gorm:"<-"` - FieldAllowSave2 string `gorm:"<-:create,update"` - FieldAllowSave3 string `gorm:"->:false;<-:create"` - FieldReadonly string `gorm:"->"` - FieldIgnore string `gorm:"-"` - AutoUnixCreateTime int64 `gorm:"autocreatetime"` - AutoUnixNanoCreateTime int64 `gorm:"autocreatetime:nano"` - AutoUnixUpdateTime int64 `gorm:"autoupdatetime"` - AutoUnixNanoUpdateTime int64 `gorm:"autoupdatetime:nano"` + Name string + FieldAllowCreate string `gorm:"<-:create"` + FieldAllowUpdate string `gorm:"<-:update"` + FieldAllowSave string `gorm:"<-"` + FieldAllowSave2 string `gorm:"<-:create,update"` + FieldAllowSave3 string `gorm:"->:false;<-:create"` + FieldReadonly string `gorm:"->"` + FieldIgnore string `gorm:"-"` + AutoUnixCreateTime int64 `gorm:"autocreatetime"` + AutoUnixMilliCreateTime int64 `gorm:"autocreatetime:milli"` + AutoUnixNanoCreateTime int64 `gorm:"autocreatetime:nano"` + AutoUnixUpdateTime int64 `gorm:"autoupdatetime"` + AutoUnixMilliUpdateTime int64 `gorm:"autoupdatetime:milli"` + AutoUnixNanoUpdateTime int64 `gorm:"autoupdatetime:nano"` } DB.Migrator().DropTable(&CustomizeFieldStruct{}) @@ -118,6 +120,10 @@ func TestCustomizeField(t *testing.T) { t.Fatalf("invalid create/update unix time: %#v", result) } + if result.AutoUnixMilliCreateTime != result.AutoUnixMilliUpdateTime || result.AutoUnixMilliCreateTime == 0 || result.AutoUnixMilliCreateTime/result.AutoUnixCreateTime < 1e3 { + t.Fatalf("invalid create/update unix milli time: %#v", result) + } + if result.AutoUnixNanoCreateTime != result.AutoUnixNanoUpdateTime || result.AutoUnixNanoCreateTime == 0 || result.AutoUnixNanoCreateTime/result.AutoUnixCreateTime < 1e6 { t.Fatalf("invalid create/update unix nano time: %#v", result) } @@ -163,6 +169,8 @@ func TestCustomizeField(t *testing.T) { createWithDefaultTime := generateStruct("create_with_default_time") createWithDefaultTime.AutoUnixCreateTime = 100 createWithDefaultTime.AutoUnixUpdateTime = 100 + createWithDefaultTime.AutoUnixMilliCreateTime = 100 + createWithDefaultTime.AutoUnixMilliUpdateTime = 100 createWithDefaultTime.AutoUnixNanoCreateTime = 100 createWithDefaultTime.AutoUnixNanoUpdateTime = 100 DB.Create(&createWithDefaultTime) @@ -174,6 +182,10 @@ func TestCustomizeField(t *testing.T) { t.Fatalf("invalid create/update unix time: %#v", createWithDefaultTimeResult) } + if createWithDefaultTimeResult.AutoUnixMilliCreateTime != createWithDefaultTimeResult.AutoUnixMilliUpdateTime || createWithDefaultTimeResult.AutoUnixMilliCreateTime != 100 { + t.Fatalf("invalid create/update unix milli time: %#v", createWithDefaultTimeResult) + } + if createWithDefaultTimeResult.AutoUnixNanoCreateTime != createWithDefaultTimeResult.AutoUnixNanoUpdateTime || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 { t.Fatalf("invalid create/update unix nano time: %#v", createWithDefaultTimeResult) } From 07ce8caf7df21e067de87a048d3cf638426bfe33 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 30 Jul 2020 17:42:41 +0800 Subject: [PATCH 15/65] Remove labeler workflows --- .github/workflows/labeler.yml | 19 ------------------- 1 file changed, 19 deletions(-) delete mode 100644 .github/workflows/labeler.yml diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml deleted file mode 100644 index 1490730b..00000000 --- a/.github/workflows/labeler.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: "Issue Labeler" -on: - issues: - types: [opened, edited, reopened] - pull_request: - types: [opened, edited, reopened, ready_for_review, synchronize] - -jobs: - triage: - runs-on: ubuntu-latest - name: Label issues and pull requests - steps: - - name: check out - uses: actions/checkout@v2 - - - name: labeler - uses: jinzhu/super-labeler-action@develop - with: - GITHUB_TOKEN: "${{ secrets.GITHUB_TOKEN }}" From 81c68db87fe8c4dc18a86caf198466d6fe29b0d3 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 30 Jul 2020 17:56:16 +0800 Subject: [PATCH 16/65] Fix zero time failed on mysql 8 --- tests/scanner_valuer_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 2c2c1e18..63a7c63c 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -258,5 +258,5 @@ func (t *EmptyTime) Scan(v interface{}) error { } func (t EmptyTime) Value() (driver.Value, error) { - return t.Time, nil + return time.Now() /* pass tests, mysql 8 doesn't support 0000-00-00 by default */, nil } From dc299b900f5916c101b36b23edc77801ca76d056 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 31 Jul 2020 14:47:26 +0800 Subject: [PATCH 17/65] Use specified table when preloading data with Join --- callbacks/query.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 66bbf805..be829fbc 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -124,13 +124,13 @@ func BuildQuerySQL(db *gorm.DB) { for idx, ref := range relation.References { if ref.OwnPrimaryKey { exprs[idx] = clause.Eq{ - Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.PrimaryKey.DBName}, + Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName}, Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, } } else { if ref.PrimaryValue == "" { exprs[idx] = clause.Eq{ - Column: clause.Column{Table: db.Statement.Schema.Table, Name: ref.ForeignKey.DBName}, + Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName}, Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, } } else { From 2676fa4fb8e3c2b11c6bc72c1fb639c1586f6f3b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 31 Jul 2020 18:19:25 +0800 Subject: [PATCH 18/65] Remove autoincrement tag for join table, close #3217 --- schema/relationship.go | 4 ++-- schema/utils.go | 2 +- schema/utils_test.go | 1 + tests/postgres_test.go | 4 ++-- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index e67092b4..b7ab4f66 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -220,7 +220,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Name: joinFieldName, PkgPath: ownField.StructField.PkgPath, Type: ownField.StructField.Type, - Tag: removeSettingFromTag(ownField.StructField.Tag, "column"), + Tag: removeSettingFromTag(removeSettingFromTag(ownField.StructField.Tag, "column"), "autoincrement"), }) } @@ -243,7 +243,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Name: joinFieldName, PkgPath: relField.StructField.PkgPath, Type: relField.StructField.Type, - Tag: removeSettingFromTag(relField.StructField.Tag, "column"), + Tag: removeSettingFromTag(removeSettingFromTag(relField.StructField.Tag, "column"), "autoincrement"), }) } diff --git a/schema/utils.go b/schema/utils.go index defa83af..1481d428 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -50,7 +50,7 @@ func toColumns(val string) (results []string) { } func removeSettingFromTag(tag reflect.StructTag, name string) reflect.StructTag { - return reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`:.*?)(;|("))`).ReplaceAllString(string(tag), "${1}${4}")) + return reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}")) } // GetRelationsValues get relations's values from a reflect value diff --git a/schema/utils_test.go b/schema/utils_test.go index e70169bf..1b47ef25 100644 --- a/schema/utils_test.go +++ b/schema/utils_test.go @@ -13,6 +13,7 @@ func TestRemoveSettingFromTag(t *testing.T) { `gorm:"column:db" other:"before:value;column:db;after:value"`: `gorm:"" other:"before:value;column:db;after:value"`, `gorm:"before:value;column:db ;after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value;after:value" other:"before:value;column:db;after:value"`, `gorm:"before:value;column:db; after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value; after:value" other:"before:value;column:db;after:value"`, + `gorm:"before:value;column; after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value; after:value" other:"before:value;column:db;after:value"`, } for k, v := range tags { diff --git a/tests/postgres_test.go b/tests/postgres_test.go index a0b1fddb..85cd34d4 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -39,13 +39,13 @@ func TestPostgres(t *testing.T) { } type Post struct { - ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4()"` + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"` Title string Categories []*Category `gorm:"Many2Many:post_categories"` } type Category struct { - ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4()"` + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"` Title string Posts []*Post `gorm:"Many2Many:post_categories"` } From f83b00d20dd57bb0df964cacfefa8f7b259a09d3 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 3 Aug 2020 10:30:25 +0800 Subject: [PATCH 19/65] Fix Count with Select when Model not specfied, close #3220 --- finisher_api.go | 11 +++++++++-- schema/schema.go | 4 ++++ tests/count_test.go | 12 ++++++++++++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 77bea578..33a4f121 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -274,11 +274,18 @@ func (db *DB) Count(count *int64) (tx *DB) { expr := clause.Expr{SQL: "count(1)"} if len(tx.Statement.Selects) == 1 { + dbName := tx.Statement.Selects[0] if tx.Statement.Parse(tx.Statement.Model) == nil { - if f := tx.Statement.Schema.LookUpField(tx.Statement.Selects[0]); f != nil { - expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: f.DBName}}} + if f := tx.Statement.Schema.LookUpField(dbName); f != nil { + dbName = f.DBName } } + + if tx.Statement.Distinct { + expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}} + } else { + expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}} + } } tx.Statement.AddClause(clause.Select{Expression: expr}) diff --git a/schema/schema.go b/schema/schema.go index 1106f0c5..9206c24e 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -72,6 +72,10 @@ type Tabler interface { // get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { + if dest == nil { + return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) + } + modelType := reflect.ValueOf(dest).Type() for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() diff --git a/tests/count_test.go b/tests/count_test.go index 826d6a36..05661ae8 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -2,6 +2,7 @@ package tests_test import ( "fmt" + "regexp" "testing" "gorm.io/gorm" @@ -55,4 +56,15 @@ func TestCount(t *testing.T) { if count3 != 2 { t.Errorf("Should get correct count for count with group, but got %v", count3) } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + result := dryDB.Table("users").Select("name").Count(&count) + if !regexp.MustCompile(`SELECT COUNT\(.name.\) FROM .*users.*`).MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Table("users").Distinct("name").Count(&count) + if !regexp.MustCompile(`SELECT COUNT\(DISTINCT\(.name.\)\) FROM .*users.*`).MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String()) + } } From c11c939b959c489c96bd6b5967b6a47c8b402ceb Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 3 Aug 2020 21:48:36 +0800 Subject: [PATCH 20/65] callbacks support sort with wildcard --- callbacks.go | 16 ++++++++++++++-- gorm.go | 2 +- prepare_stmt.go | 34 +++++++++++++++++----------------- tests/callbacks_test.go | 8 ++++++++ 4 files changed, 40 insertions(+), 20 deletions(-) diff --git a/callbacks.go b/callbacks.go index c917a678..baeb6c09 100644 --- a/callbacks.go +++ b/callbacks.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "reflect" + "sort" "time" "gorm.io/gorm/logger" @@ -207,6 +208,9 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { names, sorted []string sortCallback func(*callback) error ) + sort.Slice(cs, func(i, j int) bool { + return cs[j].before == "*" || cs[j].after == "*" + }) for _, c := range cs { // show warning message the callback name already exists @@ -218,7 +222,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { sortCallback = func(c *callback) error { if c.before != "" { // if defined before callback - if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 { + if c.before == "*" && len(sorted) > 0 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { + sorted = append([]string{c.name}, sorted...) + } + } else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 { if curIdx := getRIndex(sorted, c.name); curIdx == -1 { // if before callback already sorted, append current callback just after it sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...) @@ -232,7 +240,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { } if c.after != "" { // if defined after callback - if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 { + if c.after == "*" && len(sorted) > 0 { + if curIdx := getRIndex(sorted, c.name); curIdx == -1 { + sorted = append(sorted, c.name) + } + } else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 { if curIdx := getRIndex(sorted, c.name); curIdx == -1 { // if after callback sorted, append current callback to last sorted = append(sorted, c.name) diff --git a/gorm.go b/gorm.go index c786b5a5..1ace0099 100644 --- a/gorm.go +++ b/gorm.go @@ -165,7 +165,7 @@ func (db *DB) Session(config *Session) *DB { preparedStmt := v.(*PreparedStmtDB) tx.Statement.ConnPool = &PreparedStmtDB{ ConnPool: db.Config.ConnPool, - mux: preparedStmt.mux, + Mux: preparedStmt.Mux, Stmts: preparedStmt.Stmts, } } diff --git a/prepare_stmt.go b/prepare_stmt.go index 2f4e1d57..7e87558d 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -9,12 +9,12 @@ import ( type PreparedStmtDB struct { Stmts map[string]*sql.Stmt PreparedSQL []string - mux sync.RWMutex + Mux sync.RWMutex ConnPool } func (db *PreparedStmtDB) Close() { - db.mux.Lock() + db.Mux.Lock() for _, query := range db.PreparedSQL { if stmt, ok := db.Stmts[query]; ok { delete(db.Stmts, query) @@ -22,21 +22,21 @@ func (db *PreparedStmtDB) Close() { } } - db.mux.Unlock() + db.Mux.Unlock() } func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { - db.mux.RLock() + db.Mux.RLock() if stmt, ok := db.Stmts[query]; ok { - db.mux.RUnlock() + db.Mux.RUnlock() return stmt, nil } - db.mux.RUnlock() + db.Mux.RUnlock() - db.mux.Lock() + db.Mux.Lock() // double check if stmt, ok := db.Stmts[query]; ok { - db.mux.Unlock() + db.Mux.Unlock() return stmt, nil } @@ -45,7 +45,7 @@ func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { db.Stmts[query] = stmt db.PreparedSQL = append(db.PreparedSQL, query) } - db.mux.Unlock() + db.Mux.Unlock() return stmt, err } @@ -63,10 +63,10 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. if err == nil { result, err = stmt.ExecContext(ctx, args...) if err != nil { - db.mux.Lock() + db.Mux.Lock() stmt.Close() delete(db.Stmts, query) - db.mux.Unlock() + db.Mux.Unlock() } } return result, err @@ -77,10 +77,10 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . if err == nil { rows, err = stmt.QueryContext(ctx, args...) if err != nil { - db.mux.Lock() + db.Mux.Lock() stmt.Close() delete(db.Stmts, query) - db.mux.Unlock() + db.Mux.Unlock() } } return rows, err @@ -104,10 +104,10 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. if err == nil { result, err = tx.Tx.Stmt(stmt).ExecContext(ctx, args...) if err != nil { - tx.PreparedStmtDB.mux.Lock() + tx.PreparedStmtDB.Mux.Lock() stmt.Close() delete(tx.PreparedStmtDB.Stmts, query) - tx.PreparedStmtDB.mux.Unlock() + tx.PreparedStmtDB.Mux.Unlock() } } return result, err @@ -118,10 +118,10 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . if err == nil { rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...) if err != nil { - tx.PreparedStmtDB.mux.Lock() + tx.PreparedStmtDB.Mux.Lock() stmt.Close() delete(tx.PreparedStmtDB.Stmts, query) - tx.PreparedStmtDB.mux.Unlock() + tx.PreparedStmtDB.Mux.Unlock() } } return rows, err diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index 1dbae441..84f56165 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -96,6 +96,14 @@ func TestCallbacks(t *testing.T) { callbacks: []callback{{h: c1}, {name: "c", h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, results: []string{"c1", "c4", "c3"}, }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5, before: "*"}}, + results: []string{"c5", "c1", "c2", "c3", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, after: "*"}, {h: c4}, {h: c5, before: "*"}}, + results: []string{"c5", "c1", "c2", "c4", "c3"}, + }, } for idx, data := range datas { From ff985b90cc0f2f11be492300dd9f6914cba0cf22 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 4 Aug 2020 12:10:19 +0800 Subject: [PATCH 21/65] Fix failed to guess relations for embedded types, close #3224 --- migrator/migrator.go | 1 + schema/field.go | 2 + schema/relationship.go | 69 +++++++++++++++++++++++++++-------- tests/callbacks_test.go | 8 +++- tests/embedded_struct_test.go | 14 +++++++ 5 files changed, 76 insertions(+), 18 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 3e5d86d3..d50159dd 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -120,6 +120,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } return nil }); err != nil { + fmt.Println(err) return err } } diff --git a/schema/field.go b/schema/field.go index 4eb95b98..1ca4cb6d 100644 --- a/schema/field.go +++ b/schema/field.go @@ -62,6 +62,7 @@ type Field struct { TagSettings map[string]string Schema *Schema EmbeddedSchema *Schema + OwnerSchema *Schema ReflectValueOf func(reflect.Value) reflect.Value ValueOf func(reflect.Value) (value interface{}, zero bool) Set func(reflect.Value, interface{}) error @@ -321,6 +322,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } for _, ef := range field.EmbeddedSchema.Fields { ef.Schema = schema + ef.OwnerSchema = field.EmbeddedSchema ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...) // index is negative means is pointer if field.FieldType.Kind() == reflect.Struct { diff --git a/schema/relationship.go b/schema/relationship.go index b7ab4f66..93080105 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -5,6 +5,7 @@ import ( "reflect" "regexp" "strings" + "sync" "github.com/jinzhu/inflection" "gorm.io/gorm/clause" @@ -66,9 +67,16 @@ func (schema *Schema) parseRelation(field *Field) { } ) - if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { - schema.err = err - return + if field.OwnerSchema != nil { + if relation.FieldSchema, err = Parse(fieldValue, &sync.Map{}, schema.namer); err != nil { + schema.err = err + return + } + } else { + if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { + schema.err = err + return + } } if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { @@ -78,7 +86,7 @@ func (schema *Schema) parseRelation(field *Field) { } else { switch field.IndirectFieldType.Kind() { case reflect.Struct, reflect.Slice: - schema.guessRelation(relation, field, true) + schema.guessRelation(relation, field, guessHas) default: schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name) } @@ -316,21 +324,50 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } } -func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessHas bool) { +type guessLevel int + +const ( + guessHas guessLevel = iota + guessEmbeddedHas + guessBelongs + guessEmbeddedBelongs +) + +func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) { var ( primaryFields, foreignFields []*Field primarySchema, foreignSchema = schema, relation.FieldSchema ) - if !guessHas { - primarySchema, foreignSchema = relation.FieldSchema, schema + reguessOrErr := func(err string, args ...interface{}) { + switch gl { + case guessHas: + schema.guessRelation(relation, field, guessEmbeddedHas) + case guessEmbeddedHas: + schema.guessRelation(relation, field, guessBelongs) + case guessBelongs: + schema.guessRelation(relation, field, guessEmbeddedBelongs) + default: + schema.err = fmt.Errorf(err, args...) + } } - reguessOrErr := func(err string, args ...interface{}) { - if guessHas { - schema.guessRelation(relation, field, false) + switch gl { + case guessEmbeddedHas: + if field.OwnerSchema != nil { + primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema } else { - schema.err = fmt.Errorf(err, args...) + reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) + return + } + case guessBelongs: + primarySchema, foreignSchema = relation.FieldSchema, schema + case guessEmbeddedBelongs: + if field.OwnerSchema != nil { + primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema + } else { + reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) + return } } @@ -345,8 +382,8 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH } } else { for _, primaryField := range primarySchema.PrimaryFields { - lookUpName := schema.Name + primaryField.Name - if !guessHas { + lookUpName := primarySchema.Name + primaryField.Name + if gl == guessBelongs { lookUpName = field.Name + primaryField.Name } @@ -358,7 +395,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH } if len(foreignFields) == 0 { - reguessOrErr("failed to guess %v's relations with %v's field %v 1 g %v", relation.FieldSchema, schema, field.Name, guessHas) + reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) return } else if len(relation.primaryKeys) > 0 { for idx, primaryKey := range relation.primaryKeys { @@ -394,11 +431,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH relation.References = append(relation.References, &Reference{ PrimaryKey: primaryFields[idx], ForeignKey: foreignField, - OwnPrimaryKey: schema == primarySchema && guessHas, + OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas), }) } - if guessHas { + if gl == guessHas || gl == guessEmbeddedHas { relation.Type = "has" } else { relation.Type = BelongsTo diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index 84f56165..02765b8c 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -101,8 +101,12 @@ func TestCallbacks(t *testing.T) { results: []string{"c5", "c1", "c2", "c3", "c4"}, }, { - callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, after: "*"}, {h: c4}, {h: c5, before: "*"}}, - results: []string{"c5", "c1", "c2", "c4", "c3"}, + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "*"}, {h: c4}, {h: c5, before: "*"}}, + results: []string{"c3", "c5", "c1", "c2", "c4"}, + }, + { + callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c4", after: "*"}, {h: c4, after: "*"}, {h: c5, before: "*"}}, + results: []string{"c5", "c1", "c2", "c3", "c4"}, }, } diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index 7f40a0a4..fb0d6f23 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -7,6 +7,7 @@ import ( "testing" "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" ) func TestEmbeddedStruct(t *testing.T) { @@ -152,3 +153,16 @@ func TestEmbeddedScanValuer(t *testing.T) { t.Errorf("Failed to create got error %v", err) } } + +func TestEmbeddedRelations(t *testing.T) { + type AdvancedUser struct { + User `gorm:"embedded"` + Advanced bool + } + + DB.Debug().Migrator().DropTable(&AdvancedUser{}) + + if err := DB.Debug().AutoMigrate(&AdvancedUser{}); err != nil { + t.Errorf("Failed to auto migrate advanced user, got error %v", err) + } +} From f962872b48fae9095c9309d1c94215c4636befe8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 5 Aug 2020 14:22:35 +0800 Subject: [PATCH 22/65] Fix labeler --- .github/workflows/labeler.yml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 .github/workflows/labeler.yml diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml new file mode 100644 index 00000000..bc1add53 --- /dev/null +++ b/.github/workflows/labeler.yml @@ -0,0 +1,19 @@ +name: "Issue Labeler" +on: + issues: + types: [opened, edited, reopened] + pull_request: + types: [opened, edited, reopened] + +jobs: + triage: + runs-on: ubuntu-latest + name: Label issues and pull requests + steps: + - name: check out + uses: actions/checkout@v2 + + - name: labeler + uses: jinzhu/super-labeler-action@develop + with: + GITHUB_TOKEN: "${{ secrets.GITHUB_TOKEN }}" From da1e54d5abb4482ca2accabbad0a1e1d65a9fc8f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 6 Aug 2020 15:37:36 +0800 Subject: [PATCH 23/65] Add sql-cli --- tests/tests_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/tests_test.go b/tests/tests_test.go index 5aedc061..192160a0 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -64,6 +64,8 @@ func OpenTestConnection() (db *gorm.DB, err error) { // USE gorm; // CREATE USER gorm FROM LOGIN gorm; // sp_changedbowner 'gorm'; + // npm install -g sql-cli + // mssql -u gorm -p LoremIpsum86 -d gorm -o 9930 log.Println("testing sqlserver...") if dbDSN == "" { dbDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm" From 3df249c127e637f8af6c99e5e4fed9c466803d79 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 6 Aug 2020 16:25:26 +0800 Subject: [PATCH 24/65] Use table expr when inserting table, close #3239 --- callbacks/create.go | 8 ++------ tests/go.mod | 4 ++-- tests/table_test.go | 11 +++++++++++ 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index b41a3ef2..3a414dd7 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -43,9 +43,7 @@ func Create(config *Config) func(db *gorm.DB) { if db.Statement.SQL.String() == "" { db.Statement.SQL.Grow(180) - db.Statement.AddClauseIfNotExists(clause.Insert{ - Table: clause.Table{Name: db.Statement.Table}, - }) + db.Statement.AddClauseIfNotExists(clause.Insert{}) db.Statement.AddClause(ConvertToCreateValues(db.Statement)) db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") @@ -105,9 +103,7 @@ func CreateWithReturning(db *gorm.DB) { } if db.Statement.SQL.String() == "" { - db.Statement.AddClauseIfNotExists(clause.Insert{ - Table: clause.Table{Name: db.Statement.Table}, - }) + db.Statement.AddClauseIfNotExists(clause.Insert{}) db.Statement.AddClause(ConvertToCreateValues(db.Statement)) db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") diff --git a/tests/go.mod b/tests/go.mod index 6eb6eb07..82d4fdc8 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,8 +8,8 @@ require ( github.com/lib/pq v1.6.0 gorm.io/driver/mysql v0.3.1 gorm.io/driver/postgres v0.2.6 - gorm.io/driver/sqlite v1.0.8 - gorm.io/driver/sqlserver v0.2.5 + gorm.io/driver/sqlite v1.0.9 + gorm.io/driver/sqlserver v0.2.6 gorm.io/gorm v0.2.19 ) diff --git a/tests/table_test.go b/tests/table_test.go index faee6499..647b5e19 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -40,6 +40,17 @@ func TestTable(t *testing.T) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } + r = dryDB.Create(&UserWithTable{}).Statement + if DB.Dialector.Name() != "sqlite" { + if !regexp.MustCompile(`INSERT INTO .gorm.\..user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } else { + if !regexp.MustCompile(`INSERT INTO .user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } + r = dryDB.Table("(?) as u", DB.Model(&User{}).Select("name")).Find(&User{}).Statement if !regexp.MustCompile("SELECT \\* FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) From 39c8d6220b75b5a28dfff6ae88da17485b35dc46 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 6 Aug 2020 17:48:46 +0800 Subject: [PATCH 25/65] Fix soft delete panic when using unaddressable value --- soft_delete.go | 2 +- tests/delete_test.go | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/soft_delete.go b/soft_delete.go index 6b88b1a5..180bf745 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -64,7 +64,7 @@ func (SoftDeleteClause) ModifyStatement(stmt *Statement) { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) } - if stmt.Dest != stmt.Model && stmt.Model != nil { + if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil { _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) diff --git a/tests/delete_test.go b/tests/delete_test.go index 3d461f65..f5b3e784 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -43,6 +43,14 @@ func TestDelete(t *testing.T) { t.Errorf("no error should returns when query %v, but got %v", user.ID, err) } } + + if err := DB.Delete(users[0]).Error; err != nil { + t.Errorf("errors happened when delete: %v", err) + } + + if err := DB.Where("id = ?", users[0].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { + t.Errorf("should returns record not found error, but got %v", err) + } } func TestDeleteWithTable(t *testing.T) { From 15b96ed3f482a29201b2c6c15fa0d3936d4d9a17 Mon Sep 17 00:00:00 2001 From: Caelansar Date: Mon, 10 Aug 2020 15:34:20 +0800 Subject: [PATCH 26/65] add testcase --- tests/scanner_valuer_test.go | 69 +++++++++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 16 deletions(-) diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 63a7c63c..6b8f086e 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -35,7 +35,9 @@ func TestScannerValuer(t *testing.T) { {"name1", "value1"}, {"name2", "value2"}, }, - Role: Role{Name: "admin"}, + Role: Role{Name: "admin"}, + ExampleStruct: ExampleStruct1{"name", "value"}, + ExampleStructPtr: &ExampleStruct1{"name", "value"}, } if err := DB.Create(&data).Error; err != nil { @@ -49,6 +51,14 @@ func TestScannerValuer(t *testing.T) { } AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs") + + if result.ExampleStructPtr.Val != "value" { + t.Errorf(`ExampleStructPtr.Val should equal to "value", but got %v`, result.ExampleStructPtr.Val) + } + + if result.ExampleStruct.Val != "value" { + t.Errorf(`ExampleStruct.Val should equal to "value", but got %v`, result.ExampleStruct.Val) + } } func TestScannerValuerWithFirstOrCreate(t *testing.T) { @@ -124,21 +134,23 @@ func TestInvalidValuer(t *testing.T) { type ScannerValuerStruct struct { gorm.Model - Name sql.NullString - Gender *sql.NullString - Age sql.NullInt64 - Male sql.NullBool - Height sql.NullFloat64 - Birthday sql.NullTime - Password EncryptedData - Bytes []byte - Num Num - Strings StringsSlice - Structs StructsSlice - Role Role - UserID *sql.NullInt64 - User User - EmptyTime EmptyTime + Name sql.NullString + Gender *sql.NullString + Age sql.NullInt64 + Male sql.NullBool + Height sql.NullFloat64 + Birthday sql.NullTime + Password EncryptedData + Bytes []byte + Num Num + Strings StringsSlice + Structs StructsSlice + Role Role + UserID *sql.NullInt64 + User User + EmptyTime EmptyTime + ExampleStruct ExampleStruct1 + ExampleStructPtr *ExampleStruct1 } type EncryptedData []byte @@ -207,6 +219,31 @@ type ExampleStruct struct { Value string } +type ExampleStruct1 struct { + Name string `json:"name,omitempty"` + Val string `json:"val,omitempty"` +} + +func (s ExampleStruct1) Value() (driver.Value, error) { + if len(s.Name) == 0 { + return nil, nil + } + //for test, has no practical meaning + s.Name = "" + return json.Marshal(s) +} + +func (s *ExampleStruct1) Scan(src interface{}) error { + switch value := src.(type) { + case string: + return json.Unmarshal([]byte(value), s) + case []byte: + return json.Unmarshal(value, s) + default: + return errors.New("not supported") + } +} + type StructsSlice []ExampleStruct func (l StructsSlice) Value() (driver.Value, error) { From 4a9d3a688aa47a7db7611902f6467f0b311aee79 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 11 Aug 2020 21:22:51 +0800 Subject: [PATCH 27/65] Don't parse ignored anonymous field --- schema/field.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/schema/field.go b/schema/field.go index 1ca4cb6d..ea6364a4 100644 --- a/schema/field.go +++ b/schema/field.go @@ -311,7 +311,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer) { + if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable)) { if reflect.Indirect(fieldValue).Kind() == reflect.Struct { var err error field.Creatable = false From a3dda47afac01b7430efb200d27473e24fe2fca9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 11 Aug 2020 21:22:51 +0800 Subject: [PATCH 28/65] Don't parse ignored anonymous field --- schema/field.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/schema/field.go b/schema/field.go index 1ca4cb6d..ea6364a4 100644 --- a/schema/field.go +++ b/schema/field.go @@ -311,7 +311,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer) { + if _, ok := field.TagSettings["EMBEDDED"]; ok || (fieldStruct.Anonymous && !isValuer && (field.Creatable || field.Updatable || field.Readable)) { if reflect.Indirect(fieldValue).Kind() == reflect.Struct { var err error field.Creatable = false From 7d45833f3e309f9c15bb9ca301c1782b23cb9f0e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Aug 2020 12:05:55 +0800 Subject: [PATCH 29/65] Fix driver.Valuer interface returns nil, close #3248 --- schema/field.go | 60 +++++++++++++++++------------------- tests/scanner_valuer_test.go | 52 ++++++++++++++++--------------- 2 files changed, 56 insertions(+), 56 deletions(-) diff --git a/schema/field.go b/schema/field.go index ea6364a4..84fdb695 100644 --- a/schema/field.go +++ b/schema/field.go @@ -731,40 +731,10 @@ func (field *Field) setupValuerAndSetter() { return nil } default: - if _, ok := fieldValue.Interface().(sql.Scanner); ok { - // struct scanner - field.Set = func(value reflect.Value, v interface{}) (err error) { - if valuer, ok := v.(driver.Valuer); ok { - v, _ = valuer.Value() - } - - reflectV := reflect.ValueOf(v) - if !reflectV.IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) - } else if reflectV.Kind() == reflect.Ptr { - if reflectV.Elem().IsNil() || !reflectV.Elem().IsValid() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) - } else { - return field.Set(value, reflectV.Elem().Interface()) - } - } else { - err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) - } - return - } - } else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { + if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { // pointer scanner field.Set = func(value reflect.Value, v interface{}) (err error) { reflectV := reflect.ValueOf(v) - - if valuer, ok := v.(driver.Valuer); ok { - if valuer == nil || reflectV.IsNil() { - field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) - } else { - v, _ = valuer.Value() - } - } - if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { @@ -778,10 +748,38 @@ func (field *Field) setupValuerAndSetter() { if fieldValue.IsNil() { fieldValue.Set(reflect.New(field.FieldType.Elem())) } + + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + err = fieldValue.Interface().(sql.Scanner).Scan(v) } return } + } else if _, ok := fieldValue.Interface().(sql.Scanner); ok { + // struct scanner + field.Set = func(value reflect.Value, v interface{}) (err error) { + reflectV := reflect.ValueOf(v) + if !reflectV.IsValid() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Type().AssignableTo(field.FieldType) { + field.ReflectValueOf(value).Set(reflectV) + } else if reflectV.Kind() == reflect.Ptr { + if reflectV.IsNil() || !reflectV.IsValid() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else { + return field.Set(value, reflectV.Elem().Interface()) + } + } else { + if valuer, ok := v.(driver.Valuer); ok { + v, _ = valuer.Value() + } + + err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) + } + return + } } else { field.Set = func(value reflect.Value, v interface{}) (err error) { return fallbackSetter(value, v, field.Set) diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index 6b8f086e..b8306af7 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -36,8 +36,8 @@ func TestScannerValuer(t *testing.T) { {"name2", "value2"}, }, Role: Role{Name: "admin"}, - ExampleStruct: ExampleStruct1{"name", "value"}, - ExampleStructPtr: &ExampleStruct1{"name", "value"}, + ExampleStruct: ExampleStruct{"name", "value1"}, + ExampleStructPtr: &ExampleStruct{"name", "value2"}, } if err := DB.Create(&data).Error; err != nil { @@ -46,19 +46,18 @@ func TestScannerValuer(t *testing.T) { var result ScannerValuerStruct - if err := DB.Find(&result).Error; err != nil { + if err := DB.Find(&result, "id = ?", data.ID).Error; err != nil { t.Fatalf("no error should happen when query scanner, valuer struct, but got %v", err) } + if result.ExampleStructPtr.Val != "value2" { + t.Errorf(`ExampleStructPtr.Val should equal to "value2", but got %v`, result.ExampleStructPtr.Val) + } + + if result.ExampleStruct.Val != "value1" { + t.Errorf(`ExampleStruct.Val should equal to "value1", but got %#v`, result.ExampleStruct) + } AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs") - - if result.ExampleStructPtr.Val != "value" { - t.Errorf(`ExampleStructPtr.Val should equal to "value", but got %v`, result.ExampleStructPtr.Val) - } - - if result.ExampleStruct.Val != "value" { - t.Errorf(`ExampleStruct.Val should equal to "value", but got %v`, result.ExampleStruct.Val) - } } func TestScannerValuerWithFirstOrCreate(t *testing.T) { @@ -68,9 +67,11 @@ func TestScannerValuerWithFirstOrCreate(t *testing.T) { } data := ScannerValuerStruct{ - Name: sql.NullString{String: "name", Valid: true}, - Gender: &sql.NullString{String: "M", Valid: true}, - Age: sql.NullInt64{Int64: 18, Valid: true}, + Name: sql.NullString{String: "name", Valid: true}, + Gender: &sql.NullString{String: "M", Valid: true}, + Age: sql.NullInt64{Int64: 18, Valid: true}, + ExampleStruct: ExampleStruct{"name", "value1"}, + ExampleStructPtr: &ExampleStruct{"name", "value2"}, } var result ScannerValuerStruct @@ -109,7 +110,9 @@ func TestInvalidValuer(t *testing.T) { } data := ScannerValuerStruct{ - Password: EncryptedData("xpass1"), + Password: EncryptedData("xpass1"), + ExampleStruct: ExampleStruct{"name", "value1"}, + ExampleStructPtr: &ExampleStruct{"name", "value2"}, } if err := DB.Create(&data).Error; err == nil { @@ -149,8 +152,8 @@ type ScannerValuerStruct struct { UserID *sql.NullInt64 User User EmptyTime EmptyTime - ExampleStruct ExampleStruct1 - ExampleStructPtr *ExampleStruct1 + ExampleStruct ExampleStruct + ExampleStructPtr *ExampleStruct } type EncryptedData []byte @@ -215,25 +218,24 @@ func (l *StringsSlice) Scan(input interface{}) error { } type ExampleStruct struct { - Name string - Value string + Name string + Val string } -type ExampleStruct1 struct { - Name string `json:"name,omitempty"` - Val string `json:"val,omitempty"` +func (ExampleStruct) GormDataType() string { + return "bytes" } -func (s ExampleStruct1) Value() (driver.Value, error) { +func (s ExampleStruct) Value() (driver.Value, error) { if len(s.Name) == 0 { return nil, nil } - //for test, has no practical meaning + // for test, has no practical meaning s.Name = "" return json.Marshal(s) } -func (s *ExampleStruct1) Scan(src interface{}) error { +func (s *ExampleStruct) Scan(src interface{}) error { switch value := src.(type) { case string: return json.Unmarshal([]byte(value), s) From 045d5f853838b9800acdb8ae204969ba3d93e00a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Aug 2020 12:18:36 +0800 Subject: [PATCH 30/65] Fix count with join and no model, close #3255 --- callbacks/query.go | 2 +- tests/count_test.go | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/callbacks/query.go b/callbacks/query.go index be829fbc..5ae1e904 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -96,7 +96,7 @@ func BuildQuerySQL(db *gorm.DB) { // inline joins if len(db.Statement.Joins) != 0 { - if len(db.Statement.Selects) == 0 { + if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil { clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) for idx, dbName := range db.Statement.Schema.DBNames { clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} diff --git a/tests/count_test.go b/tests/count_test.go index 05661ae8..216fa3a1 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -67,4 +67,9 @@ func TestCount(t *testing.T) { if !regexp.MustCompile(`SELECT COUNT\(DISTINCT\(.name.\)\) FROM .*users.*`).MatchString(result.Statement.SQL.String()) { t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String()) } + + var count4 int64 + if err := DB.Debug().Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 { + t.Errorf("count with join, got error: %v, count %v", err, count) + } } From ecc946be6e93a108bbdcc10cf2719d08baa50f3f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Aug 2020 16:05:06 +0800 Subject: [PATCH 31/65] Test update from sub query --- callbacks/update.go | 9 +++++++-- tests/update_test.go | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index 12806af6..0ced3ffb 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -174,11 +174,16 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { sort.Strings(keys) for _, k := range keys { + kv := value[k] + if _, ok := kv.(*gorm.DB); ok { + kv = []interface{}{kv} + } + if stmt.Schema != nil { if field := stmt.Schema.LookUpField(k); field != nil { if field.DBName != "" { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value[k]}) + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv}) assignValue(field, value[k]) } } else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) { @@ -189,7 +194,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { - set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: value[k]}) + set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv}) } } diff --git a/tests/update_test.go b/tests/update_test.go index 2ff150dd..83a7b9a2 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -545,3 +545,21 @@ func TestUpdatesTableWithIgnoredValues(t *testing.T) { t.Errorf("element's ignored field should not be updated") } } + +func TestUpdateFromSubQuery(t *testing.T) { + user := *GetUser("update_from_sub_query", Config{Company: true}) + if err := DB.Create(&user).Error; err != nil { + t.Errorf("failed to create user, got error: %v", err) + } + + if err := DB.Model(&user).Update("name", DB.Model(&Company{}).Select("name").Where("companies.id = users.company_id")).Error; err != nil { + t.Errorf("failed to update with sub query, got error %v", err) + } + + var result User + DB.First(&result, user.ID) + + if result.Name != user.Company.Name { + t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name) + } +} From dea93edb6acdccdb398a5f9d89412f9bd0be5b39 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Aug 2020 16:28:21 +0800 Subject: [PATCH 32/65] Copy TableExpr when clone statement --- statement.go | 1 + tests/update_test.go | 10 ++++++++++ 2 files changed, 11 insertions(+) diff --git a/statement.go b/statement.go index e9d826c4..b5b5db5a 100644 --- a/statement.go +++ b/statement.go @@ -392,6 +392,7 @@ func (stmt *Statement) Parse(value interface{}) (err error) { func (stmt *Statement) clone() *Statement { newStmt := &Statement{ + TableExpr: stmt.TableExpr, Table: stmt.Table, Model: stmt.Model, Dest: stmt.Dest, diff --git a/tests/update_test.go b/tests/update_test.go index 83a7b9a2..a59a8856 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -562,4 +562,14 @@ func TestUpdateFromSubQuery(t *testing.T) { if result.Name != user.Company.Name { t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name) } + + DB.Model(&user.Company).Update("Name", "new company name") + if err := DB.Table("users").Where("1 = 1").Update("name", DB.Table("companies").Select("name").Where("companies.id = users.company_id")).Error; err != nil { + t.Errorf("failed to update with sub query, got error %v", err) + } + + DB.First(&result, user.ID) + if result.Name != "new company name" { + t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name) + } } From 2c4e8571259bf6193cf5d396594104fca7fa727d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Aug 2020 18:09:04 +0800 Subject: [PATCH 33/65] Should ignore association conditions when querying with struct --- statement.go | 12 ++++++------ tests/query_test.go | 16 ++++++++++++++++ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/statement.go b/statement.go index b5b5db5a..6114f468 100644 --- a/statement.go +++ b/statement.go @@ -309,10 +309,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c for _, field := range s.Fields { if field.Readable { if v, isZero := field.ValueOf(reflectValue); !isZero { - if field.DBName == "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) - } else { + if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) + } else if field.DataType != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) } } } @@ -322,10 +322,10 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c for _, field := range s.Fields { if field.Readable { if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { - if field.DBName == "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) - } else { + if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) + } else if field.DataType != "" { + conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) } } } diff --git a/tests/query_test.go b/tests/query_test.go index 59f1130b..4c2a2abd 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -103,6 +103,22 @@ func TestFind(t *testing.T) { }) } +func TestQueryWithAssociation(t *testing.T) { + user := *GetUser("query_with_association", Config{Account: true, Pets: 2, Toys: 1, Company: true, Manager: true, Team: 2, Languages: 1, Friends: 3}) + + if err := DB.Create(&user).Error; err != nil { + t.Fatalf("errors happened when create user: %v", err) + } + + if err := DB.Where(&user).First(&User{}).Error; err != nil { + t.Errorf("search with struct with association should returns no error, but got %v", err) + } + + if err := DB.Where(user).First(&User{}).Error; err != nil { + t.Errorf("search with struct with association should returns no error, but got %v", err) + } +} + func TestFindInBatches(t *testing.T) { var users = []User{ *GetUser("find_in_batches", Config{}), From 2faff25dfbcfff9e3fb37c8fcf1a20a468f887a8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 13 Aug 2020 18:38:39 +0800 Subject: [PATCH 34/65] Fix FirstOr(Init/Create) when assigning with association --- finisher_api.go | 67 +++++++++++++++++++++++++++++++-------------- tests/query_test.go | 2 ++ 2 files changed, 48 insertions(+), 21 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 33a4f121..8a3d4199 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -8,6 +8,7 @@ import ( "strings" "gorm.io/gorm/clause" + "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) @@ -132,19 +133,47 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat return } -func (tx *DB) assignExprsToValue(exprs []clause.Expression) { - for _, expr := range exprs { - if eq, ok := expr.(clause.Eq); ok { - switch column := eq.Column.(type) { - case string: - if field := tx.Statement.Schema.LookUpField(column); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) +func (tx *DB) assignInterfacesToValue(values ...interface{}) { + for _, value := range values { + switch v := value.(type) { + case []clause.Expression: + for _, expr := range v { + if eq, ok := expr.(clause.Eq); ok { + switch column := eq.Column.(type) { + case string: + if field := tx.Statement.Schema.LookUpField(column); field != nil { + tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + } + case clause.Column: + if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { + tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + } + default: + } } - case clause.Column: - if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { - tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) + } + case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}: + exprs := tx.Statement.BuildCondition(value) + tx.assignInterfacesToValue(exprs) + default: + if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil { + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + switch reflectValue.Kind() { + case reflect.Struct: + for _, f := range s.Fields { + if f.Readable { + if v, isZero := f.ValueOf(reflectValue); !isZero { + if field := tx.Statement.Schema.LookUpField(f.Name); field != nil { + tx.AddError(field.Set(tx.Statement.ReflectValue, v)) + } + } + } + } } - default: + } else if len(values) > 0 { + exprs := tx.Statement.BuildCondition(values[0], values[1:]...) + tx.assignInterfacesToValue(exprs) + return } } } @@ -154,22 +183,20 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) { if c, ok := tx.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { - tx.assignExprsToValue(where.Exprs) + tx.assignInterfacesToValue(where.Exprs) } } // initialize with attrs, conds if len(tx.Statement.attrs) > 0 { - exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...) - tx.assignExprsToValue(exprs) + tx.assignInterfacesToValue(tx.Statement.attrs...) } tx.Error = nil } // initialize with attrs, conds if len(tx.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) - tx.assignExprsToValue(exprs) + tx.assignInterfacesToValue(tx.Statement.assigns...) } return } @@ -180,20 +207,18 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { if c, ok := tx.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { - tx.assignExprsToValue(where.Exprs) + tx.assignInterfacesToValue(where.Exprs) } } // initialize with attrs, conds if len(tx.Statement.attrs) > 0 { - exprs := tx.Statement.BuildCondition(tx.Statement.attrs[0], tx.Statement.attrs[1:]...) - tx.assignExprsToValue(exprs) + tx.assignInterfacesToValue(tx.Statement.attrs...) } // initialize with attrs, conds if len(tx.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) - tx.assignExprsToValue(exprs) + tx.assignInterfacesToValue(tx.Statement.assigns...) } return tx.Create(dest) diff --git a/tests/query_test.go b/tests/query_test.go index 4c2a2abd..72dd89b9 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -110,6 +110,8 @@ func TestQueryWithAssociation(t *testing.T) { t.Fatalf("errors happened when create user: %v", err) } + user.CreatedAt = time.Time{} + user.UpdatedAt = time.Time{} if err := DB.Where(&user).First(&User{}).Error; err != nil { t.Errorf("search with struct with association should returns no error, but got %v", err) } From 6834c25cec6b037299970cc845de1a186e04ba1f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 17 Aug 2020 12:02:41 +0800 Subject: [PATCH 35/65] Fix stack overflow for embedded self-referred associations, close #3269 --- schema/field.go | 8 +++++++- schema/model_test.go | 22 ++++++++++++++++++++++ schema/relationship.go | 17 +++++++---------- schema/schema_test.go | 6 ++++++ 4 files changed, 42 insertions(+), 11 deletions(-) diff --git a/schema/field.go b/schema/field.go index 84fdb695..78eeccdc 100644 --- a/schema/field.go +++ b/schema/field.go @@ -317,7 +317,13 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Creatable = false field.Updatable = false field.Readable = false - if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil { + + cacheStore := schema.cacheStore + if _, embedded := schema.cacheStore.Load("embedded_cache_store"); !embedded { + cacheStore = &sync.Map{} + cacheStore.Store("embedded_cache_store", true) + } + if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, schema.namer); err != nil { schema.err = err } for _, ef := range field.EmbeddedSchema.Fields { diff --git a/schema/model_test.go b/schema/model_test.go index a13372b5..84c7b327 100644 --- a/schema/model_test.go +++ b/schema/model_test.go @@ -39,3 +39,25 @@ type AdvancedDataTypeUser struct { Active mybool Admin *mybool } + +type BaseModel struct { + ID uint `gorm:"primarykey"` + CreatedAt time.Time + CreatedBy *int + Created *VersionUser `gorm:"foreignKey:CreatedBy"` + UpdatedAt time.Time + DeletedAt gorm.DeletedAt `gorm:"index"` +} + +type VersionModel struct { + BaseModel + Version int + CompanyID int +} + +type VersionUser struct { + VersionModel + Name string + Age uint + Birthday *time.Time +} diff --git a/schema/relationship.go b/schema/relationship.go index 93080105..537a3582 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -5,7 +5,6 @@ import ( "reflect" "regexp" "strings" - "sync" "github.com/jinzhu/inflection" "gorm.io/gorm/clause" @@ -67,16 +66,14 @@ func (schema *Schema) parseRelation(field *Field) { } ) + cacheStore := schema.cacheStore if field.OwnerSchema != nil { - if relation.FieldSchema, err = Parse(fieldValue, &sync.Map{}, schema.namer); err != nil { - schema.err = err - return - } - } else { - if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil { - schema.err = err - return - } + cacheStore = field.OwnerSchema.cacheStore + } + + if relation.FieldSchema, err = Parse(fieldValue, cacheStore, schema.namer); err != nil { + schema.err = err + return } if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { diff --git a/schema/schema_test.go b/schema/schema_test.go index 99781e47..966f80e4 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -160,3 +160,9 @@ func TestCustomizeTableName(t *testing.T) { t.Errorf("Failed to customize table with TableName method") } } + +func TestNestedModel(t *testing.T) { + if _, err := schema.Parse(&VersionUser{}, &sync.Map{}, schema.NamingStrategy{}); err != nil { + t.Fatalf("failed to parse nested user, got error %v", err) + } +} From 2a716e04e6528f1979dc0a7a2de509f0350e9e04 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 17 Aug 2020 12:16:42 +0800 Subject: [PATCH 36/65] Avoid panic for invalid transaction, close #3271 --- finisher_api.go | 6 ++++-- tests/transaction_test.go | 20 ++++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 8a3d4199..19534460 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -445,7 +445,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { // Commit commit a transaction func (db *DB) Commit() *DB { - if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { + if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() { db.AddError(committer.Commit()) } else { db.AddError(ErrInvalidTransaction) @@ -456,7 +456,9 @@ func (db *DB) Commit() *DB { // Rollback rollback a transaction func (db *DB) Rollback() *DB { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { - db.AddError(committer.Rollback()) + if !reflect.ValueOf(committer).IsNil() { + db.AddError(committer.Rollback()) + } } else { db.AddError(ErrInvalidTransaction) } diff --git a/tests/transaction_test.go b/tests/transaction_test.go index c101388a..aea151d9 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "context" "errors" "testing" @@ -57,6 +58,25 @@ func TestTransaction(t *testing.T) { } } +func TestCancelTransaction(t *testing.T) { + ctx := context.Background() + ctx, cancelFunc := context.WithCancel(ctx) + cancelFunc() + + user := *GetUser("cancel_transaction", Config{}) + DB.Create(&user) + + err := DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + var result User + tx.First(&result, user.ID) + return nil + }) + + if err == nil { + t.Fatalf("Transaction should get error when using cancelled context") + } +} + func TestTransactionWithBlock(t *testing.T) { assertPanic := func(f func()) { defer func() { From 681268cc43a2aa665e5577680b88ac77b9e5b64c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 17 Aug 2020 16:31:09 +0800 Subject: [PATCH 37/65] Refactor Create/Query/Update/DeleteClauses interface --- schema/field.go | 22 -------------------- schema/interfaces.go | 8 ++++---- schema/schema.go | 17 ++++++++++++++++ soft_delete.go | 48 +++++++++++++++++++++++++++++++++----------- 4 files changed, 57 insertions(+), 38 deletions(-) diff --git a/schema/field.go b/schema/field.go index 78eeccdc..bc47e543 100644 --- a/schema/field.go +++ b/schema/field.go @@ -88,23 +88,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } fieldValue := reflect.New(field.IndirectFieldType) - - if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { - field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses()...) - } - - if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { - field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses()...) - } - - if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses()...) - } - - if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses()...) - } - // if field is valuer, used its value or first fields as data type valuer, isValuer := fieldValue.Interface().(driver.Valuer) if isValuer { @@ -353,11 +336,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.TagSettings[k] = v } } - - field.Schema.CreateClauses = append(field.Schema.CreateClauses, field.EmbeddedSchema.CreateClauses...) - field.Schema.QueryClauses = append(field.Schema.QueryClauses, field.EmbeddedSchema.QueryClauses...) - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, field.EmbeddedSchema.UpdateClauses...) - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, field.EmbeddedSchema.DeleteClauses...) } else { schema.err = fmt.Errorf("invalid embedded struct for %v's field %v, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType) } diff --git a/schema/interfaces.go b/schema/interfaces.go index f5d07843..e8e51e4c 100644 --- a/schema/interfaces.go +++ b/schema/interfaces.go @@ -7,17 +7,17 @@ type GormDataTypeInterface interface { } type CreateClausesInterface interface { - CreateClauses() []clause.Interface + CreateClauses(*Field) []clause.Interface } type QueryClausesInterface interface { - QueryClauses() []clause.Interface + QueryClauses(*Field) []clause.Interface } type UpdateClausesInterface interface { - UpdateClauses() []clause.Interface + UpdateClauses(*Field) []clause.Interface } type DeleteClausesInterface interface { - DeleteClauses() []clause.Interface + DeleteClauses(*Field) []clause.Interface } diff --git a/schema/schema.go b/schema/schema.go index 9206c24e..d81da4b8 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -219,6 +219,23 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) return schema, schema.err } } + + fieldValue := reflect.New(field.IndirectFieldType) + if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { + field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { + field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) + } + + if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) + } } } diff --git a/soft_delete.go b/soft_delete.go index 180bf745..875623bc 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -24,37 +24,61 @@ func (n DeletedAt) Value() (driver.Value, error) { return n.Time, nil } -func (DeletedAt) QueryClauses() []clause.Interface { +func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface { return []clause.Interface{ clause.Where{Exprs: []clause.Expression{ clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: "deleted_at"}, + Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: nil, }, }}, } } -func (DeletedAt) DeleteClauses() []clause.Interface { - return []clause.Interface{SoftDeleteClause{}} +type SoftDeleteQueryClause struct { + Field *schema.Field } -type SoftDeleteClause struct { -} - -func (SoftDeleteClause) Name() string { +func (sd SoftDeleteQueryClause) Name() string { return "" } -func (SoftDeleteClause) Build(clause.Builder) { +func (sd SoftDeleteQueryClause) Build(clause.Builder) { } -func (SoftDeleteClause) MergeClause(*clause.Clause) { +func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) { } -func (SoftDeleteClause) ModifyStatement(stmt *Statement) { +func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { + if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{ + clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil}, + }}) + stmt.Clauses["soft_delete_enabled"] = clause.Clause{} + } +} + +func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface { + return []clause.Interface{SoftDeleteDeleteClause{Field: f}} +} + +type SoftDeleteDeleteClause struct { + Field *schema.Field +} + +func (sd SoftDeleteDeleteClause) Name() string { + return "" +} + +func (sd SoftDeleteDeleteClause) Build(clause.Builder) { +} + +func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) { +} + +func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { if stmt.SQL.String() == "" { - stmt.AddClause(clause.Set{{Column: clause.Column{Name: "deleted_at"}, Value: stmt.DB.NowFunc()}}) + stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: stmt.DB.NowFunc()}}) if stmt.Schema != nil { _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) From 9fcc337bd1ccfccfddcdbd4a9b8b08ad08bf465c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 17 Aug 2020 17:41:36 +0800 Subject: [PATCH 38/65] Fix create from map --- callbacks/associations.go | 59 ++++++++++++++++++++++++--------------- callbacks/create.go | 22 ++++++++++++--- callbacks/helper.go | 10 ++++++- tests/create_test.go | 39 ++++++++++++++++++++++++++ tests/go.mod | 2 +- 5 files changed, 103 insertions(+), 29 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 3508335a..2710ffe9 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -48,14 +48,19 @@ func SaveBeforeAssociations(db *gorm.DB) { elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 0) for i := 0; i < db.Statement.ReflectValue.Len(); i++ { obj := db.Statement.ReflectValue.Index(i) - if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value - rv := rel.Field.ReflectValueOf(obj) // relation reflect value - objs = append(objs, obj) - if isPtr { - elems = reflect.Append(elems, rv) - } else { - elems = reflect.Append(elems, rv.Addr()) + + if reflect.Indirect(obj).Kind() == reflect.Struct { + if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value + rv := rel.Field.ReflectValueOf(obj) // relation reflect value + objs = append(objs, obj) + if isPtr { + elems = reflect.Append(elems, rv) + } else { + elems = reflect.Append(elems, rv.Addr()) + } } + } else { + break } } @@ -112,22 +117,24 @@ func SaveAfterAssociations(db *gorm.DB) { for i := 0; i < db.Statement.ReflectValue.Len(); i++ { obj := db.Statement.ReflectValue.Index(i) - if _, zero := rel.Field.ValueOf(obj); !zero { - rv := rel.Field.ReflectValueOf(obj) - if rv.Kind() != reflect.Ptr { - rv = rv.Addr() - } - - for _, ref := range rel.References { - if ref.OwnPrimaryKey { - fv, _ := ref.PrimaryKey.ValueOf(obj) - db.AddError(ref.ForeignKey.Set(rv, fv)) - } else if ref.PrimaryValue != "" { - db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) + if reflect.Indirect(obj).Kind() == reflect.Struct { + if _, zero := rel.Field.ValueOf(obj); !zero { + rv := rel.Field.ReflectValueOf(obj) + if rv.Kind() != reflect.Ptr { + rv = rv.Addr() } - } - elems = reflect.Append(elems, rv) + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + fv, _ := ref.PrimaryKey.ValueOf(obj) + db.AddError(ref.ForeignKey.Set(rv, fv)) + } else if ref.PrimaryValue != "" { + db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) + } + } + + elems = reflect.Append(elems, rv) + } } } @@ -207,7 +214,10 @@ func SaveAfterAssociations(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - appendToElems(db.Statement.ReflectValue.Index(i)) + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) + } } case reflect.Struct: appendToElems(db.Statement.ReflectValue) @@ -277,7 +287,10 @@ func SaveAfterAssociations(db *gorm.DB) { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - appendToElems(db.Statement.ReflectValue.Index(i)) + obj := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(obj).Kind() == reflect.Struct { + appendToElems(obj) + } } case reflect.Struct: appendToElems(db.Statement.ReflectValue) diff --git a/callbacks/create.go b/callbacks/create.go index 3a414dd7..4cc0f555 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -61,16 +61,26 @@ func Create(config *Config) func(db *gorm.DB) { case reflect.Slice, reflect.Array: if config.LastInsertIDReversed { for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i)) + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) insertID-- } } } else { for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue.Index(i)); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) insertID++ } } @@ -140,6 +150,10 @@ func CreateWithReturning(db *gorm.DB) { for rows.Next() { BEGIN: reflectValue := db.Statement.ReflectValue.Index(int(db.RowsAffected)) + if reflect.Indirect(reflectValue).Kind() != reflect.Struct { + break + } + for idx, field := range fields { fieldValue := field.ReflectValueOf(reflectValue) diff --git a/callbacks/helper.go b/callbacks/helper.go index 7bd910f6..80fbc2a1 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -26,6 +26,10 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { values.Columns = append(values.Columns, clause.Column{Name: k}) + if len(values.Values) == 0 { + values.Values = [][]interface{}{{}} + } + values.Values[0] = append(values.Values[0], value) } } @@ -61,11 +65,15 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st sort.Strings(columns) values.Values = make([][]interface{}, len(mapValues)) + values.Columns = make([]clause.Column, len(columns)) for idx, column := range columns { + values.Columns[idx] = clause.Column{Name: column} + for i, v := range result[column] { - if i == 0 { + if len(values.Values[i]) == 0 { values.Values[i] = make([]interface{}, len(columns)) } + values.Values[i][idx] = v } } diff --git a/tests/create_test.go b/tests/create_test.go index ae6e1232..ab0a78d4 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -39,6 +39,45 @@ func TestCreate(t *testing.T) { } } +func TestCreateFromMap(t *testing.T) { + if err := DB.Model(&User{}).Create(map[string]interface{}{"Name": "create_from_map", "Age": 18}).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + + var result User + if err := DB.Where("name = ?", "create_from_map").First(&result).Error; err != nil || result.Age != 18 { + t.Fatalf("failed to create from map, got error %v", err) + } + + if err := DB.Model(&User{}).Create(map[string]interface{}{"name": "create_from_map_1", "age": 18}).Error; err != nil { + t.Fatalf("failed to create data from map, got error: %v", err) + } + + var result1 User + if err := DB.Where("name = ?", "create_from_map_1").First(&result1).Error; err != nil || result1.Age != 18 { + t.Fatalf("failed to create from map, got error %v", err) + } + + datas := []map[string]interface{}{ + {"Name": "create_from_map_2", "Age": 19}, + {"name": "create_from_map_3", "Age": 20}, + } + + if err := DB.Model(&User{}).Create(datas).Error; err != nil { + t.Fatalf("failed to create data from slice of map, got error: %v", err) + } + + var result2 User + if err := DB.Where("name = ?", "create_from_map_2").First(&result2).Error; err != nil || result2.Age != 19 { + t.Fatalf("failed to query data after create from slice of map, got error %v", err) + } + + var result3 User + if err := DB.Where("name = ?", "create_from_map_3").First(&result3).Error; err != nil || result3.Age != 20 { + t.Fatalf("failed to query data after create from slice of map, got error %v", err) + } +} + func TestCreateWithAssociations(t *testing.T) { var user = *GetUser("create_with_associations", Config{ Account: true, diff --git a/tests/go.mod b/tests/go.mod index 82d4fdc8..54a808d0 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( gorm.io/driver/mysql v0.3.1 gorm.io/driver/postgres v0.2.6 gorm.io/driver/sqlite v1.0.9 - gorm.io/driver/sqlserver v0.2.6 + gorm.io/driver/sqlserver v0.2.7 gorm.io/gorm v0.2.19 ) From dc48e04896aa529bb4014390347e21e2c4c509b2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 18 Aug 2020 11:21:40 +0800 Subject: [PATCH 39/65] Fix nested embedded struct, close #3278 --- schema/field.go | 8 +++----- schema/model_test.go | 5 ++--- schema/schema.go | 37 ++++++++++++++++++----------------- schema/schema_test.go | 18 ++++++++++++++++- schema/utils.go | 2 ++ tests/embedded_struct_test.go | 4 ++-- utils/tests/utils.go | 2 +- 7 files changed, 46 insertions(+), 30 deletions(-) diff --git a/schema/field.go b/schema/field.go index bc47e543..35c1e44d 100644 --- a/schema/field.go +++ b/schema/field.go @@ -301,14 +301,12 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Updatable = false field.Readable = false - cacheStore := schema.cacheStore - if _, embedded := schema.cacheStore.Load("embedded_cache_store"); !embedded { - cacheStore = &sync.Map{} - cacheStore.Store("embedded_cache_store", true) - } + cacheStore := &sync.Map{} + cacheStore.Store(embeddedCacheKey, true) if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, schema.namer); err != nil { schema.err = err } + for _, ef := range field.EmbeddedSchema.Fields { ef.Schema = schema ef.OwnerSchema = field.EmbeddedSchema diff --git a/schema/model_test.go b/schema/model_test.go index 84c7b327..1f2b0948 100644 --- a/schema/model_test.go +++ b/schema/model_test.go @@ -41,7 +41,7 @@ type AdvancedDataTypeUser struct { } type BaseModel struct { - ID uint `gorm:"primarykey"` + ID uint CreatedAt time.Time CreatedBy *int Created *VersionUser `gorm:"foreignKey:CreatedBy"` @@ -51,8 +51,7 @@ type BaseModel struct { type VersionModel struct { BaseModel - Version int - CompanyID int + Version int } type VersionUser struct { diff --git a/schema/schema.go b/schema/schema.go index d81da4b8..458256d1 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -212,29 +212,30 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } if _, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded { - // parse relations for unidentified fields - for _, field := range schema.Fields { - if field.DataType == "" && field.Creatable { - if schema.parseRelation(field); schema.err != nil { - return schema, schema.err + if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { + for _, field := range schema.Fields { + if field.DataType == "" && field.Creatable { + if schema.parseRelation(field); schema.err != nil { + return schema, schema.err + } } - } - fieldValue := reflect.New(field.IndirectFieldType) - if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { - field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) - } + fieldValue := reflect.New(field.IndirectFieldType) + if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { + field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) + } - if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { - field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) - } + if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { + field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) + } - if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { - field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) - } + if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { + field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) + } - if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { - field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) + if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { + field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) + } } } } diff --git a/schema/schema_test.go b/schema/schema_test.go index 966f80e4..c0ad3c25 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -162,7 +162,23 @@ func TestCustomizeTableName(t *testing.T) { } func TestNestedModel(t *testing.T) { - if _, err := schema.Parse(&VersionUser{}, &sync.Map{}, schema.NamingStrategy{}); err != nil { + versionUser, err := schema.Parse(&VersionUser{}, &sync.Map{}, schema.NamingStrategy{}) + + if err != nil { t.Fatalf("failed to parse nested user, got error %v", err) } + + fields := []schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"VersionModel", "BaseModel", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true}, + {Name: "CreatedBy", DBName: "created_by", BindNames: []string{"VersionModel", "BaseModel", "CreatedBy"}, DataType: schema.Int, Size: 64}, + {Name: "Version", DBName: "version", BindNames: []string{"VersionModel", "Version"}, DataType: schema.Int, Size: 64}, + } + + for _, f := range fields { + checkSchemaField(t, versionUser, &f, func(f *schema.Field) { + f.Creatable = true + f.Updatable = true + f.Readable = true + }) + } } diff --git a/schema/utils.go b/schema/utils.go index 1481d428..29f2fefb 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -9,6 +9,8 @@ import ( "gorm.io/gorm/utils" ) +var embeddedCacheKey = "embedded_cache_store" + func ParseTagSetting(str string, sep string) map[string]string { settings := map[string]string{} names := strings.Split(str, sep) diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index fb0d6f23..c29078bd 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -160,9 +160,9 @@ func TestEmbeddedRelations(t *testing.T) { Advanced bool } - DB.Debug().Migrator().DropTable(&AdvancedUser{}) + DB.Migrator().DropTable(&AdvancedUser{}) - if err := DB.Debug().AutoMigrate(&AdvancedUser{}); err != nil { + if err := DB.AutoMigrate(&AdvancedUser{}); err != nil { t.Errorf("Failed to auto migrate advanced user, got error %v", err) } } diff --git a/utils/tests/utils.go b/utils/tests/utils.go index 0067d5c6..817e4b0b 100644 --- a/utils/tests/utils.go +++ b/utils/tests/utils.go @@ -76,7 +76,7 @@ func AssertEqual(t *testing.T, got, expect interface{}) { } } else { name := reflect.ValueOf(got).Type().Elem().Name() - t.Errorf("%v expects length: %v, got %v", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len()) + t.Errorf("%v expects length: %v, got %v (expects: %+v, got %+v)", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len(), expect, got) } return } From 50826742fd0bd26caf55a7a5a96b2c85b612f4ae Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 18 Aug 2020 18:00:36 +0800 Subject: [PATCH 40/65] Add error gorm.ErrInvalidData --- callbacks/create.go | 2 ++ callbacks/update.go | 2 ++ errors.go | 2 ++ tests/update_test.go | 9 +++++++++ 4 files changed, 15 insertions(+) diff --git a/callbacks/create.go b/callbacks/create.go index 4cc0f555..7a32ed5c 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -309,6 +309,8 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { } } } + default: + stmt.AddError(gorm.ErrInvalidData) } } diff --git a/callbacks/update.go b/callbacks/update.go index 0ced3ffb..5656d166 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -252,6 +252,8 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } } + default: + stmt.AddError(gorm.ErrInvalidData) } } diff --git a/errors.go b/errors.go index 115b8e25..32ff8ec1 100644 --- a/errors.go +++ b/errors.go @@ -19,6 +19,8 @@ var ( ErrPrimaryKeyRequired = errors.New("primary key required") // ErrModelValueRequired model value required ErrModelValueRequired = errors.New("model value required") + // ErrInvalidData unsupported data + ErrInvalidData = errors.New("unsupported data") // ErrUnsupportedDriver unsupported driver ErrUnsupportedDriver = errors.New("unsupported driver") // ErrRegistered registered diff --git a/tests/update_test.go b/tests/update_test.go index a59a8856..49a13be9 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -334,6 +334,15 @@ func TestSelectWithUpdateWithMap(t *testing.T) { AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages") } +func TestWithUpdateWithInvalidMap(t *testing.T) { + user := *GetUser("update_with_invalid_map", Config{}) + DB.Create(&user) + + if err := DB.Model(&user).Updates(map[string]string{"name": "jinzhu"}).Error; !errors.Is(err, gorm.ErrInvalidData) { + t.Errorf("should returns error for unsupported updating data") + } +} + func TestOmitWithUpdate(t *testing.T) { user := *GetUser("omit_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) DB.Create(&user) From b5de8aeb425cc9eccf92b8c3252fc0a7201ed52e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 18 Aug 2020 18:58:53 +0800 Subject: [PATCH 41/65] Fix overrite SELECT clause --- chainable_api.go | 3 +++ finisher_api.go | 2 +- tests/query_test.go | 5 +++++ 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/chainable_api.go b/chainable_api.go index 4df8780e..78724cc8 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -91,6 +91,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { return } } + delete(tx.Statement.Clauses, "SELECT") case string: fields := strings.FieldsFunc(v, utils.IsChar) @@ -112,6 +113,8 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { return } } + + delete(tx.Statement.Clauses, "SELECT") } else { tx.Statement.AddClause(clause.Select{ Distinct: db.Statement.Distinct, diff --git a/finisher_api.go b/finisher_api.go index 19534460..88873948 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -294,7 +294,7 @@ func (db *DB) Count(count *int64) (tx *DB) { if len(tx.Statement.Selects) == 0 { tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}}) - defer tx.Statement.AddClause(clause.Select{}) + defer delete(tx.Statement.Clauses, "SELECT") } else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") { expr := clause.Expr{SQL: "count(1)"} diff --git a/tests/query_test.go b/tests/query_test.go index 72dd89b9..d71c813a 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -346,6 +346,11 @@ func TestSelect(t *testing.T) { if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) { t.Fatalf("Build Select with u.*, but got %v", r.Statement.SQL.String()) } + + r = dryDB.Select("count(*)").Select("u.*").Table("users as u").First(&User{}, user.ID) + if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) { + t.Fatalf("Build Select with u.*, but got %v", r.Statement.SQL.String()) + } } func TestOmit(t *testing.T) { From 3411425d651e540cf19f9845d83cc507d929f2e6 Mon Sep 17 00:00:00 2001 From: deepoli <67894732+deepoil@users.noreply.github.com> Date: Tue, 18 Aug 2020 20:03:09 +0900 Subject: [PATCH 42/65] fix return value and delete unused default (#3280) --- chainable_api.go | 2 +- finisher_api.go | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/chainable_api.go b/chainable_api.go index 78724cc8..9b46a95b 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -69,7 +69,7 @@ func (db *DB) Distinct(args ...interface{}) (tx *DB) { if len(args) > 0 { tx = tx.Select(args[0], args[1:]...) } - return tx + return } // Select specify fields that you want when querying, creating, updating diff --git a/finisher_api.go b/finisher_api.go index 88873948..db069c5c 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -148,7 +148,6 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) } - default: } } } From c1782d60c149483111b021e29c412d9139bd46ba Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 19 Aug 2020 15:47:08 +0800 Subject: [PATCH 43/65] Fix embedded scanner/valuer, close #3283 --- schema/field.go | 34 +++++++++++++++++++++------------- tests/scanner_valuer_test.go | 6 ++++++ 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/schema/field.go b/schema/field.go index 35c1e44d..59367399 100644 --- a/schema/field.go +++ b/schema/field.go @@ -92,32 +92,40 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { valuer, isValuer := fieldValue.Interface().(driver.Valuer) if isValuer { if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok { - var overrideFieldValue bool - if v, err := valuer.Value(); v != nil && err == nil { - overrideFieldValue = true + if v, err := valuer.Value(); reflect.ValueOf(v).IsValid() && err == nil { fieldValue = reflect.ValueOf(v) } - if field.IndirectFieldType.Kind() == reflect.Struct { - for i := 0; i < field.IndirectFieldType.NumField(); i++ { - if !overrideFieldValue { - newFieldType := field.IndirectFieldType.Field(i).Type + var getRealFieldValue func(reflect.Value) + getRealFieldValue = func(v reflect.Value) { + rv := reflect.Indirect(v) + if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + for i := 0; i < rv.Type().NumField(); i++ { + newFieldType := rv.Type().Field(i).Type for newFieldType.Kind() == reflect.Ptr { newFieldType = newFieldType.Elem() } fieldValue = reflect.New(newFieldType) - overrideFieldValue = true - } - // copy tag settings from valuer - for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") { - if _, ok := field.TagSettings[key]; !ok { - field.TagSettings[key] = value + if rv.Type() != reflect.Indirect(fieldValue).Type() { + getRealFieldValue(fieldValue) + } + + if fieldValue.IsValid() { + return + } + + for key, value := range ParseTagSetting(field.IndirectFieldType.Field(i).Tag.Get("gorm"), ";") { + if _, ok := field.TagSettings[key]; !ok { + field.TagSettings[key] = value + } } } } } + + getRealFieldValue(fieldValue) } } diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index b8306af7..ce8a2b50 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -27,6 +27,7 @@ func TestScannerValuer(t *testing.T) { Male: sql.NullBool{Bool: true, Valid: true}, Height: sql.NullFloat64{Float64: 1.8888, Valid: true}, Birthday: sql.NullTime{Time: time.Now(), Valid: true}, + Allergen: NullString{sql.NullString{String: "Allergen", Valid: true}}, Password: EncryptedData("pass1"), Bytes: []byte("byte"), Num: 18, @@ -143,6 +144,7 @@ type ScannerValuerStruct struct { Male sql.NullBool Height sql.NullFloat64 Birthday sql.NullTime + Allergen NullString Password EncryptedData Bytes []byte Num Num @@ -299,3 +301,7 @@ func (t *EmptyTime) Scan(v interface{}) error { func (t EmptyTime) Value() (driver.Value, error) { return time.Now() /* pass tests, mysql 8 doesn't support 0000-00-00 by default */, nil } + +type NullString struct { + sql.NullString +} From 3313c11888538af30abed9b168550b426a4af082 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 19 Aug 2020 19:02:32 +0800 Subject: [PATCH 44/65] Fix embedded struct containing field named ID, close #3286 --- schema/field.go | 8 ++++++++ schema/schema_helper_test.go | 9 +++++++-- schema/schema_test.go | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 2 deletions(-) diff --git a/schema/field.go b/schema/field.go index 59367399..de937132 100644 --- a/schema/field.go +++ b/schema/field.go @@ -336,6 +336,14 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.PrimaryKey = true } else { ef.PrimaryKey = false + + if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) { + ef.AutoIncrement = false + } + + if ef.DefaultValue == "" { + ef.HasDefaultValue = false + } } for k, v := range field.TagSettings { diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index f202b487..4e916f84 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -49,7 +49,12 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* } } - if parsedField, ok := s.FieldsByName[f.Name]; !ok { + parsedField, ok := s.FieldsByDBName[f.DBName] + if !ok { + parsedField, ok = s.FieldsByName[f.Name] + } + + if !ok { t.Errorf("schema %v failed to look up field with name %v", s, f.Name) } else { tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") @@ -62,7 +67,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* for _, name := range []string{f.DBName, f.Name} { if name != "" { - if field := s.LookUpField(name); field == nil || parsedField != field { + if field := s.LookUpField(name); field == nil || (field.Name != name && field.DBName != name) { t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) } } diff --git a/schema/schema_test.go b/schema/schema_test.go index c0ad3c25..c28812af 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -182,3 +182,35 @@ func TestNestedModel(t *testing.T) { }) } } + +func TestEmbeddedStruct(t *testing.T) { + type Company struct { + ID int + Name string + } + + type Corp struct { + ID uint + Base Company `gorm:"embedded;embeddedPrefix:company_"` + } + + cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, schema.NamingStrategy{}) + + if err != nil { + t.Fatalf("failed to parse embedded struct with primary key, got error %v", err) + } + + fields := []schema.Field{ + {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true}, + {Name: "ID", DBName: "company_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Name", DBName: "company_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + } + + for _, f := range fields { + checkSchemaField(t, cropSchema, &f, func(f *schema.Field) { + f.Creatable = true + f.Updatable = true + f.Readable = true + }) + } +} From 528e5ba5c41b647367d48e527b9fe9ad7dfcdd72 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 19 Aug 2020 20:30:39 +0800 Subject: [PATCH 45/65] Cleanup Model after Count --- finisher_api.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/finisher_api.go b/finisher_api.go index db069c5c..cf46f78a 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -289,6 +289,9 @@ func (db *DB) Count(count *int64) (tx *DB) { tx = db.getInstance() if tx.Statement.Model == nil { tx.Statement.Model = tx.Statement.Dest + defer func() { + tx.Statement.Model = nil + }() } if len(tx.Statement.Selects) == 0 { From 0c9870d1ae52a466837daf7f8386e3f2c0c1505c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 20 Aug 2020 10:39:01 +0800 Subject: [PATCH 46/65] Test Association Mode with conditions --- tests/associations_has_many_test.go | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/tests/associations_has_many_test.go b/tests/associations_has_many_test.go index d8befd8a..173e9231 100644 --- a/tests/associations_has_many_test.go +++ b/tests/associations_has_many_test.go @@ -21,6 +21,23 @@ func TestHasManyAssociation(t *testing.T) { DB.Model(&user2).Association("Pets").Find(&user2.Pets) CheckUser(t, user2, user) + var pets []Pet + DB.Model(&user).Where("name = ?", user.Pets[0].Name).Association("Pets").Find(&pets) + + if len(pets) != 1 { + t.Fatalf("should only find one pets, but got %v", len(pets)) + } + + CheckPet(t, pets[0], *user.Pets[0]) + + if count := DB.Model(&user).Where("name = ?", user.Pets[1].Name).Association("Pets").Count(); count != 1 { + t.Fatalf("should only find one pets, but got %v", count) + } + + if count := DB.Model(&user).Where("name = ?", "not found").Association("Pets").Count(); count != 0 { + t.Fatalf("should only find no pet with invalid conditions, but got %v", count) + } + // Count AssertAssociationCount(t, user, "Pets", 2, "") @@ -40,13 +57,13 @@ func TestHasManyAssociation(t *testing.T) { AssertAssociationCount(t, user, "Pets", 3, "AfterAppend") - var pets = []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}} + var pets2 = []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}} - if err := DB.Model(&user2).Association("Pets").Append(&pets); err != nil { + if err := DB.Model(&user2).Association("Pets").Append(&pets2); err != nil { t.Fatalf("Error happened when append pet, got %v", err) } - for _, pet := range pets { + for _, pet := range pets2 { var pet = pet if pet.ID == 0 { t.Fatalf("Pet's ID should be created") From 06de6e8834baf8ed56230727cdf715809e2c7f27 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 20 Aug 2020 10:58:35 +0800 Subject: [PATCH 47/65] Test same field name from embedded field, close #3291 --- schema/schema_helper_test.go | 2 +- schema/schema_test.go | 17 +++++++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/schema/schema_helper_test.go b/schema/schema_helper_test.go index 4e916f84..cc0306e0 100644 --- a/schema/schema_helper_test.go +++ b/schema/schema_helper_test.go @@ -57,7 +57,7 @@ func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(* if !ok { t.Errorf("schema %v failed to look up field with name %v", s, f.Name) } else { - tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "Tag", "TagSettings") + tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "TagSettings") if f.DBName != "" { if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { diff --git a/schema/schema_test.go b/schema/schema_test.go index c28812af..8bd1e5ca 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -4,6 +4,7 @@ import ( "sync" "testing" + "gorm.io/gorm" "gorm.io/gorm/schema" "gorm.io/gorm/utils/tests" ) @@ -184,13 +185,19 @@ func TestNestedModel(t *testing.T) { } func TestEmbeddedStruct(t *testing.T) { + type CorpBase struct { + gorm.Model + OwnerID string + } + type Company struct { - ID int - Name string + ID int + OwnerID int + Name string } type Corp struct { - ID uint + CorpBase Base Company `gorm:"embedded;embeddedPrefix:company_"` } @@ -201,9 +208,11 @@ func TestEmbeddedStruct(t *testing.T) { } fields := []schema.Field{ - {Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true}, + {Name: "ID", DBName: "id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}}, {Name: "ID", DBName: "company_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, {Name: "Name", DBName: "company_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "company_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String}, } for _, f := range fields { From f88e8b072c6e9dc5ecb0530823ee957f9cff5f6f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 20 Aug 2020 18:13:29 +0800 Subject: [PATCH 48/65] Check valid pointer before use it as Valuer --- schema/field.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/schema/field.go b/schema/field.go index de937132..497aa02d 100644 --- a/schema/field.go +++ b/schema/field.go @@ -473,16 +473,16 @@ func (field *Field) setupValuerAndSetter() { } } - if valuer, ok := v.(driver.Valuer); ok { - if v, err = valuer.Value(); err == nil { - err = setter(value, v) - } - } else if reflectV.Kind() == reflect.Ptr { + if reflectV.Kind() == reflect.Ptr { if reflectV.IsNil() { field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) } else { err = setter(value, reflectV.Elem().Interface()) } + } else if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + err = setter(value, v) + } } else { return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) } From 2b510d6423f6299d53eee6a69252a6acc4c431c7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 21 Aug 2020 15:40:50 +0800 Subject: [PATCH 49/65] Don't create index for join table, close #3294 --- schema/relationship.go | 4 ++-- schema/utils.go | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index 537a3582..c8d129f2 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -225,7 +225,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Name: joinFieldName, PkgPath: ownField.StructField.PkgPath, Type: ownField.StructField.Type, - Tag: removeSettingFromTag(removeSettingFromTag(ownField.StructField.Tag, "column"), "autoincrement"), + Tag: removeSettingFromTag(ownField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), }) } @@ -248,7 +248,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel Name: joinFieldName, PkgPath: relField.StructField.PkgPath, Type: relField.StructField.Type, - Tag: removeSettingFromTag(removeSettingFromTag(relField.StructField.Tag, "column"), "autoincrement"), + Tag: removeSettingFromTag(relField.StructField.Tag, "column", "autoincrement", "index", "unique", "uniqueindex"), }) } diff --git a/schema/utils.go b/schema/utils.go index 29f2fefb..41bd9d60 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -51,8 +51,11 @@ func toColumns(val string) (results []string) { return } -func removeSettingFromTag(tag reflect.StructTag, name string) reflect.StructTag { - return reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}")) +func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.StructTag { + for _, name := range names { + tag = reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}")) + } + return tag } // GetRelationsValues get relations's values from a reflect value From 3a97639880a6a965c5e8209e2ff5557008e8b191 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Aug 2020 10:40:37 +0800 Subject: [PATCH 50/65] Fix unordered joins, close #3267 --- callbacks/query.go | 8 ++++---- chainable_api.go | 5 +---- statement.go | 13 +++++++++---- tests/joins_test.go | 8 ++++++++ 4 files changed, 22 insertions(+), 12 deletions(-) diff --git a/callbacks/query.go b/callbacks/query.go index 5ae1e904..f6cb32d5 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -104,12 +104,12 @@ func BuildQuerySQL(db *gorm.DB) { } joins := []clause.Join{} - for name, conds := range db.Statement.Joins { + for _, join := range db.Statement.Joins { if db.Statement.Schema == nil { joins = append(joins, clause.Join{ - Expression: clause.Expr{SQL: name, Vars: conds}, + Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, }) - } else if relation, ok := db.Statement.Schema.Relationships.Relations[name]; ok { + } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { tableAliasName := relation.Name for _, s := range relation.FieldSchema.DBNames { @@ -149,7 +149,7 @@ func BuildQuerySQL(db *gorm.DB) { }) } else { joins = append(joins, clause.Join{ - Expression: clause.Expr{SQL: name, Vars: conds}, + Expression: clause.Expr{SQL: join.Name, Vars: join.Conds}, }) } } diff --git a/chainable_api.go b/chainable_api.go index 9b46a95b..e1b73457 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -172,10 +172,7 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() - if tx.Statement.Joins == nil { - tx.Statement.Joins = map[string][]interface{}{} - } - tx.Statement.Joins[query] = args + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args}) return } diff --git a/statement.go b/statement.go index 6114f468..214a15bb 100644 --- a/statement.go +++ b/statement.go @@ -29,7 +29,7 @@ type Statement struct { Distinct bool Selects []string // selected columns Omits []string // omit columns - Joins map[string][]interface{} + Joins []join Preloads map[string][]interface{} Settings sync.Map ConnPool ConnPool @@ -44,6 +44,11 @@ type Statement struct { assigns []interface{} } +type join struct { + Name string + Conds []interface{} +} + // StatementModifier statement modifier interface type StatementModifier interface { ModifyStatement(*Statement) @@ -401,7 +406,6 @@ func (stmt *Statement) clone() *Statement { Distinct: stmt.Distinct, Selects: stmt.Selects, Omits: stmt.Omits, - Joins: map[string][]interface{}{}, Preloads: map[string][]interface{}{}, ConnPool: stmt.ConnPool, Schema: stmt.Schema, @@ -417,8 +421,9 @@ func (stmt *Statement) clone() *Statement { newStmt.Preloads[k] = p } - for k, j := range stmt.Joins { - newStmt.Joins[k] = j + if len(stmt.Joins) > 0 { + newStmt.Joins = make([]join, len(stmt.Joins)) + copy(newStmt.Joins, stmt.Joins) } stmt.Settings.Range(func(k, v interface{}) bool { diff --git a/tests/joins_test.go b/tests/joins_test.go index e54d3784..f78ddf67 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "regexp" "sort" "testing" @@ -88,6 +89,13 @@ func TestJoinConds(t *testing.T) { if db5.Error != nil { t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error()) } + + dryDB := DB.Session(&gorm.Session{DryRun: true}) + stmt := dryDB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5).Statement + + if !regexp.MustCompile("SELECT .* FROM .users. left join pets.*join accounts.*").MatchString(stmt.SQL.String()) { + t.Errorf("joins should be ordered, but got %v", stmt.SQL.String()) + } } func TestJoinsWithSelect(t *testing.T) { From cc6a64adfb0ed47d5f8ccf8de13eaf8145656973 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Aug 2020 15:40:19 +0800 Subject: [PATCH 51/65] Support smart migrate, close #3078 --- migrator.go | 1 + migrator/migrator.go | 63 ++++++++++++++++++++++++++++++++-- schema/field.go | 5 +++ statement.go | 1 - tests/go.mod | 6 ++-- tests/migrate_test.go | 80 +++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 149 insertions(+), 7 deletions(-) diff --git a/migrator.go b/migrator.go index 37051f81..ed8a8e26 100644 --- a/migrator.go +++ b/migrator.go @@ -42,6 +42,7 @@ type Migrator interface { AddColumn(dst interface{}, field string) error DropColumn(dst interface{}, field string) error AlterColumn(dst interface{}, field string) error + MigrateColumn(dst interface{}, field *schema.Field, columnType *sql.ColumnType) error HasColumn(dst interface{}, field string) bool RenameColumn(dst interface{}, oldName, field string) error ColumnTypes(dst interface{}) ([]*sql.ColumnType, error) diff --git a/migrator/migrator.go b/migrator/migrator.go index d50159dd..d93b8a6d 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "reflect" + "regexp" "strings" "gorm.io/gorm" @@ -80,7 +81,6 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { // AutoMigrate func (m Migrator) AutoMigrate(values ...interface{}) error { - // TODO smart migrate data type for _, value := range m.ReorderModels(values, true) { tx := m.DB.Session(&gorm.Session{}) if !tx.Migrator().HasTable(value) { @@ -89,11 +89,26 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } } else { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { + columnTypes, _ := m.DB.Migrator().ColumnTypes(value) + for _, field := range stmt.Schema.FieldsByDBName { - if !tx.Migrator().HasColumn(value, field.DBName) { + var foundColumn *sql.ColumnType + + for _, columnType := range columnTypes { + if columnType.Name() == field.DBName { + foundColumn = columnType + break + } + } + + if foundColumn == nil { + // not found, add column if err := tx.Migrator().AddColumn(value, field.DBName); err != nil { return err } + } else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil { + // found, smart migrate + return err } } @@ -120,7 +135,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } return nil }); err != nil { - fmt.Println(err) return err } } @@ -327,6 +341,49 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error }) } +func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType *sql.ColumnType) error { + // found, smart migrate + fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL) + realDataType := strings.ToLower(columnType.DatabaseTypeName()) + + alterColumn := false + + // check size + if length, _ := columnType.Length(); length != int64(field.Size) { + if length > 0 && field.Size > 0 { + alterColumn = true + } else { + // has size in data type and not equal + matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllString(realDataType, 1) + matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1) + if len(matches) > 0 && matches[1] != fmt.Sprint(field.Size) || len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) { + alterColumn = true + } + } + } + + // check precision + if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { + if strings.Contains(fullDataType, fmt.Sprint(field.Precision)) { + alterColumn = true + } + } + + // check nullable + if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull { + // not primary key & database is nullable + if !field.PrimaryKey && nullable { + alterColumn = true + } + } + + if alterColumn { + return m.DB.Migrator().AlterColumn(value, field.Name) + } + + return nil +} + func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) { err = m.RunWithValue(value, func(stmt *gorm.Statement) error { rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows() diff --git a/schema/field.go b/schema/field.go index 497aa02d..524d19fb 100644 --- a/schema/field.go +++ b/schema/field.go @@ -55,6 +55,7 @@ type Field struct { Comment string Size int Precision int + Scale int FieldType reflect.Type IndirectFieldType reflect.Type StructField reflect.StructField @@ -160,6 +161,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Precision, _ = strconv.Atoi(p) } + if s, ok := field.TagSettings["SCALE"]; ok { + field.Scale, _ = strconv.Atoi(s) + } + if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) { field.NotNull = true } diff --git a/statement.go b/statement.go index 214a15bb..95d23fa5 100644 --- a/statement.go +++ b/statement.go @@ -379,7 +379,6 @@ func (stmt *Statement) Build(clauses ...string) { } } } - // TODO handle named vars } func (stmt *Statement) Parse(value interface{}) (err error) { diff --git a/tests/go.mod b/tests/go.mod index 54a808d0..9d4e892d 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,11 +6,11 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.3.1 - gorm.io/driver/postgres v0.2.6 + gorm.io/driver/mysql v0.3.2 + gorm.io/driver/postgres v0.2.9 gorm.io/driver/sqlite v1.0.9 gorm.io/driver/sqlserver v0.2.7 - gorm.io/gorm v0.2.19 + gorm.io/gorm v0.2.36 ) replace gorm.io/gorm => ../ diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 1b002049..4cc8a7c3 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -47,6 +47,86 @@ func TestMigrate(t *testing.T) { } } +func TestSmartMigrateColumn(t *testing.T) { + type UserMigrateColumn struct { + ID uint + Name string + Salary float64 + Birthday time.Time + } + + DB.Migrator().DropTable(&UserMigrateColumn{}) + + DB.AutoMigrate(&UserMigrateColumn{}) + + type UserMigrateColumn2 struct { + ID uint + Name string `gorm:"size:128"` + Salary float64 `gorm:"precision:2"` + Birthday time.Time `gorm:"precision:2"` + } + + if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{}) + if err != nil { + t.Fatalf("failed to get column types, got error: %v", err) + } + + for _, columnType := range columnTypes { + switch columnType.Name() { + case "name": + if length, _ := columnType.Length(); length != 0 && length != 128 { + t.Fatalf("name's length should be 128, but got %v", length) + } + case "salary": + if precision, o, _ := columnType.DecimalSize(); precision != 0 && precision != 2 { + t.Fatalf("salary's precision should be 2, but got %v %v", precision, o) + } + case "birthday": + if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 2 { + t.Fatalf("birthday's precision should be 2, but got %v", precision) + } + } + } + + type UserMigrateColumn3 struct { + ID uint + Name string `gorm:"size:256"` + Salary float64 `gorm:"precision:3"` + Birthday time.Time `gorm:"precision:3"` + } + + if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn3{}); err != nil { + t.Fatalf("failed to auto migrate, got error: %v", err) + } + + columnTypes, err = DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{}) + if err != nil { + t.Fatalf("failed to get column types, got error: %v", err) + } + + for _, columnType := range columnTypes { + switch columnType.Name() { + case "name": + if length, _ := columnType.Length(); length != 0 && length != 256 { + t.Fatalf("name's length should be 128, but got %v", length) + } + case "salary": + if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 { + t.Fatalf("salary's precision should be 2, but got %v", precision) + } + case "birthday": + if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 { + t.Fatalf("birthday's precision should be 2, but got %v", precision) + } + } + } + +} + func TestMigrateWithComment(t *testing.T) { type UserWithComment struct { gorm.Model From ebdb4edda8363fdd79c87ab323ca19b2be7a8872 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Aug 2020 20:08:23 +0800 Subject: [PATCH 52/65] Add AllowGlobalUpdate mode --- callbacks/delete.go | 2 +- callbacks/update.go | 2 +- gorm.go | 7 +++++++ soft_delete.go | 2 +- tests/delete_test.go | 4 ++++ tests/update_test.go | 4 ++++ 6 files changed, 18 insertions(+), 3 deletions(-) diff --git a/callbacks/delete.go b/callbacks/delete.go index 288f2d69..f444f020 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -51,7 +51,7 @@ func Delete(db *gorm.DB) { } } - if _, ok := db.Statement.Clauses["WHERE"]; !ok { + if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { db.AddError(gorm.ErrMissingWhereClause) return } diff --git a/callbacks/update.go b/callbacks/update.go index 5656d166..bd8a4150 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -69,7 +69,7 @@ func Update(db *gorm.DB) { db.Statement.Build("UPDATE", "SET", "WHERE") } - if _, ok := db.Statement.Clauses["WHERE"]; !ok { + if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { db.AddError(gorm.ErrMissingWhereClause) return } diff --git a/gorm.go b/gorm.go index 1ace0099..3c187f42 100644 --- a/gorm.go +++ b/gorm.go @@ -32,6 +32,8 @@ type Config struct { DisableAutomaticPing bool // DisableForeignKeyConstraintWhenMigrating DisableForeignKeyConstraintWhenMigrating bool + // AllowGlobalUpdate allow global update + AllowGlobalUpdate bool // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder @@ -61,6 +63,7 @@ type Session struct { PrepareStmt bool WithConditions bool SkipDefaultTransaction bool + AllowGlobalUpdate bool Context context.Context Logger logger.Interface NowFunc func() time.Time @@ -154,6 +157,10 @@ func (db *DB) Session(config *Session) *DB { tx.Config.SkipDefaultTransaction = true } + if config.AllowGlobalUpdate { + txConfig.AllowGlobalUpdate = true + } + if config.Context != nil { tx.Statement = tx.Statement.clone() tx.Statement.DB = tx diff --git a/soft_delete.go b/soft_delete.go index 875623bc..d33bf866 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -98,7 +98,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { } } - if _, ok := stmt.Clauses["WHERE"]; !ok { + if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok { stmt.DB.AddError(ErrMissingWhereClause) return } diff --git a/tests/delete_test.go b/tests/delete_test.go index f5b3e784..09c1a075 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -118,4 +118,8 @@ func TestBlockGlobalDelete(t *testing.T) { if err := DB.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { t.Errorf("should returns missing WHERE clause while deleting error") } + + if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&User{}).Error; err != nil { + t.Errorf("should returns no error while enable global update, but got err %v", err) + } } diff --git a/tests/update_test.go b/tests/update_test.go index 49a13be9..e52dc652 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -222,6 +222,10 @@ func TestBlockGlobalUpdate(t *testing.T) { if err := DB.Model(&User{}).Update("name", "jinzhu").Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { t.Errorf("should returns missing WHERE clause while updating error, got err %v", err) } + + if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&User{}).Update("name", "jinzhu").Error; err != nil { + t.Errorf("should returns no error while enable global update, but got err %v", err) + } } func TestSelectWithUpdate(t *testing.T) { From 84dbb36d3bd91a5e7b3c1ee5a617ea923a4098d2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 24 Aug 2020 20:24:25 +0800 Subject: [PATCH 53/65] Add Golang v1.15 --- .github/workflows/tests.yml | 10 +++++----- tests/default_value_test.go | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b626ce94..4388c31d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,7 +13,7 @@ jobs: sqlite: strategy: matrix: - go: ['1.14', '1.13'] + go: ['1.15', '1.14', '1.13'] platform: [ubuntu-latest, macos-latest] # can not run in windows OS runs-on: ${{ matrix.platform }} @@ -38,7 +38,7 @@ jobs: sqlite_windows: strategy: matrix: - go: ['1.14', '1.13'] + go: ['1.15', '1.14', '1.13'] platform: [windows-latest] runs-on: ${{ matrix.platform }} @@ -64,7 +64,7 @@ jobs: strategy: matrix: dbversion: ['mysql:latest', 'mysql:5.7', 'mysql:5.6', 'mariadb:latest'] - go: ['1.14', '1.13'] + go: ['1.15', '1.14', '1.13'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -108,7 +108,7 @@ jobs: strategy: matrix: dbversion: ['postgres:latest', 'postgres:11', 'postgres:10'] - go: ['1.14', '1.13'] + go: ['1.15', '1.14', '1.13'] platform: [ubuntu-latest] # can not run in macOS and widnowsOS runs-on: ${{ matrix.platform }} @@ -150,7 +150,7 @@ jobs: sqlserver: strategy: matrix: - go: ['1.14', '1.13'] + go: ['1.15', '1.14', '1.13'] platform: [ubuntu-latest] # can not run test in macOS and windows runs-on: ${{ matrix.platform }} diff --git a/tests/default_value_test.go b/tests/default_value_test.go index ea496d60..aa4a511a 100644 --- a/tests/default_value_test.go +++ b/tests/default_value_test.go @@ -10,7 +10,7 @@ func TestDefaultValue(t *testing.T) { type Harumph struct { gorm.Model Email string `gorm:"not null;index:,unique"` - Name string `gorm:"not null;default:'foo'"` + Name string `gorm:"not null;default:foo"` Name2 string `gorm:"size:233;not null;default:'foo'"` Name3 string `gorm:"size:233;not null;default:''"` Age int `gorm:"default:18"` From 3dfa8a66f1bef0a7469c34968cb298c208e59fb9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 25 Aug 2020 17:27:28 +0800 Subject: [PATCH 54/65] Fix panic when delet without pointer, close #3308 --- callbacks/delete.go | 12 ++++++------ soft_delete.go | 5 ----- tests/delete_test.go | 4 ++++ 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/callbacks/delete.go b/callbacks/delete.go index f444f020..76b78fb4 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -41,7 +41,7 @@ func Delete(db *gorm.DB) { db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) } - if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { + if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) @@ -51,15 +51,15 @@ func Delete(db *gorm.DB) { } } - if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { - db.AddError(gorm.ErrMissingWhereClause) - return - } - db.Statement.AddClauseIfNotExists(clause.From{}) db.Statement.Build("DELETE", "FROM", "WHERE") } + if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { + db.AddError(gorm.ErrMissingWhereClause) + return + } + if !db.DryRun && db.Error == nil { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) diff --git a/soft_delete.go b/soft_delete.go index d33bf866..484f565c 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -98,11 +98,6 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { } } - if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok { - stmt.DB.AddError(ErrMissingWhereClause) - return - } - stmt.AddClauseIfNotExists(clause.Update{}) stmt.Build("UPDATE", "SET", "WHERE") } diff --git a/tests/delete_test.go b/tests/delete_test.go index 09c1a075..17299677 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -48,6 +48,10 @@ func TestDelete(t *testing.T) { t.Errorf("errors happened when delete: %v", err) } + if err := DB.Delete(User{}).Error; err != gorm.ErrMissingWhereClause { + t.Errorf("errors happened when delete: %v", err) + } + if err := DB.Where("id = ?", users[0].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { t.Errorf("should returns record not found error, but got %v", err) } From 0f3201e73b97c358d2b7d98d24185fab91e5dd73 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 25 Aug 2020 18:18:16 +0800 Subject: [PATCH 55/65] friendly invalid field error message --- schema/relationship.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index c8d129f2..dad2e629 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -336,7 +336,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue primarySchema, foreignSchema = schema, relation.FieldSchema ) - reguessOrErr := func(err string, args ...interface{}) { + reguessOrErr := func() { switch gl { case guessHas: schema.guessRelation(relation, field, guessEmbeddedHas) @@ -345,7 +345,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue case guessBelongs: schema.guessRelation(relation, field, guessEmbeddedBelongs) default: - schema.err = fmt.Errorf(err, args...) + schema.err = fmt.Errorf("invalid field found for struct %v's field %v, need to define a foreign key for relations or it need to implement the Valuer/Scanner interface", schema, field.Name) } } @@ -354,7 +354,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue if field.OwnerSchema != nil { primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema } else { - reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) + reguessOrErr() return } case guessBelongs: @@ -363,7 +363,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue if field.OwnerSchema != nil { primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema } else { - reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) + reguessOrErr() return } } @@ -373,7 +373,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue if f := foreignSchema.LookUpField(foreignKey); f != nil { foreignFields = append(foreignFields, f) } else { - reguessOrErr("unsupported relations %v for %v on field %v with foreign keys %v", relation.FieldSchema, schema, field.Name, relation.foreignKeys) + reguessOrErr() return } } @@ -392,7 +392,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue } if len(foreignFields) == 0 { - reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl) + reguessOrErr() return } else if len(relation.primaryKeys) > 0 { for idx, primaryKey := range relation.primaryKeys { @@ -400,11 +400,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue if len(primaryFields) < idx+1 { primaryFields = append(primaryFields, f) } else if f != primaryFields[idx] { - reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.primaryKeys) + reguessOrErr() return } } else { - reguessOrErr("unsupported relations %v for %v on field %v with primary keys %v", relation.FieldSchema, schema, field.Name, relation.primaryKeys) + reguessOrErr() return } } @@ -414,7 +414,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue } else if len(primarySchema.PrimaryFields) == len(foreignFields) { primaryFields = append(primaryFields, primarySchema.PrimaryFields...) } else { - reguessOrErr("unsupported relations %v for %v on field %v", relation.FieldSchema, schema, field.Name) + reguessOrErr() return } } From 3195ae12072f51d15064a3428f4e906c6873c4e2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 25 Aug 2020 18:59:19 +0800 Subject: [PATCH 56/65] Allow override alias table in preload conditions --- callbacks/preload.go | 6 +++--- tests/preload_test.go | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index cd09a6d6..25b8cb2b 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -50,7 +50,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { joinResults := rel.JoinTable.MakeSlice().Elem() column, values := schema.ToQueryValues(rel.JoinTable.Table, joinForeignKeys, joinForeignValues) - tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()) + db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error) // convert join identity map to relation identity map fieldValues := make([]interface{}, len(joinForeignFields)) @@ -93,7 +93,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } reflectResults := rel.FieldSchema.MakeSlice().Elem() - column, values := schema.ToQueryValues(rel.FieldSchema.Table, relForeignKeys, foreignValues) + column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) for _, cond := range conds { if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { @@ -103,7 +103,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { } } - tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...) + db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error) fieldValues := make([]interface{}, len(relForeignFields)) diff --git a/tests/preload_test.go b/tests/preload_test.go index 3caa17b4..7e5d2622 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -5,6 +5,7 @@ import ( "strconv" "testing" + "gorm.io/gorm" "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -108,6 +109,20 @@ func TestPreloadWithConds(t *testing.T) { } CheckUser(t, users2[0], users[0]) + + var users3 []User + if err := DB.Preload("Account", func(tx *gorm.DB) *gorm.DB { + return tx.Table("accounts AS a").Select("a.*") + }).Find(&users3, "id IN ?", userIDs).Error; err != nil { + t.Errorf("failed to query, got error %v", err) + } + sort.Slice(users3, func(i, j int) bool { + return users2[i].ID < users2[j].ID + }) + + for i, u := range users3 { + CheckUser(t, u, users[i]) + } } func TestNestedPreloadWithConds(t *testing.T) { From 0d96f99499f2501a0d3a5e0d93ef157cc287e44f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 26 Aug 2020 12:22:11 +0800 Subject: [PATCH 57/65] Update README --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index b51297c4..c727e2cf 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. * Composite Primary Key * Auto Migrations * Logger -* Extendable, write Plugins based on GORM callbacks +* Extendable, flexible plugin API: Database Resolver (Multiple Databases, Read/Write Splitting) / Prometheus… * Every feature comes with tests * Developer Friendly @@ -40,4 +40,3 @@ The fantastic ORM library for Golang, aims to be developer friendly. © Jinzhu, 2013~time.Now Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License) - From ce8853e7a6142420a786be1b0f0c5ffeb8778778 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 27 Aug 2020 15:03:57 +0800 Subject: [PATCH 58/65] Add GormValuer interface support --- README.md | 2 +- callbacks/create.go | 8 +++--- callbacks/delete.go | 4 +-- callbacks/interfaces.go | 39 ++++++++++++++++++++++++++++ callbacks/query.go | 2 +- callbacks/update.go | 8 +++--- interfaces.go | 37 +++------------------------ schema/interfaces.go | 4 ++- statement.go | 2 ++ tests/scanner_valuer_test.go | 49 ++++++++++++++++++++++++++++++++++++ 10 files changed, 108 insertions(+), 47 deletions(-) create mode 100644 callbacks/interfaces.go diff --git a/README.md b/README.md index c727e2cf..9c0aded0 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. * Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point * Context, Prepared Statment Mode, DryRun Mode * Batch Insert, FindInBatches, Find To Map -* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg +* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr * Composite Primary Key * Auto Migrations * Logger diff --git a/callbacks/create.go b/callbacks/create.go index 7a32ed5c..cc7e2671 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -12,14 +12,14 @@ func BeforeCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.BeforeSave { - if i, ok := value.(gorm.BeforeSaveInterface); ok { + if i, ok := value.(BeforeSaveInterface); ok { called = true db.AddError(i.BeforeSave(tx)) } } if db.Statement.Schema.BeforeCreate { - if i, ok := value.(gorm.BeforeCreateInterface); ok { + if i, ok := value.(BeforeCreateInterface); ok { called = true db.AddError(i.BeforeCreate(tx)) } @@ -203,14 +203,14 @@ func AfterCreate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.AfterSave { - if i, ok := value.(gorm.AfterSaveInterface); ok { + if i, ok := value.(AfterSaveInterface); ok { called = true db.AddError(i.AfterSave(tx)) } } if db.Statement.Schema.AfterCreate { - if i, ok := value.(gorm.AfterCreateInterface); ok { + if i, ok := value.(AfterCreateInterface); ok { called = true db.AddError(i.AfterCreate(tx)) } diff --git a/callbacks/delete.go b/callbacks/delete.go index 76b78fb4..e95117a1 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -11,7 +11,7 @@ import ( func BeforeDelete(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { callMethod(db, func(value interface{}, tx *gorm.DB) bool { - if i, ok := value.(gorm.BeforeDeleteInterface); ok { + if i, ok := value.(BeforeDeleteInterface); ok { db.AddError(i.BeforeDelete(tx)) return true } @@ -75,7 +75,7 @@ func Delete(db *gorm.DB) { func AfterDelete(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { callMethod(db, func(value interface{}, tx *gorm.DB) bool { - if i, ok := value.(gorm.AfterDeleteInterface); ok { + if i, ok := value.(AfterDeleteInterface); ok { db.AddError(i.AfterDelete(tx)) return true } diff --git a/callbacks/interfaces.go b/callbacks/interfaces.go new file mode 100644 index 00000000..2302470f --- /dev/null +++ b/callbacks/interfaces.go @@ -0,0 +1,39 @@ +package callbacks + +import "gorm.io/gorm" + +type BeforeCreateInterface interface { + BeforeCreate(*gorm.DB) error +} + +type AfterCreateInterface interface { + AfterCreate(*gorm.DB) error +} + +type BeforeUpdateInterface interface { + BeforeUpdate(*gorm.DB) error +} + +type AfterUpdateInterface interface { + AfterUpdate(*gorm.DB) error +} + +type BeforeSaveInterface interface { + BeforeSave(*gorm.DB) error +} + +type AfterSaveInterface interface { + AfterSave(*gorm.DB) error +} + +type BeforeDeleteInterface interface { + BeforeDelete(*gorm.DB) error +} + +type AfterDeleteInterface interface { + AfterDelete(*gorm.DB) error +} + +type AfterFindInterface interface { + AfterFind(*gorm.DB) error +} diff --git a/callbacks/query.go b/callbacks/query.go index f6cb32d5..0703b92e 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -214,7 +214,7 @@ func Preload(db *gorm.DB) { func AfterQuery(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind { callMethod(db, func(value interface{}, tx *gorm.DB) bool { - if i, ok := value.(gorm.AfterFindInterface); ok { + if i, ok := value.(AfterFindInterface); ok { db.AddError(i.AfterFind(tx)) return true } diff --git a/callbacks/update.go b/callbacks/update.go index bd8a4150..73c062b4 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -32,14 +32,14 @@ func BeforeUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.BeforeSave { - if i, ok := value.(gorm.BeforeSaveInterface); ok { + if i, ok := value.(BeforeSaveInterface); ok { called = true db.AddError(i.BeforeSave(tx)) } } if db.Statement.Schema.BeforeUpdate { - if i, ok := value.(gorm.BeforeUpdateInterface); ok { + if i, ok := value.(BeforeUpdateInterface); ok { called = true db.AddError(i.BeforeUpdate(tx)) } @@ -90,14 +90,14 @@ func AfterUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.AfterSave { - if i, ok := value.(gorm.AfterSaveInterface); ok { + if i, ok := value.(AfterSaveInterface); ok { called = true db.AddError(i.AfterSave(tx)) } } if db.Statement.Schema.AfterUpdate { - if i, ok := value.(gorm.AfterUpdateInterface); ok { + if i, ok := value.(AfterUpdateInterface); ok { called = true db.AddError(i.AfterUpdate(tx)) } diff --git a/interfaces.go b/interfaces.go index b2ce59b3..e933952b 100644 --- a/interfaces.go +++ b/interfaces.go @@ -53,38 +53,7 @@ type TxCommitter interface { Rollback() error } -type BeforeCreateInterface interface { - BeforeCreate(*DB) error -} - -type AfterCreateInterface interface { - AfterCreate(*DB) error -} - -type BeforeUpdateInterface interface { - BeforeUpdate(*DB) error -} - -type AfterUpdateInterface interface { - AfterUpdate(*DB) error -} - -type BeforeSaveInterface interface { - BeforeSave(*DB) error -} - -type AfterSaveInterface interface { - AfterSave(*DB) error -} - -type BeforeDeleteInterface interface { - BeforeDelete(*DB) error -} - -type AfterDeleteInterface interface { - AfterDelete(*DB) error -} - -type AfterFindInterface interface { - AfterFind(*DB) error +// Valuer gorm valuer interface +type Valuer interface { + GormValue(context.Context, *DB) clause.Expr } diff --git a/schema/interfaces.go b/schema/interfaces.go index e8e51e4c..98abffbd 100644 --- a/schema/interfaces.go +++ b/schema/interfaces.go @@ -1,6 +1,8 @@ package schema -import "gorm.io/gorm/clause" +import ( + "gorm.io/gorm/clause" +) type GormDataTypeInterface interface { GormDataType() string diff --git a/statement.go b/statement.go index 95d23fa5..fba1991d 100644 --- a/statement.go +++ b/statement.go @@ -161,6 +161,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { stmt.Vars = append(stmt.Vars, v.Value) case clause.Column, clause.Table: stmt.QuoteTo(writer, v) + case Valuer: + stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) case clause.Expr: var varStr strings.Builder var sql = v.SQL diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index ce8a2b50..ec16ccf6 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -1,16 +1,20 @@ package tests_test import ( + "context" "database/sql" "database/sql/driver" "encoding/json" "errors" + "fmt" "reflect" + "regexp" "strconv" "testing" "time" "gorm.io/gorm" + "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -305,3 +309,48 @@ func (t EmptyTime) Value() (driver.Value, error) { type NullString struct { sql.NullString } + +type Point struct { + X, Y int +} + +func (point *Point) Scan(v interface{}) error { + return nil +} + +func (point Point) GormDataType() string { + return "geo" +} + +func (point Point) GormValue(ctx context.Context, db *gorm.DB) clause.Expr { + return clause.Expr{ + SQL: "ST_PointFromText(?)", + Vars: []interface{}{fmt.Sprintf("POINT(%d %d)", point.X, point.Y)}, + } +} + +func TestGORMValuer(t *testing.T) { + type UserWithPoint struct { + Name string + Point Point + } + + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + + stmt := dryRunDB.Create(&UserWithPoint{ + Name: "jinzhu", + Point: Point{X: 100, Y: 100}, + }).Statement + + if stmt.SQL.String() == "" || len(stmt.Vars) != 2 { + t.Errorf("Failed to generate sql, got %v", stmt.SQL.String()) + } + + if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } +} From 7a90496701f7b81e06daaa134a8f8853c1f935d5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 27 Aug 2020 16:27:59 +0800 Subject: [PATCH 59/65] Test create from sql expr with map --- callbacks/create.go | 4 ++++ callbacks/helper.go | 12 ++++++++---- tests/scanner_valuer_test.go | 26 ++++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/callbacks/create.go b/callbacks/create.go index cc7e2671..c59b14b5 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -225,8 +225,12 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { switch value := stmt.Dest.(type) { case map[string]interface{}: values = ConvertMapToValuesForCreate(stmt, value) + case *map[string]interface{}: + values = ConvertMapToValuesForCreate(stmt, *value) case []map[string]interface{}: values = ConvertSliceOfMapToValuesForCreate(stmt, value) + case *[]map[string]interface{}: + values = ConvertSliceOfMapToValuesForCreate(stmt, *value) default: var ( selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) diff --git a/callbacks/helper.go b/callbacks/helper.go index 80fbc2a1..e0a66dd2 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -20,8 +20,10 @@ func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]inter for _, k := range keys { value := mapValue[k] - if field := stmt.Schema.LookUpField(k); field != nil { - k = field.DBName + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(k); field != nil { + k = field.DBName + } } if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { @@ -46,8 +48,10 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st for idx, mapValue := range mapValues { for k, v := range mapValue { - if field := stmt.Schema.LookUpField(k); field != nil { - k = field.DBName + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(k); field != nil { + k = field.DBName + } } if _, ok := result[k]; !ok { diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index ec16ccf6..dbf5adac 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -353,4 +353,30 @@ func TestGORMValuer(t *testing.T) { if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { t.Errorf("generated vars is not equal, got %v", stmt.Vars) } + + stmt = dryRunDB.Model(UserWithPoint{}).Create(map[string]interface{}{ + "Name": "jinzhu", + "Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}}, + }).Statement + + if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } + + stmt = dryRunDB.Table("user_with_points").Create(&map[string]interface{}{ + "Name": "jinzhu", + "Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}}, + }).Statement + + if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.Name.,.Point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } } From cd54dddd94a992edd446611aeccc939a64ad2658 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 27 Aug 2020 18:42:40 +0800 Subject: [PATCH 60/65] Test update with GormValuer --- tests/go.mod | 2 +- tests/scanner_valuer_test.go | 19 +++++++++++++++---- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index 9d4e892d..b0ed4497 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( gorm.io/driver/mysql v0.3.2 gorm.io/driver/postgres v0.2.9 gorm.io/driver/sqlite v1.0.9 - gorm.io/driver/sqlserver v0.2.7 + gorm.io/driver/sqlserver v0.2.8 gorm.io/gorm v0.2.36 ) diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index dbf5adac..f42daae7 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -314,10 +314,6 @@ type Point struct { X, Y int } -func (point *Point) Scan(v interface{}) error { - return nil -} - func (point Point) GormDataType() string { return "geo" } @@ -379,4 +375,19 @@ func TestGORMValuer(t *testing.T) { if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { t.Errorf("generated vars is not equal, got %v", stmt.Vars) } + + stmt = dryRunDB.Session(&gorm.Session{ + AllowGlobalUpdate: true, + }).Model(&UserWithPoint{}).Updates(UserWithPoint{ + Name: "jinzhu", + Point: Point{X: 100, Y: 100}, + }).Statement + + if !regexp.MustCompile(`UPDATE .user_with_points. SET .name.=.+,.point.=ST_PointFromText\(.+\)`).MatchString(stmt.SQL.String()) { + t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + } + + if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { + t.Errorf("generated vars is not equal, got %v", stmt.Vars) + } } From d50dbb0896100640d61a8b4017aa46946f3bc6c5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 27 Aug 2020 19:15:40 +0800 Subject: [PATCH 61/65] Fix check valid db name, close #3315 --- chainable_api.go | 6 +++--- finisher_api.go | 2 +- utils/utils.go | 4 ++-- utils/utils_test.go | 14 ++++++++++++++ 4 files changed, 20 insertions(+), 6 deletions(-) create mode 100644 utils/utils_test.go diff --git a/chainable_api.go b/chainable_api.go index e1b73457..c8417a6d 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -93,7 +93,7 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { } delete(tx.Statement.Clauses, "SELECT") case string: - fields := strings.FieldsFunc(v, utils.IsChar) + fields := strings.FieldsFunc(v, utils.IsValidDBNameChar) // normal field names if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") { @@ -133,7 +133,7 @@ func (db *DB) Omit(columns ...string) (tx *DB) { tx = db.getInstance() if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { - tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsChar) + tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar) } else { tx.Statement.Omits = columns } @@ -180,7 +180,7 @@ func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() - fields := strings.FieldsFunc(name, utils.IsChar) + fields := strings.FieldsFunc(name, utils.IsValidDBNameChar) tx.Statement.AddClause(clause.GroupBy{ Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}}, }) diff --git a/finisher_api.go b/finisher_api.go index cf46f78a..2cde3c31 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -362,7 +362,7 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { tx.AddError(ErrModelValueRequired) } - fields := strings.FieldsFunc(column, utils.IsChar) + fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) tx.Statement.AddClauseIfNotExists(clause.Select{ Distinct: tx.Statement.Distinct, Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, diff --git a/utils/utils.go b/utils/utils.go index e93f3055..71336f4b 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -29,8 +29,8 @@ func FileWithLineNum() string { return "" } -func IsChar(c rune) bool { - return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' +func IsValidDBNameChar(c rune) bool { + return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' } func CheckTruth(val interface{}) bool { diff --git a/utils/utils_test.go b/utils/utils_test.go new file mode 100644 index 00000000..5737c511 --- /dev/null +++ b/utils/utils_test.go @@ -0,0 +1,14 @@ +package utils + +import ( + "strings" + "testing" +) + +func TestIsValidDBNameChar(t *testing.T) { + for _, db := range []string{"db", "dbName", "db_name", "db1", "1dbname", "db$name"} { + if fields := strings.FieldsFunc(db, IsValidDBNameChar); len(fields) != 1 { + t.Fatalf("failed to parse db name %v", db) + } + } +} From dacbaa5f02bf40efa5d8841047c047f7a5340d9f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 27 Aug 2020 19:52:01 +0800 Subject: [PATCH 62/65] Fix update attrs order --- callbacks/update.go | 6 ++++-- tests/scanner_valuer_test.go | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/callbacks/update.go b/callbacks/update.go index 73c062b4..46f59157 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -199,7 +199,8 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } if !stmt.UpdatingColumn && stmt.Schema != nil { - for _, field := range stmt.Schema.FieldsByDBName { + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.LookUpField(dbName) if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { now := stmt.DB.NowFunc() @@ -222,7 +223,8 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { switch updatingValue.Kind() { case reflect.Struct: set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) - for _, field := range stmt.Schema.FieldsByDBName { + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.LookUpField(dbName) if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { value, isZero := field.ValueOf(updatingValue) diff --git a/tests/scanner_valuer_test.go b/tests/scanner_valuer_test.go index f42daae7..fb1f5791 100644 --- a/tests/scanner_valuer_test.go +++ b/tests/scanner_valuer_test.go @@ -384,7 +384,7 @@ func TestGORMValuer(t *testing.T) { }).Statement if !regexp.MustCompile(`UPDATE .user_with_points. SET .name.=.+,.point.=ST_PointFromText\(.+\)`).MatchString(stmt.SQL.String()) { - t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) + t.Errorf("update with sql.Expr, but got %v", stmt.SQL.String()) } if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { From c19a3abefb2aef853e4541ae1af7fa93f2dc0848 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 28 Aug 2020 11:31:13 +0800 Subject: [PATCH 63/65] Fix self-referential belongs to, close #3319 --- association.go | 4 ++-- schema/relationship.go | 34 +++++++++++++++++++--------------- schema/relationship_test.go | 14 ++++++++++++++ schema/schema_test.go | 2 +- 4 files changed, 36 insertions(+), 18 deletions(-) diff --git a/association.go b/association.go index e59b8938..25e1fe8d 100644 --- a/association.go +++ b/association.go @@ -54,7 +54,7 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro for _, queryClause := range association.Relationship.JoinTable.QueryClauses { joinStmt.AddClause(queryClause) } - joinStmt.Build("WHERE", "LIMIT") + joinStmt.Build("WHERE") tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) } @@ -112,7 +112,7 @@ func (association *Association) Replace(values ...interface{}) error { updateMap[ref.ForeignKey.DBName] = nil } - association.DB.UpdateColumns(updateMap) + association.Error = association.DB.UpdateColumns(updateMap).Error } case schema.HasOne, schema.HasMany: var ( diff --git a/schema/relationship.go b/schema/relationship.go index dad2e629..5132ff74 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -82,7 +82,9 @@ func (schema *Schema) parseRelation(field *Field) { schema.buildMany2ManyRelation(relation, field, many2many) } else { switch field.IndirectFieldType.Kind() { - case reflect.Struct, reflect.Slice: + case reflect.Struct: + schema.guessRelation(relation, field, guessBelongs) + case reflect.Slice: schema.guessRelation(relation, field, guessHas) default: schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name) @@ -324,10 +326,10 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel type guessLevel int const ( - guessHas guessLevel = iota - guessEmbeddedHas - guessBelongs + guessBelongs guessLevel = iota guessEmbeddedBelongs + guessHas + guessEmbeddedHas ) func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) { @@ -338,25 +340,19 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue reguessOrErr := func() { switch gl { - case guessHas: - schema.guessRelation(relation, field, guessEmbeddedHas) - case guessEmbeddedHas: - schema.guessRelation(relation, field, guessBelongs) case guessBelongs: schema.guessRelation(relation, field, guessEmbeddedBelongs) + case guessEmbeddedBelongs: + schema.guessRelation(relation, field, guessHas) + case guessHas: + schema.guessRelation(relation, field, guessEmbeddedHas) + // case guessEmbeddedHas: default: schema.err = fmt.Errorf("invalid field found for struct %v's field %v, need to define a foreign key for relations or it need to implement the Valuer/Scanner interface", schema, field.Name) } } switch gl { - case guessEmbeddedHas: - if field.OwnerSchema != nil { - primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema - } else { - reguessOrErr() - return - } case guessBelongs: primarySchema, foreignSchema = relation.FieldSchema, schema case guessEmbeddedBelongs: @@ -366,6 +362,14 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue reguessOrErr() return } + case guessHas: + case guessEmbeddedHas: + if field.OwnerSchema != nil { + primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema + } else { + reguessOrErr() + return + } } if len(relation.foreignKeys) > 0 { diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 2c09f528..2e85c538 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -55,6 +55,20 @@ func TestBelongsToOverrideReferences(t *testing.T) { }) } +func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) { + type User struct { + ID int32 `gorm:"primaryKey"` + Name string + CreatedBy *int32 + Creator *User `gorm:"foreignKey:CreatedBy;references:ID"` + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Creator", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User", + References: []Reference{{"ID", "User", "CreatedBy", "User", "", false}}, + }) +} + func TestHasOneOverrideForeignKey(t *testing.T) { type Profile struct { gorm.Model diff --git a/schema/schema_test.go b/schema/schema_test.go index 8bd1e5ca..4d13ebd2 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -171,7 +171,7 @@ func TestNestedModel(t *testing.T) { fields := []schema.Field{ {Name: "ID", DBName: "id", BindNames: []string{"VersionModel", "BaseModel", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true}, - {Name: "CreatedBy", DBName: "created_by", BindNames: []string{"VersionModel", "BaseModel", "CreatedBy"}, DataType: schema.Int, Size: 64}, + {Name: "CreatedBy", DBName: "created_by", BindNames: []string{"VersionModel", "BaseModel", "CreatedBy"}, DataType: schema.Uint, Size: 64}, {Name: "Version", DBName: "version", BindNames: []string{"VersionModel", "Version"}, DataType: schema.Int, Size: 64}, } From 94c6bb980b8c3775d98121d5d42109cefe596c5c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 28 Aug 2020 12:25:25 +0800 Subject: [PATCH 64/65] Refactor association --- association.go | 92 ++++++++++++++++++++------------------------------ 1 file changed, 37 insertions(+), 55 deletions(-) diff --git a/association.go b/association.go index 25e1fe8d..db77cc4e 100644 --- a/association.go +++ b/association.go @@ -43,32 +43,8 @@ func (db *DB) Association(column string) *Association { func (association *Association) Find(out interface{}, conds ...interface{}) error { if association.Error == nil { - var ( - queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) - tx = association.DB.Model(out) - ) - - if association.Relationship.JoinTable != nil { - if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { - joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} - for _, queryClause := range association.Relationship.JoinTable.QueryClauses { - joinStmt.AddClause(queryClause) - } - joinStmt.Build("WHERE") - tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) - } - - tx.Clauses(clause.From{Joins: []clause.Join{{ - Table: clause.Table{Name: association.Relationship.JoinTable.Table}, - ON: clause.Where{Exprs: queryConds}, - }}}) - } else { - tx.Clauses(clause.Where{Exprs: queryConds}) - } - - association.Error = tx.Find(out, conds...).Error + association.Error = association.buildCondition().Find(out, conds...).Error } - return association.Error } @@ -80,7 +56,7 @@ func (association *Association) Append(values ...interface{}) error { association.Error = association.Replace(values...) } default: - association.saveAssociation(false, values...) + association.saveAssociation( /*clear*/ false, values...) } } @@ -90,7 +66,7 @@ func (association *Association) Append(values ...interface{}) error { func (association *Association) Replace(values ...interface{}) error { if association.Error == nil { // save associations - association.saveAssociation(true, values...) + association.saveAssociation( /*clear*/ true, values...) // set old associations's foreign key to null reflectValue := association.DB.Statement.ReflectValue @@ -234,7 +210,7 @@ func (association *Association) Delete(values ...interface{}) error { var ( primaryFields, relPrimaryFields []*schema.Field joinPrimaryKeys, joinRelPrimaryKeys []string - modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + joinValue = reflect.New(rel.JoinTable.ModelType).Interface() ) for _, ref := range rel.References { @@ -259,10 +235,11 @@ func (association *Association) Delete(values ...interface{}) error { relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) - association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue).Error + association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error } if association.Error == nil { + // clean up deleted values's foreign key relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) cleanUpDeletedRelations := func(data reflect.Value) { @@ -328,33 +305,8 @@ func (association *Association) Clear() error { func (association *Association) Count() (count int64) { if association.Error == nil { - var ( - conds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) - modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() - tx = association.DB.Model(modelValue) - ) - - if association.Relationship.JoinTable != nil { - if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { - joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} - for _, queryClause := range association.Relationship.JoinTable.QueryClauses { - joinStmt.AddClause(queryClause) - } - joinStmt.Build("WHERE", "LIMIT") - tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) - } - - tx.Clauses(clause.From{Joins: []clause.Join{{ - Table: clause.Table{Name: association.Relationship.JoinTable.Table}, - ON: clause.Where{Exprs: conds}, - }}}) - } else { - tx.Clauses(clause.Where{Exprs: conds}) - } - - association.Error = tx.Count(&count).Error + association.Error = association.buildCondition().Count(&count).Error } - return } @@ -435,6 +387,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ switch reflectValue.Kind() { case reflect.Slice, reflect.Array: if len(values) != reflectValue.Len() { + // clear old data if clear && len(values) == 0 { for i := 0; i < reflectValue.Len(); i++ { if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { @@ -467,6 +420,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Save(reflectValue.Index(i).Addr().Interface()).Error } case reflect.Struct: + // clear old data if clear && len(values) == 0 { association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) @@ -498,3 +452,31 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } } } + +func (association *Association) buildCondition() *DB { + var ( + queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) + modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() + tx = association.DB.Model(modelValue) + ) + + if association.Relationship.JoinTable != nil { + if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { + joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} + for _, queryClause := range association.Relationship.JoinTable.QueryClauses { + joinStmt.AddClause(queryClause) + } + joinStmt.Build("WHERE") + tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) + } + + tx.Clauses(clause.From{Joins: []clause.Join{{ + Table: clause.Table{Name: association.Relationship.JoinTable.Table}, + ON: clause.Where{Exprs: queryConds}, + }}}) + } else { + tx.Clauses(clause.Where{Exprs: queryConds}) + } + + return tx +} From 06461b32549fb13090b92713703228da2e8290aa Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 28 Aug 2020 21:16:47 +0800 Subject: [PATCH 65/65] GORM V2.0.0 --- tests/go.mod | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index b0ed4497..1a6fe7a8 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,11 +6,11 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.3.2 - gorm.io/driver/postgres v0.2.9 - gorm.io/driver/sqlite v1.0.9 - gorm.io/driver/sqlserver v0.2.8 - gorm.io/gorm v0.2.36 + gorm.io/driver/mysql v1.0.0 + gorm.io/driver/postgres v1.0.0 + gorm.io/driver/sqlite v1.1.0 + gorm.io/driver/sqlserver v1.0.0 + gorm.io/gorm v1.9.19 ) replace gorm.io/gorm => ../