Merge branch 'go-gorm:master' into master

This commit is contained in:
ayakut 2024-03-21 18:36:22 +02:00 committed by GitHub
commit 35bade6ca7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 108 additions and 70 deletions

View File

@ -111,6 +111,17 @@ func Create(config *Config) func(db *gorm.DB) {
pkField *schema.Field pkField *schema.Field
pkFieldName = "@id" pkFieldName = "@id"
) )
insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0
if !insertOk {
if !supportReturning {
db.AddError(err)
}
return
}
if db.Statement.Schema != nil { if db.Statement.Schema != nil {
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
return return
@ -119,13 +130,6 @@ func Create(config *Config) func(db *gorm.DB) {
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
} }
insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0
if !insertOk {
db.AddError(err)
return
}
// append @id column with value for auto-increment primary key // append @id column with value for auto-increment primary key
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1 // the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
switch values := db.Statement.Dest.(type) { switch values := db.Statement.Dest.(type) {
@ -142,6 +146,11 @@ func Create(config *Config) func(db *gorm.DB) {
} }
} }
} }
if config.LastInsertIDReversed {
insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement
}
for _, mapValue := range mapValues { for _, mapValue := range mapValues {
if mapValue != nil { if mapValue != nil {
mapValue[pkFieldName] = insertID mapValue[pkFieldName] = insertID

View File

@ -34,6 +34,19 @@ var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeO
// RegEx matches only numeric values // RegEx matches only numeric values
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`) var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
func isNumeric(k reflect.Kind) bool {
switch k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return true
case reflect.Float32, reflect.Float64:
return true
default:
return false
}
}
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability // ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
var ( var (
@ -110,6 +123,12 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
convertParams(v, idx) convertParams(v, idx)
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() { } else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
convertParams(reflect.Indirect(rv).Interface(), idx) convertParams(reflect.Indirect(rv).Interface(), idx)
} else if isNumeric(rv.Kind()) {
if rv.CanInt() || rv.CanUint() {
vars[idx] = fmt.Sprintf("%d", rv.Interface())
} else {
vars[idx] = fmt.Sprintf("%.6f", rv.Interface())
}
} else { } else {
for _, t := range convertibleTypes { for _, t := range convertibleTypes {
if rv.Type().ConvertibleTo(t) { if rv.Type().ConvertibleTo(t) {

View File

@ -37,14 +37,18 @@ func format(v []byte, escaper string) string {
func TestExplainSQL(t *testing.T) { func TestExplainSQL(t *testing.T) {
type role string type role string
type password []byte type password []byte
type intType int
type floatType float64
var ( var (
tt = now.MustParse("2020-02-23 11:10:10") tt = now.MustParse("2020-02-23 11:10:10")
myrole = role("admin") myrole = role("admin")
pwd = password("pass") pwd = password("pass")
jsVal = []byte(`{"Name":"test","Val":"test"}`) jsVal = []byte(`{"Name":"test","Val":"test"}`)
js = JSON(jsVal) js = JSON(jsVal)
esVal = []byte(`{"Name":"test","Val":"test"}`) esVal = []byte(`{"Name":"test","Val":"test"}`)
es = ExampleStruct{Name: "test", Val: "test"} es = ExampleStruct{Name: "test", Val: "test"}
intVal intType = 1
floatVal floatType = 1.23
) )
results := []struct { results := []struct {
@ -107,6 +111,18 @@ func TestExplainSQL(t *testing.T) {
Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es},
Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)),
}, },
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, intVal},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1)`,
},
{
SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
NumericRegexp: nil,
Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, floatVal},
Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1.230000)`,
},
} }
for idx, r := range results { for idx, r := range results {

View File

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"regexp" "regexp"
"strconv"
"strings" "strings"
"time" "time"
@ -518,12 +519,18 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
} else if !dvNotNull && currentDefaultNotNull { } else if !dvNotNull && currentDefaultNotNull {
// null -> default value // null -> default value
alterColumn = true alterColumn = true
} else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) || } else if currentDefaultNotNull || dvNotNull {
(field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) { switch field.GORMDataType {
// default value not equal case schema.Time:
// not both null if !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()")) {
if currentDefaultNotNull || dvNotNull { alterColumn = true
alterColumn = true }
case schema.Bool:
v1, _ := strconv.ParseBool(dv)
v2, _ := strconv.ParseBool(field.DefaultValue)
alterColumn = v1 != v2
default:
alterColumn = dv != field.DefaultValue
} }
} }
} }

View File

@ -3,6 +3,8 @@ package gorm
import ( import (
"context" "context"
"database/sql" "database/sql"
"database/sql/driver"
"errors"
"reflect" "reflect"
"sync" "sync"
) )
@ -147,7 +149,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
stmt, err := db.prepare(ctx, db.ConnPool, false, query) stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil { if err == nil {
result, err = stmt.ExecContext(ctx, args...) result, err = stmt.ExecContext(ctx, args...)
if err != nil { if errors.Is(err, driver.ErrBadConn) {
db.Mux.Lock() db.Mux.Lock()
defer db.Mux.Unlock() defer db.Mux.Unlock()
go stmt.Close() go stmt.Close()
@ -161,7 +163,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
stmt, err := db.prepare(ctx, db.ConnPool, false, query) stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil { if err == nil {
rows, err = stmt.QueryContext(ctx, args...) rows, err = stmt.QueryContext(ctx, args...)
if err != nil { if errors.Is(err, driver.ErrBadConn) {
db.Mux.Lock() db.Mux.Lock()
defer db.Mux.Unlock() defer db.Mux.Unlock()
@ -207,7 +209,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil { if err == nil {
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
if err != nil { if errors.Is(err, driver.ErrBadConn) {
tx.PreparedStmtDB.Mux.Lock() tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock() defer tx.PreparedStmtDB.Mux.Unlock()
@ -222,7 +224,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil { if err == nil {
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...) rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
if err != nil { if errors.Is(err, driver.ErrBadConn) {
tx.PreparedStmtDB.Mux.Lock() tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock() defer tx.PreparedStmtDB.Mux.Unlock()

View File

@ -713,18 +713,16 @@ func TestCreateFromMapWithoutPK(t *testing.T) {
} }
func TestCreateFromMapWithTable(t *testing.T) { func TestCreateFromMapWithTable(t *testing.T) {
if !isMysql() { tableDB := DB.Table("users")
t.Skipf("This test case skipped, because of only supportting for mysql") supportLastInsertID := isMysql() || isSqlite()
}
tableDB := DB.Table("`users`")
// case 1: create from map[string]interface{} // case 1: create from map[string]interface{}
record := map[string]interface{}{"`name`": "create_from_map_with_table", "`age`": 18} record := map[string]interface{}{"name": "create_from_map_with_table", "age": 18}
if err := tableDB.Create(record).Error; err != nil { if err := tableDB.Create(record).Error; err != nil {
t.Fatalf("failed to create data from map with table, got error: %v", err) t.Fatalf("failed to create data from map with table, got error: %v", err)
} }
if _, ok := record["@id"]; !ok { if _, ok := record["@id"]; !ok && supportLastInsertID {
t.Fatal("failed to create data from map with table, returning map has no key '@id'") t.Fatal("failed to create data from map with table, returning map has no key '@id'")
} }
@ -733,8 +731,8 @@ func TestCreateFromMapWithTable(t *testing.T) {
t.Fatalf("failed to create from map, got error %v", err) t.Fatalf("failed to create from map, got error %v", err)
} }
if int64(res["id"].(uint64)) != record["@id"] { if _, ok := record["@id"]; ok && fmt.Sprint(res["id"]) != fmt.Sprint(record["@id"]) {
t.Fatal("failed to create data from map with table, @id != id") t.Fatalf("failed to create data from map with table, @id != id, got %v, expect %v", res["id"], record["@id"])
} }
// case 2: create from *map[string]interface{} // case 2: create from *map[string]interface{}
@ -743,7 +741,7 @@ func TestCreateFromMapWithTable(t *testing.T) {
if err := tableDB2.Create(&record1).Error; err != nil { if err := tableDB2.Create(&record1).Error; err != nil {
t.Fatalf("failed to create data from map, got error: %v", err) t.Fatalf("failed to create data from map, got error: %v", err)
} }
if _, ok := record1["@id"]; !ok { if _, ok := record1["@id"]; !ok && supportLastInsertID {
t.Fatal("failed to create data from map with table, returning map has no key '@id'") t.Fatal("failed to create data from map with table, returning map has no key '@id'")
} }
@ -752,7 +750,7 @@ func TestCreateFromMapWithTable(t *testing.T) {
t.Fatalf("failed to create from map, got error %v", err) t.Fatalf("failed to create from map, got error %v", err)
} }
if int64(res1["id"].(uint64)) != record1["@id"] { if _, ok := record1["@id"]; ok && fmt.Sprint(res1["id"]) != fmt.Sprint(record1["@id"]) {
t.Fatal("failed to create data from map with table, @id != id") t.Fatal("failed to create data from map with table, @id != id")
} }
@ -767,11 +765,11 @@ func TestCreateFromMapWithTable(t *testing.T) {
t.Fatalf("failed to create data from slice of map, got error: %v", err) t.Fatalf("failed to create data from slice of map, got error: %v", err)
} }
if _, ok := records[0]["@id"]; !ok { if _, ok := records[0]["@id"]; !ok && supportLastInsertID {
t.Fatal("failed to create data from map with table, returning map has no key '@id'") t.Fatal("failed to create data from map with table, returning map has no key '@id'")
} }
if _, ok := records[1]["@id"]; !ok { if _, ok := records[1]["@id"]; !ok && supportLastInsertID {
t.Fatal("failed to create data from map with table, returning map has no key '@id'") t.Fatal("failed to create data from map with table, returning map has no key '@id'")
} }
@ -785,11 +783,11 @@ func TestCreateFromMapWithTable(t *testing.T) {
t.Fatalf("failed to query data after create from slice of map, got error %v", err) t.Fatalf("failed to query data after create from slice of map, got error %v", err)
} }
if int64(res2["id"].(uint64)) != records[0]["@id"] { if _, ok := records[0]["@id"]; ok && fmt.Sprint(res2["id"]) != fmt.Sprint(records[0]["@id"]) {
t.Fatal("failed to create data from map with table, @id != id") t.Errorf("failed to create data from map with table, @id != id, got %v, expect %v", res2["id"], records[0]["@id"])
} }
if int64(res3["id"].(uint64)) != records[1]["@id"] { if _, ok := records[1]["id"]; ok && fmt.Sprint(res3["id"]) != fmt.Sprint(records[1]["@id"]) {
t.Fatal("failed to create data from map with table, @id != id") t.Errorf("failed to create data from map with table, @id != id")
} }
} }

View File

@ -7,11 +7,11 @@ require (
github.com/jinzhu/now v1.1.5 github.com/jinzhu/now v1.1.5
github.com/lib/pq v1.10.9 github.com/lib/pq v1.10.9
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
gorm.io/driver/mysql v1.5.5 gorm.io/driver/mysql v1.5.6
gorm.io/driver/postgres v1.5.7 gorm.io/driver/postgres v1.5.7
gorm.io/driver/sqlite v1.5.5 gorm.io/driver/sqlite v1.5.5
gorm.io/driver/sqlserver v1.5.3 gorm.io/driver/sqlserver v1.5.3
gorm.io/gorm v1.25.7 gorm.io/gorm v1.25.8
) )
require ( require (

View File

@ -281,6 +281,10 @@ func isMysql() bool {
return os.Getenv("GORM_DIALECT") == "mysql" return os.Getenv("GORM_DIALECT") == "mysql"
} }
func isSqlite() bool {
return os.Getenv("GORM_DIALECT") == "sqlite"
}
func db(unscoped bool) *gorm.DB { func db(unscoped bool) *gorm.DB {
if unscoped { if unscoped {
return DB.Unscoped() return DB.Unscoped()

View File

@ -7,6 +7,7 @@ import (
"math/rand" "math/rand"
"os" "os"
"reflect" "reflect"
"strconv"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -1420,7 +1421,7 @@ func TestMigrateSameEmbeddedFieldName(t *testing.T) {
AssertEqual(t, nil, err) AssertEqual(t, nil, err)
} }
func TestMigrateDefaultNullString(t *testing.T) { func TestMigrateWithDefaultValue(t *testing.T) {
if DB.Dialector.Name() == "sqlserver" { if DB.Dialector.Name() == "sqlserver" {
// sqlserver driver treats NULL and 'NULL' the same // sqlserver driver treats NULL and 'NULL' the same
t.Skip("skip sqlserver") t.Skip("skip sqlserver")
@ -1434,6 +1435,7 @@ func TestMigrateDefaultNullString(t *testing.T) {
type NullStringModel struct { type NullStringModel struct {
ID uint ID uint
Content string `gorm:"default:'null'"` Content string `gorm:"default:'null'"`
Active bool `gorm:"default:false"`
} }
tableName := "null_string_model" tableName := "null_string_model"
@ -1454,6 +1456,14 @@ func TestMigrateDefaultNullString(t *testing.T) {
AssertEqual(t, defVal, "null") AssertEqual(t, defVal, "null")
AssertEqual(t, ok, true) AssertEqual(t, ok, true)
columnType2, err := findColumnType(tableName, "active")
AssertEqual(t, err, nil)
defVal, ok = columnType2.DefaultValue()
bv, _ := strconv.ParseBool(defVal)
AssertEqual(t, bv, false)
AssertEqual(t, ok, true)
// default 'null' -> 'null' // default 'null' -> 'null'
session := DB.Session(&gorm.Session{Logger: Tracer{ session := DB.Session(&gorm.Session{Logger: Tracer{
Logger: DB.Config.Logger, Logger: DB.Config.Logger,

View File

@ -126,33 +126,6 @@ func TestPreparedStmtDeadlock(t *testing.T) {
AssertEqual(t, sqlDB.Stats().InUse, 0) AssertEqual(t, sqlDB.Stats().InUse, 0)
} }
func TestPreparedStmtError(t *testing.T) {
tx, err := OpenTestConnection(&gorm.Config{})
AssertEqual(t, err, nil)
sqlDB, _ := tx.DB()
sqlDB.SetMaxOpenConns(1)
tx = tx.Session(&gorm.Session{PrepareStmt: true})
wg := sync.WaitGroup{}
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
// err prepare
tag := Tag{Locale: "zh"}
tx.Table("users").Find(&tag)
wg.Done()
}()
}
wg.Wait()
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
AssertEqual(t, ok, true)
AssertEqual(t, len(conn.Stmts), 0)
AssertEqual(t, sqlDB.Stats().InUse, 0)
}
func TestPreparedStmtInTransaction(t *testing.T) { func TestPreparedStmtInTransaction(t *testing.T) {
user := User{Name: "jinzhu"} user := User{Name: "jinzhu"}