Merge branch 'master' into distinguish_unique

# Conflicts:
#	tests/go.mod
This commit is contained in:
black 2023-11-20 15:44:03 +08:00
commit 647bd36d01
14 changed files with 279 additions and 50 deletions

View File

@ -41,4 +41,4 @@ The fantastic ORM library for Golang, aims to be developer friendly.
© Jinzhu, 2013~time.Now © Jinzhu, 2013~time.Now
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License) Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE)

View File

@ -103,13 +103,53 @@ func Create(config *Config) func(db *gorm.DB) {
} }
db.RowsAffected, _ = result.RowsAffected() db.RowsAffected, _ = result.RowsAffected()
if db.RowsAffected != 0 && db.Statement.Schema != nil && if db.RowsAffected == 0 {
db.Statement.Schema.PrioritizedPrimaryField != nil && return
db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { }
insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0 var (
if !insertOk { pkField *schema.Field
db.AddError(err) pkFieldName = "@id"
)
if db.Statement.Schema != nil {
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
return
}
pkField = db.Statement.Schema.PrioritizedPrimaryField
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
}
insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0
if !insertOk {
db.AddError(err)
return
}
// append @id column with value for auto-increment primary key
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
switch values := db.Statement.Dest.(type) {
case map[string]interface{}:
values[pkFieldName] = insertID
case *map[string]interface{}:
(*values)[pkFieldName] = insertID
case []map[string]interface{}, *[]map[string]interface{}:
mapValues, ok := values.([]map[string]interface{})
if !ok {
if v, ok := values.(*[]map[string]interface{}); ok {
if *v != nil {
mapValues = *v
}
}
}
for _, mapValue := range mapValues {
if mapValue != nil {
mapValue[pkFieldName] = insertID
}
insertID += schema.DefaultAutoIncrementIncrement
}
default:
if pkField == nil {
return return
} }
@ -122,10 +162,10 @@ func Create(config *Config) func(db *gorm.DB) {
break break
} }
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) _, isZero := pkField.ValueOf(db.Statement.Context, rv)
if isZero { if isZero {
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement insertID -= pkField.AutoIncrementIncrement
} }
} }
} else { } else {
@ -135,16 +175,16 @@ func Create(config *Config) func(db *gorm.DB) {
break break
} }
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero { if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement insertID += pkField.AutoIncrementIncrement
} }
} }
} }
case reflect.Struct: case reflect.Struct:
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) _, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
if isZero { if isZero {
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
} }
} }
} }

View File

@ -69,7 +69,7 @@ type Interface interface {
} }
var ( var (
// Discard Discard logger will print any log to io.Discard // Discard logger will print any log to io.Discard
Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{}) Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{})
// Default Default logger // Default Default logger
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
@ -78,7 +78,7 @@ var (
IgnoreRecordNotFoundError: false, IgnoreRecordNotFoundError: false,
Colorful: true, Colorful: true,
}) })
// Recorder Recorder logger records running SQL into a recorder instance // Recorder logger records running SQL into a recorder instance
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
) )
@ -129,28 +129,30 @@ func (l *logger) LogMode(level LogLevel) Interface {
} }
// Info print info // Info print info
func (l logger) Info(ctx context.Context, msg string, data ...interface{}) { func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Info { if l.LogLevel >= Info {
l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
} }
} }
// Warn print warn messages // Warn print warn messages
func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) { func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Warn { if l.LogLevel >= Warn {
l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
} }
} }
// Error print error messages // Error print error messages
func (l logger) Error(ctx context.Context, msg string, data ...interface{}) { func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Error { if l.LogLevel >= Error {
l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
} }
} }
// Trace print sql message // Trace print sql message
func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { //
//nolint:cyclop
func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
if l.LogLevel <= Silent { if l.LogLevel <= Silent {
return return
} }
@ -182,8 +184,8 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i
} }
} }
// Trace print sql message // ParamsFilter filter params
func (l logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
if l.Config.ParameterizedQueries { if l.Config.ParameterizedQueries {
return sql, nil return sql, nil
} }
@ -198,8 +200,8 @@ type traceRecorder struct {
Err error Err error
} }
// New new trace recorder // New trace recorder
func (l traceRecorder) New() *traceRecorder { func (l *traceRecorder) New() *traceRecorder {
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
} }

View File

