diff --git a/callbacks/create.go b/callbacks/create.go index d930e922..afea2cca 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -111,6 +111,17 @@ func Create(config *Config) func(db *gorm.DB) { pkField *schema.Field pkFieldName = "@id" ) + + insertID, err := result.LastInsertId() + insertOk := err == nil && insertID > 0 + + if !insertOk { + if !supportReturning { + db.AddError(err) + } + return + } + if db.Statement.Schema != nil { if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { return @@ -119,13 +130,6 @@ func Create(config *Config) func(db *gorm.DB) { pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName } - insertID, err := result.LastInsertId() - insertOk := err == nil && insertID > 0 - if !insertOk { - db.AddError(err) - return - } - // append @id column with value for auto-increment primary key // the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1 switch values := db.Statement.Dest.(type) { @@ -142,6 +146,11 @@ func Create(config *Config) func(db *gorm.DB) { } } } + + if config.LastInsertIDReversed { + insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement + } + for _, mapValue := range mapValues { if mapValue != nil { mapValue[pkFieldName] = insertID diff --git a/logger/sql.go b/logger/sql.go index 8ce8d8b1..ad478795 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -34,6 +34,19 @@ var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeO // RegEx matches only numeric values var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`) +func isNumeric(k reflect.Kind) bool { + switch k { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return true + case reflect.Float32, reflect.Float64: + return true + default: + return false + } +} + // ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { var ( @@ -110,6 +123,12 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a convertParams(v, idx) } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { convertParams(reflect.Indirect(rv).Interface(), idx) + } else if isNumeric(rv.Kind()) { + if rv.CanInt() || rv.CanUint() { + vars[idx] = fmt.Sprintf("%d", rv.Interface()) + } else { + vars[idx] = fmt.Sprintf("%.6f", rv.Interface()) + } } else { for _, t := range convertibleTypes { if rv.Type().ConvertibleTo(t) { diff --git a/logger/sql_test.go b/logger/sql_test.go index 036ef3a4..9002a7eb 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -37,14 +37,18 @@ func format(v []byte, escaper string) string { func TestExplainSQL(t *testing.T) { type role string type password []byte + type intType int + type floatType float64 var ( - tt = now.MustParse("2020-02-23 11:10:10") - myrole = role("admin") - pwd = password("pass") - jsVal = []byte(`{"Name":"test","Val":"test"}`) - js = JSON(jsVal) - esVal = []byte(`{"Name":"test","Val":"test"}`) - es = ExampleStruct{Name: "test", Val: "test"} + tt = now.MustParse("2020-02-23 11:10:10") + myrole = role("admin") + pwd = password("pass") + jsVal = []byte(`{"Name":"test","Val":"test"}`) + js = JSON(jsVal) + esVal = []byte(`{"Name":"test","Val":"test"}`) + es = ExampleStruct{Name: "test", Val: "test"} + intVal intType = 1 + floatVal floatType = 1.23 ) results := []struct { @@ -107,6 +111,18 @@ func TestExplainSQL(t *testing.T) { Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, intVal}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1)`, + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, floatVal}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1.230000)`, + }, } for idx, r := range results { diff --git a/migrator/migrator.go b/migrator/migrator.go index 9fee9d60..702fda2b 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -7,6 +7,7 @@ import ( "fmt" "reflect" "regexp" + "strconv" "strings" "time" @@ -518,12 +519,18 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } else if !dvNotNull && currentDefaultNotNull { // null -> default value alterColumn = true - } else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) || - (field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) { - // default value not equal - // not both null - if currentDefaultNotNull || dvNotNull { - alterColumn = true + } else if currentDefaultNotNull || dvNotNull { + switch field.GORMDataType { + case schema.Time: + if !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()")) { + alterColumn = true + } + case schema.Bool: + v1, _ := strconv.ParseBool(dv) + v2, _ := strconv.ParseBool(field.DefaultValue) + alterColumn = v1 != v2 + default: + alterColumn = dv != field.DefaultValue } } } diff --git a/prepare_stmt.go b/prepare_stmt.go index aa944624..c60b5db7 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -3,6 +3,8 @@ package gorm import ( "context" "database/sql" + "database/sql/driver" + "errors" "reflect" "sync" ) @@ -147,7 +149,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { result, err = stmt.ExecContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { db.Mux.Lock() defer db.Mux.Unlock() go stmt.Close() @@ -161,7 +163,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { rows, err = stmt.QueryContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { db.Mux.Lock() defer db.Mux.Unlock() @@ -207,7 +209,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { tx.PreparedStmtDB.Mux.Lock() defer tx.PreparedStmtDB.Mux.Unlock() @@ -222,7 +224,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...) - if err != nil { + if errors.Is(err, driver.ErrBadConn) { tx.PreparedStmtDB.Mux.Lock() defer tx.PreparedStmtDB.Mux.Unlock() diff --git a/tests/create_test.go b/tests/create_test.go index 5e97a542..abb82472 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -713,18 +713,16 @@ func TestCreateFromMapWithoutPK(t *testing.T) { } func TestCreateFromMapWithTable(t *testing.T) { - if !isMysql() { - t.Skipf("This test case skipped, because of only supportting for mysql") - } - tableDB := DB.Table("`users`") + tableDB := DB.Table("users") + supportLastInsertID := isMysql() || isSqlite() // case 1: create from map[string]interface{} - record := map[string]interface{}{"`name`": "create_from_map_with_table", "`age`": 18} + record := map[string]interface{}{"name": "create_from_map_with_table", "age": 18} if err := tableDB.Create(record).Error; err != nil { t.Fatalf("failed to create data from map with table, got error: %v", err) } - if _, ok := record["@id"]; !ok { + if _, ok := record["@id"]; !ok && supportLastInsertID { t.Fatal("failed to create data from map with table, returning map has no key '@id'") } @@ -733,8 +731,8 @@ func TestCreateFromMapWithTable(t *testing.T) { t.Fatalf("failed to create from map, got error %v", err) } - if int64(res["id"].(uint64)) != record["@id"] { - t.Fatal("failed to create data from map with table, @id != id") + if _, ok := record["@id"]; ok && fmt.Sprint(res["id"]) != fmt.Sprint(record["@id"]) { + t.Fatalf("failed to create data from map with table, @id != id, got %v, expect %v", res["id"], record["@id"]) } // case 2: create from *map[string]interface{} @@ -743,7 +741,7 @@ func TestCreateFromMapWithTable(t *testing.T) { if err := tableDB2.Create(&record1).Error; err != nil { t.Fatalf("failed to create data from map, got error: %v", err) } - if _, ok := record1["@id"]; !ok { + if _, ok := record1["@id"]; !ok && supportLastInsertID { t.Fatal("failed to create data from map with table, returning map has no key '@id'") } @@ -752,7 +750,7 @@ func TestCreateFromMapWithTable(t *testing.T) { t.Fatalf("failed to create from map, got error %v", err) } - if int64(res1["id"].(uint64)) != record1["@id"] { + if _, ok := record1["@id"]; ok && fmt.Sprint(res1["id"]) != fmt.Sprint(record1["@id"]) { t.Fatal("failed to create data from map with table, @id != id") } @@ -767,11 +765,11 @@ func TestCreateFromMapWithTable(t *testing.T) { t.Fatalf("failed to create data from slice of map, got error: %v", err) } - if _, ok := records[0]["@id"]; !ok { + if _, ok := records[0]["@id"]; !ok && supportLastInsertID { t.Fatal("failed to create data from map with table, returning map has no key '@id'") } - if _, ok := records[1]["@id"]; !ok { + if _, ok := records[1]["@id"]; !ok && supportLastInsertID { t.Fatal("failed to create data from map with table, returning map has no key '@id'") } @@ -785,11 +783,11 @@ func TestCreateFromMapWithTable(t *testing.T) { t.Fatalf("failed to query data after create from slice of map, got error %v", err) } - if int64(res2["id"].(uint64)) != records[0]["@id"] { - t.Fatal("failed to create data from map with table, @id != id") + if _, ok := records[0]["@id"]; ok && fmt.Sprint(res2["id"]) != fmt.Sprint(records[0]["@id"]) { + t.Errorf("failed to create data from map with table, @id != id, got %v, expect %v", res2["id"], records[0]["@id"]) } - if int64(res3["id"].(uint64)) != records[1]["@id"] { - t.Fatal("failed to create data from map with table, @id != id") + if _, ok := records[1]["id"]; ok && fmt.Sprint(res3["id"]) != fmt.Sprint(records[1]["@id"]) { + t.Errorf("failed to create data from map with table, @id != id") } } diff --git a/tests/go.mod b/tests/go.mod index 350152d3..3d3901d9 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,11 +7,11 @@ require ( github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.9 github.com/stretchr/testify v1.9.0 - gorm.io/driver/mysql v1.5.5 + gorm.io/driver/mysql v1.5.6 gorm.io/driver/postgres v1.5.7 gorm.io/driver/sqlite v1.5.5 gorm.io/driver/sqlserver v1.5.3 - gorm.io/gorm v1.25.7 + gorm.io/gorm v1.25.8 ) require ( diff --git a/tests/helper_test.go b/tests/helper_test.go index feb67f9e..dc250b7c 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -281,6 +281,10 @@ func isMysql() bool { return os.Getenv("GORM_DIALECT") == "mysql" } +func isSqlite() bool { + return os.Getenv("GORM_DIALECT") == "sqlite" +} + func db(unscoped bool) *gorm.DB { if unscoped { return DB.Unscoped() diff --git a/tests/migrate_test.go b/tests/migrate_test.go index b25b9da6..d955c8d7 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -7,6 +7,7 @@ import ( "math/rand" "os" "reflect" + "strconv" "strings" "testing" "time" @@ -1420,7 +1421,7 @@ func TestMigrateSameEmbeddedFieldName(t *testing.T) { AssertEqual(t, nil, err) } -func TestMigrateDefaultNullString(t *testing.T) { +func TestMigrateWithDefaultValue(t *testing.T) { if DB.Dialector.Name() == "sqlserver" { // sqlserver driver treats NULL and 'NULL' the same t.Skip("skip sqlserver") @@ -1434,6 +1435,7 @@ func TestMigrateDefaultNullString(t *testing.T) { type NullStringModel struct { ID uint Content string `gorm:"default:'null'"` + Active bool `gorm:"default:false"` } tableName := "null_string_model" @@ -1454,6 +1456,14 @@ func TestMigrateDefaultNullString(t *testing.T) { AssertEqual(t, defVal, "null") AssertEqual(t, ok, true) + columnType2, err := findColumnType(tableName, "active") + AssertEqual(t, err, nil) + + defVal, ok = columnType2.DefaultValue() + bv, _ := strconv.ParseBool(defVal) + AssertEqual(t, bv, false) + AssertEqual(t, ok, true) + // default 'null' -> 'null' session := DB.Session(&gorm.Session{Logger: Tracer{ Logger: DB.Config.Logger, diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index b234c8bf..b86bc3d6 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -126,33 +126,6 @@ func TestPreparedStmtDeadlock(t *testing.T) { AssertEqual(t, sqlDB.Stats().InUse, 0) } -func TestPreparedStmtError(t *testing.T) { - tx, err := OpenTestConnection(&gorm.Config{}) - AssertEqual(t, err, nil) - - sqlDB, _ := tx.DB() - sqlDB.SetMaxOpenConns(1) - - tx = tx.Session(&gorm.Session{PrepareStmt: true}) - - wg := sync.WaitGroup{} - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - // err prepare - tag := Tag{Locale: "zh"} - tx.Table("users").Find(&tag) - wg.Done() - }() - } - wg.Wait() - - conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) - AssertEqual(t, ok, true) - AssertEqual(t, len(conn.Stmts), 0) - AssertEqual(t, sqlDB.Stats().InUse, 0) -} - func TestPreparedStmtInTransaction(t *testing.T) { user := User{Name: "jinzhu"}