Merge branch 'go-gorm:master' into master
This commit is contained in:
		
						commit
						35bade6ca7
					
				| @ -111,6 +111,17 @@ func Create(config *Config) func(db *gorm.DB) { | ||||
| 			pkField     *schema.Field | ||||
| 			pkFieldName = "@id" | ||||
| 		) | ||||
| 
 | ||||
| 		insertID, err := result.LastInsertId() | ||||
| 		insertOk := err == nil && insertID > 0 | ||||
| 
 | ||||
| 		if !insertOk { | ||||
| 			if !supportReturning { | ||||
| 				db.AddError(err) | ||||
| 			} | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		if db.Statement.Schema != nil { | ||||
| 			if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { | ||||
| 				return | ||||
| @ -119,13 +130,6 @@ func Create(config *Config) func(db *gorm.DB) { | ||||
| 			pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName | ||||
| 		} | ||||
| 
 | ||||
| 		insertID, err := result.LastInsertId() | ||||
| 		insertOk := err == nil && insertID > 0 | ||||
| 		if !insertOk { | ||||
| 			db.AddError(err) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		// append @id column with value for auto-increment primary key
 | ||||
| 		// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
 | ||||
| 		switch values := db.Statement.Dest.(type) { | ||||
| @ -142,6 +146,11 @@ func Create(config *Config) func(db *gorm.DB) { | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if config.LastInsertIDReversed { | ||||
| 				insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement | ||||
| 			} | ||||
| 
 | ||||
| 			for _, mapValue := range mapValues { | ||||
| 				if mapValue != nil { | ||||
| 					mapValue[pkFieldName] = insertID | ||||
|  | ||||
| @ -34,6 +34,19 @@ var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeO | ||||
| // RegEx matches only numeric values
 | ||||
| var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`) | ||||
| 
 | ||||
| func isNumeric(k reflect.Kind) bool { | ||||
| 	switch k { | ||||
| 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | ||||
| 		return true | ||||
| 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | ||||
| 		return true | ||||
| 	case reflect.Float32, reflect.Float64: | ||||
| 		return true | ||||
| 	default: | ||||
| 		return false | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
 | ||||
| func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { | ||||
| 	var ( | ||||
| @ -110,6 +123,12 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a | ||||
| 				convertParams(v, idx) | ||||
| 			} else if rv.Kind() == reflect.Ptr && !rv.IsZero() { | ||||
| 				convertParams(reflect.Indirect(rv).Interface(), idx) | ||||
| 			} else if isNumeric(rv.Kind()) { | ||||
| 				if rv.CanInt() || rv.CanUint() { | ||||
| 					vars[idx] = fmt.Sprintf("%d", rv.Interface()) | ||||
| 				} else { | ||||
| 					vars[idx] = fmt.Sprintf("%.6f", rv.Interface()) | ||||
| 				} | ||||
| 			} else { | ||||
| 				for _, t := range convertibleTypes { | ||||
| 					if rv.Type().ConvertibleTo(t) { | ||||
|  | ||||
| @ -37,14 +37,18 @@ func format(v []byte, escaper string) string { | ||||
| func TestExplainSQL(t *testing.T) { | ||||
| 	type role string | ||||
| 	type password []byte | ||||
| 	type intType int | ||||
| 	type floatType float64 | ||||
| 	var ( | ||||
| 		tt     = now.MustParse("2020-02-23 11:10:10") | ||||
| 		myrole = role("admin") | ||||
| 		pwd    = password("pass") | ||||
| 		jsVal  = []byte(`{"Name":"test","Val":"test"}`) | ||||
| 		js     = JSON(jsVal) | ||||
| 		esVal  = []byte(`{"Name":"test","Val":"test"}`) | ||||
| 		es     = ExampleStruct{Name: "test", Val: "test"} | ||||
| 		tt                 = now.MustParse("2020-02-23 11:10:10") | ||||
| 		myrole             = role("admin") | ||||
| 		pwd                = password("pass") | ||||
| 		jsVal              = []byte(`{"Name":"test","Val":"test"}`) | ||||
| 		js                 = JSON(jsVal) | ||||
| 		esVal              = []byte(`{"Name":"test","Val":"test"}`) | ||||
| 		es                 = ExampleStruct{Name: "test", Val: "test"} | ||||
| 		intVal   intType   = 1 | ||||
| 		floatVal floatType = 1.23 | ||||
| 	) | ||||
| 
 | ||||
| 	results := []struct { | ||||
| @ -107,6 +111,18 @@ func TestExplainSQL(t *testing.T) { | ||||
| 			Vars:          []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, | ||||
| 			Result:        fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), | ||||
| 		}, | ||||
| 		{ | ||||
| 			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", | ||||
| 			NumericRegexp: nil, | ||||
| 			Vars:          []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, intVal}, | ||||
| 			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1)`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			SQL:           "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", | ||||
| 			NumericRegexp: nil, | ||||
| 			Vars:          []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, floatVal}, | ||||
| 			Result:        `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1.230000)`, | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for idx, r := range results { | ||||
|  | ||||
| @ -7,6 +7,7 @@ import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"regexp" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| @ -518,12 +519,18 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy | ||||
| 		} else if !dvNotNull && currentDefaultNotNull { | ||||
| 			// null -> default value
 | ||||
| 			alterColumn = true | ||||
| 		} else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) || | ||||
| 			(field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) { | ||||
| 			// default value not equal
 | ||||
| 			// not both null
 | ||||
| 			if currentDefaultNotNull || dvNotNull { | ||||
| 				alterColumn = true | ||||
| 		} else if currentDefaultNotNull || dvNotNull { | ||||
| 			switch field.GORMDataType { | ||||
| 			case schema.Time: | ||||
| 				if !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()")) { | ||||
| 					alterColumn = true | ||||
| 				} | ||||
| 			case schema.Bool: | ||||
| 				v1, _ := strconv.ParseBool(dv) | ||||
| 				v2, _ := strconv.ParseBool(field.DefaultValue) | ||||
| 				alterColumn = v1 != v2 | ||||
| 			default: | ||||
| 				alterColumn = dv != field.DefaultValue | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| @ -3,6 +3,8 @@ package gorm | ||||
| import ( | ||||
| 	"context" | ||||
| 	"database/sql" | ||||
| 	"database/sql/driver" | ||||
| 	"errors" | ||||
| 	"reflect" | ||||
| 	"sync" | ||||
| ) | ||||
| @ -147,7 +149,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. | ||||
| 	stmt, err := db.prepare(ctx, db.ConnPool, false, query) | ||||
| 	if err == nil { | ||||
| 		result, err = stmt.ExecContext(ctx, args...) | ||||
| 		if err != nil { | ||||
| 		if errors.Is(err, driver.ErrBadConn) { | ||||
| 			db.Mux.Lock() | ||||
| 			defer db.Mux.Unlock() | ||||
| 			go stmt.Close() | ||||
| @ -161,7 +163,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . | ||||
| 	stmt, err := db.prepare(ctx, db.ConnPool, false, query) | ||||
| 	if err == nil { | ||||
| 		rows, err = stmt.QueryContext(ctx, args...) | ||||
| 		if err != nil { | ||||
| 		if errors.Is(err, driver.ErrBadConn) { | ||||
| 			db.Mux.Lock() | ||||
| 			defer db.Mux.Unlock() | ||||
| 
 | ||||
| @ -207,7 +209,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. | ||||
| 	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) | ||||
| 	if err == nil { | ||||
| 		result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) | ||||
| 		if err != nil { | ||||
| 		if errors.Is(err, driver.ErrBadConn) { | ||||
| 			tx.PreparedStmtDB.Mux.Lock() | ||||
| 			defer tx.PreparedStmtDB.Mux.Unlock() | ||||
| 
 | ||||
| @ -222,7 +224,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . | ||||
| 	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) | ||||
| 	if err == nil { | ||||
| 		rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...) | ||||
| 		if err != nil { | ||||
| 		if errors.Is(err, driver.ErrBadConn) { | ||||
| 			tx.PreparedStmtDB.Mux.Lock() | ||||
| 			defer tx.PreparedStmtDB.Mux.Unlock() | ||||
| 
 | ||||
|  | ||||
| @ -713,18 +713,16 @@ func TestCreateFromMapWithoutPK(t *testing.T) { | ||||
| } | ||||
| 
 | ||||
| func TestCreateFromMapWithTable(t *testing.T) { | ||||
| 	if !isMysql() { | ||||
| 		t.Skipf("This test case skipped, because of only supportting for mysql") | ||||
| 	} | ||||
| 	tableDB := DB.Table("`users`") | ||||
| 	tableDB := DB.Table("users") | ||||
| 	supportLastInsertID := isMysql() || isSqlite() | ||||
| 
 | ||||
| 	// case 1: create from map[string]interface{}
 | ||||
| 	record := map[string]interface{}{"`name`": "create_from_map_with_table", "`age`": 18} | ||||
| 	record := map[string]interface{}{"name": "create_from_map_with_table", "age": 18} | ||||
| 	if err := tableDB.Create(record).Error; err != nil { | ||||
| 		t.Fatalf("failed to create data from map with table, got error: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if _, ok := record["@id"]; !ok { | ||||
| 	if _, ok := record["@id"]; !ok && supportLastInsertID { | ||||
| 		t.Fatal("failed to create data from map with table, returning map has no key '@id'") | ||||
| 	} | ||||
| 
 | ||||
| @ -733,8 +731,8 @@ func TestCreateFromMapWithTable(t *testing.T) { | ||||
| 		t.Fatalf("failed to create from map, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if int64(res["id"].(uint64)) != record["@id"] { | ||||
| 		t.Fatal("failed to create data from map with table, @id != id") | ||||
| 	if _, ok := record["@id"]; ok && fmt.Sprint(res["id"]) != fmt.Sprint(record["@id"]) { | ||||
| 		t.Fatalf("failed to create data from map with table, @id != id, got %v, expect %v", res["id"], record["@id"]) | ||||
| 	} | ||||
| 
 | ||||
| 	// case 2: create from *map[string]interface{}
 | ||||
| @ -743,7 +741,7 @@ func TestCreateFromMapWithTable(t *testing.T) { | ||||
| 	if err := tableDB2.Create(&record1).Error; err != nil { | ||||
| 		t.Fatalf("failed to create data from map, got error: %v", err) | ||||
| 	} | ||||
| 	if _, ok := record1["@id"]; !ok { | ||||
| 	if _, ok := record1["@id"]; !ok && supportLastInsertID { | ||||
| 		t.Fatal("failed to create data from map with table, returning map has no key '@id'") | ||||
| 	} | ||||
| 
 | ||||
| @ -752,7 +750,7 @@ func TestCreateFromMapWithTable(t *testing.T) { | ||||
| 		t.Fatalf("failed to create from map, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if int64(res1["id"].(uint64)) != record1["@id"] { | ||||
| 	if _, ok := record1["@id"]; ok && fmt.Sprint(res1["id"]) != fmt.Sprint(record1["@id"]) { | ||||
| 		t.Fatal("failed to create data from map with table, @id != id") | ||||
| 	} | ||||
| 
 | ||||
| @ -767,11 +765,11 @@ func TestCreateFromMapWithTable(t *testing.T) { | ||||
| 		t.Fatalf("failed to create data from slice of map, got error: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if _, ok := records[0]["@id"]; !ok { | ||||
| 	if _, ok := records[0]["@id"]; !ok && supportLastInsertID { | ||||
| 		t.Fatal("failed to create data from map with table, returning map has no key '@id'") | ||||
| 	} | ||||
| 
 | ||||
| 	if _, ok := records[1]["@id"]; !ok { | ||||
| 	if _, ok := records[1]["@id"]; !ok && supportLastInsertID { | ||||
| 		t.Fatal("failed to create data from map with table, returning map has no key '@id'") | ||||
| 	} | ||||
| 
 | ||||
| @ -785,11 +783,11 @@ func TestCreateFromMapWithTable(t *testing.T) { | ||||
| 		t.Fatalf("failed to query data after create from slice of map, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if int64(res2["id"].(uint64)) != records[0]["@id"] { | ||||
| 		t.Fatal("failed to create data from map with table, @id != id") | ||||
| 	if _, ok := records[0]["@id"]; ok && fmt.Sprint(res2["id"]) != fmt.Sprint(records[0]["@id"]) { | ||||
| 		t.Errorf("failed to create data from map with table, @id != id, got %v, expect %v", res2["id"], records[0]["@id"]) | ||||
| 	} | ||||
| 
 | ||||
| 	if int64(res3["id"].(uint64)) != records[1]["@id"] { | ||||
| 		t.Fatal("failed to create data from map with table, @id != id") | ||||
| 	if _, ok := records[1]["id"]; ok && fmt.Sprint(res3["id"]) != fmt.Sprint(records[1]["@id"]) { | ||||
| 		t.Errorf("failed to create data from map with table, @id != id") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -7,11 +7,11 @@ require ( | ||||
| 	github.com/jinzhu/now v1.1.5 | ||||
| 	github.com/lib/pq v1.10.9 | ||||
| 	github.com/stretchr/testify v1.9.0 | ||||
| 	gorm.io/driver/mysql v1.5.5 | ||||
| 	gorm.io/driver/mysql v1.5.6 | ||||
| 	gorm.io/driver/postgres v1.5.7 | ||||
| 	gorm.io/driver/sqlite v1.5.5 | ||||
| 	gorm.io/driver/sqlserver v1.5.3 | ||||
| 	gorm.io/gorm v1.25.7 | ||||
| 	gorm.io/gorm v1.25.8 | ||||
| ) | ||||
| 
 | ||||
| require ( | ||||
|  | ||||
| @ -281,6 +281,10 @@ func isMysql() bool { | ||||
| 	return os.Getenv("GORM_DIALECT") == "mysql" | ||||
| } | ||||
| 
 | ||||
| func isSqlite() bool { | ||||
| 	return os.Getenv("GORM_DIALECT") == "sqlite" | ||||
| } | ||||
| 
 | ||||
| func db(unscoped bool) *gorm.DB { | ||||
| 	if unscoped { | ||||
| 		return DB.Unscoped() | ||||
|  | ||||
| @ -7,6 +7,7 @@ import ( | ||||
| 	"math/rand" | ||||
| 	"os" | ||||
| 	"reflect" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| @ -1420,7 +1421,7 @@ func TestMigrateSameEmbeddedFieldName(t *testing.T) { | ||||
| 	AssertEqual(t, nil, err) | ||||
| } | ||||
| 
 | ||||
| func TestMigrateDefaultNullString(t *testing.T) { | ||||
| func TestMigrateWithDefaultValue(t *testing.T) { | ||||
| 	if DB.Dialector.Name() == "sqlserver" { | ||||
| 		// sqlserver driver treats NULL and 'NULL' the same
 | ||||
| 		t.Skip("skip sqlserver") | ||||
| @ -1434,6 +1435,7 @@ func TestMigrateDefaultNullString(t *testing.T) { | ||||
| 	type NullStringModel struct { | ||||
| 		ID      uint | ||||
| 		Content string `gorm:"default:'null'"` | ||||
| 		Active  bool   `gorm:"default:false"` | ||||
| 	} | ||||
| 
 | ||||
| 	tableName := "null_string_model" | ||||
| @ -1454,6 +1456,14 @@ func TestMigrateDefaultNullString(t *testing.T) { | ||||
| 	AssertEqual(t, defVal, "null") | ||||
| 	AssertEqual(t, ok, true) | ||||
| 
 | ||||
| 	columnType2, err := findColumnType(tableName, "active") | ||||
| 	AssertEqual(t, err, nil) | ||||
| 
 | ||||
| 	defVal, ok = columnType2.DefaultValue() | ||||
| 	bv, _ := strconv.ParseBool(defVal) | ||||
| 	AssertEqual(t, bv, false) | ||||
| 	AssertEqual(t, ok, true) | ||||
| 
 | ||||
| 	// default 'null' -> 'null'
 | ||||
| 	session := DB.Session(&gorm.Session{Logger: Tracer{ | ||||
| 		Logger: DB.Config.Logger, | ||||
|  | ||||
| @ -126,33 +126,6 @@ func TestPreparedStmtDeadlock(t *testing.T) { | ||||
| 	AssertEqual(t, sqlDB.Stats().InUse, 0) | ||||
| } | ||||
| 
 | ||||
| func TestPreparedStmtError(t *testing.T) { | ||||
| 	tx, err := OpenTestConnection(&gorm.Config{}) | ||||
| 	AssertEqual(t, err, nil) | ||||
| 
 | ||||
| 	sqlDB, _ := tx.DB() | ||||
| 	sqlDB.SetMaxOpenConns(1) | ||||
| 
 | ||||
| 	tx = tx.Session(&gorm.Session{PrepareStmt: true}) | ||||
| 
 | ||||
| 	wg := sync.WaitGroup{} | ||||
| 	for i := 0; i < 10; i++ { | ||||
| 		wg.Add(1) | ||||
| 		go func() { | ||||
| 			// err prepare
 | ||||
| 			tag := Tag{Locale: "zh"} | ||||
| 			tx.Table("users").Find(&tag) | ||||
| 			wg.Done() | ||||
| 		}() | ||||
| 	} | ||||
| 	wg.Wait() | ||||
| 
 | ||||
| 	conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) | ||||
| 	AssertEqual(t, ok, true) | ||||
| 	AssertEqual(t, len(conn.Stmts), 0) | ||||
| 	AssertEqual(t, sqlDB.Stats().InUse, 0) | ||||
| } | ||||
| 
 | ||||
| func TestPreparedStmtInTransaction(t *testing.T) { | ||||
| 	user := User{Name: "jinzhu"} | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 ayakut
						ayakut