@ -27,6 +27,8 @@ var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
// TODO:? Create const vars for raw sql queries ? // TODO:? Create const vars for raw sql queries ?
var _ gorm.Migrator = (*Migrator)(nil)
// Migrator m struct // Migrator m struct
type Migrator struct { type Migrator struct {
Config Config

View File

@ -49,6 +49,8 @@ const (
Bytes DataType = "bytes" Bytes DataType = "bytes"
) )
const DefaultAutoIncrementIncrement int64 = 1
// Field is the representation of model schema's field // Field is the representation of model schema's field
type Field struct { type Field struct {
Name string Name string
@ -125,7 +127,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]), NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
Unique: utils.CheckTruth(tagSetting["UNIQUE"]), Unique: utils.CheckTruth(tagSetting["UNIQUE"]),
Comment: tagSetting["COMMENT"], Comment: tagSetting["COMMENT"],
AutoIncrementIncrement: 1, AutoIncrementIncrement: DefaultAutoIncrementIncrement,
} }
for field.IndirectFieldType.Kind() == reflect.Ptr { for field.IndirectFieldType.Kind() == reflect.Ptr {

View File

@ -27,6 +27,8 @@ type Replacer interface {
Replace(name string) string Replace(name string) string
} }
var _ Namer = (*NamingStrategy)(nil)
// NamingStrategy tables, columns naming strategy // NamingStrategy tables, columns naming strategy
type NamingStrategy struct { type NamingStrategy struct {
TablePrefix string TablePrefix string

View File

@ -278,8 +278,6 @@ func TestBelongsToAssociationUnscoped(t *testing.T) {
t.Fatalf("failed to create items, got error: %v", err) t.Fatalf("failed to create items, got error: %v", err)
} }
tx = tx.Debug()
// test replace // test replace
if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{ if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{
Logo: "updated logo", Logo: "updated logo",

View File

@ -29,7 +29,7 @@ func TestCountWithGroup(t *testing.T) {
} }
var count2 int64 var count2 int64
if err := DB.Debug().Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil { if err := DB.Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil {
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err)) t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
} }
if count2 != 2 { if count2 != 2 {

View File

@ -2,6 +2,7 @@ package tests_test
import ( import (
"errors" "errors"
"fmt"
"regexp" "regexp"
"testing" "testing"
"time" "time"
@ -580,7 +581,7 @@ func TestCreateWithAutoIncrementCompositeKey(t *testing.T) {
} }
} }
func TestCreateOnConfilctWithDefalutNull(t *testing.T) { func TestCreateOnConflictWithDefaultNull(t *testing.T) {
type OnConfilctUser struct { type OnConfilctUser struct {
ID string ID string
Name string `gorm:"default:null"` Name string `gorm:"default:null"`
@ -615,3 +616,180 @@ func TestCreateOnConfilctWithDefalutNull(t *testing.T) {
AssertEqual(t, u2.Email, "on-confilct-user-email-2") AssertEqual(t, u2.Email, "on-confilct-user-email-2")
AssertEqual(t, u2.Mobile, "133xxxx") AssertEqual(t, u2.Mobile, "133xxxx")
} }
func TestCreateFromMapWithoutPK(t *testing.T) {
if !isMysql() {
t.Skipf("This test case skipped, because of only supportting for mysql")
}
// case 1: one record, create from map[string]interface{}
mapValue1 := map[string]interface{}{"name": "create_from_map_with_schema1", "age": 1}
if err := DB.Model(&User{}).Create(mapValue1).Error; err != nil {
t.Fatalf("failed to create data from map, got error: %v", err)
}
if _, ok := mapValue1["id"]; !ok {
t.Fatal("failed to create data from map with table, returning map has no primary key")
}
var result1 User
if err := DB.Where("name = ?", "create_from_map_with_schema1").First(&result1).Error; err != nil || result1.Age != 1 {
t.Fatalf("failed to create from map, got error %v", err)
}
var idVal int64
_, ok := mapValue1["id"].(uint)
if ok {
t.Skipf("This test case skipped, because the db supports returning")
}
idVal, ok = mapValue1["id"].(int64)
if !ok {
t.Fatal("ret result missing id")
}
if int64(result1.ID) != idVal {
t.Fatal("failed to create data from map with table, @id != id")
}
// case2: one record, create from *map[string]interface{}
mapValue2 := map[string]interface{}{"name": "create_from_map_with_schema2", "age": 1}
if err := DB.Model(&User{}).Create(&mapValue2).Error; err != nil {
t.Fatalf("failed to create data from map, got error: %v", err)
}
if _, ok := mapValue2["id"]; !ok {
t.Fatal("failed to create data from map with table, returning map has no primary key")
}
var result2 User
if err := DB.Where("name = ?", "create_from_map_with_schema2").First(&result2).Error; err != nil || result2.Age != 1 {
t.Fatalf("failed to create from map, got error %v", err)
}
_, ok = mapValue2["id"].(uint)
if ok {
t.Skipf("This test case skipped, because the db supports returning")
}
idVal, ok = mapValue2["id"].(int64)
if !ok {
t.Fatal("ret result missing id")
}
if int64(result2.ID) != idVal {
t.Fatal("failed to create data from map with table, @id != id")
}
// case 3: records
values := []map[string]interface{}{
{"name": "create_from_map_with_schema11", "age": 1}, {"name": "create_from_map_with_schema12", "age": 1},
}
beforeLen := len(values)
if err := DB.Model(&User{}).Create(&values).Error; err != nil {
t.Fatalf("failed to create data from map, got error: %v", err)
}
// mariadb with returning, values will be appended with id map
if len(values) == beforeLen*2 {
t.Skipf("This test case skipped, because the db supports returning")
}
for i := range values {
v, ok := values[i]["id"]
if !ok {
t.Fatal("failed to create data from map with table, returning map has no primary key")
}
var result User
if err := DB.Where("name = ?", fmt.Sprintf("create_from_map_with_schema1%d", i+1)).First(&result).Error; err != nil || result.Age != 1 {
t.Fatalf("failed to create from map, got error %v", err)
}
if int64(result.ID) != v.(int64) {
t.Fatal("failed to create data from map with table, @id != id")
}
}
}
func TestCreateFromMapWithTable(t *testing.T) {
if !isMysql() {
t.Skipf("This test case skipped, because of only supportting for mysql")
}
tableDB := DB.Table("`users`")
// case 1: create from map[string]interface{}
record := map[string]interface{}{"`name`": "create_from_map_with_table", "`age`": 18}
if err := tableDB.Create(record).Error; err != nil {
t.Fatalf("failed to create data from map with table, got error: %v", err)
}
if _, ok := record["@id"]; !ok {
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
}
var res map[string]interface{}
if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table").Find(&res).Error; err != nil || res["age"] != int64(18) {
t.Fatalf("failed to create from map, got error %v", err)
}
if int64(res["id"].(uint64)) != record["@id"] {
t.Fatal("failed to create data from map with table, @id != id")
}
// case 2: create from *map[string]interface{}
record1 := map[string]interface{}{"name": "create_from_map_with_table_1", "age": 18}
tableDB2 := DB.Table("users")
if err := tableDB2.Create(&record1).Error; err != nil {
t.Fatalf("failed to create data from map, got error: %v", err)
}
if _, ok := record1["@id"]; !ok {
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
}
var res1 map[string]interface{}
if err := tableDB2.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_1").Find(&res1).Error; err != nil || res1["age"] != int64(18) {
t.Fatalf("failed to create from map, got error %v", err)
}
if int64(res1["id"].(uint64)) != record1["@id"] {
t.Fatal("failed to create data from map with table, @id != id")
}
// case 3: create from []map[string]interface{}
records := []map[string]interface{}{
{"name": "create_from_map_with_table_2", "age": 19},
{"name": "create_from_map_with_table_3", "age": 20},
}
tableDB = DB.Table("users")
if err := tableDB.Create(&records).Error; err != nil {
t.Fatalf("failed to create data from slice of map, got error: %v", err)
}
if _, ok := records[0]["@id"]; !ok {
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
}
if _, ok := records[1]["@id"]; !ok {
t.Fatal("failed to create data from map with table, returning map has no key '@id'")
}
var res2 map[string]interface{}
if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_2").Find(&res2).Error; err != nil || res2["age"] != int64(19) {
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
}
var res3 map[string]interface{}
if err := DB.Table("users").Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_3").Find(&res3).Error; err != nil || res3["age"] != int64(20) {
t.Fatalf("failed to query data after create from slice of map, got error %v", err)
}
if int64(res2["id"].(uint64)) != records[0]["@id"] {
t.Fatal("failed to create data from map with table, @id != id")
}
if int64(res3["id"].(uint64)) != records[1]["@id"] {
t.Fatal("failed to create data from map with table, @id != id")
}
}

View File

@ -236,8 +236,15 @@ func TestEmbeddedScanValuer(t *testing.T) {
} }
func TestEmbeddedRelations(t *testing.T) { func TestEmbeddedRelations(t *testing.T) {
type EmbUser struct {
gorm.Model
Name string
Age uint
Languages []Language `gorm:"many2many:EmbUserSpeak;"`
}
type AdvancedUser struct { type AdvancedUser struct {
User `gorm:"embedded"` EmbUser `gorm:"embedded"`
Advanced bool Advanced bool
} }

View File

@ -3,15 +3,15 @@ module gorm.io/gorm/tests
go 1.18 go 1.18
require ( require (
github.com/google/uuid v1.3.1 github.com/google/uuid v1.4.0
github.com/jinzhu/now v1.1.5 github.com/jinzhu/now v1.1.5
github.com/lib/pq v1.10.9 github.com/lib/pq v1.10.9
github.com/stretchr/testify v1.8.4 github.com/stretchr/testify v1.8.4
gorm.io/driver/mysql v1.5.2-0.20230612053416-48b6526a21f0 gorm.io/driver/mysql v1.5.2
gorm.io/driver/postgres v1.5.3-0.20230607070428-18bc84b75196 gorm.io/driver/postgres v1.5.4
gorm.io/driver/sqlite v1.5.4 gorm.io/driver/sqlite v1.5.4
gorm.io/driver/sqlserver v1.5.2-0.20230613072041-6e2cde390b0a gorm.io/driver/sqlserver v1.5.2
gorm.io/gorm v1.25.4 gorm.io/gorm v1.25.5
) )
require ( require (
@ -21,16 +21,15 @@ require (
github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.4.3 // indirect github.com/jackc/pgx/v5 v5.5.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/kr/text v0.2.0 // indirect github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-sqlite3 v1.14.17 // indirect github.com/mattn/go-sqlite3 v1.14.18 // indirect
github.com/microsoft/go-mssqldb v1.6.0 // indirect github.com/microsoft/go-mssqldb v1.6.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/crypto v0.15.0 // indirect
github.com/rogpeppe/go-internal v1.11.0 // indirect golang.org/x/text v0.14.0 // indirect
golang.org/x/crypto v0.14.0 // indirect
golang.org/x/text v0.13.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
) )
replace gorm.io/gorm => ../ replace gorm.io/gorm => ../
replace github.com/jackc/pgx/v5 => github.com/jackc/pgx/v5 v5.4.3

View File

@ -429,7 +429,6 @@ func TestEmbedPreload(t *testing.T) {
}, },
} }
DB = DB.Debug()
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
actual := Org{} actual := Org{}

