Merge branch 'master' into feat_type_aliases

This commit is contained in:
a631807682 2022-09-22 12:01:24 +08:00
commit 311b852cfd
No known key found for this signature in database
GPG Key ID: 137D1D75522168AB
8 changed files with 106 additions and 59 deletions

View File

@ -70,11 +70,13 @@ func Update(config *Config) func(db *gorm.DB) {
if db.Statement.SQL.Len() == 0 { if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(180) db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Update{}) db.Statement.AddClauseIfNotExists(clause.Update{})
if _, ok := db.Statement.Clauses["SET"]; !ok {
if set := ConvertToAssignments(db.Statement); len(set) != 0 { if set := ConvertToAssignments(db.Statement); len(set) != 0 {
db.Statement.AddClause(set) db.Statement.AddClause(set)
} else if _, ok := db.Statement.Clauses["SET"]; !ok { } else {
return return
} }
}
db.Statement.Build(db.Statement.BuildClauses...) db.Statement.Build(db.Statement.BuildClauses...)
} }

View File

@ -13,7 +13,7 @@ import (
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
// Create insert the value into database // Create inserts value, returning the inserted data's primary key in value's id
func (db *DB) Create(value interface{}) (tx *DB) { func (db *DB) Create(value interface{}) (tx *DB) {
if db.CreateBatchSize > 0 { if db.CreateBatchSize > 0 {
return db.CreateInBatches(value, db.CreateBatchSize) return db.CreateInBatches(value, db.CreateBatchSize)
@ -24,7 +24,7 @@ func (db *DB) Create(value interface{}) (tx *DB) {
return tx.callbacks.Create().Execute(tx) return tx.callbacks.Create().Execute(tx)
} }
// CreateInBatches insert the value in batches into database // CreateInBatches inserts value in batches of batchSize
func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
reflectValue := reflect.Indirect(reflect.ValueOf(value)) reflectValue := reflect.Indirect(reflect.ValueOf(value))
@ -68,7 +68,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
return return
} }
// Save update value in database, if the value doesn't have primary key, will insert it // Save updates value in database. If value doesn't contain a matching primary key, value is inserted.
func (db *DB) Save(value interface{}) (tx *DB) { func (db *DB) Save(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = value tx.Statement.Dest = value
@ -114,7 +114,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
return return
} }
// First find first record that match given conditions, order by primary key // First finds the first record ordered by primary key, matching given conditions conds
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1).Order(clause.OrderByColumn{ tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
@ -129,7 +129,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
} }
// Take return a record that match given conditions, the order will depend on the database implementation // Take finds the first record returned by the database in no specified order, matching given conditions conds
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1) tx = db.Limit(1)
if len(conds) > 0 { if len(conds) > 0 {
@ -142,7 +142,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
} }
// Last find last record that match given conditions, order by primary key // Last finds the last record ordered by primary key, matching given conditions conds
func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1).Order(clause.OrderByColumn{ tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
@ -158,7 +158,7 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
} }
// Find find records that match given conditions // Find finds all records matching given conditions conds
func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if len(conds) > 0 { if len(conds) > 0 {
@ -170,7 +170,7 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
} }
// FindInBatches find records in batches // FindInBatches finds all records in batches of batchSize
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
var ( var (
tx = db.Order(clause.OrderByColumn{ tx = db.Order(clause.OrderByColumn{
@ -286,7 +286,8 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) {
} }
} }
// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions) // FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds.
// Each conds must be a struct or map.
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{ queryTx := db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
@ -312,7 +313,8 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
return return
} }
// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions) // FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds.
// Each conds must be a struct or map.
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
@ -360,14 +362,14 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
return tx return tx
} }
// Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields // Update updates column with value using callbacks. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
func (db *DB) Update(column string, value interface{}) (tx *DB) { func (db *DB) Update(column string, value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = map[string]interface{}{column: value} tx.Statement.Dest = map[string]interface{}{column: value}
return tx.callbacks.Update().Execute(tx) return tx.callbacks.Update().Execute(tx)
} }
// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields // Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
func (db *DB) Updates(values interface{}) (tx *DB) { func (db *DB) Updates(values interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = values tx.Statement.Dest = values
@ -388,8 +390,8 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
return tx.callbacks.Update().Execute(tx) return tx.callbacks.Update().Execute(tx)
} }
// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. // Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If
// If value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current // value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current
// time if null. // time if null.
func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
@ -484,7 +486,7 @@ func (db *DB) Rows() (*sql.Rows, error) {
return rows, tx.Error return rows, tx.Error
} }
// Scan scan value to a struct // Scan scans selected value to the struct dest
func (db *DB) Scan(dest interface{}) (tx *DB) { func (db *DB) Scan(dest interface{}) (tx *DB) {
config := *db.Config config := *db.Config
currentLogger, newLogger := config.Logger, logger.Recorder.New() currentLogger, newLogger := config.Logger, logger.Recorder.New()
@ -509,7 +511,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
return return
} }
// Pluck used to query single column from a model as a map // Pluck queries a single column from a model, returning in the slice dest. E.g.:
// var ages []int64 // var ages []int64
// db.Model(&users).Pluck("age", &ages) // db.Model(&users).Pluck("age", &ages)
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
@ -552,7 +554,8 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
return tx.Error return tx.Error
} }
// Connection use a db conn to execute Multiple commands,this conn will put conn pool after it is executed. // Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is
// returned to the connection pool.
func (db *DB) Connection(fc func(tx *DB) error) (err error) { func (db *DB) Connection(fc func(tx *DB) error) (err error) {
if db.Error != nil { if db.Error != nil {
return db.Error return db.Error
@ -574,7 +577,9 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) {
return fc(tx) return fc(tx)
} }
// Transaction start a transaction as a block, return error will rollback, otherwise to commit. // Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an
// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs
// they are rolled back.
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
panicked := true panicked := true
@ -617,7 +622,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
return return
} }
// Begin begins a transaction // Begin begins a transaction with any transaction options opts
func (db *DB) Begin(opts ...*sql.TxOptions) *DB { func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
var ( var (
// clone statement // clone statement
@ -646,7 +651,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
return tx return tx
} }
// Commit commit a transaction // Commit commits the changes in a transaction
func (db *DB) Commit() *DB { func (db *DB) Commit() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
db.AddError(committer.Commit()) db.AddError(committer.Commit())
@ -656,7 +661,7 @@ func (db *DB) Commit() *DB {
return db return db
} }
// Rollback rollback a transaction // Rollback rollbacks the changes in a transaction
func (db *DB) Rollback() *DB { func (db *DB) Rollback() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
if !reflect.ValueOf(committer).IsNil() { if !reflect.ValueOf(committer).IsNil() {
@ -686,7 +691,7 @@ func (db *DB) RollbackTo(name string) *DB {
return db return db
} }
// Exec execute raw sql // Exec executes raw sql
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.SQL = strings.Builder{} tx.Statement.SQL = strings.Builder{}

View File

@ -15,7 +15,7 @@ import (
) )
var ( var (
regFullDataType = regexp.MustCompile(`[^\d]*(\d+)[^\d]?`) regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
) )
// Migrator m struct // Migrator m struct
@ -135,6 +135,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
} }
} }
} }
}
for _, chk := range stmt.Schema.ParseCheckConstraints() { for _, chk := range stmt.Schema.ParseCheckConstraints() {
if !tx.Migrator().HasConstraint(value, chk.Name) { if !tx.Migrator().HasConstraint(value, chk.Name) {
@ -143,7 +144,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
} }
} }
} }
}
for _, idx := range stmt.Schema.ParseIndexes() { for _, idx := range stmt.Schema.ParseIndexes() {
if !tx.Migrator().HasIndex(value, idx.Name) { if !tx.Migrator().HasIndex(value, idx.Name) {
@ -490,7 +490,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
} }
if alterColumn && !field.IgnoreMigration { if alterColumn && !field.IgnoreMigration {
return m.DB.Migrator().AlterColumn(value, field.Name) return m.DB.Migrator().AlterColumn(value, field.DBName)
} }
return nil return nil

View File

@ -66,9 +66,12 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int
db.RowsAffected++ db.RowsAffected++
db.AddError(rows.Scan(values...)) db.AddError(rows.Scan(values...))
joinedSchemaMap := make(map[*schema.Field]interface{}, 0) joinedSchemaMap := make(map[*schema.Field]interface{})
for idx, field := range fields { for idx, field := range fields {
if field != nil { if field == nil {
continue
}
if len(joinFields) == 0 || joinFields[idx][0] == nil { if len(joinFields) == 0 || joinFields[idx][0] == nil {
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
} else { } else {
@ -91,7 +94,6 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int
field.NewValuePool.Put(values[idx]) field.NewValuePool.Put(values[idx])
} }
} }
}
// ScanMode scan data mode // ScanMode scan data mode
type ScanMode uint8 type ScanMode uint8

View File

@ -650,7 +650,7 @@ func (stmt *Statement) Changed(fields ...string) bool {
return false return false
} }
var nameMatcher = regexp.MustCompile(`^(?:[\W]?([A-Za-z_0-9]+?)[\W]?\.)?[\W]?([A-Za-z_0-9]+?)[\W]?$`) var nameMatcher = regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?\W?(\w+?)\W?$`)
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {

View File

@ -959,3 +959,41 @@ func TestMigrateArrayTypeModel(t *testing.T) {
AssertEqual(t, nil, err) AssertEqual(t, nil, err)
AssertEqual(t, "integer[]", ct.DatabaseTypeName()) AssertEqual(t, "integer[]", ct.DatabaseTypeName())
} }
func TestMigrateSameEmbeddedFieldName(t *testing.T) {
type UserStat struct {
GroundDestroyCount int
}
type GameUser struct {
gorm.Model
StatAb UserStat `gorm:"embedded;embeddedPrefix:stat_ab_"`
}
type UserStat1 struct {
GroundDestroyCount string
}
type GroundRate struct {
GroundDestroyCount int
}
type GameUser1 struct {
gorm.Model
StatAb UserStat1 `gorm:"embedded;embeddedPrefix:stat_ab_"`
GroundRateRb GroundRate `gorm:"embedded;embeddedPrefix:rate_ground_rb_"`
}
DB.Migrator().DropTable(&GameUser{})
err := DB.AutoMigrate(&GameUser{})
AssertEqual(t, nil, err)
err = DB.Table("game_users").AutoMigrate(&GameUser1{})
AssertEqual(t, nil, err)
_, err = findColumnType(&GameUser{}, "stat_ab_ground_destory_count")
AssertEqual(t, nil, err)
_, err = findColumnType(&GameUser{}, "rate_ground_rb_ground_destory_count")
AssertEqual(t, nil, err)
}

View File

@ -63,13 +63,13 @@ func TestPostgres(t *testing.T) {
} }
type Post struct { type Post struct {
ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"` ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();"`
Title string Title string
Categories []*Category `gorm:"Many2Many:post_categories"` Categories []*Category `gorm:"Many2Many:post_categories"`
} }
type Category struct { type Category struct {
ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"` ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();"`
Title string Title string
Posts []*Post `gorm:"Many2Many:post_categories"` Posts []*Post `gorm:"Many2Many:post_categories"`
} }

View File

@ -62,7 +62,7 @@ func TestUpsert(t *testing.T) {
} }
r := DB.Session(&gorm.Session{DryRun: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(&RestrictedLanguage{Code: "upsert_code", Name: "upsert_name", Lang: "upsert_lang"}) r := DB.Session(&gorm.Session{DryRun: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(&RestrictedLanguage{Code: "upsert_code", Name: "upsert_name", Lang: "upsert_lang"})
if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.[^\w]*$`).MatchString(r.Statement.SQL.String()) { if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.\W*$`).MatchString(r.Statement.SQL.String()) {
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
} }
} }