Merge branch 'master' into test_migrate_column
commit 02b48e6798
.github/workflows/invalid_question.yml (vendored, 2 changes)

@@ -16,7 +16,7 @@ jobs:
       ACTIONS_STEP_DEBUG: true
     steps:
       - name: Close Stale Issues
-        uses: actions/stale@v5
+        uses: actions/stale@v6
        with:
          repo-token: ${{ secrets.GITHUB_TOKEN }}
          stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
.github/workflows/missing_playground.yml (vendored, 2 changes)

@@ -16,7 +16,7 @@ jobs:
       ACTIONS_STEP_DEBUG: true
     steps:
       - name: Close Stale Issues
-        uses: actions/stale@v5
+        uses: actions/stale@v6
        with:
          repo-token: ${{ secrets.GITHUB_TOKEN }}
          stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"
.github/workflows/stale.yml (vendored, 2 changes)

@@ -16,7 +16,7 @@ jobs:
       ACTIONS_STEP_DEBUG: true
     steps:
      - name: Close Stale Issues
-        uses: actions/stale@v5
+        uses: actions/stale@v6
        with:
          repo-token: ${{ secrets.GITHUB_TOKEN }}
          stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days"
.github/workflows/tests.yml (vendored, 8 changes)

@@ -16,7 +16,7 @@ jobs:
   sqlite:
     strategy:
       matrix:
-        go: ['1.18', '1.17', '1.16']
+        go: ['1.19', '1.18', '1.17', '1.16']
         platform: [ubuntu-latest] # can not run in windows OS
     runs-on: ${{ matrix.platform }}

@@ -42,7 +42,7 @@ jobs:
     strategy:
       matrix:
         dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest']
-        go: ['1.18', '1.17', '1.16']
+        go: ['1.19', '1.18', '1.17', '1.16']
         platform: [ubuntu-latest]
     runs-on: ${{ matrix.platform }}

@@ -86,7 +86,7 @@ jobs:
     strategy:
       matrix:
         dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10']
-        go: ['1.18', '1.17', '1.16']
+        go: ['1.19', '1.18', '1.17', '1.16']
         platform: [ubuntu-latest] # can not run in macOS and Windows
     runs-on: ${{ matrix.platform }}

@@ -128,7 +128,7 @@ jobs:
   sqlserver:
     strategy:
       matrix:
-        go: ['1.18', '1.17', '1.16']
+        go: ['1.19', '1.18', '1.17', '1.16']
         platform: [ubuntu-latest] # can not run test in macOS and windows
     runs-on: ${{ matrix.platform }}
.gitignore (vendored, 1 change)

@@ -4,3 +4,4 @@ coverage.txt
 _book
 .idea
 vendor
+.vscode
@@ -507,8 +507,10 @@ func (association *Association) buildCondition() *DB {
                joinStmt.AddClause(queryClause)
            }
            joinStmt.Build("WHERE")
+           if len(joinStmt.SQL.String()) > 0 {
                tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
            }
+       }

        tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{
            Table: clause.Table{Name: association.Relationship.JoinTable.Table},
@@ -206,7 +206,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
                    }
                }

-               cacheKey := utils.ToStringKey(relPrimaryValues)
+               cacheKey := utils.ToStringKey(relPrimaryValues...)
                if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
                    identityMap[cacheKey] = true
                    if isPtr {

@@ -292,7 +292,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
                    }
                }

-               cacheKey := utils.ToStringKey(relPrimaryValues)
+               cacheKey := utils.ToStringKey(relPrimaryValues...)
                if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
                    identityMap[cacheKey] = true
                    distinctElems = reflect.Append(distinctElems, elem)
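The `...` spread matters here because the key builder is variadic: passing the slice as a single argument folds all primary key values into one element, while spreading it keys the identity map by each value. A small standalone illustration of the difference; the `toKey` helper below is a stand-in, not GORM's `utils.ToStringKey`.

    package main

    import "fmt"

    // toKey stands in for a variadic key builder such as utils.ToStringKey.
    func toKey(values ...interface{}) string {
        return fmt.Sprintf("%d elem(s): %v", len(values), values)
    }

    func main() {
        pk := []interface{}{1, "en"} // composite primary key values
        fmt.Println(toKey(pk))       // slice passed as one value: one element wrapping the whole slice
        fmt.Println(toKey(pk...))    // slice spread: each primary key value becomes its own element
    }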
@@ -70,11 +70,13 @@ func Update(config *Config) func(db *gorm.DB) {
        if db.Statement.SQL.Len() == 0 {
            db.Statement.SQL.Grow(180)
            db.Statement.AddClauseIfNotExists(clause.Update{})
+           if _, ok := db.Statement.Clauses["SET"]; !ok {
                if set := ConvertToAssignments(db.Statement); len(set) != 0 {
                    db.Statement.AddClause(set)
-               } else if _, ok := db.Statement.Clauses["SET"]; !ok {
+               } else {
                    return
                }
+           }

            db.Statement.Build(db.Statement.BuildClauses...)
        }
@@ -13,7 +13,7 @@ import (
    "gorm.io/gorm/utils"
 )

-// Create insert the value into database
+// Create inserts value, returning the inserted data's primary key in value's id
 func (db *DB) Create(value interface{}) (tx *DB) {
    if db.CreateBatchSize > 0 {
        return db.CreateInBatches(value, db.CreateBatchSize)

@@ -24,7 +24,7 @@ func (db *DB) Create(value interface{}) (tx *DB) {
    return tx.callbacks.Create().Execute(tx)
 }

-// CreateInBatches insert the value in batches into database
+// CreateInBatches inserts value in batches of batchSize
 func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
    reflectValue := reflect.Indirect(reflect.ValueOf(value))

@@ -68,7 +68,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
    return
 }

-// Save update value in database, if the value doesn't have primary key, will insert it
+// Save updates value in database. If value doesn't contain a matching primary key, value is inserted.
 func (db *DB) Save(value interface{}) (tx *DB) {
    tx = db.getInstance()
    tx.Statement.Dest = value

@@ -114,7 +114,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
    return
 }

-// First find first record that match given conditions, order by primary key
+// First finds the first record ordered by primary key, matching given conditions conds
 func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
    tx = db.Limit(1).Order(clause.OrderByColumn{
        Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},

@@ -129,7 +129,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
    return tx.callbacks.Query().Execute(tx)
 }

-// Take return a record that match given conditions, the order will depend on the database implementation
+// Take finds the first record returned by the database in no specified order, matching given conditions conds
 func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
    tx = db.Limit(1)
    if len(conds) > 0 {

@@ -142,7 +142,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
    return tx.callbacks.Query().Execute(tx)
 }

-// Last find last record that match given conditions, order by primary key
+// Last finds the last record ordered by primary key, matching given conditions conds
 func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
    tx = db.Limit(1).Order(clause.OrderByColumn{
        Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},

@@ -158,7 +158,7 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
    return tx.callbacks.Query().Execute(tx)
 }

-// Find find records that match given conditions
+// Find finds all records matching given conditions conds
 func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
    tx = db.getInstance()
    if len(conds) > 0 {

@@ -170,7 +170,7 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
    return tx.callbacks.Query().Execute(tx)
 }

-// FindInBatches find records in batches
+// FindInBatches finds all records in batches of batchSize
 func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
    var (
        tx = db.Order(clause.OrderByColumn{
@@ -202,7 +202,9 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
        batch++

        if result.Error == nil && result.RowsAffected != 0 {
-           tx.AddError(fc(result, batch))
+           fcTx := result.Session(&Session{NewDB: true})
+           fcTx.RowsAffected = result.RowsAffected
+           tx.AddError(fc(fcTx, batch))
        } else if result.Error != nil {
            tx.AddError(result.Error)
        }
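With this change the batch callback receives a fresh session (`NewDB: true`) seeded with the batch's `RowsAffected`, so statements issued inside `fc` no longer inherit the outer batch query's clauses. A usage sketch under assumptions: a `User` model with illustrative fields and an open `*gorm.DB` handle named `db`.

    func processInactiveUsers(db *gorm.DB) error {
        var users []User
        result := db.Where("active = ?", false).FindInBatches(&users, 100, func(tx *gorm.DB, batch int) error {
            // tx is a clean session here; its own Where/Save calls do not stack
            // onto the outer "active = ?" batch query.
            for i := range users {
                users[i].Notified = true // Notified is an assumed field
            }
            return tx.Save(&users).Error
        })
        return result.Error
    }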
@@ -284,7 +286,8 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) {
    }
 }

-// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
+// FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds.
+// Each conds must be a struct or map.
 func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
    queryTx := db.Limit(1).Order(clause.OrderByColumn{
        Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
@@ -310,7 +313,8 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
    return
 }

-// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions)
+// FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds.
+// Each conds must be a struct or map.
 func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
    tx = db.getInstance()
    queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
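A usage sketch of the documented behaviour; `db`, `user`, and the `User` fields below are illustrative assumptions, not part of this commit.

    var user User
    // Finds the first user named "jinzhu"; if none exists, creates one and also
    // applies the Attrs values (Attrs are only used when a record is created).
    err := db.Where(User{Name: "jinzhu"}).Attrs(User{Age: 20}).FirstOrCreate(&user).Error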
@@ -358,14 +362,14 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
    return tx
 }

-// Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
+// Update updates column with value using callbacks. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
 func (db *DB) Update(column string, value interface{}) (tx *DB) {
    tx = db.getInstance()
    tx.Statement.Dest = map[string]interface{}{column: value}
    return tx.callbacks.Update().Execute(tx)
 }

-// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
+// Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
 func (db *DB) Updates(values interface{}) (tx *DB) {
    tx = db.getInstance()
    tx.Statement.Dest = values
@@ -386,7 +390,9 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
    return tx.callbacks.Update().Execute(tx)
 }

-// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
+// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If
+// value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current
+// time if null.
 func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
    tx = db.getInstance()
    if len(conds) > 0 {
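A sketch of the soft-delete behaviour described in the new comment, assuming a `User` model that embeds `gorm.Model` (and therefore has a `DeletedAt` field); the handle `db` is assumed.

    db.Delete(&User{}, 10)                         // soft delete: sets deleted_at for id = 10
    db.Where("name = ?", "jinzhu").Delete(&User{}) // soft delete by condition
    db.Unscoped().Delete(&User{}, 10)              // hard delete, actually removes the row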
@@ -480,7 +486,7 @@ func (db *DB) Rows() (*sql.Rows, error) {
    return rows, tx.Error
 }

-// Scan scan value to a struct
+// Scan scans selected value to the struct dest
 func (db *DB) Scan(dest interface{}) (tx *DB) {
    config := *db.Config
    currentLogger, newLogger := config.Logger, logger.Recorder.New()

@@ -505,7 +511,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
    return
 }

-// Pluck used to query single column from a model as a map
+// Pluck queries a single column from a model, returning in the slice dest. E.g.:
 //     var ages []int64
 //     db.Model(&users).Pluck("age", &ages)
 func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {

@@ -548,7 +554,8 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
    return tx.Error
 }

-// Connection use a db conn to execute Multiple commands,this conn will put conn pool after it is executed.
+// Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is
+// returned to the connection pool.
 func (db *DB) Connection(fc func(tx *DB) error) (err error) {
    if db.Error != nil {
        return db.Error

@@ -570,7 +577,9 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) {
    return fc(tx)
 }

-// Transaction start a transaction as a block, return error will rollback, otherwise to commit.
+// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an
+// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs
+// they are rolled back.
 func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
    panicked := true

@@ -613,7 +622,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
    return
 }

-// Begin begins a transaction
+// Begin begins a transaction with any transaction options opts
 func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
    var (
        // clone statement

@@ -642,7 +651,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
    return tx
 }

-// Commit commit a transaction
+// Commit commits the changes in a transaction
 func (db *DB) Commit() *DB {
    if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
        db.AddError(committer.Commit())

@@ -652,7 +661,7 @@ func (db *DB) Commit() *DB {
    return db
 }

-// Rollback rollback a transaction
+// Rollback rollbacks the changes in a transaction
 func (db *DB) Rollback() *DB {
    if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
        if !reflect.ValueOf(committer).IsNil() {

@@ -682,7 +691,7 @@ func (db *DB) RollbackTo(name string) *DB {
    return db
 }

-// Exec execute raw sql
+// Exec executes raw sql
 func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
    tx = db.getInstance()
    tx.Statement.SQL = strings.Builder{}
go.mod (2 changes)

@@ -1,6 +1,6 @@
 module gorm.io/gorm

-go 1.14
+go 1.16

 require (
    github.com/jinzhu/inflection v1.0.0
gorm.go (15 changes)

@@ -179,7 +179,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {

        preparedStmt := &PreparedStmtDB{
            ConnPool:    db.ConnPool,
-           Stmts:       map[string]Stmt{},
+           Stmts:       map[string](*Stmt){},
            Mux:         &sync.RWMutex{},
            PreparedSQL: make([]string, 0, 100),
        }
@@ -248,11 +248,19 @@ func (db *DB) Session(config *Session) *DB {
    if config.PrepareStmt {
        if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
            preparedStmt := v.(*PreparedStmtDB)
+           switch t := tx.Statement.ConnPool.(type) {
+           case Tx:
+               tx.Statement.ConnPool = &PreparedStmtTX{
+                   Tx:             t,
+                   PreparedStmtDB: preparedStmt,
+               }
+           default:
                tx.Statement.ConnPool = &PreparedStmtDB{
                    ConnPool: db.Config.ConnPool,
                    Mux:      preparedStmt.Mux,
                    Stmts:    preparedStmt.Stmts,
                }
+           }
            txConfig.ConnPool = tx.Statement.ConnPool
            txConfig.PrepareStmt = true
        }
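The new switch means a prepared-statement session opened while already inside a transaction wraps the transaction's connection in a `PreparedStmtTX` instead of falling back to a plain `PreparedStmtDB`. A usage sketch; `db`, `user`, and error handling are assumed for brevity.

    tx := db.Begin()
    stmtTx := tx.Session(&gorm.Session{PrepareStmt: true})
    stmtTx.First(&user, 1) // runs on the transaction's connection via the shared statement cache
    tx.Commit()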
@@ -300,7 +308,8 @@ func (db *DB) WithContext(ctx context.Context) *DB {

 // Debug start debug mode
 func (db *DB) Debug() (tx *DB) {
-   return db.Session(&Session{
+   tx = db.getInstance()
+   return tx.Session(&Session{
        Logger: db.Logger.LogMode(logger.Info),
    })
 }

@@ -412,7 +421,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac
    relation, ok := modelSchema.Relationships.Relations[field]
    isRelation := ok && relation.JoinTable != nil
    if !isRelation {
-       return fmt.Errorf("failed to found relation: %s", field)
+       return fmt.Errorf("failed to find relation: %s", field)
    }

    for _, ref := range relation.References {
@@ -4,7 +4,7 @@ import (
    "context"
    "errors"
    "fmt"
-   "io/ioutil"
+   "io"
    "log"
    "os"
    "time"

@@ -68,8 +68,8 @@ type Interface interface {
 }

 var (
-   // Discard Discard logger will print any log to ioutil.Discard
-   Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{})
+   // Discard Discard logger will print any log to io.Discard
+   Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{})
    // Default Default logger
    Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
        SlowThreshold: 200 * time.Millisecond,
@@ -30,6 +30,8 @@ func isPrintable(s string) bool {

 var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}

+var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
+
 // ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
 func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
    var (

@@ -138,9 +140,18 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
        sql = newSQL.String()
    } else {
        sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
-       for idx, v := range vars {
-           sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v, 1)
+       sql = numericPlaceholderRe.ReplaceAllStringFunc(sql, func(v string) string {
+           num := v[1 : len(v)-1]
+           n, _ := strconv.Atoi(num)
+
+           // position var start from 1 ($1, $2)
+           n -= 1
+           if n >= 0 && n <= len(vars)-1 {
+               return vars[n]
+           }
        }
+           return v
+       })
    }

    return sql
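The old loop substituted `$1$`, `$2$`, ... positionally with `strings.Replace`, which breaks when placeholders repeat or appear out of order; the new code parses the index out of each token. A standalone illustration of the same idea (not the full `ExplainSQL`):

    package main

    import (
        "fmt"
        "regexp"
        "strconv"
    )

    var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)

    func main() {
        // Intermediate form after numericPlaceholder.ReplaceAllString(sql, "$$$1$$"):
        // "$2" became "$2$" and "$1" became "$1$".
        sql := "SELECT * FROM users WHERE name = $2$ AND age > $1$"
        vars := []string{"18", "'jinzhu'"}

        out := numericPlaceholderRe.ReplaceAllStringFunc(sql, func(v string) string {
            n, _ := strconv.Atoi(v[1 : len(v)-1]) // strip both '$'
            if n >= 1 && n <= len(vars) {
                return vars[n-1] // placeholders are 1-based
            }
            return v
        })
        fmt.Println(out) // SELECT * FROM users WHERE name = 'jinzhu' AND age > 18
    }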
@@ -68,6 +68,7 @@ type Migrator interface {
    // Database
    CurrentDatabase() string
    FullDataTypeOf(*schema.Field) clause.Expr
+   GetTypeAliases(databaseTypeName string) []string

    // Tables
    CreateTable(dst ...interface{}) error
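`GetTypeAliases` lets a dialect report alternative names for a column type so `MigrateColumn` (changed below) can treat them as the same type instead of issuing an ALTER. The default implementation added later in this diff returns nil; the override below is a hypothetical dialect-side sketch that embeds gorm.io/gorm/migrator's `Migrator`, and the alias table is illustrative rather than any driver's actual mapping.

    // Hypothetical dialect migrator; the alias mapping is illustrative only.
    type myDialectMigrator struct {
        migrator.Migrator
    }

    func (m myDialectMigrator) GetTypeAliases(databaseTypeName string) []string {
        aliases := map[string][]string{
            "int8":    {"bigint"},  // e.g. the database reports int8 while the model declares bigint
            "numeric": {"decimal"},
        }
        return aliases[databaseTypeName]
    }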
@@ -15,7 +15,7 @@ import (
 )

 var (
-   regFullDataType = regexp.MustCompile(`[^\d]*(\d+)[^\d]?`)
+   regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
 )

 // Migrator m struct

@@ -135,6 +135,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
                    }
                }
            }
+           }

            for _, chk := range stmt.Schema.ParseCheckConstraints() {
                if !tx.Migrator().HasConstraint(value, chk.Name) {

@@ -143,7 +144,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
                    }
                }
            }
-           }

            for _, idx := range stmt.Schema.ParseIndexes() {
                if !tx.Migrator().HasIndex(value, idx.Name) {

@@ -408,10 +408,28 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy

    alterColumn := false

+   if !field.PrimaryKey {
        // check type
-   if !field.PrimaryKey && !strings.HasPrefix(fullDataType, realDataType) {
+       var isSameType bool
+       if strings.HasPrefix(fullDataType, realDataType) {
+           isSameType = true
+       }
+
+       // check type aliases
+       if !isSameType {
+           aliases := m.DB.Migrator().GetTypeAliases(realDataType)
+           for _, alias := range aliases {
+               if strings.HasPrefix(fullDataType, alias) {
+                   isSameType = true
+                   break
+               }
+           }
+       }
+
+       if !isSameType {
            alterColumn = true
        }
+   }

    // check size
    if length, ok := columnType.Length(); length != int64(field.Size) {

@@ -478,7 +496,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
    }

    if alterColumn && !field.IgnoreMigration {
-       return m.DB.Migrator().AlterColumn(value, field.Name)
+       return m.DB.Migrator().AlterColumn(value, field.DBName)
    }

    return nil

@@ -863,3 +881,8 @@ func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} {
 func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) {
    return nil, errors.New("not support")
 }
+
+// GetTypeAliases return database type aliases
+func (m Migrator) GetTypeAliases(databaseTypeName string) []string {
+   return nil
+}
@@ -9,10 +9,12 @@ import (
 type Stmt struct {
    *sql.Stmt
    Transaction bool
+   prepared    chan struct{}
+   prepareErr  error
 }

 type PreparedStmtDB struct {
-   Stmts       map[string]Stmt
+   Stmts       map[string]*Stmt
    PreparedSQL []string
    Mux         *sync.RWMutex
    ConnPool

@@ -46,27 +48,57 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
    db.Mux.RLock()
    if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
        db.Mux.RUnlock()
-       return stmt, nil
+       // wait for other goroutines prepared
+       <-stmt.prepared
+       if stmt.prepareErr != nil {
+           return Stmt{}, stmt.prepareErr
+       }
+
+       return *stmt, nil
    }
    db.Mux.RUnlock()

    db.Mux.Lock()
-   defer db.Mux.Unlock()

    // double check
    if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
-       return stmt, nil
-   } else if ok {
-       go stmt.Close()
+       db.Mux.Unlock()
+       // wait for other goroutines prepared
+       <-stmt.prepared
+       if stmt.prepareErr != nil {
+           return Stmt{}, stmt.prepareErr
        }
+
+       return *stmt, nil
+   }
+
+   // cache preparing stmt first
+   cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
+   db.Stmts[query] = &cacheStmt
+   db.Mux.Unlock()
+
+   // prepare completed
+   defer close(cacheStmt.prepared)
+
+   // Reason why cannot lock conn.PrepareContext
+   // suppose the maxopen is 1, g1 is creating record and g2 is querying record.
+   // 1. g1 begin tx, g1 is requeued because of waiting for the system call, now `db.ConnPool` db.numOpen == 1.
+   // 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release.
+   // 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
    stmt, err := conn.PrepareContext(ctx, query)
-   if err == nil {
-       db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction}
-       db.PreparedSQL = append(db.PreparedSQL, query)
+   if err != nil {
+       cacheStmt.prepareErr = err
+       db.Mux.Lock()
+       delete(db.Stmts, query)
+       db.Mux.Unlock()
+       return Stmt{}, err
    }

-   return db.Stmts[query], err
+   db.Mux.Lock()
+   cacheStmt.Stmt = stmt
+   db.PreparedSQL = append(db.PreparedSQL, query)
+   db.Mux.Unlock()
+
+   return cacheStmt, nil
 }

 func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
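The rewritten prepare follows a common concurrency pattern: publish a placeholder entry under the mutex, release the mutex, do the slow PrepareContext call, then close the `prepared` channel so concurrent callers that found the placeholder can proceed (or observe `prepareErr`). A generic standalone sketch of that pattern, not GORM code:

    package main

    import "sync"

    type entry struct {
        value string
        err   error
        done  chan struct{}
    }

    type cache struct {
        mu sync.Mutex
        m  map[string]*entry
    }

    func (c *cache) getOrBuild(key string, build func() (string, error)) (string, error) {
        c.mu.Lock()
        if e, ok := c.m[key]; ok {
            c.mu.Unlock()
            <-e.done // wait for whichever goroutine is building this entry
            return e.value, e.err
        }
        e := &entry{done: make(chan struct{})}
        c.m[key] = e // placeholder is visible before the slow work starts
        c.mu.Unlock()

        e.value, e.err = build() // slow work runs without holding the lock
        if e.err != nil {
            c.mu.Lock()
            delete(c.m, key) // drop failed placeholders so later calls can retry
            c.mu.Unlock()
        }
        close(e.done)
        return e.value, e.err
    }

    func main() {
        c := &cache{m: map[string]*entry{}}
        _, _ = c.getOrBuild("stmt", func() (string, error) { return "prepared", nil })
    }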
scan.go (53 changes)

@@ -66,9 +66,12 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int
    db.RowsAffected++
    db.AddError(rows.Scan(values...))

-   joinedSchemaMap := make(map[*schema.Field]interface{}, 0)
+   joinedSchemaMap := make(map[*schema.Field]interface{})
    for idx, field := range fields {
-       if field != nil {
+       if field == nil {
+           continue
+       }
+
        if len(joinFields) == 0 || joinFields[idx][0] == nil {
            db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
        } else {

@@ -91,7 +94,6 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int
        field.NewValuePool.Put(values[idx])
        }
    }
-   }

 // ScanMode scan data mode
 type ScanMode uint8

@@ -162,7 +164,6 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
    default:
        var (
            fields             = make([]*schema.Field, len(columns))
-           selectedColumnsMap = make(map[string]int, len(columns))
            joinFields         [][2]*schema.Field
            sch                = db.Statement.Schema
            reflectValue       = db.Statement.ReflectValue

@@ -198,26 +199,24 @@ func Scan(rows Rows, db *DB, mode ScanMode) {

        // Not Pluck
        if sch != nil {
-           schFieldsCount := len(sch.Fields)
+           matchedFieldCount := make(map[string]int, len(columns))
            for idx, column := range columns {
                if field := sch.LookUpField(column); field != nil && field.Readable {
-                   if curIndex, ok := selectedColumnsMap[column]; ok {
-                       fields[idx] = field // handle duplicate fields
-                       offset := curIndex + 1
-                       // handle sch inconsistent with database
-                       // like Raw(`...`).Scan
-                       if schFieldsCount > offset {
-                           for fieldIndex, selectField := range sch.Fields[offset:] {
+                   fields[idx] = field
+                   if count, ok := matchedFieldCount[column]; ok {
+                       // handle duplicate fields
+                       for _, selectField := range sch.Fields {
                            if selectField.DBName == column && selectField.Readable {
-                               selectedColumnsMap[column] = curIndex + fieldIndex + 1
+                               if count == 0 {
+                                   matchedFieldCount[column]++
                                    fields[idx] = selectField
                                    break
                                }
+                               count--
                            }
                        }
                    } else {
-                       fields[idx] = field
-                       selectedColumnsMap[column] = idx
+                       matchedFieldCount[column] = 1
                    }
                } else if names := strings.Split(column, "__"); len(names) > 1 {
                    if rel, ok := sch.Relationships.Relations[names[0]]; ok {

@@ -241,12 +240,21 @@ func Scan(rows Rows, db *DB, mode ScanMode) {

    switch reflectValue.Kind() {
    case reflect.Slice, reflect.Array:
-       var elem reflect.Value
-       recyclableStruct := reflect.New(reflectValueType)
+       var (
+           elem             reflect.Value
+           recyclableStruct = reflect.New(reflectValueType)
+           isArrayKind      = reflectValue.Kind() == reflect.Array
+       )

        if !update || reflectValue.Len() == 0 {
            update = false
+           // if the slice cap is externally initialized, the externally initialized slice is directly used here
+           if reflectValue.Cap() == 0 {
                db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
+           } else if !isArrayKind {
+               reflectValue.SetLen(0)
+               db.Statement.ReflectValue.Set(reflectValue)
+           }
        }

        for initialized || rows.Next() {

@@ -277,10 +285,15 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
            db.scanIntoStruct(rows, elem, values, fields, joinFields)

            if !update {
-               if isPtr {
-                   reflectValue = reflect.Append(reflectValue, elem)
+               if !isPtr {
+                   elem = elem.Elem()
+               }
+               if isArrayKind {
+                   if reflectValue.Len() >= int(db.RowsAffected) {
+                       reflectValue.Index(int(db.RowsAffected - 1)).Set(elem)
+                   }
                } else {
-                   reflectValue = reflect.Append(reflectValue, elem.Elem())
+                   reflectValue = reflect.Append(reflectValue, elem)
                }
            }
        }
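Two effects of the scan changes, as usage sketches; the `User` model, the `score` column, and the `db` handle are assumed for illustration.

    // Fixed-size array destinations are now filled in place; rows beyond the
    // array's length are simply not stored.
    var topTwo [2]User
    db.Order("score DESC").Find(&topTwo)

    // A slice created with spare capacity is reused instead of being replaced
    // by a freshly allocated one.
    users := make([]User, 0, 128)
    db.Limit(100).Find(&users)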
@@ -403,18 +403,14 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
            }

            if ef.PrimaryKey {
-               if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) {
-                   ef.PrimaryKey = true
-               } else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) {
-                   ef.PrimaryKey = true
-               } else {
+               if !utils.CheckTruth(ef.TagSettings["PRIMARYKEY"], ef.TagSettings["PRIMARY_KEY"]) {
                    ef.PrimaryKey = false
+
                    if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) {
                        ef.AutoIncrement = false
                    }

-                   if ef.DefaultValue == "" {
+                   if !ef.AutoIncrement && ef.DefaultValue == "" {
                        ef.HasDefaultValue = false
                    }
                }

@@ -472,9 +468,6 @@ func (field *Field) setupValuerAndSetter() {
        oldValuerOf := field.ValueOf
        field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
            value, zero := oldValuerOf(ctx, v)
-           if zero {
-               return value, zero
-           }

            s, ok := value.(SerializerValuerInterface)
            if !ok {

@@ -487,7 +480,7 @@ func (field *Field) setupValuerAndSetter() {
                Destination: v,
                Context:     ctx,
                fieldValue:  value,
-           }, false
+           }, zero
        }
    }

@@ -191,7 +191,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
        err             error
        joinTableFields []reflect.StructField
        fieldsMap       = map[string]*Field{}
-       ownFieldsMap    = map[string]bool{} // fix self join many2many
+       ownFieldsMap    = map[string]*Field{} // fix self join many2many
+       referFieldsMap  = map[string]*Field{}
        joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"])
        joinReferences  = toColumns(field.TagSettings["JOINREFERENCES"])
    )

@@ -229,7 +230,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
            joinFieldName = strings.Title(joinForeignKeys[idx])
        }

-       ownFieldsMap[joinFieldName] = true
+       ownFieldsMap[joinFieldName] = ownField
        fieldsMap[joinFieldName] = ownField
        joinTableFields = append(joinTableFields, reflect.StructField{
            Name: joinFieldName,

@@ -242,9 +243,6 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel

    for idx, relField := range refForeignFields {
        joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name
-       if len(joinReferences) > idx {
-           joinFieldName = strings.Title(joinReferences[idx])
-       }

        if _, ok := ownFieldsMap[joinFieldName]; ok {
            if field.Name != relation.FieldSchema.Name {

@@ -254,6 +252,13 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
            }
        }

+       if len(joinReferences) > idx {
+           joinFieldName = strings.Title(joinReferences[idx])
+       }
+
+       referFieldsMap[joinFieldName] = relField
+
+       if _, ok := fieldsMap[joinFieldName]; !ok {
            fieldsMap[joinFieldName] = relField
            joinTableFields = append(joinTableFields, reflect.StructField{
                Name: joinFieldName,

@@ -263,6 +268,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
                    "column", "autoincrement", "index", "unique", "uniqueindex"),
            })
        }
+       }

    joinTableFields = append(joinTableFields, reflect.StructField{
        Name: strings.Title(schema.Name) + field.Name,

@@ -317,31 +323,37 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
            f.Size = fieldsMap[f.Name].Size
        }
        relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f)
-       ownPrimaryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name]

-       if ownPrimaryField {
+       if of, ok := ownFieldsMap[f.Name]; ok {
            joinRel := relation.JoinTable.Relationships.Relations[relName]
            joinRel.Field = relation.Field
            joinRel.References = append(joinRel.References, &Reference{
-               PrimaryKey: fieldsMap[f.Name],
+               PrimaryKey: of,
                ForeignKey: f,
            })
-       } else {
+
+           relation.References = append(relation.References, &Reference{
+               PrimaryKey:    of,
+               ForeignKey:    f,
+               OwnPrimaryKey: true,
+           })
+       }
+
+       if rf, ok := referFieldsMap[f.Name]; ok {
            joinRefRel := relation.JoinTable.Relationships.Relations[relRefName]
            if joinRefRel.Field == nil {
                joinRefRel.Field = relation.Field
            }
            joinRefRel.References = append(joinRefRel.References, &Reference{
-               PrimaryKey: fieldsMap[f.Name],
+               PrimaryKey: rf,
+               ForeignKey: f,
+           })
+
+           relation.References = append(relation.References, &Reference{
+               PrimaryKey: rf,
                ForeignKey: f,
            })
        }

-       relation.References = append(relation.References, &Reference{
-           PrimaryKey: fieldsMap[f.Name],
-           ForeignKey: f,
-           OwnPrimaryKey: ownPrimaryField,
-       })
    }
    }
 }
@@ -10,7 +10,7 @@ import (

 func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) {
    if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil {
-       t.Errorf("Failed to parse schema")
+       t.Errorf("Failed to parse schema, got error %v", err)
    } else {
        for _, rel := range relations {
            checkSchemaRelation(t, s, rel)

@@ -305,6 +305,33 @@ func TestMany2ManyOverrideForeignKey(t *testing.T) {
    })
 }

+func TestMany2ManySharedForeignKey(t *testing.T) {
+   type Profile struct {
+       gorm.Model
+       Name         string
+       Kind         string
+       ProfileRefer uint
+   }
+
+   type User struct {
+       gorm.Model
+       Profiles []Profile `gorm:"many2many:user_profiles;foreignKey:Refer,Kind;joinForeignKey:UserRefer,Kind;References:ProfileRefer,Kind;joinReferences:ProfileR,Kind"`
+       Kind     string
+       Refer    uint
+   }
+
+   checkStructRelation(t, &User{}, Relation{
+       Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
+       JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"},
+       References: []Reference{
+           {"Refer", "User", "UserRefer", "user_profiles", "", true},
+           {"Kind", "User", "Kind", "user_profiles", "", true},
+           {"ProfileRefer", "Profile", "ProfileR", "user_profiles", "", false},
+           {"Kind", "Profile", "Kind", "user_profiles", "", false},
+       },
+   })
+}
+
 func TestMany2ManyOverrideJoinForeignKey(t *testing.T) {
    type Profile struct {
        gorm.Model
@@ -112,7 +112,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
        schemaCacheKey = modelType
    }

-   // Load exist schmema cache, return if exists
+   // Load exist schema cache, return if exists
    if v, ok := cacheStore.Load(schemaCacheKey); ok {
        s := v.(*Schema)
        // Wait for the initialization of other goroutines to complete

@@ -146,7 +146,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
    // When the schema initialization is completed, the channel will be closed
    defer close(schema.initialized)

-   // Load exist schmema cache, return if exists
+   // Load exist schema cache, return if exists
    if v, ok := cacheStore.Load(schemaCacheKey); ok {
        s := v.(*Schema)
        // Wait for the initialization of other goroutines to complete

@@ -239,6 +239,14 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
                field.HasDefaultValue = true
                field.AutoIncrement = true
            }
+       case String:
+           if _, ok := field.TagSettings["PRIMARYKEY"]; !ok {
+               if !field.HasDefaultValue || field.DefaultValueInterface != nil {
+                   schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
+               }
+
+               field.HasDefaultValue = true
+           }
        }
    }

@@ -88,8 +88,10 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
            return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue)
        }

+       if len(bytes) > 0 {
            err = json.Unmarshal(bytes, fieldValue.Interface())
        }
+   }

    field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
    return

@@ -117,9 +119,15 @@ func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.

 // Value implements serializer interface
 func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) {
+   rv := reflect.ValueOf(fieldValue)
    switch v := fieldValue.(type) {
-   case int64, int, uint, uint64, int32, uint32, int16, uint16, *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
-       result = time.Unix(reflect.Indirect(reflect.ValueOf(v)).Int(), 0)
+   case int64, int, uint, uint64, int32, uint32, int16, uint16:
+       result = time.Unix(reflect.Indirect(rv).Int(), 0)
+   case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
+       if rv.IsZero() {
+           return nil, nil
+       }
+       result = time.Unix(reflect.Indirect(rv).Int(), 0)
    default:
        err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
    }
|
|||||||
default:
|
default:
|
||||||
return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue)
|
return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue)
|
||||||
}
|
}
|
||||||
|
if len(bytesValue) > 0 {
|
||||||
decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue))
|
decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue))
|
||||||
err = decoder.Decode(fieldValue.Interface())
|
err = decoder.Decode(fieldValue.Interface())
|
||||||
}
|
}
|
||||||
|
}
|
||||||
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
|
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@@ -650,7 +650,7 @@ func (stmt *Statement) Changed(fields ...string) bool {
     return false
 }

-var nameMatcher = regexp.MustCompile(`^[\W]?(?:[a-z_0-9]+?)[\W]?\.[\W]?([a-z_0-9]+?)[\W]?$`)
+var nameMatcher = regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?\W?(\w+?)\W?$`)

 // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
 func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
@@ -672,8 +672,8 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (
             }
         } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
             results[field.DBName] = true
-        } else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 2 {
-            results[matches[1]] = true
+        } else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 3 && (matches[1] == stmt.Table || matches[1] == "") {
+            results[matches[2]] = true
         } else {
             results[column] = true
         }
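The rewritten nameMatcher now captures two groups, an optional table qualifier and the column itself, and SelectAndOmitColumns only honours a qualified name when the qualifier matches stmt.Table (or is absent). The pattern can be checked in isolation; the inputs here mirror the updated test below:

package main

import (
    "fmt"
    "regexp"
)

var nameMatcher = regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?\W?(\w+?)\W?$`)

func main() {
    for _, s := range []string{"`table`.`name`", "'table23'.name1", "`name_1`"} {
        // matches[1] is the table qualifier (empty when unqualified), matches[2] the column.
        fmt.Printf("%-20s %q\n", s, nameMatcher.FindStringSubmatch(s))
    }
}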
@@ -36,17 +36,21 @@ func TestWhereCloneCorruption(t *testing.T) {
 }

 func TestNameMatcher(t *testing.T) {
-    for k, v := range map[string]string{
-        "table.name":         "name",
-        "`table`.`name`":     "name",
-        "'table'.'name'":     "name",
-        "'table'.name":       "name",
-        "table1.name_23":     "name_23",
-        "`table_1`.`name23`": "name23",
-        "'table23'.'name_1'": "name_1",
-        "'table23'.name1":    "name1",
+    for k, v := range map[string][]string{
+        "table.name":         {"table", "name"},
+        "`table`.`name`":     {"table", "name"},
+        "'table'.'name'":     {"table", "name"},
+        "'table'.name":       {"table", "name"},
+        "table1.name_23":     {"table1", "name_23"},
+        "`table_1`.`name23`": {"table_1", "name23"},
+        "'table23'.'name_1'": {"table23", "name_1"},
+        "'table23'.name1":    {"table23", "name1"},
+        "'name1'":            {"", "name1"},
+        "`name_1`":           {"", "name_1"},
+        "`Name_1`":           {"", "Name_1"},
+        "`Table`.`nAme`":     {"Table", "nAme"},
     } {
-        if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 2 || matches[1] != v {
+        if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 3 || matches[1] != v[0] || matches[2] != v[1] {
             t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v)
         }
     }
@@ -4,6 +4,8 @@ import (
     "testing"

     "gorm.io/gorm"
+    "gorm.io/gorm/clause"
+    "gorm.io/gorm/schema"
     . "gorm.io/gorm/utils/tests"
 )

@@ -284,3 +286,65 @@ func TestAssociationError(t *testing.T) {
     err = DB.Model(&emptyUser).Association("Languages").Delete(&user1.Languages)
     AssertEqual(t, err, gorm.ErrPrimaryKeyRequired)
 }
+
+type (
+    myType           string
+    emptyQueryClause struct {
+        Field *schema.Field
+    }
+)
+
+func (myType) QueryClauses(f *schema.Field) []clause.Interface {
+    return []clause.Interface{emptyQueryClause{Field: f}}
+}
+
+func (sd emptyQueryClause) Name() string {
+    return "empty"
+}
+
+func (sd emptyQueryClause) Build(clause.Builder) {
+}
+
+func (sd emptyQueryClause) MergeClause(*clause.Clause) {
+}
+
+func (sd emptyQueryClause) ModifyStatement(stmt *gorm.Statement) {
+    // do nothing
+}
+
+func TestAssociationEmptyQueryClause(t *testing.T) {
+    type Organization struct {
+        gorm.Model
+        Name string
+    }
+    type Region struct {
+        gorm.Model
+        Name          string
+        Organizations []Organization `gorm:"many2many:region_orgs;"`
+    }
+    type RegionOrg struct {
+        RegionId       uint
+        OrganizationId uint
+        Empty          myType
+    }
+    if err := DB.SetupJoinTable(&Region{}, "Organizations", &RegionOrg{}); err != nil {
+        t.Fatalf("Failed to set up join table, got error: %s", err)
+    }
+    if err := DB.Migrator().DropTable(&Organization{}, &Region{}); err != nil {
+        t.Fatalf("Failed to migrate, got error: %s", err)
+    }
+    if err := DB.AutoMigrate(&Organization{}, &Region{}); err != nil {
+        t.Fatalf("Failed to migrate, got error: %v", err)
+    }
+    region := &Region{Name: "Region1"}
+    if err := DB.Create(region).Error; err != nil {
+        t.Fatalf("fail to create region %v", err)
+    }
+    var orgs []Organization
+
+    if err := DB.Model(&Region{}).Association("Organizations").Find(&orgs); err != nil {
+        t.Fatalf("fail to find region organizations %v", err)
+    } else {
+        AssertEqual(t, len(orgs), 0)
+    }
+}
@@ -113,6 +113,9 @@ func TestCallbacks(t *testing.T) {

     for idx, data := range datas {
         db, err := gorm.Open(nil, nil)
+        if err != nil {
+            t.Fatal(err)
+        }
         callbacks := db.Callback()

         for _, c := range data.callbacks {
@@ -168,3 +168,29 @@ func TestEmbeddedRelations(t *testing.T) {
         }
     }
 }
+
+func TestEmbeddedTagSetting(t *testing.T) {
+    type Tag1 struct {
+        Id int64 `gorm:"autoIncrement"`
+    }
+    type Tag2 struct {
+        Id int64
+    }
+
+    type EmbeddedTag struct {
+        Tag1 Tag1 `gorm:"Embedded;"`
+        Tag2 Tag2 `gorm:"Embedded;EmbeddedPrefix:t2_"`
+        Name string
+    }
+
+    DB.Migrator().DropTable(&EmbeddedTag{})
+    err := DB.Migrator().AutoMigrate(&EmbeddedTag{})
+    AssertEqual(t, err, nil)
+
+    t1 := EmbeddedTag{Name: "embedded_tag"}
+    err = DB.Save(&t1).Error
+    AssertEqual(t, err, nil)
+    if t1.Tag1.Id == 0 {
+        t.Errorf("embedded struct's primary field should be rewrited")
+    }
+}
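For readers unfamiliar with the tags used in the new test: Embedded inlines the struct's fields into the parent table and EmbeddedPrefix prepends a prefix to their column names, so Tag1.Id maps to the plain id column (and, being autoIncrement, is written back after Save) while Tag2.Id maps to t2_id. Column naming can be inspected without a database via GORM's schema parser; a small sketch, not part of the commit:

package main

import (
    "fmt"
    "sync"

    "gorm.io/gorm/schema"
)

type Tag1 struct {
    Id int64 `gorm:"autoIncrement"`
}

type Tag2 struct {
    Id int64
}

type EmbeddedTag struct {
    Tag1 Tag1 `gorm:"Embedded;"`
    Tag2 Tag2 `gorm:"Embedded;EmbeddedPrefix:t2_"`
    Name string
}

func main() {
    s, err := schema.Parse(&EmbeddedTag{}, &sync.Map{}, schema.NamingStrategy{})
    if err != nil {
        panic(err)
    }
    for _, f := range s.Fields {
        // Expect roughly: Id -> id, Id -> t2_id, Name -> name.
        fmt.Println(f.Name, "->", f.DBName)
    }
}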
14
tests/go.mod
@@ -1,20 +1,20 @@
 module gorm.io/gorm/tests

-go 1.14
+go 1.16

 require (
     github.com/denisenkom/go-mssqldb v0.12.2 // indirect
     github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
     github.com/google/uuid v1.3.0
     github.com/jinzhu/now v1.1.5
-    github.com/lib/pq v1.10.6
-    github.com/mattn/go-sqlite3 v1.14.14 // indirect
-    golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect
-    gorm.io/driver/mysql v1.3.4
-    gorm.io/driver/postgres v1.3.8
+    github.com/lib/pq v1.10.7
+    github.com/mattn/go-sqlite3 v1.14.15 // indirect
+    golang.org/x/crypto v0.0.0-20220926161630-eccd6366d1be // indirect
+    gorm.io/driver/mysql v1.3.6
+    gorm.io/driver/postgres v1.3.10
     gorm.io/driver/sqlite v1.3.6
     gorm.io/driver/sqlserver v1.3.2
-    gorm.io/gorm v1.23.7
+    gorm.io/gorm v1.23.10
 )

 replace gorm.io/gorm => ../
@@ -80,6 +80,7 @@ func CheckPet(t *testing.T, pet Pet, expect Pet) {
         t.Fatalf("errors happened when query: %v", err)
     } else {
         AssertObjEqual(t, newPet, pet, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name")
+        AssertObjEqual(t, newPet, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name")
     }
 }

@@ -174,6 +175,7 @@ func CheckUser(t *testing.T, user User, expect User) {
             var manager User
             DB.First(&manager, "id = ?", *user.ManagerID)
             AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
+            AssertObjEqual(t, manager, expect.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
         }
     } else if user.ManagerID != nil {
         t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID)
@@ -229,3 +229,34 @@ func TestJoinWithSoftDeleted(t *testing.T) {
         t.Fatalf("joins NamedPet and Account should not empty:%v", user2)
     }
 }
+
+func TestJoinWithSameColumnName(t *testing.T) {
+    user := GetUser("TestJoinWithSameColumnName", Config{
+        Languages: 1,
+        Pets:      1,
+    })
+    DB.Create(user)
+    type UserSpeak struct {
+        UserID       uint
+        LanguageCode string
+    }
+    type Result struct {
+        User
+        UserSpeak
+        Language
+        Pet
+    }
+
+    results := make([]Result, 0, 1)
+    DB.Select("users.*, user_speaks.*, languages.*, pets.*").Table("users").Joins("JOIN user_speaks ON user_speaks.user_id = users.id").
+        Joins("JOIN languages ON languages.code = user_speaks.language_code").
+        Joins("LEFT OUTER JOIN pets ON pets.user_id = users.id").Find(&results)

+    if len(results) == 0 {
+        t.Fatalf("no record find")
+    } else if results[0].Pet.UserID == nil || *(results[0].Pet.UserID) != user.ID {
+        t.Fatalf("wrong user id in pet")
+    } else if results[0].Pet.Name != user.Pets[0].Name {
+        t.Fatalf("wrong pet name")
+    }
+}
@@ -1048,3 +1048,41 @@ func TestMigrateDonotAlterColumn(t *testing.T) {
     err = mockM.AutoMigrate(&NotTriggerUpdate{})
     AssertEqual(t, err, nil)
 }
+
+func TestMigrateSameEmbeddedFieldName(t *testing.T) {
+    type UserStat struct {
+        GroundDestroyCount int
+    }
+
+    type GameUser struct {
+        gorm.Model
+        StatAb UserStat `gorm:"embedded;embeddedPrefix:stat_ab_"`
+    }
+
+    type UserStat1 struct {
+        GroundDestroyCount string
+    }
+
+    type GroundRate struct {
+        GroundDestroyCount int
+    }
+
+    type GameUser1 struct {
+        gorm.Model
+        StatAb       UserStat1  `gorm:"embedded;embeddedPrefix:stat_ab_"`
+        GroundRateRb GroundRate `gorm:"embedded;embeddedPrefix:rate_ground_rb_"`
+    }
+
+    DB.Migrator().DropTable(&GameUser{})
+    err := DB.AutoMigrate(&GameUser{})
+    AssertEqual(t, nil, err)
+
+    err = DB.Table("game_users").AutoMigrate(&GameUser1{})
+    AssertEqual(t, nil, err)
+
+    _, err = findColumnType(&GameUser{}, "stat_ab_ground_destory_count")
+    AssertEqual(t, nil, err)
+
+    _, err = findColumnType(&GameUser{}, "rate_ground_rb_ground_destory_count")
+    AssertEqual(t, nil, err)
+}
@@ -9,6 +9,56 @@ import (
     "gorm.io/gorm"
 )

+func TestPostgresReturningIDWhichHasStringType(t *testing.T) {
+    if DB.Dialector.Name() != "postgres" {
+        t.Skip()
+    }
+
+    type Yasuo struct {
+        ID        string `gorm:"default:gen_random_uuid()"`
+        Name      string
+        CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"`
+        UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"`
+    }
+
+    if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil {
+        t.Errorf("Failed to create extension pgcrypto, got error %v", err)
+    }
+
+    DB.Migrator().DropTable(&Yasuo{})
+
+    if err := DB.AutoMigrate(&Yasuo{}); err != nil {
+        t.Fatalf("Failed to migrate for uuid default value, got error: %v", err)
+    }
+
+    yasuo := Yasuo{Name: "jinzhu"}
+    if err := DB.Create(&yasuo).Error; err != nil {
+        t.Fatalf("should be able to create data, but got %v", err)
+    }
+
+    if yasuo.ID == "" {
+        t.Fatal("should be able to has ID, but got zero value")
+    }
+
+    var result Yasuo
+    if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu" {
+        t.Errorf("No error should happen, but got %v", err)
+    }
+
+    if err := DB.Where("id = $1", yasuo.ID).First(&Yasuo{}).Error; err != nil || yasuo.Name != "jinzhu" {
+        t.Errorf("No error should happen, but got %v", err)
+    }
+
+    yasuo.Name = "jinzhu1"
+    if err := DB.Save(&yasuo).Error; err != nil {
+        t.Errorf("Failed to update date, got error %v", err)
+    }
+
+    if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu1" {
+        t.Errorf("No error should happen, but got %v", err)
+    }
+}
+
 func TestPostgres(t *testing.T) {
     if DB.Dialector.Name() != "postgres" {
         t.Skip()
@@ -63,13 +113,13 @@ func TestPostgres(t *testing.T) {
     }

     type Post struct {
-        ID         uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"`
+        ID         uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();"`
         Title      string
         Categories []*Category `gorm:"Many2Many:post_categories"`
     }

     type Category struct {
-        ID    uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"`
+        ID    uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();"`
         Title string
         Posts []*Post `gorm:"Many2Many:post_categories"`
     }
@@ -2,6 +2,8 @@ package tests_test

 import (
     "context"
+    "errors"
+    "sync"
     "testing"
     "time"

@@ -88,3 +90,81 @@ func TestPreparedStmtFromTransaction(t *testing.T) {
     }
     tx2.Commit()
 }
+
+func TestPreparedStmtDeadlock(t *testing.T) {
+    tx, err := OpenTestConnection()
+    AssertEqual(t, err, nil)
+
+    sqlDB, _ := tx.DB()
+    sqlDB.SetMaxOpenConns(1)
+
+    tx = tx.Session(&gorm.Session{PrepareStmt: true})
+
+    wg := sync.WaitGroup{}
+    for i := 0; i < 100; i++ {
+        wg.Add(1)
+        go func() {
+            user := User{Name: "jinzhu"}
+            tx.Create(&user)
+
+            var result User
+            tx.First(&result)
+            wg.Done()
+        }()
+    }
+    wg.Wait()
+
+    conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
+    AssertEqual(t, ok, true)
+    AssertEqual(t, len(conn.Stmts), 2)
+    for _, stmt := range conn.Stmts {
+        if stmt == nil {
+            t.Fatalf("stmt cannot bee nil")
+        }
+    }
+
+    AssertEqual(t, sqlDB.Stats().InUse, 0)
+}
+
+func TestPreparedStmtError(t *testing.T) {
+    tx, err := OpenTestConnection()
+    AssertEqual(t, err, nil)
+
+    sqlDB, _ := tx.DB()
+    sqlDB.SetMaxOpenConns(1)
+
+    tx = tx.Session(&gorm.Session{PrepareStmt: true})
+
+    wg := sync.WaitGroup{}
+    for i := 0; i < 10; i++ {
+        wg.Add(1)
+        go func() {
+            // err prepare
+            tag := Tag{Locale: "zh"}
+            tx.Table("users").Find(&tag)
+            wg.Done()
+        }()
+    }
+    wg.Wait()
+
+    conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
+    AssertEqual(t, ok, true)
+    AssertEqual(t, len(conn.Stmts), 0)
+    AssertEqual(t, sqlDB.Stats().InUse, 0)
+}
+
+func TestPreparedStmtInTransaction(t *testing.T) {
+    user := User{Name: "jinzhu"}
+
+    if err := DB.Transaction(func(tx *gorm.DB) error {
+        tx.Session(&gorm.Session{PrepareStmt: true}).Create(&user)
+        return errors.New("test")
+    }); err == nil {
+        t.Error(err)
+    }
+
+    var result User
+    if err := DB.First(&result, user.ID).Error; err == nil {
+        t.Errorf("Failed, got error: %v", err)
+    }
+}
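The three new tests pin down the prepared-statement cache semantics: with PrepareStmt enabled, GORM keeps one cached statement per distinct query text in PreparedStmtDB.Stmts, so hammering the same Create/First pair from 100 goroutines over a single connection must end with exactly two cached statements and no connection left checked out, and a query that fails to prepare must not leave an entry behind. Enabling the mode and peeking at the cache looks roughly like this (a sketch using the SQLite driver; not part of the commit):

package main

import (
    "fmt"

    "gorm.io/driver/sqlite"
    "gorm.io/gorm"
)

func main() {
    db, err := gorm.Open(sqlite.Open("file::memory:"), &gorm.Config{})
    if err != nil {
        panic(err)
    }

    // Statement caching is opt-in per session (or globally via gorm.Config.PrepareStmt).
    tx := db.Session(&gorm.Session{PrepareStmt: true})
    tx.Exec("SELECT 1")

    // The session's connection pool is the statement cache itself.
    if conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB); ok {
        fmt.Println("cached statements:", len(conn.Stmts))
    }
}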
@@ -216,6 +216,30 @@ func TestFind(t *testing.T) {
         }
     }

+    // test array
+    var models2 [3]User
+    if err := DB.Where("name in (?)", []string{"find"}).Find(&models2).Error; err != nil || len(models2) != 3 {
+        t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models2))
+    } else {
+        for idx, user := range users {
+            t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) {
+                CheckUser(t, models2[idx], user)
+            })
+        }
+    }
+
+    // test smaller array
+    var models3 [2]User
+    if err := DB.Where("name in (?)", []string{"find"}).Find(&models3).Error; err != nil || len(models3) != 2 {
+        t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models3))
+    } else {
+        for idx, user := range users[:2] {
+            t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) {
+                CheckUser(t, models3[idx], user)
+            })
+        }
+    }
+
     var none []User
     if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 {
         t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none))
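The two new blocks document that Find also accepts a fixed-size array as the destination, populating at most len(array) elements; the [2]User case shows that a shorter array simply receives the first matching rows, as exercised by the new test. A self-contained sketch of the call shape (the SQLite driver and model here are illustrative, not from the commit):

package main

import (
    "fmt"

    "gorm.io/driver/sqlite"
    "gorm.io/gorm"
)

type User struct {
    ID   uint
    Name string
}

func main() {
    db, err := gorm.Open(sqlite.Open("file::memory:"), &gorm.Config{})
    if err != nil {
        panic(err)
    }
    if err := db.AutoMigrate(&User{}); err != nil {
        panic(err)
    }
    db.Create(&[]User{{Name: "find"}, {Name: "find"}, {Name: "find"}})

    // A fixed-size array destination receives at most len(array) rows.
    var firstTwo [2]User
    db.Where("name = ?", "find").Find(&firstTwo)
    fmt.Println(firstTwo[0].Name, firstTwo[1].Name) // find find
}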
@@ -257,7 +281,7 @@ func TestFindInBatches(t *testing.T) {
         totalBatch int
     )

-    if result := DB.Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error {
+    if result := DB.Table("users as u").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error {
         totalBatch += batch

         if tx.RowsAffected != 2 {
@@ -273,7 +297,7 @@ func TestFindInBatches(t *testing.T) {
         }

         if err := tx.Save(results).Error; err != nil {
-            t.Errorf("failed to save users, got error %v", err)
+            t.Fatalf("failed to save users, got error %v", err)
         }

         return nil
@@ -113,6 +113,43 @@ func TestSerializer(t *testing.T) {
     }

     AssertEqual(t, result, data)
+
+    if err := DB.Model(&result).Update("roles", "").Error; err != nil {
+        t.Fatalf("failed to update data's roles, got error %v", err)
+    }
+
+    if err := DB.First(&result, data.ID).Error; err != nil {
+        t.Fatalf("failed to query data, got error %v", err)
+    }
+}
+
+func TestSerializerZeroValue(t *testing.T) {
+    schema.RegisterSerializer("custom", NewCustomSerializer("hello"))
+    DB.Migrator().DropTable(&SerializerStruct{})
+    if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil {
+        t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)
+    }
+
+    data := SerializerStruct{}
+
+    if err := DB.Create(&data).Error; err != nil {
+        t.Fatalf("failed to create data, got error %v", err)
+    }
+
+    var result SerializerStruct
+    if err := DB.First(&result, data.ID).Error; err != nil {
+        t.Fatalf("failed to query data, got error %v", err)
+    }
+
+    AssertEqual(t, result, data)
+
+    if err := DB.Model(&result).Update("roles", "").Error; err != nil {
+        t.Fatalf("failed to update data's roles, got error %v", err)
+    }
+
+    if err := DB.First(&result, data.ID).Error; err != nil {
+        t.Fatalf("failed to query data, got error %v", err)
+    }
 }

 func TestSerializerAssignFirstOrCreate(t *testing.T) {
@@ -102,7 +102,7 @@ func TestTransactionWithBlock(t *testing.T) {
         return errors.New("the error message")
     })

-    if err.Error() != "the error message" {
+    if err != nil && err.Error() != "the error message" {
         t.Fatalf("Transaction return error will equal the block returns error")
     }

@@ -41,4 +41,19 @@ func TestUpdateBelongsTo(t *testing.T) {
     var user4 User
     DB.Preload("Company").Preload("Manager").Find(&user4, "id = ?", user.ID)
     CheckUser(t, user4, user)
+
+    user.Company.Name += "new2"
+    user.Manager.Name += "new2"
+    if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Select("`Company`").Save(&user).Error; err != nil {
+        t.Fatalf("errors happened when update: %v", err)
+    }
+
+    var user5 User
+    DB.Preload("Company").Preload("Manager").Find(&user5, "id = ?", user.ID)
+    if user5.Manager.Name != user4.Manager.Name {
+        t.Errorf("should not update user's manager")
+    } else {
+        user.Manager.Name = user4.Manager.Name
+    }
+    CheckUser(t, user, user5)
 }
@@ -92,6 +92,7 @@ func TestUpdateHasOne(t *testing.T) {
             gorm.Model
             UserID  sql.NullInt64
             Number  string `gorm:"<-:create"`
+            Number2 string
         }

         type CustomizeUser struct {
@@ -115,6 +116,7 @@ func TestUpdateHasOne(t *testing.T) {
             Name: "update-has-one-associations",
             Account: CustomizeAccount{
                 Number: number,
+                Number2: number,
             },
         }

@@ -122,6 +124,7 @@ func TestUpdateHasOne(t *testing.T) {
             t.Fatalf("errors happened when create: %v", err)
         }
         cusUser.Account.Number += "-update"
+        cusUser.Account.Number2 += "-update"
         if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Updates(&cusUser).Error; err != nil {
             t.Fatalf("errors happened when create: %v", err)
         }
@@ -129,5 +132,6 @@ func TestUpdateHasOne(t *testing.T) {
         var account2 CustomizeAccount
         DB.Find(&account2, "user_id = ?", cusUser.ID)
         AssertEqual(t, account2.Number, number)
+        AssertEqual(t, account2.Number2, cusUser.Account.Number2)
     })
 }
@@ -307,6 +307,8 @@ func TestSelectWithUpdate(t *testing.T) {
     if utils.AssertEqual(result.UpdatedAt, user.UpdatedAt) {
         t.Fatalf("Update struct should update UpdatedAt, was %+v, got %+v", result.UpdatedAt, user.UpdatedAt)
     }
+
+    AssertObjEqual(t, result, User{Name: "update_with_select"}, "Name", "Age")
 }

 func TestSelectWithUpdateWithMap(t *testing.T) {
|
@ -62,7 +62,7 @@ func TestUpsert(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r := DB.Session(&gorm.Session{DryRun: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(&RestrictedLanguage{Code: "upsert_code", Name: "upsert_name", Lang: "upsert_lang"})
|
r := DB.Session(&gorm.Session{DryRun: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(&RestrictedLanguage{Code: "upsert_code", Name: "upsert_name", Lang: "upsert_lang"})
|
||||||
if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.[^\w]*$`).MatchString(r.Statement.SQL.String()) {
|
if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.\W*$`).MatchString(r.Statement.SQL.String()) {
|
||||||
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
|
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -64,8 +64,8 @@ type Language struct {
 type Coupon struct {
     ID               int              `gorm:"primarykey; size:255"`
     AppliesToProduct []*CouponProduct `gorm:"foreignKey:CouponId;constraint:OnDelete:CASCADE"`
-    AmountOff        uint32           `gorm:"amount_off"`
-    PercentOff       float32          `gorm:"percent_off"`
+    AmountOff        uint32           `gorm:"column:amount_off"`
+    PercentOff       float32          `gorm:"column:percent_off"`
 }

 type CouponProduct struct {
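The tag change on Coupon is about using the documented key: GORM tag settings are key:value pairs, and the column name is set with the column: key; a bare gorm:"amount_off" is parsed as an unrecognised setting, so the column name was falling back to the snake-cased field name (which here happens to be the same string). The effect can be checked with the schema parser; a sketch, not part of the commit:

package main

import (
    "fmt"
    "sync"

    "gorm.io/gorm/schema"
)

type WithColumnKey struct {
    AmountOff uint32 `gorm:"column:amount_off"`
}

type BareValue struct {
    AmountOff uint32 `gorm:"amount_off"`
}

func main() {
    for _, model := range []interface{}{&WithColumnKey{}, &BareValue{}} {
        s, err := schema.Parse(model, &sync.Map{}, schema.NamingStrategy{})
        if err != nil {
            panic(err)
        }
        // Both resolve to "amount_off", but only the first does so via the column: key.
        fmt.Println(s.Name, "->", s.Fields[0].DBName)
    }
}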
@@ -12,3 +12,20 @@ func TestIsValidDBNameChar(t *testing.T) {
         }
     }
 }
+
+func TestToStringKey(t *testing.T) {
+    cases := []struct {
+        values []interface{}
+        key    string
+    }{
+        {[]interface{}{"a"}, "a"},
+        {[]interface{}{1, 2, 3}, "1_2_3"},
+        {[]interface{}{[]interface{}{1, 2, 3}}, "[1 2 3]"},
+        {[]interface{}{[]interface{}{"1", "2", "3"}}, "[1 2 3]"},
+    }
+    for _, c := range cases {
+        if key := ToStringKey(c.values...); key != c.key {
+            t.Errorf("%v: expected %v, got %v", c.values, c.key, key)
+        }
+    }
+}