View File

@ -43,9 +43,6 @@ func init() {
} }
RunMigrations() RunMigrations()
if DB.Dialector.Name() == "sqlite" {
DB.Exec("PRAGMA foreign_keys = ON")
}
} }
} }
@ -89,7 +86,10 @@ func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) {
db, err = gorm.Open(mysql.Open(dbDSN), cfg) db, err = gorm.Open(mysql.Open(dbDSN), cfg)
default: default:
log.Println("testing sqlite3...") log.Println("testing sqlite3...")
db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db?_foreign_keys=on")), cfg) db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), cfg)
if err == nil {
db.Exec("PRAGMA foreign_keys = ON")
}
} }
if err != nil { if err != nil {

View File

@ -838,7 +838,7 @@ func TestSaveWithHooks(t *testing.T) {
saveTokenOwner := func(owner *TokenOwner) (*TokenOwner, error) { saveTokenOwner := func(owner *TokenOwner) (*TokenOwner, error) {
var newOwner TokenOwner var newOwner TokenOwner
if err := DB.Transaction(func(tx *gorm.DB) error { if err := DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Debug().Session(&gorm.Session{FullSaveAssociations: true}).Save(owner).Error; err != nil { if err := tx.Session(&gorm.Session{FullSaveAssociations: true}).Save(owner).Error; err != nil {
return err return err
} }
if err := tx.Preload("Token").First(&newOwner, owner.ID).Error; err != nil { if err := tx.Preload("Token").First(&newOwner, owner.ID).Error; err != nil {