From 9b0ad4730f16d6ac7cf18d1aa42d74714959745b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 31 Aug 2020 12:08:33 +0800 Subject: [PATCH 01/67] Squashed commit of the following: MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit commit 759038a126122d5b3323979fdd7d867a4ab85585 Author: Jinzhu Date: Mon Aug 31 12:06:31 2020 +0800 Add PreparedStmt tests commit 066d54db1fc93ea58c190195104a2d7086623f69 Author: 王岚 Date: Fri Aug 28 18:40:59 2020 +0800 prepare_stmt add ctx --- gorm.go | 1 + prepare_stmt.go | 22 ++++++++--------- tests/prepared_stmt_test.go | 48 +++++++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 11 deletions(-) create mode 100644 tests/prepared_stmt_test.go diff --git a/gorm.go b/gorm.go index 3c187f42..fec4310b 100644 --- a/gorm.go +++ b/gorm.go @@ -169,6 +169,7 @@ func (db *DB) Session(config *Session) *DB { if config.PrepareStmt { if v, ok := db.cacheStore.Load("preparedStmt"); ok { + tx.Statement = tx.Statement.clone() preparedStmt := v.(*PreparedStmtDB) tx.Statement.ConnPool = &PreparedStmtDB{ ConnPool: db.Config.ConnPool, diff --git a/prepare_stmt.go b/prepare_stmt.go index 7e87558d..7c80bafe 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -25,7 +25,7 @@ func (db *PreparedStmtDB) Close() { db.Mux.Unlock() } -func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { +func (db *PreparedStmtDB) prepare(ctx context.Context, query string) (*sql.Stmt, error) { db.Mux.RLock() if stmt, ok := db.Stmts[query]; ok { db.Mux.RUnlock() @@ -40,7 +40,7 @@ func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { return stmt, nil } - stmt, err := db.ConnPool.PrepareContext(context.Background(), query) + stmt, err := db.ConnPool.PrepareContext(ctx, query) if err == nil { db.Stmts[query] = stmt db.PreparedSQL = append(db.PreparedSQL, query) @@ -59,7 +59,7 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn } func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { - stmt, err := db.prepare(query) + stmt, err := db.prepare(ctx, query) if err == nil { result, err = stmt.ExecContext(ctx, args...) if err != nil { @@ -73,7 +73,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. } func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { - stmt, err := db.prepare(query) + stmt, err := db.prepare(ctx, query) if err == nil { rows, err = stmt.QueryContext(ctx, args...) if err != nil { @@ -87,7 +87,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . } func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - stmt, err := db.prepare(query) + stmt, err := db.prepare(ctx, query) if err == nil { return stmt.QueryRowContext(ctx, args...) } @@ -100,9 +100,9 @@ type PreparedStmtTX struct { } func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { - stmt, err := tx.PreparedStmtDB.prepare(query) + stmt, err := tx.PreparedStmtDB.prepare(ctx, query) if err == nil { - result, err = tx.Tx.Stmt(stmt).ExecContext(ctx, args...) + result, err = tx.Tx.StmtContext(ctx, stmt).ExecContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() stmt.Close() @@ -114,9 +114,9 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. } func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { - stmt, err := tx.PreparedStmtDB.prepare(query) + stmt, err := tx.PreparedStmtDB.prepare(ctx, query) if err == nil { - rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...) + rows, err = tx.Tx.StmtContext(ctx, stmt).QueryContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() stmt.Close() @@ -128,9 +128,9 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . } func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - stmt, err := tx.PreparedStmtDB.prepare(query) + stmt, err := tx.PreparedStmtDB.prepare(ctx, query) if err == nil { - return tx.Tx.Stmt(stmt).QueryRowContext(ctx, args...) + return tx.Tx.StmtContext(ctx, stmt).QueryRowContext(ctx, args...) } return &sql.Row{} } diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go new file mode 100644 index 00000000..b81318d3 --- /dev/null +++ b/tests/prepared_stmt_test.go @@ -0,0 +1,48 @@ +package tests_test + +import ( + "context" + "testing" + "time" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestPreparedStmt(t *testing.T) { + tx := DB.Session(&gorm.Session{PrepareStmt: true}) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + txCtx := tx.WithContext(ctx) + + user := *GetUser("prepared_stmt", Config{}) + + txCtx.Create(&user) + + var result1 User + if err := txCtx.Find(&result1, user.ID).Error; err != nil { + t.Fatalf("no error should happen but got %v", err) + } + + time.Sleep(time.Second) + + var result2 User + if err := tx.Find(&result2, user.ID).Error; err != nil { + t.Fatalf("no error should happen but got %v", err) + } + + user2 := *GetUser("prepared_stmt2", Config{}) + if err := txCtx.Create(&user2).Error; err == nil { + t.Fatalf("should failed to create with timeout context") + } + + if err := tx.Create(&user2).Error; err != nil { + t.Fatalf("failed to create, got error %v", err) + } + + var result3 User + if err := tx.Find(&result3, user2.ID).Error; err != nil { + t.Fatalf("no error should happen but got %v", err) + } +} From 496db1f13e51ef20db2a68f6591047df6b20e292 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 31 Aug 2020 15:45:56 +0800 Subject: [PATCH 02/67] Fix named argument with multiple line SQL, fix #3336 --- clause/expression.go | 2 +- prepare_stmt.go | 2 +- tests/go.mod | 2 ++ tests/named_argument_test.go | 14 +++++++++++++- 4 files changed, 17 insertions(+), 3 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index 4d5e328b..3b914e68 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -94,7 +94,7 @@ func (expr NamedExpr) Build(builder Builder) { if v == '@' && !inName { inName = true name = []byte{} - } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' { + } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' { if inName { if nv, ok := namedMap[string(name)]; ok { builder.AddVar(builder, nv) diff --git a/prepare_stmt.go b/prepare_stmt.go index 7c80bafe..de7e2a26 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -116,7 +116,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { stmt, err := tx.PreparedStmtDB.prepare(ctx, query) if err == nil { - rows, err = tx.Tx.StmtContext(ctx, stmt).QueryContext(ctx, args...) + rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() stmt.Close() diff --git a/tests/go.mod b/tests/go.mod index c09747ab..f3dd6dbc 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -14,3 +14,5 @@ require ( ) replace gorm.io/gorm => ../ + +replace github.com/jackc/pgx/v4 => github.com/jinzhu/pgx/v4 v4.8.2 diff --git a/tests/named_argument_test.go b/tests/named_argument_test.go index 56fad5f4..d0a6f915 100644 --- a/tests/named_argument_test.go +++ b/tests/named_argument_test.go @@ -48,10 +48,22 @@ func TestNamedArg(t *testing.T) { t.Errorf("failed to update with named arg") } + namedUser.Name1 = "jinzhu-new" + namedUser.Name2 = "jinzhu-new2" + namedUser.Name3 = "jinzhu-new" + var result5 NamedUser if err := DB.Raw("SELECT * FROM named_users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2", map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result5).Error; err != nil { t.Errorf("failed to update with named arg") } - AssertEqual(t, result4, namedUser) + AssertEqual(t, result5, namedUser) + + var result6 NamedUser + if err := DB.Raw(`SELECT * FROM named_users WHERE (name1 = @name + AND name3 = @name) AND name2 = @name2`, map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result6).Error; err != nil { + t.Errorf("failed to update with named arg") + } + + AssertEqual(t, result6, namedUser) } From 0273856e4d9744c98aa42b98d485d726099e9020 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 31 Aug 2020 16:27:22 +0800 Subject: [PATCH 03/67] Don't alter column with full column data type, close #3339 --- migrator/migrator.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index d93b8a6d..c736a3e0 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -297,10 +297,12 @@ func (m Migrator) DropColumn(value interface{}, name string) error { func (m Migrator) AlterColumn(value interface{}, field string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { + fileType := clause.Expr{SQL: m.DataTypeOf(field)} return m.DB.Exec( "ALTER TABLE ? ALTER COLUMN ? TYPE ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field), + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, fileType, ).Error + } return fmt.Errorf("failed to look up field with name: %s", field) }) From 162367be7d1d10aa59dc08bb507c356b4495c95e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 1 Sep 2020 11:30:16 +0800 Subject: [PATCH 04/67] Fix multiple M2M relations on one table, close #3347 --- schema/relationship.go | 62 +++++++++++++++++++++---------------- schema/relationship_test.go | 31 +++++++++++++++++++ 2 files changed, 66 insertions(+), 27 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index 5132ff74..aa992b84 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -254,12 +254,18 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel }) } + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: schema.Name + field.Name, + Type: schema.ModelType, + Tag: `gorm:"-"`, + }) + if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { schema.err = err } relation.JoinTable.Name = many2many relation.JoinTable.Table = schema.namer.JoinTableName(many2many) - relation.JoinTable.PrimaryFields = make([]*Field, len(relation.JoinTable.Fields)) + relation.JoinTable.PrimaryFields = make([]*Field, 0, len(relation.JoinTable.Fields)) relName := relation.Schema.Name relRefName := relation.FieldSchema.Name @@ -290,36 +296,38 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } // build references - for idx, f := range relation.JoinTable.Fields { - // use same data type for foreign keys - f.DataType = fieldsMap[f.Name].DataType - f.GORMDataType = fieldsMap[f.Name].GORMDataType - relation.JoinTable.PrimaryFields[idx] = f - ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] + for _, f := range relation.JoinTable.Fields { + if f.Creatable { + // use same data type for foreign keys + f.DataType = fieldsMap[f.Name].DataType + f.GORMDataType = fieldsMap[f.Name].GORMDataType + relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f) + ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] - if ownPriamryField { - joinRel := relation.JoinTable.Relationships.Relations[relName] - joinRel.Field = relation.Field - joinRel.References = append(joinRel.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], - ForeignKey: f, - }) - } else { - joinRefRel := relation.JoinTable.Relationships.Relations[relRefName] - if joinRefRel.Field == nil { - joinRefRel.Field = relation.Field + if ownPriamryField { + joinRel := relation.JoinTable.Relationships.Relations[relName] + joinRel.Field = relation.Field + joinRel.References = append(joinRel.References, &Reference{ + PrimaryKey: fieldsMap[f.Name], + ForeignKey: f, + }) + } else { + joinRefRel := relation.JoinTable.Relationships.Relations[relRefName] + if joinRefRel.Field == nil { + joinRefRel.Field = relation.Field + } + joinRefRel.References = append(joinRefRel.References, &Reference{ + PrimaryKey: fieldsMap[f.Name], + ForeignKey: f, + }) } - joinRefRel.References = append(joinRefRel.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], - ForeignKey: f, + + relation.References = append(relation.References, &Reference{ + PrimaryKey: fieldsMap[f.Name], + ForeignKey: f, + OwnPrimaryKey: ownPriamryField, }) } - - relation.References = append(relation.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], - ForeignKey: f, - OwnPrimaryKey: ownPriamryField, - }) } } diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 2e85c538..f2d63323 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -267,3 +267,34 @@ func TestMany2ManyWithMultiPrimaryKeys(t *testing.T) { }, ) } + +func TestMultipleMany2Many(t *testing.T) { + type Thing struct { + ID int + } + + type Person struct { + ID int + Likes []Thing `gorm:"many2many:likes"` + Dislikes []Thing `gorm:"many2many:dislikes"` + } + + checkStructRelation(t, &Person{}, + Relation{ + Name: "Likes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing", + JoinTable: JoinTable{Name: "likes", Table: "likes"}, + References: []Reference{ + {"ID", "Person", "PersonID", "likes", "", true}, + {"ID", "Thing", "ThingID", "likes", "", false}, + }, + }, + Relation{ + Name: "Dislikes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing", + JoinTable: JoinTable{Name: "dislikes", Table: "dislikes"}, + References: []Reference{ + {"ID", "Person", "PersonID", "dislikes", "", true}, + {"ID", "Thing", "ThingID", "dislikes", "", false}, + }, + }, + ) +} From 308d22b166eb3b71d2a3374bfc565be29ed88eda Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 1 Sep 2020 13:48:37 +0800 Subject: [PATCH 05/67] Clean up associations before Preload, close #3345 --- callbacks/preload.go | 10 ++++++++++ tests/helper_test.go | 10 +++++----- tests/preload_test.go | 14 ++++++++++++++ 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index 25b8cb2b..9b8f762a 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -107,6 +107,16 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { fieldValues := make([]interface{}, len(relForeignFields)) + // clean up old values before preloading + switch reflectValue.Kind() { + case reflect.Struct: + rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface()) + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) + } + } + for i := 0; i < reflectResults.Len(); i++ { elem := reflectResults.Index(i) for idx, field := range relForeignFields { diff --git a/tests/helper_test.go b/tests/helper_test.go index cc0d808c..eee34e99 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -115,7 +115,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Run("Pets", func(t *testing.T) { if len(user.Pets) != len(expect.Pets) { - t.Errorf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets)) + t.Fatalf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets)) } sort.Slice(user.Pets, func(i, j int) bool { @@ -137,7 +137,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Run("Toys", func(t *testing.T) { if len(user.Toys) != len(expect.Toys) { - t.Errorf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys)) + t.Fatalf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys)) } sort.Slice(user.Toys, func(i, j int) bool { @@ -177,7 +177,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Run("Team", func(t *testing.T) { if len(user.Team) != len(expect.Team) { - t.Errorf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team)) + t.Fatalf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team)) } sort.Slice(user.Team, func(i, j int) bool { @@ -195,7 +195,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Run("Languages", func(t *testing.T) { if len(user.Languages) != len(expect.Languages) { - t.Errorf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages)) + t.Fatalf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages)) } sort.Slice(user.Languages, func(i, j int) bool { @@ -212,7 +212,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Run("Friends", func(t *testing.T) { if len(user.Friends) != len(expect.Friends) { - t.Errorf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends)) + t.Fatalf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends)) } sort.Slice(user.Friends, func(i, j int) bool { diff --git a/tests/preload_test.go b/tests/preload_test.go index 7e5d2622..76b72f14 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -31,6 +31,20 @@ func TestPreloadWithAssociations(t *testing.T) { var user2 User DB.Preload(clause.Associations).Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + + var user3 = *GetUser("preload_with_associations_new", Config{ + Account: true, + Pets: 2, + Toys: 3, + Company: true, + Manager: true, + Team: 4, + Languages: 3, + Friends: 1, + }) + + DB.Preload(clause.Associations).Find(&user3, "id = ?", user.ID) + CheckUser(t, user3, user) } func TestNestedPreload(t *testing.T) { From e98a4a3a4ef602a20803c1fc4deb3f8bdbf84fec Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 1 Sep 2020 14:01:59 +0800 Subject: [PATCH 06/67] Change default timeout interval to avoid test fail on CI --- tests/prepared_stmt_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index b81318d3..af610165 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -12,7 +12,7 @@ import ( func TestPreparedStmt(t *testing.T) { tx := DB.Session(&gorm.Session{PrepareStmt: true}) - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() txCtx := tx.WithContext(ctx) From e6f4b711a7e1f885a2200b22e40786cf0dacddcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=AE=8B=E5=B0=8F=E5=8C=97?= Date: Tue, 1 Sep 2020 15:50:53 +0800 Subject: [PATCH 07/67] fix order case (#3350) --- chainable_api.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chainable_api.go b/chainable_api.go index c8417a6d..ae2ac4f1 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -198,7 +198,7 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { // Order specify order when retrieve records from database // db.Order("name DESC") -// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression +// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) func (db *DB) Order(value interface{}) (tx *DB) { tx = db.getInstance() From e73147fa8e25bea98257444ae1d65e19a1af089d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 1 Sep 2020 16:55:30 +0800 Subject: [PATCH 08/67] Better support for scan into map, fix unfriendly data type for interface, close #3351 --- scan.go | 72 +++++++++++++++++++----------- tests/query_test.go | 104 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 149 insertions(+), 27 deletions(-) diff --git a/scan.go b/scan.go index 0b199029..89d9a07a 100644 --- a/scan.go +++ b/scan.go @@ -2,12 +2,52 @@ package gorm import ( "database/sql" + "database/sql/driver" "reflect" "strings" "gorm.io/gorm/schema" ) +func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) { + if db.Statement.Schema != nil { + for idx, name := range columns { + if field := db.Statement.Schema.LookUpField(name); field != nil { + values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() + continue + } + values[idx] = new(interface{}) + } + } else if len(columnTypes) > 0 { + for idx, columnType := range columnTypes { + if columnType.ScanType() != nil { + values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface() + } else { + values[idx] = new(interface{}) + } + } + } else { + for idx := range columns { + values[idx] = new(interface{}) + } + } +} + +func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) { + for idx, column := range columns { + if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() { + mapValue[column] = reflectValue.Interface() + if valuer, ok := mapValue[column].(driver.Valuer); ok { + mapValue[column], _ = valuer.Value() + } else if b, ok := mapValue[column].(sql.RawBytes); ok { + mapValue[column] = string(b) + } + } else { + mapValue[column] = nil + } + } +} + func Scan(rows *sql.Rows, db *DB, initialized bool) { columns, _ := rows.Columns() values := make([]interface{}, len(columns)) @@ -15,9 +55,8 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { switch dest := db.Statement.Dest.(type) { case map[string]interface{}, *map[string]interface{}: if initialized || rows.Next() { - for idx := range columns { - values[idx] = new(interface{}) - } + columnTypes, _ := rows.ColumnTypes() + prepareValues(values, db, columnTypes, columns) db.RowsAffected++ db.AddError(rows.Scan(values...)) @@ -28,38 +67,19 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { mapValue = *v } } - - for idx, column := range columns { - if v, ok := values[idx].(*interface{}); ok { - if v == nil { - mapValue[column] = nil - } else { - mapValue[column] = *v - } - } - } + scanIntoMap(mapValue, values, columns) } case *[]map[string]interface{}: + columnTypes, _ := rows.ColumnTypes() for initialized || rows.Next() { - for idx := range columns { - values[idx] = new(interface{}) - } + prepareValues(values, db, columnTypes, columns) initialized = false db.RowsAffected++ db.AddError(rows.Scan(values...)) mapValue := map[string]interface{}{} - for idx, column := range columns { - if v, ok := values[idx].(*interface{}); ok { - if v == nil { - mapValue[column] = nil - } else { - mapValue[column] = *v - } - } - } - + scanIntoMap(mapValue, values, columns) *dest = append(*dest, mapValue) } case *int, *int64, *uint, *uint64, *float32, *float64: diff --git a/tests/query_test.go b/tests/query_test.go index d71c813a..6bb68cd3 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -6,6 +6,7 @@ import ( "regexp" "sort" "strconv" + "strings" "testing" "time" @@ -61,6 +62,54 @@ func TestFind(t *testing.T) { for _, name := range []string{"Name", "Age", "Birthday"} { t.Run(name, func(t *testing.T) { dbName := DB.NamingStrategy.ColumnName("", name) + + switch name { + case "Name": + if _, ok := first[dbName].(string); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) + } + case "Age": + if _, ok := first[dbName].(uint); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) + } + case "Birthday": + if _, ok := first[dbName].(*time.Time); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) + } + } + + reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) + AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) + }) + } + } + }) + + t.Run("FirstMapWithTable", func(t *testing.T) { + var first = map[string]interface{}{} + if err := DB.Table("users").Where("name = ?", "find").Find(first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + resultType := reflect.ValueOf(first[dbName]).Type().Name() + + switch name { + case "Name": + if !strings.Contains(resultType, "string") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) + } + case "Age": + if !strings.Contains(resultType, "int") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) + } + case "Birthday": + if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) + } + } + reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) }) @@ -86,13 +135,29 @@ func TestFind(t *testing.T) { t.Run("FirstSliceOfMap", func(t *testing.T) { var allMap = []map[string]interface{}{} if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { - t.Errorf("errors happened when query first: %v", err) + t.Errorf("errors happened when query find: %v", err) } else { for idx, user := range users { t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { for _, name := range []string{"Name", "Age", "Birthday"} { t.Run(name, func(t *testing.T) { dbName := DB.NamingStrategy.ColumnName("", name) + + switch name { + case "Name": + if _, ok := allMap[idx][dbName].(string); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) + } + case "Age": + if _, ok := allMap[idx][dbName].(uint); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) + } + case "Birthday": + if _, ok := allMap[idx][dbName].(*time.Time); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) + } + } + reflectValue := reflect.Indirect(reflect.ValueOf(user)) AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) }) @@ -101,6 +166,43 @@ func TestFind(t *testing.T) { } } }) + + t.Run("FindSliceOfMapWithTable", func(t *testing.T) { + var allMap = []map[string]interface{}{} + if err := DB.Table("users").Where("name = ?", "find").Find(&allMap).Error; err != nil { + t.Errorf("errors happened when query find: %v", err) + } else { + for idx, user := range users { + t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + resultType := reflect.ValueOf(allMap[idx][dbName]).Type().Name() + + switch name { + case "Name": + if !strings.Contains(resultType, "string") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName]) + } + case "Age": + if !strings.Contains(resultType, "int") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName]) + } + case "Birthday": + if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName]) + } + } + + reflectValue := reflect.Indirect(reflect.ValueOf(user)) + AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) + }) + } + }) + } + } + }) + } func TestQueryWithAssociation(t *testing.T) { From bf6123b01e265ecfe709738b290c3ea3f6ad9bdc Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 1 Sep 2020 18:05:26 +0800 Subject: [PATCH 09/67] Fix duplicated soft delete clause --- soft_delete.go | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/soft_delete.go b/soft_delete.go index 484f565c..b13fc63f 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -25,14 +25,7 @@ func (n DeletedAt) Value() (driver.Value, error) { } func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface { - return []clause.Interface{ - clause.Where{Exprs: []clause.Expression{ - clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, - Value: nil, - }, - }}, - } + return []clause.Interface{SoftDeleteQueryClause{Field: f}} } type SoftDeleteQueryClause struct { From 22317b43c007f1a4aa21d6bf6c3e5088ce0ca507 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 1 Sep 2020 18:58:16 +0800 Subject: [PATCH 10/67] Fix migrate field, failed to migrate when field size changed --- migrator/migrator.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index c736a3e0..1aebc50d 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -356,9 +356,9 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy alterColumn = true } else { // has size in data type and not equal - matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllString(realDataType, 1) + matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllStringSubmatch(realDataType, -1) matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1) - if len(matches) > 0 && matches[1] != fmt.Sprint(field.Size) || len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) { + if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size)) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) { alterColumn = true } } From d1e17d549fc3fb9a66e150d425e090dca838ab07 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 1 Sep 2020 20:52:06 +0800 Subject: [PATCH 11/67] request ColumnTypes after new session method --- migrator/migrator.go | 2 +- tests/go.mod | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 1aebc50d..29d26c9e 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -388,7 +388,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) { err = m.RunWithValue(value, func(stmt *gorm.Statement) error { - rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows() + rows, err := m.DB.Session(&gorm.Session{}).Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows() if err == nil { columnTypes, err = rows.ColumnTypes() } diff --git a/tests/go.mod b/tests/go.mod index f3dd6dbc..30a7dda7 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,10 +6,10 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v1.0.0 + gorm.io/driver/mysql v1.0.1 gorm.io/driver/postgres v1.0.0 - gorm.io/driver/sqlite v1.1.0 - gorm.io/driver/sqlserver v1.0.1 + gorm.io/driver/sqlite v1.1.1 + gorm.io/driver/sqlserver v1.0.2 gorm.io/gorm v1.9.19 ) From 9a101c8a089b724fc19af525fcdca58bff0b7997 Mon Sep 17 00:00:00 2001 From: aimuz Date: Tue, 1 Sep 2020 21:03:37 +0800 Subject: [PATCH 12/67] fmt.Sprint() to strconv.Format (#3354) --- logger/sql.go | 14 +++++++------- schema/field.go | 2 +- utils/utils.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 8 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index 02d559c5..0efc0971 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -3,6 +3,7 @@ package logger import ( "database/sql/driver" "fmt" + "gorm.io/gorm/utils" "reflect" "regexp" "strconv" @@ -24,13 +25,12 @@ var convertableTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeO func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { var convertParams func(interface{}, int) - var vars = make([]interface{}, len(avars)) - copy(vars, avars) + var vars = make([]string, len(avars)) convertParams = func(v interface{}, idx int) { switch v := v.(type) { case bool: - vars[idx] = fmt.Sprint(v) + vars[idx] = strconv.FormatBool(v) case time.Time: if v.IsZero() { vars[idx] = escaper + "0000-00-00 00:00:00" + escaper @@ -44,7 +44,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a vars[idx] = escaper + "" + escaper } case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - vars[idx] = fmt.Sprintf("%d", v) + vars[idx] = utils.ToString(v) case float64, float32: vars[idx] = fmt.Sprintf("%.6f", v) case string: @@ -70,18 +70,18 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a } } - for idx, v := range vars { + for idx, v := range avars { convertParams(v, idx) } if numericPlaceholder == nil { for _, v := range vars { - sql = strings.Replace(sql, "?", v.(string), 1) + sql = strings.Replace(sql, "?", v, 1) } } else { sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$") for idx, v := range vars { - sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v.(string), 1) + sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v, 1) } } diff --git a/schema/field.go b/schema/field.go index 524d19fb..2e649d81 100644 --- a/schema/field.go +++ b/schema/field.go @@ -671,7 +671,7 @@ func (field *Field) setupValuerAndSetter() { case []byte: field.ReflectValueOf(value).SetString(string(data)) case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - field.ReflectValueOf(value).SetString(fmt.Sprint(data)) + field.ReflectValueOf(value).SetString(utils.ToString(data)) case float64, float32: field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) default: diff --git a/utils/utils.go b/utils/utils.go index 71336f4b..905001a5 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -83,3 +83,31 @@ func AssertEqual(src, dst interface{}) bool { } return true } + +func ToString(value interface{}) string { + switch v := value.(type) { + case string: + return v + case int: + return strconv.FormatInt(int64(v), 10) + case int8: + return strconv.FormatInt(int64(v), 10) + case int16: + return strconv.FormatInt(int64(v), 10) + case int32: + return strconv.FormatInt(int64(v), 10) + case int64: + return strconv.FormatInt(v, 10) + case uint: + return strconv.FormatUint(uint64(v), 10) + case uint8: + return strconv.FormatUint(uint64(v), 10) + case uint16: + return strconv.FormatUint(uint64(v), 10) + case uint32: + return strconv.FormatUint(uint64(v), 10) + case uint64: + return strconv.FormatUint(v, 10) + } + return "" +} From dbaa6b0ec3f451903c2983fd091c52e5efc60669 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 2 Sep 2020 16:14:26 +0800 Subject: [PATCH 13/67] Fix Scan struct with primary key, close #3357 --- callbacks.go | 2 ++ callbacks/row.go | 2 +- finisher_api.go | 19 ++++++++++++++----- logger/sql.go | 3 ++- migrator.go | 2 +- tests/scan_test.go | 18 +++++++++++++++--- 6 files changed, 35 insertions(+), 11 deletions(-) diff --git a/callbacks.go b/callbacks.go index baeb6c09..eace06ca 100644 --- a/callbacks.go +++ b/callbacks.go @@ -79,6 +79,8 @@ func (p *processor) Execute(db *DB) { if stmt.Model == nil { stmt.Model = stmt.Dest + } else if stmt.Dest == nil { + stmt.Dest = stmt.Model } if stmt.Model != nil { diff --git a/callbacks/row.go b/callbacks/row.go index 7e70382e..a36c0116 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -11,7 +11,7 @@ func RowQuery(db *gorm.DB) { } if !db.DryRun { - if _, ok := db.Get("rows"); ok { + if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) { db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } else { db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) diff --git a/finisher_api.go b/finisher_api.go index a205b859..1d5ef5fc 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -331,13 +331,13 @@ func (db *DB) Count(count *int64) (tx *DB) { } func (db *DB) Row() *sql.Row { - tx := db.getInstance() + tx := db.getInstance().InstanceSet("rows", false) tx.callbacks.Row().Execute(tx) return tx.Statement.Dest.(*sql.Row) } func (db *DB) Rows() (*sql.Rows, error) { - tx := db.Set("rows", true) + tx := db.getInstance().InstanceSet("rows", true) tx.callbacks.Row().Execute(tx) return tx.Statement.Dest.(*sql.Rows), tx.Error } @@ -345,8 +345,14 @@ func (db *DB) Rows() (*sql.Rows, error) { // Scan scan value to a struct func (db *DB) Scan(dest interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.Dest = dest - tx.callbacks.Query().Execute(tx) + if rows, err := tx.Rows(); err != nil { + tx.AddError(err) + } else { + defer rows.Close() + if rows.Next() { + tx.ScanRows(rows, dest) + } + } return } @@ -379,7 +385,10 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { tx := db.getInstance() tx.Error = tx.Statement.Parse(dest) tx.Statement.Dest = dest - tx.Statement.ReflectValue = reflect.Indirect(reflect.ValueOf(dest)) + tx.Statement.ReflectValue = reflect.ValueOf(dest) + for tx.Statement.ReflectValue.Kind() == reflect.Ptr { + tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem() + } Scan(rows, tx, true) return tx.Error } diff --git a/logger/sql.go b/logger/sql.go index 0efc0971..80645b0c 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -3,13 +3,14 @@ package logger import ( "database/sql/driver" "fmt" - "gorm.io/gorm/utils" "reflect" "regexp" "strconv" "strings" "time" "unicode" + + "gorm.io/gorm/utils" ) func isPrintable(s []byte) bool { diff --git a/migrator.go b/migrator.go index ed8a8e26..162fe680 100644 --- a/migrator.go +++ b/migrator.go @@ -9,7 +9,7 @@ import ( // Migrator returns migrator func (db *DB) Migrator() Migrator { - return db.Dialector.Migrator(db) + return db.Dialector.Migrator(db.Session(&Session{WithConditions: true})) } // AutoMigrate run auto migration for given models diff --git a/tests/scan_test.go b/tests/scan_test.go index d6a372bb..3e66a25a 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -6,6 +6,7 @@ import ( "strings" "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -16,14 +17,25 @@ func TestScan(t *testing.T) { DB.Save(&user1).Save(&user2).Save(&user3) type result struct { + ID uint Name string Age int } var res result - DB.Table("users").Select("name, age").Where("id = ?", user3.ID).Scan(&res) - if res.Name != user3.Name || res.Age != int(user3.Age) { - t.Errorf("Scan into struct should work") + DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&res) + if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) { + t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) + } + + DB.Table("users").Select("id, name, age").Where("id = ?", user2.ID).Scan(&res) + if res.ID != user2.ID || res.Name != user2.Name || res.Age != int(user2.Age) { + t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user2) + } + + DB.Model(&User{Model: gorm.Model{ID: user3.ID}}).Select("id, name, age").Scan(&res) + if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) { + t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) } var doubleAgeRes = &result{} From 680dda2c159d21c0b8f677b25519ec7fec29cd4d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 2 Sep 2020 20:09:51 +0800 Subject: [PATCH 14/67] Fix combine conditions when using string conditions, close #3358 --- clause/where.go | 52 ++++++++++++++++++++++++++++++++++++- tests/sql_builder_test.go | 54 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 1 deletion(-) diff --git a/clause/where.go b/clause/where.go index 9af9701c..a3774e1c 100644 --- a/clause/where.go +++ b/clause/where.go @@ -1,5 +1,9 @@ package clause +import ( + "strings" +) + // Where where clause type Where struct { Exprs []Expression @@ -22,6 +26,7 @@ func (where Where) Build(builder Builder) { } } + wrapInParentheses := false for idx, expr := range where.Exprs { if idx > 0 { if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { @@ -31,7 +36,36 @@ func (where Where) Build(builder Builder) { } } - expr.Build(builder) + if len(where.Exprs) > 1 { + switch v := expr.(type) { + case OrConditions: + if len(v.Exprs) == 1 { + if e, ok := v.Exprs[0].(Expr); ok { + sql := strings.ToLower(e.SQL) + wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + } + } + case AndConditions: + if len(v.Exprs) == 1 { + if e, ok := v.Exprs[0].(Expr); ok { + sql := strings.ToLower(e.SQL) + wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + } + } + case Expr: + sql := strings.ToLower(v.SQL) + wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + } + } + + if wrapInParentheses { + builder.WriteString(`(`) + expr.Build(builder) + builder.WriteString(`)`) + wrapInParentheses = false + } else { + expr.Build(builder) + } } } @@ -50,6 +84,8 @@ func (where Where) MergeClause(clause *Clause) { func And(exprs ...Expression) Expression { if len(exprs) == 0 { return nil + } else if len(exprs) == 1 { + return exprs[0] } return AndConditions{Exprs: exprs} } @@ -118,6 +154,7 @@ func (not NotConditions) Build(builder Builder) { if len(not.Exprs) > 1 { builder.WriteByte('(') } + for idx, c := range not.Exprs { if idx > 0 { builder.WriteString(" AND ") @@ -127,9 +164,22 @@ func (not NotConditions) Build(builder Builder) { negationBuilder.NegationBuild(builder) } else { builder.WriteString("NOT ") + e, wrapInParentheses := c.(Expr) + if wrapInParentheses { + sql := strings.ToLower(e.SQL) + if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses { + builder.WriteByte('(') + } + } + c.Build(builder) + + if wrapInParentheses { + builder.WriteByte(')') + } } } + if len(not.Exprs) > 1 { builder.WriteByte(')') } diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index e6038947..c0176fc3 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "regexp" "strings" "testing" @@ -188,3 +189,56 @@ func TestGroupConditions(t *testing.T) { t.Errorf("expects: %v, got %v", expects, result) } } + +func TestCombineStringConditions(t *testing.T) { + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + sql := dryRunDB.Where("a = ? or b = ?", "a", "b").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ?", "c").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR c = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Or("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) OR \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ?", "c").Not("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND c = .+ AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ?", "e").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT e = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Or("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Not("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE NOT \(a = .+ or b = .+\)$`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } +} From dbe0f4d8d7dad471d7e3931ecb7e24610adb76f5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 2 Sep 2020 20:15:12 +0800 Subject: [PATCH 15/67] Allow use NULL as default value for string, close #3363 --- schema/field.go | 2 +- tests/default_value_test.go | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/schema/field.go b/schema/field.go index 2e649d81..b49b7de6 100644 --- a/schema/field.go +++ b/schema/field.go @@ -201,7 +201,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { case reflect.String: field.DataType = String isFunc := strings.Contains(field.DefaultValue, "(") && - strings.Contains(field.DefaultValue, ")") + strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" if field.HasDefaultValue && !isFunc { field.DefaultValue = strings.Trim(field.DefaultValue, "'") diff --git a/tests/default_value_test.go b/tests/default_value_test.go index aa4a511a..44309eab 100644 --- a/tests/default_value_test.go +++ b/tests/default_value_test.go @@ -13,6 +13,7 @@ func TestDefaultValue(t *testing.T) { Name string `gorm:"not null;default:foo"` Name2 string `gorm:"size:233;not null;default:'foo'"` Name3 string `gorm:"size:233;not null;default:''"` + Name4 string `gorm:"size:233;default:null"` Age int `gorm:"default:18"` Enabled bool `gorm:"default:true"` } From 130f24090db2b9862282281f9dd288c2a214a263 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 2 Sep 2020 21:03:47 +0800 Subject: [PATCH 16/67] update default_value_test --- tests/default_value_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/default_value_test.go b/tests/default_value_test.go index 44309eab..aa4a511a 100644 --- a/tests/default_value_test.go +++ b/tests/default_value_test.go @@ -13,7 +13,6 @@ func TestDefaultValue(t *testing.T) { Name string `gorm:"not null;default:foo"` Name2 string `gorm:"size:233;not null;default:'foo'"` Name3 string `gorm:"size:233;not null;default:''"` - Name4 string `gorm:"size:233;default:null"` Age int `gorm:"default:18"` Enabled bool `gorm:"default:true"` } From fcb666cfa31ecf0de77fcd23e60a67c6819ad7fa Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 10:58:48 +0800 Subject: [PATCH 17/67] Fix associations using composite primary keys without ID field, close #3365 --- callbacks/associations.go | 18 +++++++++++++--- tests/multi_primary_keys_test.go | 36 ++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 2710ffe9..0c677f47 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -5,6 +5,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" + "gorm.io/gorm/schema" ) func SaveBeforeAssociations(db *gorm.DB) { @@ -145,7 +146,7 @@ func SaveAfterAssociations(db *gorm.DB) { } db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, + Columns: onConflictColumns(rel.FieldSchema), DoUpdates: clause.AssignmentColumns(assignmentColumns), }).Create(elems.Interface()).Error) } @@ -168,7 +169,7 @@ func SaveAfterAssociations(db *gorm.DB) { } db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, + Columns: onConflictColumns(rel.FieldSchema), DoUpdates: clause.AssignmentColumns(assignmentColumns), }).Create(f.Interface()).Error) } @@ -230,7 +231,7 @@ func SaveAfterAssociations(db *gorm.DB) { } db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, + Columns: onConflictColumns(rel.FieldSchema), DoUpdates: clause.AssignmentColumns(assignmentColumns), }).Create(elems.Interface()).Error) } @@ -310,3 +311,14 @@ func SaveAfterAssociations(db *gorm.DB) { } } } + +func onConflictColumns(s *schema.Schema) (columns []clause.Column) { + if s.PrioritizedPrimaryField != nil { + return []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}} + } + + for _, dbName := range s.PrimaryFieldDBNames { + columns = append(columns, clause.Column{Name: dbName}) + } + return +} diff --git a/tests/multi_primary_keys_test.go b/tests/multi_primary_keys_test.go index 051e3ee2..68da8a88 100644 --- a/tests/multi_primary_keys_test.go +++ b/tests/multi_primary_keys_test.go @@ -6,6 +6,7 @@ import ( "testing" "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" ) type Blog struct { @@ -410,3 +411,38 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { t.Fatalf("EN Blog's tags should be cleared") } } + +func TestCompositePrimaryKeysAssociations(t *testing.T) { + type Label struct { + BookID *uint `gorm:"primarykey"` + Name string `gorm:"primarykey"` + Value string + } + + type Book struct { + ID int + Name string + Labels []Label + } + + DB.Migrator().DropTable(&Label{}, &Book{}) + if err := DB.AutoMigrate(&Label{}, &Book{}); err != nil { + t.Fatalf("failed to migrate") + } + + book := Book{ + Name: "my book", + Labels: []Label{ + {Name: "region", Value: "emea"}, + }, + } + + DB.Create(&book) + + var result Book + if err := DB.Preload("Labels").First(&result, book.ID).Error; err != nil { + t.Fatalf("failed to preload, got error %v", err) + } + + AssertEqual(t, book, result) +} From 48b395b760d86fddad7480972791444494a8ae68 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 11:32:30 +0800 Subject: [PATCH 18/67] returns ErrEmptySlice when creating with zero length slice --- callbacks/create.go | 5 +++++ callbacks/helper.go | 5 +++++ errors.go | 2 ++ tests/create_test.go | 12 ++++++++++++ tests/go.mod | 2 ++ 5 files changed, 26 insertions(+) diff --git a/callbacks/create.go b/callbacks/create.go index 5de19d35..e37c2c60 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -252,6 +252,11 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { stmt.SQL.Grow(stmt.ReflectValue.Len() * 15) values.Values = make([][]interface{}, stmt.ReflectValue.Len()) defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} + if stmt.ReflectValue.Len() == 0 { + stmt.AddError(gorm.ErrEmptySlice) + return + } + for i := 0; i < stmt.ReflectValue.Len(); i++ { rv := reflect.Indirect(stmt.ReflectValue.Index(i)) values.Values[i] = make([]interface{}, len(values.Columns)) diff --git a/callbacks/helper.go b/callbacks/helper.go index e0a66dd2..09ec4582 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -46,6 +46,11 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) ) + if len(mapValues) == 0 { + stmt.AddError(gorm.ErrEmptySlice) + return + } + for idx, mapValue := range mapValues { for k, v := range mapValue { if stmt.Schema != nil { diff --git a/errors.go b/errors.go index 32ff8ec1..508f6957 100644 --- a/errors.go +++ b/errors.go @@ -27,4 +27,6 @@ var ( ErrRegistered = errors.New("registered") // ErrInvalidField invalid field ErrInvalidField = errors.New("invalid field") + // ErrEmptySlice empty slice found + ErrEmptySlice = errors.New("empty slice found") ) diff --git a/tests/create_test.go b/tests/create_test.go index ab0a78d4..59fdd8f1 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -287,6 +287,18 @@ func TestCreateEmptyStruct(t *testing.T) { } } +func TestCreateEmptySlice(t *testing.T) { + var data = []User{} + if err := DB.Create(&data).Error; err != gorm.ErrEmptySlice { + t.Errorf("no data should be created, got %v", err) + } + + var sliceMap = []map[string]interface{}{} + if err := DB.Model(&User{}).Create(&sliceMap).Error; err != gorm.ErrEmptySlice { + t.Errorf("no data should be created, got %v", err) + } +} + func TestCreateWithExistingTimestamp(t *testing.T) { user := User{Name: "CreateUserExistingTimestamp"} curTime := now.MustParse("2016-01-01") diff --git a/tests/go.mod b/tests/go.mod index 30a7dda7..2b336850 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -16,3 +16,5 @@ require ( replace gorm.io/gorm => ../ replace github.com/jackc/pgx/v4 => github.com/jinzhu/pgx/v4 v4.8.2 + +replace gorm.io/driver/sqlserver => /Users/jinzhu/Projects/jinzhu/sqlserver From ff3880292dc89da8061269e74cfdeb75e20aee6f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 11:48:44 +0800 Subject: [PATCH 19/67] Update missing playground template --- .github/workflows/missing_playground.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml index 6fb714ca..422cb9f5 100644 --- a/.github/workflows/missing_playground.yml +++ b/.github/workflows/missing_playground.yml @@ -13,7 +13,7 @@ jobs: uses: actions/stale@v3.0.7 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: "This issue has been automatically marked as stale as it missing playground pull request link, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details, it will be closed in 2 days if no further activity occurs." + stale-issue-message: "This issue has been automatically marked as stale as it missing playground pull request link, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details, it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-label: "status:stale" days-before-stale: 0 days-before-close: 2 From 98e15e0b95b39f9caefbb8b14a1e479a237e52fe Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 12:54:26 +0800 Subject: [PATCH 20/67] Setup DB's ConnPool in PrepareStmt mode, fix #3362 --- gorm.go | 2 ++ tests/prepared_stmt_test.go | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/gorm.go b/gorm.go index fec4310b..ed01ccfe 100644 --- a/gorm.go +++ b/gorm.go @@ -176,6 +176,8 @@ func (db *DB) Session(config *Session) *DB { Mux: preparedStmt.Mux, Stmts: preparedStmt.Stmts, } + txConfig.ConnPool = tx.Statement.ConnPool + txConfig.PrepareStmt = true } } diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index af610165..6b10b6dc 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -12,6 +12,10 @@ import ( func TestPreparedStmt(t *testing.T) { tx := DB.Session(&gorm.Session{PrepareStmt: true}) + if _, ok := tx.ConnPool.(*gorm.PreparedStmtDB); !ok { + t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") + } + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() txCtx := tx.WithContext(ctx) From 3cc7a307122e1ca2d0fbb298c264c51fce1bdd62 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 13:28:37 +0800 Subject: [PATCH 21/67] Fix tests/go.mod --- tests/go.mod | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index 2b336850..30a7dda7 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -16,5 +16,3 @@ require ( replace gorm.io/gorm => ../ replace github.com/jackc/pgx/v4 => github.com/jinzhu/pgx/v4 v4.8.2 - -replace gorm.io/driver/sqlserver => /Users/jinzhu/Projects/jinzhu/sqlserver From cf31508095ecae9a50ecfde1cf7c534d01fbe745 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 15:02:04 +0800 Subject: [PATCH 22/67] Fix tests_all.sh --- tests/go.mod | 2 +- tests/tests_all.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/go.mod b/tests/go.mod index 30a7dda7..4ddb0b69 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,7 +9,7 @@ require ( gorm.io/driver/mysql v1.0.1 gorm.io/driver/postgres v1.0.0 gorm.io/driver/sqlite v1.1.1 - gorm.io/driver/sqlserver v1.0.2 + gorm.io/driver/sqlserver v1.0.3 gorm.io/gorm v1.9.19 ) diff --git a/tests/tests_all.sh b/tests/tests_all.sh index e87ff045..744a40e9 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -10,7 +10,7 @@ if [ -d tests ] then cd tests cp go.mod go.mod.bak - sed '/$[[:space:]]*gorm.io\/driver/d' go.mod.bak > go.mod + sed '/^[[:blank:]]*gorm.io\/driver/d' go.mod.bak > go.mod cd .. fi From f2adb088c598400086b6e67506ffee38780e9c3f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 16:11:15 +0800 Subject: [PATCH 23/67] Set field size from primary fields to foreign fields --- gorm.go | 3 +++ schema/relationship.go | 9 +++++++++ 2 files changed, 12 insertions(+) diff --git a/gorm.go b/gorm.go index ed01ccfe..8efd8a73 100644 --- a/gorm.go +++ b/gorm.go @@ -319,6 +319,9 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil { f.DataType = ref.ForeignKey.DataType f.GORMDataType = ref.ForeignKey.GORMDataType + if f.Size == 0 { + f.Size = ref.ForeignKey.Size + } ref.ForeignKey = f } else { return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName) diff --git a/schema/relationship.go b/schema/relationship.go index aa992b84..47b948dc 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -165,6 +165,9 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi // use same data type for foreign keys relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType + if relation.Polymorphic.PolymorphicID.Size == 0 { + relation.Polymorphic.PolymorphicID.Size = primaryKeyField.Size + } relation.References = append(relation.References, &Reference{ PrimaryKey: primaryKeyField, @@ -301,6 +304,9 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel // use same data type for foreign keys f.DataType = fieldsMap[f.Name].DataType f.GORMDataType = fieldsMap[f.Name].GORMDataType + if f.Size == 0 { + f.Size = fieldsMap[f.Name].Size + } relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f) ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] @@ -436,6 +442,9 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue // use same data type for foreign keys foreignField.DataType = primaryFields[idx].DataType foreignField.GORMDataType = primaryFields[idx].GORMDataType + if foreignField.Size == 0 { + foreignField.Size = primaryFields[idx].Size + } relation.References = append(relation.References, &Reference{ PrimaryKey: primaryFields[idx], From 78e9c9b7488fbc71bf2ab853db4490d241cb0ada Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 18:20:57 +0800 Subject: [PATCH 24/67] raise error when failed to parse default value, close #3378 --- schema/field.go | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/schema/field.go b/schema/field.go index b49b7de6..0cb210f8 100644 --- a/schema/field.go +++ b/schema/field.go @@ -70,6 +70,8 @@ type Field struct { } func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { + var err error + field := &Field{ Name: fieldStruct.Name, BindNames: []string{fieldStruct.Name}, @@ -151,7 +153,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if num, ok := field.TagSettings["SIZE"]; ok { - var err error if field.Size, err = strconv.Atoi(num); err != nil { field.Size = -1 } @@ -181,22 +182,30 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { case reflect.Bool: field.DataType = Bool if field.HasDefaultValue && field.DefaultValue != "" { - field.DefaultValueInterface, _ = strconv.ParseBool(field.DefaultValue) + if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil { + schema.err = fmt.Errorf("failed to parse %v as default value for bool, got error: %v", field.DefaultValue, err) + } } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: field.DataType = Int if field.HasDefaultValue && field.DefaultValue != "" { - field.DefaultValueInterface, _ = strconv.ParseInt(field.DefaultValue, 0, 64) + if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil { + schema.err = fmt.Errorf("failed to parse %v as default value for int, got error: %v", field.DefaultValue, err) + } } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: field.DataType = Uint if field.HasDefaultValue && field.DefaultValue != "" { - field.DefaultValueInterface, _ = strconv.ParseUint(field.DefaultValue, 0, 64) + if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil { + schema.err = fmt.Errorf("failed to parse %v as default value for uint, got error: %v", field.DefaultValue, err) + } } case reflect.Float32, reflect.Float64: field.DataType = Float if field.HasDefaultValue && field.DefaultValue != "" { - field.DefaultValueInterface, _ = strconv.ParseFloat(field.DefaultValue, 64) + if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil { + schema.err = fmt.Errorf("failed to parse %v as default value for float, got error: %v", field.DefaultValue, err) + } } case reflect.String: field.DataType = String From 3cd81ff646090931556cf5590c41ac5d5746357c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 18:42:32 +0800 Subject: [PATCH 25/67] Fix query with specified table and conditions, close #3382 --- statement.go | 8 ++++---- tests/query_test.go | 9 ++++++++- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/statement.go b/statement.go index d72a086f..e16cf0ff 100644 --- a/statement.go +++ b/statement.go @@ -317,9 +317,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c if field.Readable { if v, isZero := field.ValueOf(reflectValue); !isZero { if field.DBName != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) } } } @@ -330,9 +330,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c if field.Readable { if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { if field.DBName != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) } } } diff --git a/tests/query_test.go b/tests/query_test.go index 6bb68cd3..795186da 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -202,7 +202,6 @@ func TestFind(t *testing.T) { } } }) - } func TestQueryWithAssociation(t *testing.T) { @@ -800,3 +799,11 @@ func TestScanNullValue(t *testing.T) { t.Fatalf("failed to query slice data with null age, got error %v", err) } } + +func TestQueryWithTableAndConditions(t *testing.T) { + result := DB.Session(&gorm.Session{DryRun: true}).Table("user").Find(&User{}, User{Name: "jinzhu"}) + + if !regexp.MustCompile(`SELECT \* FROM .user. WHERE .user.\..name. = .+ AND .user.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } +} From dd0d74fad06342a792a1cdc20101a57ee019f447 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 19:16:55 +0800 Subject: [PATCH 26/67] Fix transaction on closed conn when using prepared statement, close #3380 --- prepare_stmt.go | 14 ++++++++++++++ tests/tests_test.go | 4 ++-- tests/transaction_test.go | 21 +++++++++++++++++++++ 3 files changed, 37 insertions(+), 2 deletions(-) diff --git a/prepare_stmt.go b/prepare_stmt.go index de7e2a26..14a6aaec 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -99,6 +99,20 @@ type PreparedStmtTX struct { PreparedStmtDB *PreparedStmtDB } +func (tx *PreparedStmtTX) Commit() error { + if tx.Tx != nil { + return tx.Tx.Commit() + } + return ErrInvalidTransaction +} + +func (tx *PreparedStmtTX) Rollback() error { + if tx.Tx != nil { + return tx.Tx.Rollback() + } + return ErrInvalidTransaction +} + func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { stmt, err := tx.PreparedStmtDB.prepare(ctx, query) if err == nil { diff --git a/tests/tests_test.go b/tests/tests_test.go index 192160a0..cb73d267 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -21,7 +21,7 @@ var DB *gorm.DB func init() { var err error if DB, err = OpenTestConnection(); err != nil { - log.Printf("failed to connect database, got error %v\n", err) + log.Printf("failed to connect database, got error %v", err) os.Exit(1) } else { sqlDB, err := DB.DB() @@ -30,7 +30,7 @@ func init() { } if err != nil { - log.Printf("failed to connect database, got error %v\n", err) + log.Printf("failed to connect database, got error %v", err) } RunMigrations() diff --git a/tests/transaction_test.go b/tests/transaction_test.go index aea151d9..334600b8 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -282,3 +282,24 @@ func TestNestedTransactionWithBlock(t *testing.T) { t.Fatalf("Should find saved record") } } + +func TestTransactionOnClosedConn(t *testing.T) { + DB, err := OpenTestConnection() + if err != nil { + t.Fatalf("failed to connect database, got error %v", err) + } + rawDB, _ := DB.DB() + rawDB.Close() + + if err := DB.Transaction(func(tx *gorm.DB) error { + return nil + }); err == nil { + t.Errorf("should returns error when commit with closed conn, got error %v", err) + } + + if err := DB.Session(&gorm.Session{PrepareStmt: true}).Transaction(func(tx *gorm.DB) error { + return nil + }); err == nil { + t.Errorf("should returns error when commit with closed conn, got error %v", err) + } +} From 6a866464695e8b0291236f9038a032f68fb0b37d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 20:41:00 +0800 Subject: [PATCH 27/67] Fix use db function as integer's default value, close #3384 --- schema/field.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/schema/field.go b/schema/field.go index 0cb210f8..f8a73c60 100644 --- a/schema/field.go +++ b/schema/field.go @@ -178,41 +178,41 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Comment = val } + defaultValueFunc := strings.Contains(field.DefaultValue, "(") && + strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" switch reflect.Indirect(fieldValue).Kind() { case reflect.Bool: field.DataType = Bool - if field.HasDefaultValue && field.DefaultValue != "" { + if field.HasDefaultValue && !defaultValueFunc { if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil { schema.err = fmt.Errorf("failed to parse %v as default value for bool, got error: %v", field.DefaultValue, err) } } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: field.DataType = Int - if field.HasDefaultValue && field.DefaultValue != "" { + if field.HasDefaultValue && !defaultValueFunc { if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil { schema.err = fmt.Errorf("failed to parse %v as default value for int, got error: %v", field.DefaultValue, err) } } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: field.DataType = Uint - if field.HasDefaultValue && field.DefaultValue != "" { + if field.HasDefaultValue && !defaultValueFunc { if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil { schema.err = fmt.Errorf("failed to parse %v as default value for uint, got error: %v", field.DefaultValue, err) } } case reflect.Float32, reflect.Float64: field.DataType = Float - if field.HasDefaultValue && field.DefaultValue != "" { + if field.HasDefaultValue && !defaultValueFunc { if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil { schema.err = fmt.Errorf("failed to parse %v as default value for float, got error: %v", field.DefaultValue, err) } } case reflect.String: field.DataType = String - isFunc := strings.Contains(field.DefaultValue, "(") && - strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" - if field.HasDefaultValue && !isFunc { + if field.HasDefaultValue && !defaultValueFunc { field.DefaultValue = strings.Trim(field.DefaultValue, "'") field.DefaultValue = strings.Trim(field.DefaultValue, "\"") field.DefaultValueInterface = field.DefaultValue From 28121d44554b1f5db07658e7cc8343ace65d940d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 3 Sep 2020 20:59:41 +0800 Subject: [PATCH 28/67] Fix panic when batch creating from slice contains invalid data, close #3385 --- callbacks/create.go | 6 ++++++ tests/create_test.go | 13 +++++++++++++ 2 files changed, 19 insertions(+) diff --git a/callbacks/create.go b/callbacks/create.go index e37c2c60..c00a0a73 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -1,6 +1,7 @@ package callbacks import ( + "fmt" "reflect" "gorm.io/gorm" @@ -259,6 +260,11 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { for i := 0; i < stmt.ReflectValue.Len(); i++ { rv := reflect.Indirect(stmt.ReflectValue.Index(i)) + if !rv.IsValid() { + stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData)) + return + } + values.Values[i] = make([]interface{}, len(values.Columns)) for idx, column := range values.Columns { field := stmt.Schema.FieldsByDBName[column.Name] diff --git a/tests/create_test.go b/tests/create_test.go index 59fdd8f1..00674eec 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "errors" "testing" "time" @@ -299,6 +300,18 @@ func TestCreateEmptySlice(t *testing.T) { } } +func TestCreateInvalidSlice(t *testing.T) { + users := []*User{ + GetUser("invalid_slice_1", Config{}), + GetUser("invalid_slice_2", Config{}), + nil, + } + + if err := DB.Create(&users).Error; !errors.Is(err, gorm.ErrInvalidData) { + t.Errorf("should returns error invalid data when creating from slice that contains invalid data") + } +} + func TestCreateWithExistingTimestamp(t *testing.T) { user := User{Name: "CreateUserExistingTimestamp"} curTime := now.MustParse("2016-01-01") From f1216222284fc2f91bee7018c5c54a3662b9a2b3 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 4 Sep 2020 14:30:53 +0800 Subject: [PATCH 29/67] Don't add prefix for invalid embedded fields --- schema/field.go | 2 +- schema/schema_test.go | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/schema/field.go b/schema/field.go index f8a73c60..db044c23 100644 --- a/schema/field.go +++ b/schema/field.go @@ -340,7 +340,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...) } - if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok { + if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok && ef.DBName != "" { ef.DBName = prefix + ef.DBName } diff --git a/schema/schema_test.go b/schema/schema_test.go index 4d13ebd2..6ca5b269 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -194,6 +194,7 @@ func TestEmbeddedStruct(t *testing.T) { ID int OwnerID int Name string + Ignored string `gorm:"-"` } type Corp struct { @@ -211,15 +212,18 @@ func TestEmbeddedStruct(t *testing.T) { {Name: "ID", DBName: "id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}}, {Name: "ID", DBName: "company_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, {Name: "Name", DBName: "company_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Ignored", BindNames: []string{"Base", "Ignored"}, TagSettings: map[string]string{"-": "-", "EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, {Name: "OwnerID", DBName: "company_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, {Name: "OwnerID", DBName: "owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String}, } for _, f := range fields { checkSchemaField(t, cropSchema, &f, func(f *schema.Field) { - f.Creatable = true - f.Updatable = true - f.Readable = true + if f.Name != "Ignored" { + f.Creatable = true + f.Updatable = true + f.Readable = true + } }) } } From d8ddccf1478bf1aaf3726f2301c08fe6a9ca4183 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 4 Sep 2020 19:02:37 +0800 Subject: [PATCH 30/67] Don't marshal to null for associations after preloading, close #3395 --- callbacks/preload.go | 14 ++++++++++++-- tests/preload_test.go | 24 ++++++++++++++++++++++++ tests/scan_test.go | 8 ++++++-- 3 files changed, 42 insertions(+), 4 deletions(-) diff --git a/callbacks/preload.go b/callbacks/preload.go index 9b8f762a..aec10ec5 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -110,10 +110,20 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { // clean up old values before preloading switch reflectValue.Kind() { case reflect.Struct: - rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface()) + switch rel.Type { + case schema.HasMany, schema.Many2Many: + rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 0).Interface()) + default: + rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface()) + } case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { - rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) + switch rel.Type { + case schema.HasMany, schema.Many2Many: + rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 0).Interface()) + default: + rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) + } } } diff --git a/tests/preload_test.go b/tests/preload_test.go index 76b72f14..d9035661 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -1,6 +1,8 @@ package tests_test import ( + "encoding/json" + "regexp" "sort" "strconv" "testing" @@ -188,3 +190,25 @@ func TestNestedPreloadWithConds(t *testing.T) { CheckPet(t, *users2[2].Pets[2], *users[2].Pets[2]) } } + +func TestPreloadEmptyData(t *testing.T) { + var user = *GetUser("user_without_associations", Config{}) + DB.Create(&user) + + DB.Preload("Team").Preload("Languages").Preload("Friends").First(&user, "name = ?", user.Name) + + if r, err := json.Marshal(&user); err != nil { + t.Errorf("failed to marshal users, got error %v", err) + } else if !regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) { + t.Errorf("json marshal is not empty slice, got %v", string(r)) + } + + var results []User + DB.Preload("Team").Preload("Languages").Preload("Friends").Find(&results, "name = ?", user.Name) + + if r, err := json.Marshal(&results); err != nil { + t.Errorf("failed to marshal users, got error %v", err) + } else if !regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) { + t.Errorf("json marshal is not empty slice, got %v", string(r)) + } +} diff --git a/tests/scan_test.go b/tests/scan_test.go index 3e66a25a..92e89521 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -51,11 +51,11 @@ func TestScan(t *testing.T) { DB.Table("users").Select("name, age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&results) sort.Slice(results, func(i, j int) bool { - return strings.Compare(results[i].Name, results[j].Name) < -1 + return strings.Compare(results[i].Name, results[j].Name) <= -1 }) if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name { - t.Errorf("Scan into struct map") + t.Errorf("Scan into struct map, got %#v", results) } } @@ -84,6 +84,10 @@ func TestScanRows(t *testing.T) { results = append(results, result) } + sort.Slice(results, func(i, j int) bool { + return strings.Compare(results[i].Name, results[j].Name) <= -1 + }) + if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { t.Errorf("Should find expected results") } From 6e38a2c2d510a6823ad7b73c7e9321c8f7ceaff8 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 6 Sep 2020 10:51:21 +0800 Subject: [PATCH 31/67] Fix many2many join table name rule --- schema/naming.go | 4 ++++ schema/relationship_test.go | 8 ++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/schema/naming.go b/schema/naming.go index 9b7c9471..ecdab791 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -41,6 +41,10 @@ func (ns NamingStrategy) ColumnName(table, column string) string { // JoinTableName convert string to join table name func (ns NamingStrategy) JoinTableName(str string) string { + if strings.ToLower(str) == str { + return str + } + if ns.SingularTable { return ns.TablePrefix + toDBName(str) } diff --git a/schema/relationship_test.go b/schema/relationship_test.go index f2d63323..b9279b9f 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -206,16 +206,16 @@ func TestMany2ManyOverrideJoinForeignKey(t *testing.T) { type User struct { gorm.Model - Profiles []Profile `gorm:"many2many:user_profiles;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"` + Profiles []Profile `gorm:"many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"` Refer uint } checkStructRelation(t, &User{}, Relation{ Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", - JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"}, + JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"}, References: []Reference{ - {"ID", "User", "UserReferID", "user_profiles", "", true}, - {"ID", "Profile", "ProfileRefer", "user_profiles", "", false}, + {"ID", "User", "UserReferID", "user_profile", "", true}, + {"ID", "Profile", "ProfileRefer", "user_profile", "", false}, }, }) } From 05794298bd3d87dc8e98de8cde451b19093e2a4a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 6 Sep 2020 12:22:05 +0800 Subject: [PATCH 32/67] Fix Save with specified table, close #3396 --- finisher_api.go | 3 ++- tests/update_test.go | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/finisher_api.go b/finisher_api.go index 1d5ef5fc..6ece0f79 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -51,7 +51,8 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx.callbacks.Update().Execute(tx) if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { - if err := tx.Session(&Session{}).First(value).Error; errors.Is(err, ErrRecordNotFound) { + result := reflect.New(tx.Statement.Schema.ModelType).Interface() + if err := tx.Session(&Session{WithConditions: true}).First(result).Error; errors.Is(err, ErrRecordNotFound) { return tx.Create(value) } } diff --git a/tests/update_test.go b/tests/update_test.go index 1944ed3f..a660647c 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -629,4 +629,26 @@ func TestSaveWithPrimaryValue(t *testing.T) { var result2 Language DB.First(&result2, "code = ?", "save") AssertEqual(t, result2, lang) + + DB.Table("langs").Migrator().DropTable(&Language{}) + DB.Table("langs").AutoMigrate(&Language{}) + + if err := DB.Table("langs").Save(&lang).Error; err != nil { + t.Errorf("no error should happen when creating data, but got %v", err) + } + + var result3 Language + if err := DB.Table("langs").First(&result3, "code = ?", lang.Code).Error; err != nil || result3.Name != lang.Name { + t.Errorf("failed to find created record, got error: %v, result: %+v", err, result3) + } + + lang.Name += "name2" + if err := DB.Table("langs").Save(&lang).Error; err != nil { + t.Errorf("no error should happen when creating data, but got %v", err) + } + + var result4 Language + if err := DB.Table("langs").First(&result4, "code = ?", lang.Code).Error; err != nil || result4.Name != lang.Name { + t.Errorf("failed to find created record, got error: %v, result: %+v", err, result4) + } } From 6de0356a57f74da299e7cb2b8ccd44e86fe59675 Mon Sep 17 00:00:00 2001 From: egenchen Date: Tue, 8 Sep 2020 16:59:47 +0800 Subject: [PATCH 33/67] Fix monocolor log output inconsist with colorful log (#3425) --- logger/logger.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/logger/logger.go b/logger/logger.go index 49ae988c..0b0a7377 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -65,9 +65,9 @@ func New(writer Writer, config Config) Interface { infoStr = "%s\n[info] " warnStr = "%s\n[warn] " errStr = "%s\n[error] " - traceStr = "%s\n[%v] [rows:%d] %s" - traceWarnStr = "%s\n[%v] [rows:%d] %s" - traceErrStr = "%s %s\n[%v] [rows:%d] %s" + traceStr = "%s\n[%.3fms] [rows:%d] %s" + traceWarnStr = "%s\n[%.3fms] [rows:%d] %s" + traceErrStr = "%s %s\n[%.3fms] [rows:%d] %s" ) if config.Colorful { From c9d5c0b07aa7be8ed4bebeb376ccf158542730ea Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 8 Sep 2020 18:24:35 +0800 Subject: [PATCH 34/67] Fix create database foreign keys for same type having has many/one & many2many relationships, close #3424 --- migrator/migrator.go | 23 ++++++++++++++++++----- tests/embedded_struct_test.go | 4 +++- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 29d26c9e..98e92c96 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -586,6 +586,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i var ( modelNames, orderedModelNames []string orderedModelNamesMap = map[string]bool{} + parsedSchemas = map[*schema.Schema]bool{} valuesMap = map[string]Dependency{} insertIntoOrderedList func(name string) parseDependence func(value interface{}, addToList bool) @@ -595,23 +596,35 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i dep := Dependency{ Statement: &gorm.Statement{DB: m.DB, Dest: value}, } + beDependedOn := map[*schema.Schema]bool{} if err := dep.Parse(value); err != nil { m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err) } + if _, ok := parsedSchemas[dep.Statement.Schema]; ok { + return + } + parsedSchemas[dep.Statement.Schema] = true for _, rel := range dep.Schema.Relationships.Relations { if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema { dep.Depends = append(dep.Depends, c.ReferenceSchema) } + if rel.Type == schema.HasOne || rel.Type == schema.HasMany { + beDependedOn[rel.FieldSchema] = true + } + if rel.JoinTable != nil { - if rel.Schema != rel.FieldSchema { - dep.Depends = append(dep.Depends, rel.FieldSchema) - } // append join value - defer func(joinValue interface{}) { + defer func(rel *schema.Relationship, joinValue interface{}) { + if !beDependedOn[rel.FieldSchema] { + dep.Depends = append(dep.Depends, rel.FieldSchema) + } else { + fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface() + parseDependence(fieldValue, autoAdd) + } parseDependence(joinValue, autoAdd) - }(reflect.New(rel.JoinTable.ModelType).Interface()) + }(rel, reflect.New(rel.JoinTable.ModelType).Interface()) } } diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index c29078bd..312a5c37 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -163,6 +163,8 @@ func TestEmbeddedRelations(t *testing.T) { DB.Migrator().DropTable(&AdvancedUser{}) if err := DB.AutoMigrate(&AdvancedUser{}); err != nil { - t.Errorf("Failed to auto migrate advanced user, got error %v", err) + if DB.Dialector.Name() != "sqlite" { + t.Errorf("Failed to auto migrate advanced user, got error %v", err) + } } } From c70c097e88bd5372783da6af55c4742fa4fe83ab Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 8 Sep 2020 19:11:20 +0800 Subject: [PATCH 35/67] Refactor format SQL for driver.Valuer --- logger/sql.go | 20 ++++++++++++++++++++ tests/go.mod | 4 ---- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index 80645b0c..096b9407 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -38,6 +38,26 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a } else { vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper } + case *time.Time: + if v != nil { + if v.IsZero() { + vars[idx] = escaper + "0000-00-00 00:00:00" + escaper + } else { + vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper + } + } else { + vars[idx] = "NULL" + } + case fmt.Stringer: + vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper + case driver.Valuer: + reflectValue := reflect.ValueOf(v) + if v != nil && reflectValue.IsValid() && (reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) { + r, _ := v.Value() + vars[idx] = fmt.Sprintf("%v", r) + } else { + vars[idx] = "NULL" + } case []byte: if isPrintable(v) { vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper diff --git a/tests/go.mod b/tests/go.mod index 4ddb0b69..76db6764 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,10 +6,6 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v1.0.1 - gorm.io/driver/postgres v1.0.0 - gorm.io/driver/sqlite v1.1.1 - gorm.io/driver/sqlserver v1.0.3 gorm.io/gorm v1.9.19 ) From aceb3dad3bbd43e79d0146992701f4f25f3eabb0 Mon Sep 17 00:00:00 2001 From: caelansar <819711623@qq.com> Date: Tue, 8 Sep 2020 21:28:04 +0800 Subject: [PATCH 36/67] correct generated sql --- clause/expression.go | 3 +++ tests/query_test.go | 16 ++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/clause/expression.go b/clause/expression.go index 3b914e68..55599571 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -37,6 +37,9 @@ func (expr Expr) Build(builder Builder) { } else { switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { case reflect.Slice, reflect.Array: + if rv.Len() == 0 { + builder.AddVar(builder, nil) + } for i := 0; i < rv.Len(); i++ { if i > 0 { builder.WriteByte(',') diff --git a/tests/query_test.go b/tests/query_test.go index 795186da..e695e825 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -202,6 +202,22 @@ func TestFind(t *testing.T) { } } }) + + var models []User + if err := DB.Where("name in ?", []string{"find"}).Find(&models).Error; err != nil || len(models) != 3 { + t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models)) + } else { + for idx, user := range users { + t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, models[idx], user) + }) + } + } + + var none []User + if err := DB.Where("name in ?", []string{}).Find(&none).Error; err != nil || len(none) != 0 { + t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none)) + } } func TestQueryWithAssociation(t *testing.T) { From 222427c474a3146bf79cb782fe50fae7d80aae69 Mon Sep 17 00:00:00 2001 From: "Jonathan A. Sternberg" Date: Tue, 8 Sep 2020 18:12:14 -0500 Subject: [PATCH 37/67] Release the connection when discovering the column types in the migrator When the migrator is used to discover the column types, such as when used with `AutoMigrate()`, it does not close the query result. This changes the migrator to close the query result and it also changes the query to use `LIMIT 1` to prevent additional work against the database when only discovering the schema. Fixes #3432. --- migrator/migrator.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 98e92c96..c0e22ae0 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -388,9 +388,10 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) { err = m.RunWithValue(value, func(stmt *gorm.Statement) error { - rows, err := m.DB.Session(&gorm.Session{}).Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows() + rows, err := m.DB.Session(&gorm.Session{}).Raw("select * from ? limit 1", clause.Table{Name: stmt.Table}).Rows() if err == nil { columnTypes, err = rows.ColumnTypes() + _ = rows.Close() } return err }) From 2242ac6c0ea490f7fa7c60c61126be0fdee0d72f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 9 Sep 2020 10:31:48 +0800 Subject: [PATCH 38/67] Fix tests & refactor for PR #3429 --- clause/expression.go | 11 ++++++----- tests/go.mod | 4 ++++ tests/query_test.go | 4 ++-- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index 55599571..dde236d3 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -39,12 +39,13 @@ func (expr Expr) Build(builder Builder) { case reflect.Slice, reflect.Array: if rv.Len() == 0 { builder.AddVar(builder, nil) - } - for i := 0; i < rv.Len(); i++ { - if i > 0 { - builder.WriteByte(',') + } else { + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) } - builder.AddVar(builder, rv.Index(i).Interface()) } default: builder.AddVar(builder, expr.Vars[idx]) diff --git a/tests/go.mod b/tests/go.mod index 76db6764..4ddb0b69 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,6 +6,10 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 + gorm.io/driver/mysql v1.0.1 + gorm.io/driver/postgres v1.0.0 + gorm.io/driver/sqlite v1.1.1 + gorm.io/driver/sqlserver v1.0.3 gorm.io/gorm v1.9.19 ) diff --git a/tests/query_test.go b/tests/query_test.go index e695e825..14150038 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -204,7 +204,7 @@ func TestFind(t *testing.T) { }) var models []User - if err := DB.Where("name in ?", []string{"find"}).Find(&models).Error; err != nil || len(models) != 3 { + if err := DB.Where("name in (?)", []string{"find"}).Find(&models).Error; err != nil || len(models) != 3 { t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models)) } else { for idx, user := range users { @@ -215,7 +215,7 @@ func TestFind(t *testing.T) { } var none []User - if err := DB.Where("name in ?", []string{}).Find(&none).Error; err != nil || len(none) != 0 { + if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 { t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none)) } } From 839e09e98558d946b4bf316bcd142edcf727ac37 Mon Sep 17 00:00:00 2001 From: caelansar <819711623@qq.com> Date: Tue, 8 Sep 2020 21:28:04 +0800 Subject: [PATCH 39/67] correct generated sql --- clause/expression.go | 3 +++ tests/query_test.go | 16 ++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/clause/expression.go b/clause/expression.go index 3b914e68..55599571 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -37,6 +37,9 @@ func (expr Expr) Build(builder Builder) { } else { switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { case reflect.Slice, reflect.Array: + if rv.Len() == 0 { + builder.AddVar(builder, nil) + } for i := 0; i < rv.Len(); i++ { if i > 0 { builder.WriteByte(',') diff --git a/tests/query_test.go b/tests/query_test.go index 795186da..e695e825 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -202,6 +202,22 @@ func TestFind(t *testing.T) { } } }) + + var models []User + if err := DB.Where("name in ?", []string{"find"}).Find(&models).Error; err != nil || len(models) != 3 { + t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models)) + } else { + for idx, user := range users { + t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, models[idx], user) + }) + } + } + + var none []User + if err := DB.Where("name in ?", []string{}).Find(&none).Error; err != nil || len(none) != 0 { + t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none)) + } } func TestQueryWithAssociation(t *testing.T) { From e7188c04ca9d81767ff090bc584177f4b6fb9fcc Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 9 Sep 2020 10:31:48 +0800 Subject: [PATCH 40/67] Fix tests & refactor for PR #3429 --- clause/expression.go | 11 ++++++----- tests/go.mod | 4 ++++ tests/query_test.go | 4 ++-- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index 55599571..dde236d3 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -39,12 +39,13 @@ func (expr Expr) Build(builder Builder) { case reflect.Slice, reflect.Array: if rv.Len() == 0 { builder.AddVar(builder, nil) - } - for i := 0; i < rv.Len(); i++ { - if i > 0 { - builder.WriteByte(',') + } else { + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) } - builder.AddVar(builder, rv.Index(i).Interface()) } default: builder.AddVar(builder, expr.Vars[idx]) diff --git a/tests/go.mod b/tests/go.mod index 76db6764..4ddb0b69 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,6 +6,10 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 + gorm.io/driver/mysql v1.0.1 + gorm.io/driver/postgres v1.0.0 + gorm.io/driver/sqlite v1.1.1 + gorm.io/driver/sqlserver v1.0.3 gorm.io/gorm v1.9.19 ) diff --git a/tests/query_test.go b/tests/query_test.go index e695e825..14150038 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -204,7 +204,7 @@ func TestFind(t *testing.T) { }) var models []User - if err := DB.Where("name in ?", []string{"find"}).Find(&models).Error; err != nil || len(models) != 3 { + if err := DB.Where("name in (?)", []string{"find"}).Find(&models).Error; err != nil || len(models) != 3 { t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models)) } else { for idx, user := range users { @@ -215,7 +215,7 @@ func TestFind(t *testing.T) { } var none []User - if err := DB.Where("name in ?", []string{}).Find(&none).Error; err != nil || len(none) != 0 { + if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 { t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none)) } } From 567597f000606b2266ff4b43950f5a801c2f2f63 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 9 Sep 2020 10:53:13 +0800 Subject: [PATCH 41/67] Fix fail on sqlserver, #3433 --- migrator/migrator.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index c0e22ae0..53fd5ac0 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -388,10 +388,10 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) { err = m.RunWithValue(value, func(stmt *gorm.Statement) error { - rows, err := m.DB.Session(&gorm.Session{}).Raw("select * from ? limit 1", clause.Table{Name: stmt.Table}).Rows() + rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() if err == nil { + defer rows.Close() columnTypes, err = rows.ColumnTypes() - _ = rows.Close() } return err }) From f6117b7f3dd21629b8196c376b0284d71672d1c7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 9 Sep 2020 16:26:11 +0800 Subject: [PATCH 42/67] Should not diplay SubQuery SQL log, close #3437 --- logger/logger.go | 14 +++++++++----- statement.go | 3 ++- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/logger/logger.go b/logger/logger.go index 0b0a7377..831192fc 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -2,6 +2,7 @@ package logger import ( "context" + "io/ioutil" "log" "os" "time" @@ -54,11 +55,14 @@ type Interface interface { Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) } -var Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ - SlowThreshold: 100 * time.Millisecond, - LogLevel: Warn, - Colorful: true, -}) +var ( + Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) + Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ + SlowThreshold: 100 * time.Millisecond, + LogLevel: Warn, + Colorful: true, + }) +) func New(writer Writer, config Config) Interface { var ( diff --git a/statement.go b/statement.go index e16cf0ff..ee80f8cd 100644 --- a/statement.go +++ b/statement.go @@ -12,6 +12,7 @@ import ( "sync" "gorm.io/gorm/clause" + "gorm.io/gorm/logger" "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) @@ -189,7 +190,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { writer.WriteString("(NULL)") } case *DB: - subdb := v.Session(&Session{DryRun: true, WithConditions: true}).getInstance() + subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true, WithConditions: true}).getInstance() subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...) subdb.callbacks.Query().Execute(subdb) writer.WriteString(subdb.Statement.SQL.String()) From f6ed895caffcde0b37d181201a5cadd442b8879e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 9 Sep 2020 16:32:29 +0800 Subject: [PATCH 43/67] Build relationships if fields are not ignored, fix #3181 --- schema/relationship.go | 2 +- schema/relationship_test.go | 23 +++++++++++++++++++++++ schema/schema.go | 4 ++-- 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/schema/relationship.go b/schema/relationship.go index 47b948dc..35af111f 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -300,7 +300,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel // build references for _, f := range relation.JoinTable.Fields { - if f.Creatable { + if f.Creatable || f.Readable || f.Updatable { // use same data type for foreign keys f.DataType = fieldsMap[f.Name].DataType f.GORMDataType = fieldsMap[f.Name].GORMDataType diff --git a/schema/relationship_test.go b/schema/relationship_test.go index b9279b9f..7d7fd9c9 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -220,6 +220,29 @@ func TestMany2ManyOverrideJoinForeignKey(t *testing.T) { }) } +func TestBuildReadonlyMany2ManyRelation(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"->;many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"` + Refer uint + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"}, + References: []Reference{ + {"ID", "User", "UserReferID", "user_profile", "", true}, + {"ID", "Profile", "ProfileRefer", "user_profile", "", false}, + }, + }) +} + func TestMany2ManyWithMultiPrimaryKeys(t *testing.T) { type Tag struct { ID uint `gorm:"primary_key"` diff --git a/schema/schema.go b/schema/schema.go index ea81d683..c3d3f6e0 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -133,7 +133,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if field.DBName != "" { // nonexistence or shortest path or first appear prioritized if has permission - if v, ok := schema.FieldsByDBName[field.DBName]; !ok || (field.Creatable && len(field.BindNames) < len(v.BindNames)) { + if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) { if _, ok := schema.FieldsByDBName[field.DBName]; !ok { schema.DBNames = append(schema.DBNames, field.DBName) } @@ -219,7 +219,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if _, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded { if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { for _, field := range schema.Fields { - if field.DataType == "" && field.Creatable { + if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { if schema.parseRelation(field); schema.err != nil { return schema, schema.err } From 619d306cef27adf4681bd04edfc0a620217471b2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 10 Sep 2020 10:55:02 +0800 Subject: [PATCH 44/67] ignore (-) when creating default values, #3434 --- migrator/migrator.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 53fd5ac0..4b069c8a 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -71,7 +71,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface) - } else { + } else if field.DefaultValue != "(-)" { expr.SQL += " DEFAULT " + field.DefaultValue } } From 231effe119fd25f368fa6ff5b5724e519bf59cd9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 10 Sep 2020 11:59:18 +0800 Subject: [PATCH 45/67] Fix parse blank default value, close #3442 --- schema/field.go | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/schema/field.go b/schema/field.go index db044c23..e52a8aef 100644 --- a/schema/field.go +++ b/schema/field.go @@ -178,33 +178,34 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Comment = val } - defaultValueFunc := strings.Contains(field.DefaultValue, "(") && - strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" + // default value is function or null or blank (primary keys) + skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") && + strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" || field.DefaultValue == "" switch reflect.Indirect(fieldValue).Kind() { case reflect.Bool: field.DataType = Bool - if field.HasDefaultValue && !defaultValueFunc { + if field.HasDefaultValue && !skipParseDefaultValue { if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil { schema.err = fmt.Errorf("failed to parse %v as default value for bool, got error: %v", field.DefaultValue, err) } } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: field.DataType = Int - if field.HasDefaultValue && !defaultValueFunc { + if field.HasDefaultValue && !skipParseDefaultValue { if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil { schema.err = fmt.Errorf("failed to parse %v as default value for int, got error: %v", field.DefaultValue, err) } } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: field.DataType = Uint - if field.HasDefaultValue && !defaultValueFunc { + if field.HasDefaultValue && !skipParseDefaultValue { if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil { schema.err = fmt.Errorf("failed to parse %v as default value for uint, got error: %v", field.DefaultValue, err) } } case reflect.Float32, reflect.Float64: field.DataType = Float - if field.HasDefaultValue && !defaultValueFunc { + if field.HasDefaultValue && !skipParseDefaultValue { if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil { schema.err = fmt.Errorf("failed to parse %v as default value for float, got error: %v", field.DefaultValue, err) } @@ -212,7 +213,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { case reflect.String: field.DataType = String - if field.HasDefaultValue && !defaultValueFunc { + if field.HasDefaultValue && !skipParseDefaultValue { field.DefaultValue = strings.Trim(field.DefaultValue, "'") field.DefaultValue = strings.Trim(field.DefaultValue, "\"") field.DefaultValueInterface = field.DefaultValue From 53caa85cf48f2ff4eee47fb55a07a3f3f16388fe Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 10 Sep 2020 19:20:47 +0800 Subject: [PATCH 46/67] Use db's Logger for callbacks logs, close #3448, #3447 --- callbacks.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/callbacks.go b/callbacks.go index eace06ca..83d103df 100644 --- a/callbacks.go +++ b/callbacks.go @@ -8,7 +8,6 @@ import ( "sort" "time" - "gorm.io/gorm/logger" "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) @@ -156,7 +155,7 @@ func (p *processor) compile() (err error) { p.callbacks = callbacks if p.fns, err = sortCallbacks(p.callbacks); err != nil { - logger.Default.Error(context.Background(), "Got error when compile callbacks, got %v", err) + p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err) } return } @@ -179,7 +178,7 @@ func (c *callback) Register(name string, fn func(*DB)) error { } func (c *callback) Remove(name string) error { - logger.Default.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum()) + c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum()) c.name = name c.remove = true c.processor.callbacks = append(c.processor.callbacks, c) @@ -187,7 +186,7 @@ func (c *callback) Remove(name string) error { } func (c *callback) Replace(name string, fn func(*DB)) error { - logger.Default.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) + c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) c.name = name c.handler = fn c.replace = true @@ -217,7 +216,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { for _, c := range cs { // show warning message the callback name already exists if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove { - logger.Default.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum()) + c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum()) } names = append(names, c.name) } From 70a7bd52ca2bbf64443b7227524e4600997ea1b9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 10 Sep 2020 21:46:18 +0800 Subject: [PATCH 47/67] Support delete associations with Select when deleting --- callbacks/callbacks.go | 1 + callbacks/delete.go | 53 ++++++++++++++++++++++++++++++++++++++ tests/delete_test.go | 54 +++++++++++++++++++++++++++++++++++++++ tests/joins_table_test.go | 18 +++++++++++++ utils/utils.go | 2 +- 5 files changed, 127 insertions(+), 1 deletion(-) diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index 0a12468c..dda4b046 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -31,6 +31,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { deleteCallback := db.Callback().Delete() deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) deleteCallback.Register("gorm:before_delete", BeforeDelete) + deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations) deleteCallback.Register("gorm:delete", Delete) deleteCallback.Register("gorm:after_delete", AfterDelete) deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) diff --git a/callbacks/delete.go b/callbacks/delete.go index e95117a1..510dfae4 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -21,6 +21,59 @@ func BeforeDelete(db *gorm.DB) { } } +func DeleteBeforeAssociations(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) + + if restricted { + for column, v := range selectColumns { + if v { + if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok { + switch rel.Type { + case schema.HasOne, schema.HasMany: + queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) + modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() + tx := db.Session(&gorm.Session{}).Model(modelValue) + if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } + case schema.Many2Many: + var ( + queryConds []clause.Expression + foreignFields []*schema.Field + relForeignKeys []string + modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + table = rel.JoinTable.Table + tx = db.Session(&gorm.Session{}).Model(modelValue).Table(table) + ) + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.PrimaryKey) + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + } else if ref.PrimaryValue != "" { + queryConds = append(queryConds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } + } + + _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) + column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) + queryConds = append(queryConds, clause.IN{Column: column, Values: values}) + + if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } + } + } + } + } + } + } +} + func Delete(db *gorm.DB) { if db.Error == nil { if db.Statement.Schema != nil && !db.Statement.Unscoped { diff --git a/tests/delete_test.go b/tests/delete_test.go index 17299677..4945e837 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -5,6 +5,7 @@ import ( "testing" "gorm.io/gorm" + "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -127,3 +128,56 @@ func TestBlockGlobalDelete(t *testing.T) { t.Errorf("should returns no error while enable global update, but got err %v", err) } } + +func TestDeleteWithAssociations(t *testing.T) { + user := GetUser("delete_with_associations", Config{Account: true, Pets: 2, Toys: 4, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 1}) + + if err := DB.Create(user).Error; err != nil { + t.Fatalf("failed to create user, got error %v", err) + } + + if err := DB.Select(clause.Associations).Delete(&user).Error; err != nil { + t.Fatalf("failed to delete user, got error %v", err) + } + + for key, value := range map[string]int64{"Account": 1, "Pets": 2, "Toys": 4, "Company": 1, "Manager": 1, "Team": 1, "Languages": 0, "Friends": 0} { + if count := DB.Unscoped().Model(&user).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } + + for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} { + if count := DB.Model(&user).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } +} + +func TestDeleteSliceWithAssociations(t *testing.T) { + users := []User{ + *GetUser("delete_slice_with_associations1", Config{Account: true, Pets: 4, Toys: 1, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 4}), + *GetUser("delete_slice_with_associations2", Config{Account: true, Pets: 3, Toys: 2, Company: true, Manager: true, Team: 2, Languages: 2, Friends: 3}), + *GetUser("delete_slice_with_associations3", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 2}), + *GetUser("delete_slice_with_associations4", Config{Account: true, Pets: 1, Toys: 4, Company: true, Manager: true, Team: 4, Languages: 4, Friends: 1}), + } + + if err := DB.Create(users).Error; err != nil { + t.Fatalf("failed to create user, got error %v", err) + } + + if err := DB.Select(clause.Associations).Delete(&users).Error; err != nil { + t.Fatalf("failed to delete user, got error %v", err) + } + + for key, value := range map[string]int64{"Account": 4, "Pets": 10, "Toys": 10, "Company": 4, "Manager": 4, "Team": 10, "Languages": 0, "Friends": 0} { + if count := DB.Unscoped().Model(&users).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } + + for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 4, "Manager": 4, "Team": 0, "Languages": 0, "Friends": 0} { + if count := DB.Model(&users).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } +} diff --git a/tests/joins_table_test.go b/tests/joins_table_test.go index b8c1be77..084c2f2c 100644 --- a/tests/joins_table_test.go +++ b/tests/joins_table_test.go @@ -5,12 +5,14 @@ import ( "time" "gorm.io/gorm" + "gorm.io/gorm/clause" ) type Person struct { ID int Name string Addresses []Address `gorm:"many2many:person_addresses;"` + DeletedAt gorm.DeletedAt } type Address struct { @@ -95,4 +97,20 @@ func TestOverrideJoinTable(t *testing.T) { if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 { t.Fatalf("address should be deleted when clear with unscoped") } + + address2_1 := Address{Name: "address 2-1"} + address2_2 := Address{Name: "address 2-2"} + person2 := Person{Name: "person_2", Addresses: []Address{address2_1, address2_2}} + DB.Create(&person2) + if err := DB.Select(clause.Associations).Delete(&person2).Error; err != nil { + t.Fatalf("failed to delete person, got error: %v", err) + } + + if count := DB.Unscoped().Model(&person2).Association("Addresses").Count(); count != 2 { + t.Errorf("person's addresses expects 2, got %v", count) + } + + if count := DB.Model(&person2).Association("Addresses").Count(); count != 0 { + t.Errorf("person's addresses expects 2, got %v", count) + } } diff --git a/utils/utils.go b/utils/utils.go index 905001a5..ecba7fb9 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -30,7 +30,7 @@ func FileWithLineNum() string { } func IsValidDBNameChar(c rune) bool { - return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' + return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' } func CheckTruth(val interface{}) bool { From b8a74a80d732963df95580eae3316db140a882a4 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 11 Sep 2020 10:49:31 +0800 Subject: [PATCH 48/67] Fix embedded struct with default value, close #3451 --- schema/field.go | 24 +++++++++++++----------- tests/go.mod | 4 ++-- tests/query_test.go | 1 + 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/schema/field.go b/schema/field.go index e52a8aef..60dc8095 100644 --- a/schema/field.go +++ b/schema/field.go @@ -345,19 +345,21 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.DBName = prefix + ef.DBName } - if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { - ef.PrimaryKey = true - } else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { - ef.PrimaryKey = true - } else { - ef.PrimaryKey = false + if ef.PrimaryKey { + if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { + ef.PrimaryKey = true + } else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { + ef.PrimaryKey = true + } else { + ef.PrimaryKey = false - if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) { - ef.AutoIncrement = false - } + if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) { + ef.AutoIncrement = false + } - if ef.DefaultValue == "" { - ef.HasDefaultValue = false + if ef.DefaultValue == "" { + ef.HasDefaultValue = false + } } } diff --git a/tests/go.mod b/tests/go.mod index 4ddb0b69..f62365f8 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,8 +9,8 @@ require ( gorm.io/driver/mysql v1.0.1 gorm.io/driver/postgres v1.0.0 gorm.io/driver/sqlite v1.1.1 - gorm.io/driver/sqlserver v1.0.3 - gorm.io/gorm v1.9.19 + gorm.io/driver/sqlserver v1.0.4 + gorm.io/gorm v1.20.0 ) replace gorm.io/gorm => ../ diff --git a/tests/query_test.go b/tests/query_test.go index 14150038..36229e2c 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -648,6 +648,7 @@ func TestOffset(t *testing.T) { if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { t.Errorf("Offset should work") } + DB.Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { From e583dfa196400896932c073d05383fcf6cedeb4f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 11 Sep 2020 11:44:58 +0800 Subject: [PATCH 49/67] Allow negative number for limit --- clause/limit.go | 4 +--- tests/go.mod | 2 +- tests/query_test.go | 3 ++- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/clause/limit.go b/clause/limit.go index 1946820d..2082f4d9 100644 --- a/clause/limit.go +++ b/clause/limit.go @@ -33,10 +33,8 @@ func (limit Limit) MergeClause(clause *Clause) { clause.Name = "" if v, ok := clause.Expression.(Limit); ok { - if limit.Limit == 0 && v.Limit > 0 { + if limit.Limit == 0 && v.Limit != 0 { limit.Limit = v.Limit - } else if limit.Limit < 0 { - limit.Limit = 0 } if limit.Offset == 0 && v.Offset > 0 { diff --git a/tests/go.mod b/tests/go.mod index f62365f8..17a3b156 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,7 +8,7 @@ require ( github.com/lib/pq v1.6.0 gorm.io/driver/mysql v1.0.1 gorm.io/driver/postgres v1.0.0 - gorm.io/driver/sqlite v1.1.1 + gorm.io/driver/sqlite v1.1.2 gorm.io/driver/sqlserver v1.0.4 gorm.io/gorm v1.20.0 ) diff --git a/tests/query_test.go b/tests/query_test.go index 36229e2c..d3bcbdbe 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -625,6 +625,7 @@ func TestLimit(t *testing.T) { {Name: "LimitUser3", Age: 20}, {Name: "LimitUser4", Age: 10}, {Name: "LimitUser5", Age: 20}, + {Name: "LimitUser6", Age: 20}, } DB.Create(&users) @@ -633,7 +634,7 @@ func TestLimit(t *testing.T) { DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3) if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 { - t.Errorf("Limit should works") + t.Errorf("Limit should works, users1 %v users2 %v users3 %v", len(users1), len(users2), len(users3)) } } From 02fb382ec0b67a320fc26cdd460a70468d037779 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 11 Sep 2020 15:01:02 +0800 Subject: [PATCH 50/67] Support scan into int, string data types --- finisher_api.go | 4 +++- scan.go | 2 +- tests/scan_test.go | 10 ++++++++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/finisher_api.go b/finisher_api.go index 6ece0f79..f426839a 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -384,7 +384,9 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { tx := db.getInstance() - tx.Error = tx.Statement.Parse(dest) + if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) { + tx.AddError(err) + } tx.Statement.Dest = dest tx.Statement.ReflectValue = reflect.ValueOf(dest) for tx.Statement.ReflectValue.Kind() == reflect.Ptr { diff --git a/scan.go b/scan.go index 89d9a07a..be8782ed 100644 --- a/scan.go +++ b/scan.go @@ -82,7 +82,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { scanIntoMap(mapValue, values, columns) *dest = append(*dest, mapValue) } - case *int, *int64, *uint, *uint64, *float32, *float64: + case *int, *int64, *uint, *uint64, *float32, *float64, *string: for initialized || rows.Next() { initialized = false db.RowsAffected++ diff --git a/tests/scan_test.go b/tests/scan_test.go index 92e89521..785bb97e 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -91,4 +91,14 @@ func TestScanRows(t *testing.T) { if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { t.Errorf("Should find expected results") } + + var ages int + if err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("SUM(age)").Scan(&ages).Error; err != nil || ages != 30 { + t.Fatalf("failed to scan ages, got error %v, ages: %v", err, ages) + } + + var name string + if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name { + t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name) + } } From ed1b134e1c6d8d791fc87a7286e9c534fa2840f2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 11 Sep 2020 17:33:31 +0800 Subject: [PATCH 51/67] Fix use uint to for autoCreateTime, autoUpdateTime --- schema/field.go | 8 ++++++++ tests/customize_field_test.go | 22 +++++++++++----------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/schema/field.go b/schema/field.go index 60dc8095..4b8a5a2a 100644 --- a/schema/field.go +++ b/schema/field.go @@ -624,6 +624,14 @@ func (field *Field) setupValuerAndSetter() { field.ReflectValueOf(value).SetUint(uint64(data)) case []byte: return field.Set(value, string(data)) + case time.Time: + if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { + field.ReflectValueOf(value).SetUint(uint64(data.UnixNano())) + } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { + field.ReflectValueOf(value).SetUint(uint64(data.UnixNano() / 1e6)) + } else { + field.ReflectValueOf(value).SetUint(uint64(data.Unix())) + } case string: if i, err := strconv.ParseUint(data, 0, 64); err == nil { field.ReflectValueOf(value).SetUint(i) diff --git a/tests/customize_field_test.go b/tests/customize_field_test.go index bf3c78fa..7802eb11 100644 --- a/tests/customize_field_test.go +++ b/tests/customize_field_test.go @@ -69,12 +69,12 @@ func TestCustomizeField(t *testing.T) { FieldAllowSave3 string `gorm:"->:false;<-:create"` FieldReadonly string `gorm:"->"` FieldIgnore string `gorm:"-"` - AutoUnixCreateTime int64 `gorm:"autocreatetime"` - AutoUnixMilliCreateTime int64 `gorm:"autocreatetime:milli"` + AutoUnixCreateTime int32 `gorm:"autocreatetime"` + AutoUnixMilliCreateTime int `gorm:"autocreatetime:milli"` AutoUnixNanoCreateTime int64 `gorm:"autocreatetime:nano"` - AutoUnixUpdateTime int64 `gorm:"autoupdatetime"` - AutoUnixMilliUpdateTime int64 `gorm:"autoupdatetime:milli"` - AutoUnixNanoUpdateTime int64 `gorm:"autoupdatetime:nano"` + AutoUnixUpdateTime uint32 `gorm:"autoupdatetime"` + AutoUnixMilliUpdateTime int `gorm:"autoupdatetime:milli"` + AutoUnixNanoUpdateTime uint64 `gorm:"autoupdatetime:nano"` } DB.Migrator().DropTable(&CustomizeFieldStruct{}) @@ -116,15 +116,15 @@ func TestCustomizeField(t *testing.T) { t.Fatalf("invalid result: %#v", result) } - if result.AutoUnixCreateTime != result.AutoUnixUpdateTime || result.AutoUnixCreateTime == 0 { + if int(result.AutoUnixCreateTime) != int(result.AutoUnixUpdateTime) || result.AutoUnixCreateTime == 0 { t.Fatalf("invalid create/update unix time: %#v", result) } - if result.AutoUnixMilliCreateTime != result.AutoUnixMilliUpdateTime || result.AutoUnixMilliCreateTime == 0 || result.AutoUnixMilliCreateTime/result.AutoUnixCreateTime < 1e3 { + if int(result.AutoUnixMilliCreateTime) != int(result.AutoUnixMilliUpdateTime) || result.AutoUnixMilliCreateTime == 0 || int(result.AutoUnixMilliCreateTime)/int(result.AutoUnixCreateTime) < 1e3 { t.Fatalf("invalid create/update unix milli time: %#v", result) } - if result.AutoUnixNanoCreateTime != result.AutoUnixNanoUpdateTime || result.AutoUnixNanoCreateTime == 0 || result.AutoUnixNanoCreateTime/result.AutoUnixCreateTime < 1e6 { + if int(result.AutoUnixNanoCreateTime) != int(result.AutoUnixNanoUpdateTime) || result.AutoUnixNanoCreateTime == 0 || int(result.AutoUnixNanoCreateTime)/int(result.AutoUnixCreateTime) < 1e6 { t.Fatalf("invalid create/update unix nano time: %#v", result) } @@ -178,15 +178,15 @@ func TestCustomizeField(t *testing.T) { var createWithDefaultTimeResult CustomizeFieldStruct DB.Find(&createWithDefaultTimeResult, "name = ?", createWithDefaultTime.Name) - if createWithDefaultTimeResult.AutoUnixCreateTime != createWithDefaultTimeResult.AutoUnixUpdateTime || createWithDefaultTimeResult.AutoUnixCreateTime != 100 { + if int(createWithDefaultTimeResult.AutoUnixCreateTime) != int(createWithDefaultTimeResult.AutoUnixUpdateTime) || createWithDefaultTimeResult.AutoUnixCreateTime != 100 { t.Fatalf("invalid create/update unix time: %#v", createWithDefaultTimeResult) } - if createWithDefaultTimeResult.AutoUnixMilliCreateTime != createWithDefaultTimeResult.AutoUnixMilliUpdateTime || createWithDefaultTimeResult.AutoUnixMilliCreateTime != 100 { + if int(createWithDefaultTimeResult.AutoUnixMilliCreateTime) != int(createWithDefaultTimeResult.AutoUnixMilliUpdateTime) || createWithDefaultTimeResult.AutoUnixMilliCreateTime != 100 { t.Fatalf("invalid create/update unix milli time: %#v", createWithDefaultTimeResult) } - if createWithDefaultTimeResult.AutoUnixNanoCreateTime != createWithDefaultTimeResult.AutoUnixNanoUpdateTime || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 { + if int(createWithDefaultTimeResult.AutoUnixNanoCreateTime) != int(createWithDefaultTimeResult.AutoUnixNanoUpdateTime) || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 { t.Fatalf("invalid create/update unix nano time: %#v", createWithDefaultTimeResult) } } From 0ec10d4907762e94ac942903670184a93e7ed456 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 14 Sep 2020 12:37:16 +0800 Subject: [PATCH 52/67] Fix format SQL log, close #3465 --- logger/sql.go | 16 ++++++++++++++-- logger/sql_test.go | 6 ++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index 096b9407..69a6b10e 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -96,9 +96,21 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a } if numericPlaceholder == nil { - for _, v := range vars { - sql = strings.Replace(sql, "?", v, 1) + var idx int + var newSQL strings.Builder + + for _, v := range []byte(sql) { + if v == '?' { + if len(vars) > idx { + newSQL.WriteString(vars[idx]) + idx++ + continue + } + } + newSQL.WriteByte(v) } + + sql = newSQL.String() } else { sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$") for idx, v := range vars { diff --git a/logger/sql_test.go b/logger/sql_test.go index 180570b8..b78f761c 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -29,6 +29,12 @@ func TestExplainSQL(t *testing.T) { Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, + }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)", NumericRegexp: regexp.MustCompile(`@p(\d+)`), From 1d5f910b6e1a377f7f7defadb606a3e9c7a09c01 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 14 Sep 2020 15:29:47 +0800 Subject: [PATCH 53/67] Update workflows template --- .github/labels.json | 5 +++++ .github/workflows/invalid_question.yml | 22 ++++++++++++++++++++++ .github/workflows/missing_playground.yml | 2 +- 3 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/invalid_question.yml diff --git a/.github/labels.json b/.github/labels.json index 8b1ce849..6b9c2034 100644 --- a/.github/labels.json +++ b/.github/labels.json @@ -10,6 +10,11 @@ "colour": "#EDEDED", "description": "general questions" }, + "invalid_question": { + "name": "type:invalid question", + "colour": "#CF2E1F", + "description": "invalid question (not related to GORM or described in document or not enough information provided)" + }, "with_playground": { "name": "type:with reproduction steps", "colour": "#00ff00", diff --git a/.github/workflows/invalid_question.yml b/.github/workflows/invalid_question.yml new file mode 100644 index 00000000..5b0bd981 --- /dev/null +++ b/.github/workflows/invalid_question.yml @@ -0,0 +1,22 @@ +name: "Close invalid questions issues" +on: + schedule: + - cron: "*/10 * * * *" + +jobs: + stale: + runs-on: ubuntu-latest + env: + ACTIONS_STEP_DEBUG: true + steps: + - name: Close Stale Issues + uses: actions/stale@v3.0.7 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 2 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" + stale-issue-label: "status:stale" + days-before-stale: 0 + days-before-close: 2 + remove-stale-when-updated: true + only-labels: "type:invalid question" + diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml index 422cb9f5..ea3207d6 100644 --- a/.github/workflows/missing_playground.yml +++ b/.github/workflows/missing_playground.yml @@ -13,7 +13,7 @@ jobs: uses: actions/stale@v3.0.7 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: "This issue has been automatically marked as stale as it missing playground pull request link, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details, it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" + stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-label: "status:stale" days-before-stale: 0 days-before-close: 2 From 06d534d6eaa7f8534e51742b9930818511aaf28c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 15 Sep 2020 12:41:45 +0800 Subject: [PATCH 54/67] Cascade delete associations, close #3473 --- callbacks/delete.go | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/callbacks/delete.go b/callbacks/delete.go index 510dfae4..549a94e7 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -34,8 +34,23 @@ func DeleteBeforeAssociations(db *gorm.DB) { queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() tx := db.Session(&gorm.Session{}).Model(modelValue) - if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { - return + withoutConditions := false + + if len(db.Statement.Selects) > 0 { + tx = tx.Select(db.Statement.Selects) + } + + for _, cond := range queryConds { + if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 { + withoutConditions = true + break + } + } + + if !withoutConditions { + if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } } case schema.Many2Many: var ( From a932175ccf98130aaa3028b75daf047a32b6dca0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 15 Sep 2020 14:28:26 +0800 Subject: [PATCH 55/67] Refactor cascade delete associations --- callbacks/delete.go | 14 +++++++++++++- tests/delete_test.go | 2 +- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/callbacks/delete.go b/callbacks/delete.go index 549a94e7..85f11f4b 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -2,6 +2,7 @@ package callbacks import ( "reflect" + "strings" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -37,7 +38,18 @@ func DeleteBeforeAssociations(db *gorm.DB) { withoutConditions := false if len(db.Statement.Selects) > 0 { - tx = tx.Select(db.Statement.Selects) + var selects []string + for _, s := range db.Statement.Selects { + if s == clause.Associations { + selects = append(selects, s) + } else if strings.HasPrefix(s, column+".") { + selects = append(selects, strings.TrimPrefix(s, column+".")) + } + } + + if len(selects) > 0 { + tx = tx.Select(selects) + } } for _, cond := range queryConds { diff --git a/tests/delete_test.go b/tests/delete_test.go index 4945e837..ecd5ec39 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -136,7 +136,7 @@ func TestDeleteWithAssociations(t *testing.T) { t.Fatalf("failed to create user, got error %v", err) } - if err := DB.Select(clause.Associations).Delete(&user).Error; err != nil { + if err := DB.Select(clause.Associations, "Pets.Toy").Delete(&user).Error; err != nil { t.Fatalf("failed to delete user, got error %v", err) } From d002c70cf6ac6f35e4a2840606e65d84d33c5391 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 17 Sep 2020 21:52:41 +0800 Subject: [PATCH 56/67] Support named argument for struct --- clause/expression.go | 12 ++++++++++++ clause/expression_test.go | 10 ++++++++++ tests/go.mod | 4 ++-- 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index dde236d3..49924ef7 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -3,6 +3,7 @@ package clause import ( "database/sql" "database/sql/driver" + "go/ast" "reflect" ) @@ -89,6 +90,17 @@ func (expr NamedExpr) Build(builder Builder) { for k, v := range value { namedMap[k] = v } + default: + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + switch reflectValue.Kind() { + case reflect.Struct: + modelType := reflectValue.Type() + for i := 0; i < modelType.NumField(); i++ { + if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { + namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface() + } + } + } } } diff --git a/clause/expression_test.go b/clause/expression_test.go index 17af737d..53d79c8f 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -37,6 +37,11 @@ func TestExpr(t *testing.T) { } func TestNamedExpr(t *testing.T) { + type NamedArgument struct { + Name1 string + Name2 string + } + results := []struct { SQL string Result string @@ -66,6 +71,11 @@ func TestNamedExpr(t *testing.T) { Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, + }, { + SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @Notexist", + Vars: []interface{}{NamedArgument{Name1: "jinzhu", Name2: "jinzhu2"}}, + Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, }} for idx, result := range results { diff --git a/tests/go.mod b/tests/go.mod index 17a3b156..0db87934 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -8,9 +8,9 @@ require ( github.com/lib/pq v1.6.0 gorm.io/driver/mysql v1.0.1 gorm.io/driver/postgres v1.0.0 - gorm.io/driver/sqlite v1.1.2 + gorm.io/driver/sqlite v1.1.3 gorm.io/driver/sqlserver v1.0.4 - gorm.io/gorm v1.20.0 + gorm.io/gorm v1.20.1 ) replace gorm.io/gorm => ../ From 072f1de83a842a991ea76cecfd14a7e93d5e67c1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 18 Sep 2020 21:34:44 +0800 Subject: [PATCH 57/67] Add DryRunModeUnsupported Error for Row/Rows --- errors.go | 2 ++ finisher_api.go | 12 ++++++++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/errors.go b/errors.go index 508f6957..08755083 100644 --- a/errors.go +++ b/errors.go @@ -29,4 +29,6 @@ var ( ErrInvalidField = errors.New("invalid field") // ErrEmptySlice empty slice found ErrEmptySlice = errors.New("empty slice found") + // ErrDryRunModeUnsupported dry run mode unsupported + ErrDryRunModeUnsupported = errors.New("dry run mode unsupported") ) diff --git a/finisher_api.go b/finisher_api.go index f426839a..2c56d763 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -334,13 +334,21 @@ func (db *DB) Count(count *int64) (tx *DB) { func (db *DB) Row() *sql.Row { tx := db.getInstance().InstanceSet("rows", false) tx.callbacks.Row().Execute(tx) - return tx.Statement.Dest.(*sql.Row) + row, ok := tx.Statement.Dest.(*sql.Row) + if !ok && tx.DryRun { + db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error()) + } + return row } func (db *DB) Rows() (*sql.Rows, error) { tx := db.getInstance().InstanceSet("rows", true) tx.callbacks.Row().Execute(tx) - return tx.Statement.Dest.(*sql.Rows), tx.Error + rows, ok := tx.Statement.Dest.(*sql.Rows) + if !ok && tx.DryRun && tx.Error == nil { + tx.Error = ErrDryRunModeUnsupported + } + return rows, tx.Error } // Scan scan value to a struct From c9165fe3cafc9a66e2513caae381e6864fa0a15b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 18 Sep 2020 21:42:27 +0800 Subject: [PATCH 58/67] Don't panic when using unmatched vars in query, close #3488 --- clause/expression.go | 4 ++-- clause/expression_test.go | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/clause/expression.go b/clause/expression.go index 49924ef7..6a0dde8d 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -31,7 +31,7 @@ func (expr Expr) Build(builder Builder) { ) for _, v := range []byte(expr.SQL) { - if v == '?' { + if v == '?' && len(expr.Vars) > idx { if afterParenthesis { if _, ok := expr.Vars[idx].(driver.Valuer); ok { builder.AddVar(builder, expr.Vars[idx]) @@ -122,7 +122,7 @@ func (expr NamedExpr) Build(builder Builder) { } builder.WriteByte(v) - } else if v == '?' { + } else if v == '?' && len(expr.Vars) > idx { builder.AddVar(builder, expr.Vars[idx]) idx++ } else if inName { diff --git a/clause/expression_test.go b/clause/expression_test.go index 53d79c8f..19e30e6c 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -76,6 +76,10 @@ func TestNamedExpr(t *testing.T) { Vars: []interface{}{NamedArgument{Name1: "jinzhu", Name2: "jinzhu2"}}, Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, + }, { + SQL: "create table ? (? ?, ? ?)", + Vars: []interface{}{}, + Result: "create table ? (? ?, ? ?)", }} for idx, result := range results { From 089939c767f89087366799e47ab24d5b7b36c5e4 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 18 Sep 2020 21:50:11 +0800 Subject: [PATCH 59/67] AutoMigrate should auto create indexes, close #3486 --- migrator/migrator.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/migrator/migrator.go b/migrator/migrator.go index 4b069c8a..f390ff9f 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -133,6 +133,15 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } } } + + for _, idx := range stmt.Schema.ParseIndexes() { + if !tx.Migrator().HasIndex(value, idx.Name) { + if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil { + return err + } + } + } + return nil }); err != nil { return err From 68920449f92f24c8b17d90986eb155c251ed8fc7 Mon Sep 17 00:00:00 2001 From: caelansar <31852257+caelansar@users.noreply.github.com> Date: Sat, 19 Sep 2020 13:48:34 +0800 Subject: [PATCH 60/67] Fix format sql log (#3492) --- logger/sql.go | 4 ++-- logger/sql_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/logger/sql.go b/logger/sql.go index 69a6b10e..138a35ec 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -52,9 +52,9 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper case driver.Valuer: reflectValue := reflect.ValueOf(v) - if v != nil && reflectValue.IsValid() && (reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) { + if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { r, _ := v.Value() - vars[idx] = fmt.Sprintf("%v", r) + convertParams(r, idx) } else { vars[idx] = "NULL" } diff --git a/logger/sql_test.go b/logger/sql_test.go index b78f761c..71aa841a 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -1,13 +1,39 @@ package logger_test import ( + "database/sql/driver" + "encoding/json" + "fmt" "regexp" + "strings" "testing" "github.com/jinzhu/now" "gorm.io/gorm/logger" ) +type JSON json.RawMessage + +func (j JSON) Value() (driver.Value, error) { + if len(j) == 0 { + return nil, nil + } + return json.RawMessage(j).MarshalJSON() +} + +type ExampleStruct struct { + Name string + Val string +} + +func (s ExampleStruct) Value() (driver.Value, error) { + return json.Marshal(s) +} + +func format(v []byte, escaper string) string { + return escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper +} + func TestExplainSQL(t *testing.T) { type role string type password []byte @@ -15,6 +41,10 @@ func TestExplainSQL(t *testing.T) { tt = now.MustParse("2020-02-23 11:10:10") myrole = role("admin") pwd = password([]byte("pass")) + jsVal = []byte(`{"Name":"test","Val":"test"}`) + js = JSON(jsVal) + esVal = []byte(`{"Name":"test","Val":"test"}`) + es = ExampleStruct{Name: "test", Val: "test"} ) results := []struct { @@ -53,6 +83,18 @@ func TestExplainSQL(t *testing.T) { Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es}, + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + }, } for idx, r := range results { From 1a526e6802a9692a1340277551a9117644af21f0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 24 Sep 2020 11:32:38 +0800 Subject: [PATCH 61/67] Fix NamingStrategy with embedded struct, close #3513 --- schema/field.go | 2 +- schema/naming.go | 2 +- schema/naming_test.go | 26 ++++++++++++++++ schema/schema.go | 3 ++ schema/schema_test.go | 70 +++++++++++++++++++++++++++++++++++++++++++ schema/utils.go | 5 ++++ tests/go.mod | 2 +- 7 files changed, 107 insertions(+), 3 deletions(-) diff --git a/schema/field.go b/schema/field.go index 4b8a5a2a..ce2808a8 100644 --- a/schema/field.go +++ b/schema/field.go @@ -326,7 +326,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { cacheStore := &sync.Map{} cacheStore.Store(embeddedCacheKey, true) - if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, schema.namer); err != nil { + if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil { schema.err = err } diff --git a/schema/naming.go b/schema/naming.go index ecdab791..af753ce5 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -14,7 +14,7 @@ import ( type Namer interface { TableName(table string) string ColumnName(table, column string) string - JoinTableName(table string) string + JoinTableName(joinTable string) string RelationshipFKName(Relationship) string CheckerName(table, column string) string IndexName(table, column string) string diff --git a/schema/naming_test.go b/schema/naming_test.go index 96b83ced..a4600ceb 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -1,6 +1,7 @@ package schema import ( + "strings" "testing" ) @@ -32,3 +33,28 @@ func TestToDBName(t *testing.T) { } } } + +type NewNamingStrategy struct { + NamingStrategy +} + +func (ns NewNamingStrategy) ColumnName(table, column string) string { + baseColumnName := ns.NamingStrategy.ColumnName(table, column) + + if table == "" { + return baseColumnName + } + + s := strings.Split(table, "_") + + var prefix string + switch len(s) { + case 1: + prefix = s[0][:3] + case 2: + prefix = s[0][:1] + s[1][:2] + default: + prefix = s[0][:1] + s[1][:1] + s[2][:1] + } + return prefix + "_" + baseColumnName +} diff --git a/schema/schema.go b/schema/schema.go index c3d3f6e0..cffc19a7 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -97,6 +97,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() } + if en, ok := namer.(embeddedNamer); ok { + tableName = en.Table + } schema := &Schema{ Name: modelType.Name(), diff --git a/schema/schema_test.go b/schema/schema_test.go index 6ca5b269..a426cd90 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -1,6 +1,7 @@ package schema_test import ( + "strings" "sync" "testing" @@ -227,3 +228,72 @@ func TestEmbeddedStruct(t *testing.T) { }) } } + +type CustomizedNamingStrategy struct { + schema.NamingStrategy +} + +func (ns CustomizedNamingStrategy) ColumnName(table, column string) string { + baseColumnName := ns.NamingStrategy.ColumnName(table, column) + + if table == "" { + return baseColumnName + } + + s := strings.Split(table, "_") + + var prefix string + switch len(s) { + case 1: + prefix = s[0][:3] + case 2: + prefix = s[0][:1] + s[1][:2] + default: + prefix = s[0][:1] + s[1][:1] + s[2][:1] + } + return prefix + "_" + baseColumnName +} + +func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) { + type CorpBase struct { + gorm.Model + OwnerID string + } + + type Company struct { + ID int + OwnerID int + Name string + Ignored string `gorm:"-"` + } + + type Corp struct { + CorpBase + Base Company `gorm:"embedded;embeddedPrefix:company_"` + } + + cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, CustomizedNamingStrategy{schema.NamingStrategy{}}) + + if err != nil { + t.Fatalf("failed to parse embedded struct with primary key, got error %v", err) + } + + fields := []schema.Field{ + {Name: "ID", DBName: "cor_id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}}, + {Name: "ID", DBName: "company_cor_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Name", DBName: "company_cor_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Ignored", BindNames: []string{"Base", "Ignored"}, TagSettings: map[string]string{"-": "-", "EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "company_cor_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "cor_owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String}, + } + + for _, f := range fields { + checkSchemaField(t, cropSchema, &f, func(f *schema.Field) { + if f.Name != "Ignored" { + f.Creatable = true + f.Updatable = true + f.Readable = true + } + }) + } +} diff --git a/schema/utils.go b/schema/utils.go index 41bd9d60..55cbdeb4 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -190,3 +190,8 @@ func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interfa return columns, queryValues } } + +type embeddedNamer struct { + Table string + Namer +} diff --git a/tests/go.mod b/tests/go.mod index 0db87934..c92fa0cf 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,7 +7,7 @@ require ( github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 gorm.io/driver/mysql v1.0.1 - gorm.io/driver/postgres v1.0.0 + gorm.io/driver/postgres v1.0.1 gorm.io/driver/sqlite v1.1.3 gorm.io/driver/sqlserver v1.0.4 gorm.io/gorm v1.20.1 From 52287359153b5788d95960c963f74bebcdea88c7 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 24 Sep 2020 15:00:13 +0800 Subject: [PATCH 62/67] Don't build IN condition if value implemented Valuer interface, #3517 --- statement.go | 16 +++++++++++----- tests/query_test.go | 5 +++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/statement.go b/statement.go index ee80f8cd..38d35926 100644 --- a/statement.go +++ b/statement.go @@ -299,12 +299,18 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) switch reflectValue.Kind() { case reflect.Slice, reflect.Array: - values := make([]interface{}, reflectValue.Len()) - for i := 0; i < reflectValue.Len(); i++ { - values[i] = reflectValue.Index(i).Interface() - } + if _, ok := v[key].(driver.Valuer); ok { + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } else if _, ok := v[key].(Valuer); ok { + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } else { + values := make([]interface{}, reflectValue.Len()) + for i := 0; i < reflectValue.Len(); i++ { + values[i] = reflectValue.Index(i).Interface() + } - conds = append(conds, clause.IN{Column: key, Values: values}) + conds = append(conds, clause.IN{Column: key, Values: values}) + } default: conds = append(conds, clause.Eq{Column: key, Value: v[key]}) } diff --git a/tests/query_test.go b/tests/query_test.go index d3bcbdbe..9c9ad9f2 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -345,6 +345,11 @@ func TestNot(t *testing.T) { t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) } + result = dryDB.Where(map[string]interface{}{"name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + result = dryDB.Not("name = ?", "jinzhu").Find(&User{}) if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT.*name.* = .+").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) From c0de3c505176b0fea74c2e09fb9cae7c595b7020 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 24 Sep 2020 19:28:52 +0800 Subject: [PATCH 63/67] Support FullSaveAssociations Mode, close #3487, #3506 --- callbacks/associations.go | 61 +++++++++++++++++++-------------- callbacks/create.go | 5 ++- gorm.go | 7 ++++ logger/logger.go | 7 ++-- tests/update_belongs_to_test.go | 19 ++++++++++ tests/update_has_many_test.go | 41 ++++++++++++++++++++++ tests/update_has_one_test.go | 35 +++++++++++++++++++ tests/update_many2many_test.go | 25 ++++++++++++++ 8 files changed, 171 insertions(+), 29 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 0c677f47..64d79f24 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -66,9 +66,7 @@ func SaveBeforeAssociations(db *gorm.DB) { } if elems.Len() > 0 { - if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - DoNothing: true, - }).Create(elems.Interface()).Error) == nil { + if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil { for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i)) } @@ -81,9 +79,7 @@ func SaveBeforeAssociations(db *gorm.DB) { rv = rv.Addr() } - if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - DoNothing: true, - }).Create(rv.Interface()).Error) == nil { + if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil { setupReferences(db.Statement.ReflectValue, rv) } } @@ -145,10 +141,9 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: onConflictColumns(rel.FieldSchema), - DoUpdates: clause.AssignmentColumns(assignmentColumns), - }).Create(elems.Interface()).Error) + db.AddError(db.Session(&gorm.Session{}).Clauses( + onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), + ).Create(elems.Interface()).Error) } case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { @@ -168,10 +163,9 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: onConflictColumns(rel.FieldSchema), - DoUpdates: clause.AssignmentColumns(assignmentColumns), - }).Create(f.Interface()).Error) + db.AddError(db.Session(&gorm.Session{}).Clauses( + onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), + ).Create(f.Interface()).Error) } } } @@ -230,10 +224,9 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: onConflictColumns(rel.FieldSchema), - DoUpdates: clause.AssignmentColumns(assignmentColumns), - }).Create(elems.Interface()).Error) + db.AddError(db.Session(&gorm.Session{}).Clauses( + onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), + ).Create(elems.Interface()).Error) } } @@ -298,7 +291,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if elems.Len() > 0 { - db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(elems.Interface()).Error) + db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) for i := 0; i < elems.Len(); i++ { appendToJoins(objs[i], elems.Index(i)) @@ -312,13 +305,31 @@ func SaveAfterAssociations(db *gorm.DB) { } } -func onConflictColumns(s *schema.Schema) (columns []clause.Column) { - if s.PrioritizedPrimaryField != nil { - return []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}} +func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) clause.OnConflict { + if stmt.DB.FullSaveAssociations { + defaultUpdatingColumns = make([]string, 0, len(s.DBNames)) + for _, dbName := range s.DBNames { + if !s.LookUpField(dbName).PrimaryKey { + defaultUpdatingColumns = append(defaultUpdatingColumns, dbName) + } + } } - for _, dbName := range s.PrimaryFieldDBNames { - columns = append(columns, clause.Column{Name: dbName}) + if len(defaultUpdatingColumns) > 0 { + var columns []clause.Column + if s.PrioritizedPrimaryField != nil { + columns = []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}} + } else { + for _, dbName := range s.PrimaryFieldDBNames { + columns = append(columns, clause.Column{Name: dbName}) + } + } + + return clause.OnConflict{ + Columns: columns, + DoUpdates: clause.AssignmentColumns(defaultUpdatingColumns), + } } - return + + return clause.OnConflict{DoNothing: true} } diff --git a/callbacks/create.go b/callbacks/create.go index c00a0a73..8e2454e8 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -88,7 +88,10 @@ func Create(config *Config) func(db *gorm.DB) { } case reflect.Struct: if insertID > 0 { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { + + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } } } } else { diff --git a/gorm.go b/gorm.go index 8efd8a73..e5c4a8a4 100644 --- a/gorm.go +++ b/gorm.go @@ -20,6 +20,8 @@ type Config struct { SkipDefaultTransaction bool // NamingStrategy tables, columns naming strategy NamingStrategy schema.Namer + // FullSaveAssociations full save associations + FullSaveAssociations bool // Logger Logger logger.Interface // NowFunc the function to be used when creating a new timestamp @@ -64,6 +66,7 @@ type Session struct { WithConditions bool SkipDefaultTransaction bool AllowGlobalUpdate bool + FullSaveAssociations bool Context context.Context Logger logger.Interface NowFunc func() time.Time @@ -161,6 +164,10 @@ func (db *DB) Session(config *Session) *DB { txConfig.AllowGlobalUpdate = true } + if config.FullSaveAssociations { + txConfig.FullSaveAssociations = true + } + if config.Context != nil { tx.Statement = tx.Statement.clone() tx.Statement.DB = tx diff --git a/logger/logger.go b/logger/logger.go index 831192fc..e568fb24 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -20,6 +20,7 @@ const ( Magenta = "\033[35m" Cyan = "\033[36m" White = "\033[37m" + BlueBold = "\033[34;1m" MagentaBold = "\033[35;1m" RedBold = "\033[31;1m" YellowBold = "\033[33;1m" @@ -76,11 +77,11 @@ func New(writer Writer, config Config) Interface { if config.Colorful { infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset - warnStr = Blue + "%s\n" + Reset + Magenta + "[warn] " + Reset + warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset - traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s" + traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%d]" + Reset + " %s" traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%d]" + Magenta + " %s" + Reset - traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s" + traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%d]" + Reset + " %s" } return &logger{ diff --git a/tests/update_belongs_to_test.go b/tests/update_belongs_to_test.go index 47076e69..736dfc5b 100644 --- a/tests/update_belongs_to_test.go +++ b/tests/update_belongs_to_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -22,4 +23,22 @@ func TestUpdateBelongsTo(t *testing.T) { var user2 User DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + + user.Company.Name += "new" + user.Manager.Name += "new" + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Company").Preload("Manager").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Company").Preload("Manager").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) } diff --git a/tests/update_has_many_test.go b/tests/update_has_many_test.go index 01ea2e3a..9066cbac 100644 --- a/tests/update_has_many_test.go +++ b/tests/update_has_many_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -22,6 +23,26 @@ func TestUpdateHasManyAssociations(t *testing.T) { DB.Preload("Pets").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + for _, pet := range user.Pets { + pet.Name += "new" + } + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Pets").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Pets").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) + t.Run("Polymorphic", func(t *testing.T) { var user = *GetUser("update-has-many", Config{}) @@ -37,5 +58,25 @@ func TestUpdateHasManyAssociations(t *testing.T) { var user2 User DB.Preload("Toys").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + + for idx := range user.Toys { + user.Toys[idx].Name += "new" + } + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Toys").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Toys").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) }) } diff --git a/tests/update_has_one_test.go b/tests/update_has_one_test.go index 7b29f424..54568546 100644 --- a/tests/update_has_one_test.go +++ b/tests/update_has_one_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -23,6 +24,23 @@ func TestUpdateHasOne(t *testing.T) { DB.Preload("Account").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + user.Account.Number += "new" + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Account").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Account").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) + t.Run("Polymorphic", func(t *testing.T) { var pet = Pet{Name: "create"} @@ -39,5 +57,22 @@ func TestUpdateHasOne(t *testing.T) { var pet2 Pet DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID) CheckPet(t, pet2, pet) + + pet.Toy.Name += "new" + if err := DB.Save(&pet).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var pet3 Pet + DB.Preload("Toy").Find(&pet3, "id = ?", pet.ID) + CheckPet(t, pet2, pet3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&pet).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var pet4 Pet + DB.Preload("Toy").Find(&pet4, "id = ?", pet.ID) + CheckPet(t, pet4, pet) }) } diff --git a/tests/update_many2many_test.go b/tests/update_many2many_test.go index a46deeb0..d94ef4ab 100644 --- a/tests/update_many2many_test.go +++ b/tests/update_many2many_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -26,4 +27,28 @@ func TestUpdateMany2ManyAssociations(t *testing.T) { var user2 User DB.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + + for idx := range user.Friends { + user.Friends[idx].Name += "new" + } + + for idx := range user.Languages { + user.Languages[idx].Name += "new" + } + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Languages").Preload("Friends").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Languages").Preload("Friends").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) } From ba253982bf558543187f3eb88295b88610cdc83b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 24 Sep 2020 20:08:24 +0800 Subject: [PATCH 64/67] Fix Pluck with Time and Scanner --- scan.go | 13 +++++++++++-- schema/field.go | 6 ++++-- tests/query_test.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/scan.go b/scan.go index be8782ed..d7cddbe6 100644 --- a/scan.go +++ b/scan.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "reflect" "strings" + "time" "gorm.io/gorm/schema" ) @@ -82,7 +83,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { scanIntoMap(mapValue, values, columns) *dest = append(*dest, mapValue) } - case *int, *int64, *uint, *uint64, *float32, *float64, *string: + case *int, *int64, *uint, *uint64, *float32, *float64, *string, *time.Time: for initialized || rows.Next() { initialized = false db.RowsAffected++ @@ -134,7 +135,15 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } // pluck values into slice of data - isPluck := len(fields) == 1 && reflectValueType.Kind() != reflect.Struct + isPluck := false + if len(fields) == 1 { + if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok { + isPluck = true + } else if reflectValueType.Kind() != reflect.Struct || reflectValueType.ConvertibleTo(schema.TimeReflectType) { + isPluck = true + } + } + for initialized || rows.Next() { initialized = false db.RowsAffected++ diff --git a/schema/field.go b/schema/field.go index ce2808a8..db516c33 100644 --- a/schema/field.go +++ b/schema/field.go @@ -18,6 +18,8 @@ type DataType string type TimeType int64 +var TimeReflectType = reflect.TypeOf(time.Time{}) + const ( UnixSecond TimeType = 1 UnixMillisecond TimeType = 2 @@ -102,7 +104,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { var getRealFieldValue func(reflect.Value) getRealFieldValue = func(v reflect.Value) { rv := reflect.Indirect(v) - if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) { for i := 0; i < rv.Type().NumField(); i++ { newFieldType := rv.Type().Field(i).Type for newFieldType.Kind() == reflect.Ptr { @@ -221,7 +223,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { case reflect.Struct: if _, ok := fieldValue.Interface().(*time.Time); ok { field.DataType = Time - } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + } else if fieldValue.Type().ConvertibleTo(TimeReflectType) { field.DataType = Time } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) { field.DataType = Time diff --git a/tests/query_test.go b/tests/query_test.go index 9c9ad9f2..431ccce2 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "database/sql" "fmt" "reflect" "regexp" @@ -431,6 +432,33 @@ func TestPluck(t *testing.T) { t.Errorf("Unexpected result on pluck id, got %+v", ids) } } + + var times []time.Time + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", ×).Error; err != nil { + t.Errorf("got error when pluck time: %v", err) + } + + for idx, tv := range times { + AssertEqual(t, tv, users[idx].CreatedAt) + } + + var ptrtimes []*time.Time + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &ptrtimes).Error; err != nil { + t.Errorf("got error when pluck time: %v", err) + } + + for idx, tv := range ptrtimes { + AssertEqual(t, tv, users[idx].CreatedAt) + } + + var nulltimes []sql.NullTime + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &nulltimes).Error; err != nil { + t.Errorf("got error when pluck time: %v", err) + } + + for idx, tv := range nulltimes { + AssertEqual(t, tv.Time, users[idx].CreatedAt) + } } func TestSelect(t *testing.T) { From 9eec6ae06638665661f9872e783a42613527e146 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 27 Sep 2020 12:25:38 +0800 Subject: [PATCH 65/67] Fix affected rows for Scan, change affected rows count for row/rows to '-', close #3532 --- callbacks.go | 1 - callbacks/row.go | 2 ++ finisher_api.go | 8 ++++++++ logger/logger.go | 49 +++++++++++++++++++++++++++++++++++++++--------- scan.go | 1 + 5 files changed, 51 insertions(+), 10 deletions(-) diff --git a/callbacks.go b/callbacks.go index 83d103df..fdde21e9 100644 --- a/callbacks.go +++ b/callbacks.go @@ -74,7 +74,6 @@ func (cs *callbacks) Raw() *processor { func (p *processor) Execute(db *DB) { curTime := time.Now() stmt := db.Statement - db.RowsAffected = 0 if stmt.Model == nil { stmt.Model = stmt.Dest diff --git a/callbacks/row.go b/callbacks/row.go index a36c0116..4f985d7b 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -16,6 +16,8 @@ func RowQuery(db *gorm.DB) { } else { db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } + + db.RowsAffected = -1 } } } diff --git a/finisher_api.go b/finisher_api.go index 2c56d763..63061553 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -8,6 +8,7 @@ import ( "strings" "gorm.io/gorm/clause" + "gorm.io/gorm/logger" "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) @@ -353,7 +354,9 @@ func (db *DB) Rows() (*sql.Rows, error) { // Scan scan value to a struct func (db *DB) Scan(dest interface{}) (tx *DB) { + currentLogger, newLogger := db.Logger, logger.Recorder.New() tx = db.getInstance() + tx.Logger = newLogger if rows, err := tx.Rows(); err != nil { tx.AddError(err) } else { @@ -362,6 +365,11 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { tx.ScanRows(rows, dest) } } + + currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) { + return newLogger.SQL, tx.RowsAffected + }, tx.Error) + tx.Logger = currentLogger return } diff --git a/logger/logger.go b/logger/logger.go index e568fb24..b278ad5d 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -63,6 +63,7 @@ var ( LogLevel: Warn, Colorful: true, }) + Recorder = traceRecorder{Interface: Default} ) func New(writer Writer, config Config) Interface { @@ -70,18 +71,18 @@ func New(writer Writer, config Config) Interface { infoStr = "%s\n[info] " warnStr = "%s\n[warn] " errStr = "%s\n[error] " - traceStr = "%s\n[%.3fms] [rows:%d] %s" - traceWarnStr = "%s\n[%.3fms] [rows:%d] %s" - traceErrStr = "%s %s\n[%.3fms] [rows:%d] %s" + traceStr = "%s\n[%.3fms] [rows:%v] %s" + traceWarnStr = "%s\n[%.3fms] [rows:%v] %s" + traceErrStr = "%s %s\n[%.3fms] [rows:%v] %s" ) if config.Colorful { infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset - traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%d]" + Reset + " %s" - traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%d]" + Magenta + " %s" + Reset - traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%d]" + Reset + " %s" + traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s" + traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%v]" + Magenta + " %s" + Reset + traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s" } return &logger{ @@ -138,13 +139,43 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i switch { case err != nil && l.LogLevel >= Error: sql, rows := fc() - l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) + if rows == -1 { + l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) + } case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn: sql, rows := fc() - l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + if rows == -1 { + l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + } case l.LogLevel >= Info: sql, rows := fc() - l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + if rows == -1 { + l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + } } } } + +type traceRecorder struct { + Interface + BeginAt time.Time + SQL string + RowsAffected int64 + Err error +} + +func (l traceRecorder) New() *traceRecorder { + return &traceRecorder{Interface: l.Interface} +} + +func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { + l.BeginAt = begin + l.SQL, l.RowsAffected = fc() + l.Err = err +} diff --git a/scan.go b/scan.go index d7cddbe6..8d737b17 100644 --- a/scan.go +++ b/scan.go @@ -52,6 +52,7 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns func Scan(rows *sql.Rows, db *DB, initialized bool) { columns, _ := rows.Columns() values := make([]interface{}, len(columns)) + db.RowsAffected = 0 switch dest := db.Statement.Dest.(type) { case map[string]interface{}, *map[string]interface{}: From a2faa41cbe55dc37e2e0c30cab0fcd1b6d00c5fe Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 28 Sep 2020 10:55:27 +0800 Subject: [PATCH 66/67] Refactor NamingStrategy, close #3540 --- schema/naming.go | 7 ++++--- schema/naming_test.go | 46 ++++++++++++++++++++++++------------------- 2 files changed, 30 insertions(+), 23 deletions(-) diff --git a/schema/naming.go b/schema/naming.go index af753ce5..dbc71e04 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -42,7 +42,7 @@ func (ns NamingStrategy) ColumnName(table, column string) string { // JoinTableName convert string to join table name func (ns NamingStrategy) JoinTableName(str string) string { if strings.ToLower(str) == str { - return str + return ns.TablePrefix + str } if ns.SingularTable { @@ -53,17 +53,18 @@ func (ns NamingStrategy) JoinTableName(str string) string { // RelationshipFKName generate fk name for relation func (ns NamingStrategy) RelationshipFKName(rel Relationship) string { - return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Name)) + return strings.Replace(fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Name)), ".", "_", -1) } // CheckerName generate checker name func (ns NamingStrategy) CheckerName(table, column string) string { - return fmt.Sprintf("chk_%s_%s", table, column) + return strings.Replace(fmt.Sprintf("chk_%s_%s", table, column), ".", "_", -1) } // IndexName generate index name func (ns NamingStrategy) IndexName(table, column string) string { idxName := fmt.Sprintf("idx_%v_%v", table, toDBName(column)) + idxName = strings.Replace(idxName, ".", "_", -1) if utf8.RuneCountInString(idxName) > 64 { h := sha1.New() diff --git a/schema/naming_test.go b/schema/naming_test.go index a4600ceb..26b0dcf6 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -1,7 +1,6 @@ package schema import ( - "strings" "testing" ) @@ -34,27 +33,34 @@ func TestToDBName(t *testing.T) { } } -type NewNamingStrategy struct { - NamingStrategy -} +func TestNamingStrategy(t *testing.T) { + var ns = NamingStrategy{ + TablePrefix: "public.", + SingularTable: true, + } + idxName := ns.IndexName("public.table", "name") -func (ns NewNamingStrategy) ColumnName(table, column string) string { - baseColumnName := ns.NamingStrategy.ColumnName(table, column) - - if table == "" { - return baseColumnName + if idxName != "idx_public_table_name" { + t.Errorf("invalid index name generated, got %v", idxName) } - s := strings.Split(table, "_") - - var prefix string - switch len(s) { - case 1: - prefix = s[0][:3] - case 2: - prefix = s[0][:1] + s[1][:2] - default: - prefix = s[0][:1] + s[1][:1] + s[2][:1] + chkName := ns.CheckerName("public.table", "name") + if chkName != "chk_public_table_name" { + t.Errorf("invalid checker name generated, got %v", chkName) + } + + joinTable := ns.JoinTableName("user_languages") + if joinTable != "public.user_languages" { + t.Errorf("invalid join table generated, got %v", joinTable) + } + + joinTable2 := ns.JoinTableName("UserLanguage") + if joinTable2 != "public.user_language" { + t.Errorf("invalid join table generated, got %v", joinTable2) + } + + tableName := ns.TableName("Company") + if tableName != "public.company" { + t.Errorf("invalid table name generated, got %v", tableName) } - return prefix + "_" + baseColumnName } From dbc6b34dce7f5c4ce6f358d23bc70ac738af7793 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 29 Sep 2020 15:42:58 +0800 Subject: [PATCH 67/67] Add detailed error information when missing table name --- callbacks.go | 6 +++++- tests/go.mod | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/callbacks.go b/callbacks.go index fdde21e9..e21e0718 100644 --- a/callbacks.go +++ b/callbacks.go @@ -83,7 +83,11 @@ func (p *processor) Execute(db *DB) { if stmt.Model != nil { if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) { - db.AddError(err) + if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" { + db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err)) + } else { + db.AddError(err) + } } } diff --git a/tests/go.mod b/tests/go.mod index c92fa0cf..cbafcd7e 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,10 +7,10 @@ require ( github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 gorm.io/driver/mysql v1.0.1 - gorm.io/driver/postgres v1.0.1 + gorm.io/driver/postgres v1.0.2 gorm.io/driver/sqlite v1.1.3 gorm.io/driver/sqlserver v1.0.4 - gorm.io/gorm v1.20.1 + gorm.io/gorm v1.20.2 ) replace gorm.io/gorm => ../