Merge branch 'master' into test_migrate_column

This commit is contained in:
Jinzhu 2022-09-30 18:15:17 +08:00 committed by GitHub
commit 02b48e6798
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
42 changed files with 741 additions and 191 deletions

View File

@ -16,7 +16,7 @@ jobs:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v5
uses: actions/stale@v6
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"

View File

@ -16,7 +16,7 @@ jobs:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v5
uses: actions/stale@v6
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨"

View File

@ -16,7 +16,7 @@ jobs:
ACTIONS_STEP_DEBUG: true
steps:
- name: Close Stale Issues
uses: actions/stale@v5
uses: actions/stale@v6
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days"

View File

@ -16,7 +16,7 @@ jobs:
sqlite:
strategy:
matrix:
go: ['1.18', '1.17', '1.16']
go: ['1.19', '1.18', '1.17', '1.16']
platform: [ubuntu-latest] # can not run in windows OS
runs-on: ${{ matrix.platform }}
@ -42,7 +42,7 @@ jobs:
strategy:
matrix:
dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest']
go: ['1.18', '1.17', '1.16']
go: ['1.19', '1.18', '1.17', '1.16']
platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }}
@ -86,7 +86,7 @@ jobs:
strategy:
matrix:
dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10']
go: ['1.18', '1.17', '1.16']
go: ['1.19', '1.18', '1.17', '1.16']
platform: [ubuntu-latest] # can not run in macOS and Windows
runs-on: ${{ matrix.platform }}
@ -128,7 +128,7 @@ jobs:
sqlserver:
strategy:
matrix:
go: ['1.18', '1.17', '1.16']
go: ['1.19', '1.18', '1.17', '1.16']
platform: [ubuntu-latest] # can not run test in macOS and windows
runs-on: ${{ matrix.platform }}

1
.gitignore vendored
View File

@ -4,3 +4,4 @@ coverage.txt
_book
.idea
vendor
.vscode

View File

@ -507,8 +507,10 @@ func (association *Association) buildCondition() *DB {
joinStmt.AddClause(queryClause)
}
joinStmt.Build("WHERE")
if len(joinStmt.SQL.String()) > 0 {
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
}
}
tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{
Table: clause.Table{Name: association.Relationship.JoinTable.Table},

View File

@ -206,7 +206,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
}
}
cacheKey := utils.ToStringKey(relPrimaryValues)
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
identityMap[cacheKey] = true
if isPtr {
@ -292,7 +292,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
}
}
cacheKey := utils.ToStringKey(relPrimaryValues)
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
identityMap[cacheKey] = true
distinctElems = reflect.Append(distinctElems, elem)

View File

@ -70,11 +70,13 @@ func Update(config *Config) func(db *gorm.DB) {
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Update{})
if _, ok := db.Statement.Clauses["SET"]; !ok {
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
db.Statement.AddClause(set)
} else if _, ok := db.Statement.Clauses["SET"]; !ok {
} else {
return
}
}
db.Statement.Build(db.Statement.BuildClauses...)
}

View File

@ -13,7 +13,7 @@ import (
"gorm.io/gorm/utils"
)
// Create insert the value into database
// Create inserts value, returning the inserted data's primary key in value's id
func (db *DB) Create(value interface{}) (tx *DB) {
if db.CreateBatchSize > 0 {
return db.CreateInBatches(value, db.CreateBatchSize)
@ -24,7 +24,7 @@ func (db *DB) Create(value interface{}) (tx *DB) {
return tx.callbacks.Create().Execute(tx)
}
// CreateInBatches insert the value in batches into database
// CreateInBatches inserts value in batches of batchSize
func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
@ -68,7 +68,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
return
}
// Save update value in database, if the value doesn't have primary key, will insert it
// Save updates value in database. If value doesn't contain a matching primary key, value is inserted.
func (db *DB) Save(value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = value
@ -114,7 +114,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
return
}
// First find first record that match given conditions, order by primary key
// First finds the first record ordered by primary key, matching given conditions conds
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
@ -129,7 +129,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.callbacks.Query().Execute(tx)
}
// Take return a record that match given conditions, the order will depend on the database implementation
// Take finds the first record returned by the database in no specified order, matching given conditions conds
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1)
if len(conds) > 0 {
@ -142,7 +142,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.callbacks.Query().Execute(tx)
}
// Last find last record that match given conditions, order by primary key
// Last finds the last record ordered by primary key, matching given conditions conds
func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
@ -158,7 +158,7 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.callbacks.Query().Execute(tx)
}
// Find find records that match given conditions
// Find finds all records matching given conditions conds
func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance()
if len(conds) > 0 {
@ -170,7 +170,7 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.callbacks.Query().Execute(tx)
}
// FindInBatches find records in batches
// FindInBatches finds all records in batches of batchSize
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
var (
tx = db.Order(clause.OrderByColumn{
@ -202,7 +202,9 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
batch++
if result.Error == nil && result.RowsAffected != 0 {
tx.AddError(fc(result, batch))
fcTx := result.Session(&Session{NewDB: true})
fcTx.RowsAffected = result.RowsAffected
tx.AddError(fc(fcTx, batch))
} else if result.Error != nil {
tx.AddError(result.Error)
}
@ -284,7 +286,8 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) {
}
}
// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
// FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds.
// Each conds must be a struct or map.
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
@ -310,7 +313,8 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
return
}
// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions)
// FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds.
// Each conds must be a struct or map.
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance()
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
@ -358,14 +362,14 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
return tx
}
// Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
// Update updates column with value using callbacks. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
func (db *DB) Update(column string, value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = map[string]interface{}{column: value}
return tx.callbacks.Update().Execute(tx)
}
// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields
// Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
func (db *DB) Updates(values interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = values
@ -386,7 +390,9 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
return tx.callbacks.Update().Execute(tx)
}
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If
// value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current
// time if null.
func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance()
if len(conds) > 0 {
@ -480,7 +486,7 @@ func (db *DB) Rows() (*sql.Rows, error) {
return rows, tx.Error
}
// Scan scan value to a struct
// Scan scans selected value to the struct dest
func (db *DB) Scan(dest interface{}) (tx *DB) {
config := *db.Config
currentLogger, newLogger := config.Logger, logger.Recorder.New()
@ -505,7 +511,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
return
}
// Pluck used to query single column from a model as a map
// Pluck queries a single column from a model, returning in the slice dest. E.g.:
// var ages []int64
// db.Model(&users).Pluck("age", &ages)
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
@ -548,7 +554,8 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
return tx.Error
}
// Connection use a db conn to execute Multiple commands,this conn will put conn pool after it is executed.
// Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is
// returned to the connection pool.
func (db *DB) Connection(fc func(tx *DB) error) (err error) {
if db.Error != nil {
return db.Error
@ -570,7 +577,9 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) {
return fc(tx)
}
// Transaction start a transaction as a block, return error will rollback, otherwise to commit.
// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an
// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs
// they are rolled back.
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
panicked := true
@ -613,7 +622,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
return
}
// Begin begins a transaction
// Begin begins a transaction with any transaction options opts
func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
var (
// clone statement
@ -642,7 +651,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
return tx
}
// Commit commit a transaction
// Commit commits the changes in a transaction
func (db *DB) Commit() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
db.AddError(committer.Commit())
@ -652,7 +661,7 @@ func (db *DB) Commit() *DB {
return db
}
// Rollback rollback a transaction
// Rollback rollbacks the changes in a transaction
func (db *DB) Rollback() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
if !reflect.ValueOf(committer).IsNil() {
@ -682,7 +691,7 @@ func (db *DB) RollbackTo(name string) *DB {
return db
}
// Exec execute raw sql
// Exec executes raw sql
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.SQL = strings.Builder{}

2
go.mod
View File

@ -1,6 +1,6 @@
module gorm.io/gorm
go 1.14
go 1.16
require (
github.com/jinzhu/inflection v1.0.0

15
gorm.go
View File

@ -179,7 +179,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
preparedStmt := &PreparedStmtDB{
ConnPool: db.ConnPool,
Stmts: map[string]Stmt{},
Stmts: map[string](*Stmt){},
Mux: &sync.RWMutex{},
PreparedSQL: make([]string, 0, 100),
}
@ -248,11 +248,19 @@ func (db *DB) Session(config *Session) *DB {
if config.PrepareStmt {
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
preparedStmt := v.(*PreparedStmtDB)
switch t := tx.Statement.ConnPool.(type) {
case Tx:
tx.Statement.ConnPool = &PreparedStmtTX{
Tx: t,
PreparedStmtDB: preparedStmt,
}
default:
tx.Statement.ConnPool = &PreparedStmtDB{
ConnPool: db.Config.ConnPool,
Mux: preparedStmt.Mux,
Stmts: preparedStmt.Stmts,
}
}
txConfig.ConnPool = tx.Statement.ConnPool
txConfig.PrepareStmt = true
}
@ -300,7 +308,8 @@ func (db *DB) WithContext(ctx context.Context) *DB {
// Debug start debug mode
func (db *DB) Debug() (tx *DB) {
return db.Session(&Session{
tx = db.getInstance()
return tx.Session(&Session{
Logger: db.Logger.LogMode(logger.Info),
})
}
@ -412,7 +421,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac
relation, ok := modelSchema.Relationships.Relations[field]
isRelation := ok && relation.JoinTable != nil
if !isRelation {
return fmt.Errorf("failed to found relation: %s", field)
return fmt.Errorf("failed to find relation: %s", field)
}
for _, ref := range relation.References {

View File

@ -4,7 +4,7 @@ import (
"context"
"errors"
"fmt"
"io/ioutil"
"io"
"log"
"os"
"time"
@ -68,8 +68,8 @@ type Interface interface {
}
var (
// Discard Discard logger will print any log to ioutil.Discard
Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{})
// Discard Discard logger will print any log to io.Discard
Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{})
// Default Default logger
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
SlowThreshold: 200 * time.Millisecond,

View File

@ -30,6 +30,8 @@ func isPrintable(s string) bool {
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
var (
@ -138,9 +140,18 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
sql = newSQL.String()
} else {
sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
for idx, v := range vars {
sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v, 1)
sql = numericPlaceholderRe.ReplaceAllStringFunc(sql, func(v string) string {
num := v[1 : len(v)-1]
n, _ := strconv.Atoi(num)
// position var start from 1 ($1, $2)
n -= 1
if n >= 0 && n <= len(vars)-1 {
return vars[n]
}
return v
})
}
return sql

View File

@ -68,6 +68,7 @@ type Migrator interface {
// Database
CurrentDatabase() string
FullDataTypeOf(*schema.Field) clause.Expr
GetTypeAliases(databaseTypeName string) []string
// Tables
CreateTable(dst ...interface{}) error

View File

@ -15,7 +15,7 @@ import (
)
var (
regFullDataType = regexp.MustCompile(`[^\d]*(\d+)[^\d]?`)
regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
)
// Migrator m struct
@ -135,6 +135,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
}
}
}
}
for _, chk := range stmt.Schema.ParseCheckConstraints() {
if !tx.Migrator().HasConstraint(value, chk.Name) {
@ -143,7 +144,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
}
}
}
}
for _, idx := range stmt.Schema.ParseIndexes() {
if !tx.Migrator().HasIndex(value, idx.Name) {
@ -408,10 +408,28 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
alterColumn := false
if !field.PrimaryKey {
// check type
if !field.PrimaryKey && !strings.HasPrefix(fullDataType, realDataType) {
var isSameType bool
if strings.HasPrefix(fullDataType, realDataType) {
isSameType = true
}
// check type aliases
if !isSameType {
aliases := m.DB.Migrator().GetTypeAliases(realDataType)
for _, alias := range aliases {
if strings.HasPrefix(fullDataType, alias) {
isSameType = true
break
}
}
}
if !isSameType {
alterColumn = true
}
}
// check size
if length, ok := columnType.Length(); length != int64(field.Size) {
@ -478,7 +496,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
}
if alterColumn && !field.IgnoreMigration {
return m.DB.Migrator().AlterColumn(value, field.Name)
return m.DB.Migrator().AlterColumn(value, field.DBName)
}
return nil
@ -863,3 +881,8 @@ func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} {
func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) {
return nil, errors.New("not support")
}
// GetTypeAliases return database type aliases
func (m Migrator) GetTypeAliases(databaseTypeName string) []string {
return nil
}

View File

@ -9,10 +9,12 @@ import (
type Stmt struct {
*sql.Stmt
Transaction bool
prepared chan struct{}
prepareErr error
}
type PreparedStmtDB struct {
Stmts map[string]Stmt
Stmts map[string]*Stmt
PreparedSQL []string
Mux *sync.RWMutex
ConnPool
@ -46,27 +48,57 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
db.Mux.RLock()
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
db.Mux.RUnlock()
return stmt, nil
// wait for other goroutines prepared
<-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
}
return *stmt, nil
}
db.Mux.RUnlock()
db.Mux.Lock()
defer db.Mux.Unlock()
// double check
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
return stmt, nil
} else if ok {
go stmt.Close()
db.Mux.Unlock()
// wait for other goroutines prepared
<-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
}
return *stmt, nil
}
// cache preparing stmt first
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
db.Stmts[query] = &cacheStmt
db.Mux.Unlock()
// prepare completed
defer close(cacheStmt.prepared)
// Reason why cannot lock conn.PrepareContext
// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
// 1. g1 begin tx, g1 is requeued because of waiting for the system call, now `db.ConnPool` db.numOpen == 1.
// 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release.
// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
stmt, err := conn.PrepareContext(ctx, query)
if err == nil {
db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction}
db.PreparedSQL = append(db.PreparedSQL, query)
if err != nil {
cacheStmt.prepareErr = err
db.Mux.Lock()
delete(db.Stmts, query)
db.Mux.Unlock()
return Stmt{}, err
}
return db.Stmts[query], err
db.Mux.Lock()
cacheStmt.Stmt = stmt
db.PreparedSQL = append(db.PreparedSQL, query)
db.Mux.Unlock()
return cacheStmt, nil
}
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {

53
scan.go
View File

@ -66,9 +66,12 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int
db.RowsAffected++
db.AddError(rows.Scan(values...))
joinedSchemaMap := make(map[*schema.Field]interface{}, 0)
joinedSchemaMap := make(map[*schema.Field]interface{})
for idx, field := range fields {
if field != nil {
if field == nil {
continue
}
if len(joinFields) == 0 || joinFields[idx][0] == nil {
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
} else {
@ -90,7 +93,6 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int
// release data to pool
field.NewValuePool.Put(values[idx])
}
}
}
// ScanMode scan data mode
@ -162,7 +164,6 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
default:
var (
fields = make([]*schema.Field, len(columns))
selectedColumnsMap = make(map[string]int, len(columns))
joinFields [][2]*schema.Field
sch = db.Statement.Schema
reflectValue = db.Statement.ReflectValue
@ -198,26 +199,24 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
// Not Pluck
if sch != nil {
schFieldsCount := len(sch.Fields)
matchedFieldCount := make(map[string]int, len(columns))
for idx, column := range columns {
if field := sch.LookUpField(column); field != nil && field.Readable {
if curIndex, ok := selectedColumnsMap[column]; ok {
fields[idx] = field // handle duplicate fields
offset := curIndex + 1
// handle sch inconsistent with database
// like Raw(`...`).Scan
if schFieldsCount > offset {
for fieldIndex, selectField := range sch.Fields[offset:] {
fields[idx] = field
if count, ok := matchedFieldCount[column]; ok {
// handle duplicate fields
for _, selectField := range sch.Fields {
if selectField.DBName == column && selectField.Readable {
selectedColumnsMap[column] = curIndex + fieldIndex + 1
if count == 0 {
matchedFieldCount[column]++
fields[idx] = selectField
break
}
count--
}
}
} else {
fields[idx] = field
selectedColumnsMap[column] = idx
matchedFieldCount[column] = 1
}
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
@ -241,12 +240,21 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
var elem reflect.Value
recyclableStruct := reflect.New(reflectValueType)
var (
elem reflect.Value
recyclableStruct = reflect.New(reflectValueType)
isArrayKind = reflectValue.Kind() == reflect.Array
)
if !update || reflectValue.Len() == 0 {
update = false
// if the slice cap is externally initialized, the externally initialized slice is directly used here
if reflectValue.Cap() == 0 {
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
} else if !isArrayKind {
reflectValue.SetLen(0)
db.Statement.ReflectValue.Set(reflectValue)
}
}
for initialized || rows.Next() {
@ -277,10 +285,15 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
db.scanIntoStruct(rows, elem, values, fields, joinFields)
if !update {
if isPtr {
reflectValue = reflect.Append(reflectValue, elem)
if !isPtr {
elem = elem.Elem()
}
if isArrayKind {
if reflectValue.Len() >= int(db.RowsAffected) {
reflectValue.Index(int(db.RowsAffected - 1)).Set(elem)
}
} else {
reflectValue = reflect.Append(reflectValue, elem.Elem())
reflectValue = reflect.Append(reflectValue, elem)
}
}
}

View File

@ -403,18 +403,14 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
if ef.PrimaryKey {
if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) {
ef.PrimaryKey = true
} else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) {
ef.PrimaryKey = true
} else {
if !utils.CheckTruth(ef.TagSettings["PRIMARYKEY"], ef.TagSettings["PRIMARY_KEY"]) {
ef.PrimaryKey = false
if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) {
ef.AutoIncrement = false
}
if ef.DefaultValue == "" {
if !ef.AutoIncrement && ef.DefaultValue == "" {
ef.HasDefaultValue = false
}
}
@ -472,9 +468,6 @@ func (field *Field) setupValuerAndSetter() {
oldValuerOf := field.ValueOf
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
value, zero := oldValuerOf(ctx, v)
if zero {
return value, zero
}
s, ok := value.(SerializerValuerInterface)
if !ok {
@ -487,7 +480,7 @@ func (field *Field) setupValuerAndSetter() {
Destination: v,
Context: ctx,
fieldValue: value,
}, false
}, zero
}
}

View File

@ -191,7 +191,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
err error
joinTableFields []reflect.StructField
fieldsMap = map[string]*Field{}
ownFieldsMap = map[string]bool{} // fix self join many2many
ownFieldsMap = map[string]*Field{} // fix self join many2many
referFieldsMap = map[string]*Field{}
joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"])
joinReferences = toColumns(field.TagSettings["JOINREFERENCES"])
)
@ -229,7 +230,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
joinFieldName = strings.Title(joinForeignKeys[idx])
}
ownFieldsMap[joinFieldName] = true
ownFieldsMap[joinFieldName] = ownField
fieldsMap[joinFieldName] = ownField
joinTableFields = append(joinTableFields, reflect.StructField{
Name: joinFieldName,
@ -242,9 +243,6 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
for idx, relField := range refForeignFields {
joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name
if len(joinReferences) > idx {
joinFieldName = strings.Title(joinReferences[idx])
}
if _, ok := ownFieldsMap[joinFieldName]; ok {
if field.Name != relation.FieldSchema.Name {
@ -254,6 +252,13 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
}
}
if len(joinReferences) > idx {
joinFieldName = strings.Title(joinReferences[idx])
}
referFieldsMap[joinFieldName] = relField
if _, ok := fieldsMap[joinFieldName]; !ok {
fieldsMap[joinFieldName] = relField
joinTableFields = append(joinTableFields, reflect.StructField{
Name: joinFieldName,
@ -263,6 +268,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
"column", "autoincrement", "index", "unique", "uniqueindex"),
})
}
}
joinTableFields = append(joinTableFields, reflect.StructField{
Name: strings.Title(schema.Name) + field.Name,
@ -317,31 +323,37 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
f.Size = fieldsMap[f.Name].Size
}
relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f)
ownPrimaryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name]
if ownPrimaryField {
if of, ok := ownFieldsMap[f.Name]; ok {
joinRel := relation.JoinTable.Relationships.Relations[relName]
joinRel.Field = relation.Field
joinRel.References = append(joinRel.References, &Reference{
PrimaryKey: fieldsMap[f.Name],
PrimaryKey: of,
ForeignKey: f,
})
} else {
relation.References = append(relation.References, &Reference{
PrimaryKey: of,
ForeignKey: f,
OwnPrimaryKey: true,
})
}
if rf, ok := referFieldsMap[f.Name]; ok {
joinRefRel := relation.JoinTable.Relationships.Relations[relRefName]
if joinRefRel.Field == nil {
joinRefRel.Field = relation.Field
}
joinRefRel.References = append(joinRefRel.References, &Reference{
PrimaryKey: fieldsMap[f.Name],
PrimaryKey: rf,
ForeignKey: f,
})
relation.References = append(relation.References, &Reference{
PrimaryKey: rf,
ForeignKey: f,
})
}
relation.References = append(relation.References, &Reference{
PrimaryKey: fieldsMap[f.Name],
ForeignKey: f,
OwnPrimaryKey: ownPrimaryField,
})
}
}
}

View File

@ -10,7 +10,7 @@ import (
func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) {
if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil {
t.Errorf("Failed to parse schema")
t.Errorf("Failed to parse schema, got error %v", err)
} else {
for _, rel := range relations {
checkSchemaRelation(t, s, rel)
@ -305,6 +305,33 @@ func TestMany2ManyOverrideForeignKey(t *testing.T) {
})
}
func TestMany2ManySharedForeignKey(t *testing.T) {
type Profile struct {
gorm.Model
Name string
Kind string
ProfileRefer uint
}
type User struct {
gorm.Model
Profiles []Profile `gorm:"many2many:user_profiles;foreignKey:Refer,Kind;joinForeignKey:UserRefer,Kind;References:ProfileRefer,Kind;joinReferences:ProfileR,Kind"`
Kind string
Refer uint
}
checkStructRelation(t, &User{}, Relation{
Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"},
References: []Reference{
{"Refer", "User", "UserRefer", "user_profiles", "", true},
{"Kind", "User", "Kind", "user_profiles", "", true},
{"ProfileRefer", "Profile", "ProfileR", "user_profiles", "", false},
{"Kind", "Profile", "Kind", "user_profiles", "", false},
},
})
}
func TestMany2ManyOverrideJoinForeignKey(t *testing.T) {
type Profile struct {
gorm.Model

View File

@ -112,7 +112,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
schemaCacheKey = modelType
}
// Load exist schmema cache, return if exists
// Load exist schema cache, return if exists
if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
@ -146,7 +146,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
// When the schema initialization is completed, the channel will be closed
defer close(schema.initialized)
// Load exist schmema cache, return if exists
// Load exist schema cache, return if exists
if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
@ -239,6 +239,14 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
field.HasDefaultValue = true
field.AutoIncrement = true
}
case String:
if _, ok := field.TagSettings["PRIMARYKEY"]; !ok {
if !field.HasDefaultValue || field.DefaultValueInterface != nil {
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
}
field.HasDefaultValue = true
}
}
}

View File

@ -88,8 +88,10 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue)
}
if len(bytes) > 0 {
err = json.Unmarshal(bytes, fieldValue.Interface())
}
}
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
return
@ -117,9 +119,15 @@ func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.
// Value implements serializer interface
func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) {
rv := reflect.ValueOf(fieldValue)
switch v := fieldValue.(type) {
case int64, int, uint, uint64, int32, uint32, int16, uint16, *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
result = time.Unix(reflect.Indirect(reflect.ValueOf(v)).Int(), 0)
case int64, int, uint, uint64, int32, uint32, int16, uint16:
result = time.Unix(reflect.Indirect(rv).Int(), 0)
case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
if rv.IsZero() {
return nil, nil
}
result = time.Unix(reflect.Indirect(rv).Int(), 0)
default:
err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
}
@ -142,9 +150,11 @@ func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
default:
return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue)
}
if len(bytesValue) > 0 {
decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue))
err = decoder.Decode(fieldValue.Interface())
}
}
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
return
}

View File

@ -650,7 +650,7 @@ func (stmt *Statement) Changed(fields ...string) bool {
return false
}
var nameMatcher = regexp.MustCompile(`^[\W]?(?:[a-z_0-9]+?)[\W]?\.[\W]?([a-z_0-9]+?)[\W]?$`)
var nameMatcher = regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?\W?(\w+?)\W?$`)
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
@ -672,8 +672,8 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (
}
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
results[field.DBName] = true
} else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 2 {
results[matches[1]] = true
} else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 3 && (matches[1] == stmt.Table || matches[1] == "") {
results[matches[2]] = true
} else {
results[column] = true
}

View File

@ -36,17 +36,21 @@ func TestWhereCloneCorruption(t *testing.T) {
}
func TestNameMatcher(t *testing.T) {
for k, v := range map[string]string{
"table.name": "name",
"`table`.`name`": "name",
"'table'.'name'": "name",
"'table'.name": "name",
"table1.name_23": "name_23",
"`table_1`.`name23`": "name23",
"'table23'.'name_1'": "name_1",
"'table23'.name1": "name1",
for k, v := range map[string][]string{
"table.name": {"table", "name"},
"`table`.`name`": {"table", "name"},
"'table'.'name'": {"table", "name"},
"'table'.name": {"table", "name"},
"table1.name_23": {"table1", "name_23"},
"`table_1`.`name23`": {"table_1", "name23"},
"'table23'.'name_1'": {"table23", "name_1"},
"'table23'.name1": {"table23", "name1"},
"'name1'": {"", "name1"},
"`name_1`": {"", "name_1"},
"`Name_1`": {"", "Name_1"},
"`Table`.`nAme`": {"Table", "nAme"},
} {
if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 2 || matches[1] != v {
if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 3 || matches[1] != v[0] || matches[2] != v[1] {
t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v)
}
}

View File

@ -4,6 +4,8 @@ import (
"testing"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
. "gorm.io/gorm/utils/tests"
)
@ -284,3 +286,65 @@ func TestAssociationError(t *testing.T) {
err = DB.Model(&emptyUser).Association("Languages").Delete(&user1.Languages)
AssertEqual(t, err, gorm.ErrPrimaryKeyRequired)
}
type (
myType string
emptyQueryClause struct {
Field *schema.Field
}
)
func (myType) QueryClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{emptyQueryClause{Field: f}}
}
func (sd emptyQueryClause) Name() string {
return "empty"
}
func (sd emptyQueryClause) Build(clause.Builder) {
}
func (sd emptyQueryClause) MergeClause(*clause.Clause) {
}
func (sd emptyQueryClause) ModifyStatement(stmt *gorm.Statement) {
// do nothing
}
func TestAssociationEmptyQueryClause(t *testing.T) {
type Organization struct {
gorm.Model
Name string
}
type Region struct {
gorm.Model
Name string
Organizations []Organization `gorm:"many2many:region_orgs;"`
}
type RegionOrg struct {
RegionId uint
OrganizationId uint
Empty myType
}
if err := DB.SetupJoinTable(&Region{}, "Organizations", &RegionOrg{}); err != nil {
t.Fatalf("Failed to set up join table, got error: %s", err)
}
if err := DB.Migrator().DropTable(&Organization{}, &Region{}); err != nil {
t.Fatalf("Failed to migrate, got error: %s", err)
}
if err := DB.AutoMigrate(&Organization{}, &Region{}); err != nil {
t.Fatalf("Failed to migrate, got error: %v", err)
}
region := &Region{Name: "Region1"}
if err := DB.Create(region).Error; err != nil {
t.Fatalf("fail to create region %v", err)
}
var orgs []Organization
if err := DB.Model(&Region{}).Association("Organizations").Find(&orgs); err != nil {
t.Fatalf("fail to find region organizations %v", err)
} else {
AssertEqual(t, len(orgs), 0)
}
}

View File

@ -113,6 +113,9 @@ func TestCallbacks(t *testing.T) {
for idx, data := range datas {
db, err := gorm.Open(nil, nil)
if err != nil {
t.Fatal(err)
}
callbacks := db.Callback()
for _, c := range data.callbacks {

View File

@ -168,3 +168,29 @@ func TestEmbeddedRelations(t *testing.T) {
}
}
}
func TestEmbeddedTagSetting(t *testing.T) {
type Tag1 struct {
Id int64 `gorm:"autoIncrement"`
}
type Tag2 struct {
Id int64
}
type EmbeddedTag struct {
Tag1 Tag1 `gorm:"Embedded;"`
Tag2 Tag2 `gorm:"Embedded;EmbeddedPrefix:t2_"`
Name string
}
DB.Migrator().DropTable(&EmbeddedTag{})
err := DB.Migrator().AutoMigrate(&EmbeddedTag{})
AssertEqual(t, err, nil)
t1 := EmbeddedTag{Name: "embedded_tag"}
err = DB.Save(&t1).Error
AssertEqual(t, err, nil)
if t1.Tag1.Id == 0 {
t.Errorf("embedded struct's primary field should be rewrited")
}
}

View File

@ -1,20 +1,20 @@
module gorm.io/gorm/tests
go 1.14
go 1.16
require (
github.com/denisenkom/go-mssqldb v0.12.2 // indirect
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/google/uuid v1.3.0
github.com/jinzhu/now v1.1.5
github.com/lib/pq v1.10.6
github.com/mattn/go-sqlite3 v1.14.14 // indirect
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect
gorm.io/driver/mysql v1.3.4
gorm.io/driver/postgres v1.3.8
github.com/lib/pq v1.10.7
github.com/mattn/go-sqlite3 v1.14.15 // indirect
golang.org/x/crypto v0.0.0-20220926161630-eccd6366d1be // indirect
gorm.io/driver/mysql v1.3.6
gorm.io/driver/postgres v1.3.10
gorm.io/driver/sqlite v1.3.6
gorm.io/driver/sqlserver v1.3.2
gorm.io/gorm v1.23.7
gorm.io/gorm v1.23.10
)
replace gorm.io/gorm => ../

View File

@ -80,6 +80,7 @@ func CheckPet(t *testing.T, pet Pet, expect Pet) {
t.Fatalf("errors happened when query: %v", err)
} else {
AssertObjEqual(t, newPet, pet, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name")
AssertObjEqual(t, newPet, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name")
}
}
@ -174,6 +175,7 @@ func CheckUser(t *testing.T, user User, expect User) {
var manager User
DB.First(&manager, "id = ?", *user.ManagerID)
AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
AssertObjEqual(t, manager, expect.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active")
}
} else if user.ManagerID != nil {
t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID)

View File

@ -229,3 +229,34 @@ func TestJoinWithSoftDeleted(t *testing.T) {
t.Fatalf("joins NamedPet and Account should not empty:%v", user2)
}
}
func TestJoinWithSameColumnName(t *testing.T) {
user := GetUser("TestJoinWithSameColumnName", Config{
Languages: 1,
Pets: 1,
})
DB.Create(user)
type UserSpeak struct {
UserID uint
LanguageCode string
}
type Result struct {
User
UserSpeak
Language
Pet
}
results := make([]Result, 0, 1)
DB.Select("users.*, user_speaks.*, languages.*, pets.*").Table("users").Joins("JOIN user_speaks ON user_speaks.user_id = users.id").
Joins("JOIN languages ON languages.code = user_speaks.language_code").
Joins("LEFT OUTER JOIN pets ON pets.user_id = users.id").Find(&results)
if len(results) == 0 {
t.Fatalf("no record find")
} else if results[0].Pet.UserID == nil || *(results[0].Pet.UserID) != user.ID {
t.Fatalf("wrong user id in pet")
} else if results[0].Pet.Name != user.Pets[0].Name {
t.Fatalf("wrong pet name")
}
}

View File

@ -1048,3 +1048,41 @@ func TestMigrateDonotAlterColumn(t *testing.T) {
err = mockM.AutoMigrate(&NotTriggerUpdate{})
AssertEqual(t, err, nil)
}
func TestMigrateSameEmbeddedFieldName(t *testing.T) {
type UserStat struct {
GroundDestroyCount int
}
type GameUser struct {
gorm.Model
StatAb UserStat `gorm:"embedded;embeddedPrefix:stat_ab_"`
}
type UserStat1 struct {
GroundDestroyCount string
}
type GroundRate struct {
GroundDestroyCount int
}
type GameUser1 struct {
gorm.Model
StatAb UserStat1 `gorm:"embedded;embeddedPrefix:stat_ab_"`
GroundRateRb GroundRate `gorm:"embedded;embeddedPrefix:rate_ground_rb_"`
}
DB.Migrator().DropTable(&GameUser{})
err := DB.AutoMigrate(&GameUser{})
AssertEqual(t, nil, err)
err = DB.Table("game_users").AutoMigrate(&GameUser1{})
AssertEqual(t, nil, err)
_, err = findColumnType(&GameUser{}, "stat_ab_ground_destory_count")
AssertEqual(t, nil, err)
_, err = findColumnType(&GameUser{}, "rate_ground_rb_ground_destory_count")
AssertEqual(t, nil, err)
}

View File

@ -9,6 +9,56 @@ import (
"gorm.io/gorm"
)
func TestPostgresReturningIDWhichHasStringType(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
t.Skip()
}
type Yasuo struct {
ID string `gorm:"default:gen_random_uuid()"`
Name string
CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"`
UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"`
}
if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil {
t.Errorf("Failed to create extension pgcrypto, got error %v", err)
}
DB.Migrator().DropTable(&Yasuo{})
if err := DB.AutoMigrate(&Yasuo{}); err != nil {
t.Fatalf("Failed to migrate for uuid default value, got error: %v", err)
}
yasuo := Yasuo{Name: "jinzhu"}
if err := DB.Create(&yasuo).Error; err != nil {
t.Fatalf("should be able to create data, but got %v", err)
}
if yasuo.ID == "" {
t.Fatal("should be able to has ID, but got zero value")
}
var result Yasuo
if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu" {
t.Errorf("No error should happen, but got %v", err)
}
if err := DB.Where("id = $1", yasuo.ID).First(&Yasuo{}).Error; err != nil || yasuo.Name != "jinzhu" {
t.Errorf("No error should happen, but got %v", err)
}
yasuo.Name = "jinzhu1"
if err := DB.Save(&yasuo).Error; err != nil {
t.Errorf("Failed to update date, got error %v", err)
}
if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu1" {
t.Errorf("No error should happen, but got %v", err)
}
}
func TestPostgres(t *testing.T) {
if DB.Dialector.Name() != "postgres" {
t.Skip()
@ -63,13 +113,13 @@ func TestPostgres(t *testing.T) {
}
type Post struct {
ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"`
ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();"`
Title string
Categories []*Category `gorm:"Many2Many:post_categories"`
}
type Category struct {
ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"`
ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();"`
Title string
Posts []*Post `gorm:"Many2Many:post_categories"`
}

View File

@ -2,6 +2,8 @@ package tests_test
import (
"context"
"sync"
"errors"
"testing"
"time"
@ -88,3 +90,81 @@ func TestPreparedStmtFromTransaction(t *testing.T) {
}
tx2.Commit()
}
func TestPreparedStmtDeadlock(t *testing.T) {
tx, err := OpenTestConnection()
AssertEqual(t, err, nil)
sqlDB, _ := tx.DB()
sqlDB.SetMaxOpenConns(1)
tx = tx.Session(&gorm.Session{PrepareStmt: true})
wg := sync.WaitGroup{}
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
user := User{Name: "jinzhu"}
tx.Create(&user)
var result User
tx.First(&result)
wg.Done()
}()
}
wg.Wait()
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
AssertEqual(t, ok, true)
AssertEqual(t, len(conn.Stmts), 2)
for _, stmt := range conn.Stmts {
if stmt == nil {
t.Fatalf("stmt cannot bee nil")
}
}
AssertEqual(t, sqlDB.Stats().InUse, 0)
}
func TestPreparedStmtError(t *testing.T) {
tx, err := OpenTestConnection()
AssertEqual(t, err, nil)
sqlDB, _ := tx.DB()
sqlDB.SetMaxOpenConns(1)
tx = tx.Session(&gorm.Session{PrepareStmt: true})
wg := sync.WaitGroup{}
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
// err prepare
tag := Tag{Locale: "zh"}
tx.Table("users").Find(&tag)
wg.Done()
}()
}
wg.Wait()
conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB)
AssertEqual(t, ok, true)
AssertEqual(t, len(conn.Stmts), 0)
AssertEqual(t, sqlDB.Stats().InUse, 0)
}
func TestPreparedStmtInTransaction(t *testing.T) {
user := User{Name: "jinzhu"}
if err := DB.Transaction(func(tx *gorm.DB) error {
tx.Session(&gorm.Session{PrepareStmt: true}).Create(&user)
return errors.New("test")
}); err == nil {
t.Error(err)
}
var result User
if err := DB.First(&result, user.ID).Error; err == nil {
t.Errorf("Failed, got error: %v", err)
}
}

View File

@ -216,6 +216,30 @@ func TestFind(t *testing.T) {
}
}
// test array
var models2 [3]User
if err := DB.Where("name in (?)", []string{"find"}).Find(&models2).Error; err != nil || len(models2) != 3 {
t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models2))
} else {
for idx, user := range users {
t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) {
CheckUser(t, models2[idx], user)
})
}
}
// test smaller array
var models3 [2]User
if err := DB.Where("name in (?)", []string{"find"}).Find(&models3).Error; err != nil || len(models3) != 2 {
t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models3))
} else {
for idx, user := range users[:2] {
t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) {
CheckUser(t, models3[idx], user)
})
}
}
var none []User
if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 {
t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none))
@ -257,7 +281,7 @@ func TestFindInBatches(t *testing.T) {
totalBatch int
)
if result := DB.Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error {
if result := DB.Table("users as u").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error {
totalBatch += batch
if tx.RowsAffected != 2 {
@ -273,7 +297,7 @@ func TestFindInBatches(t *testing.T) {
}
if err := tx.Save(results).Error; err != nil {
t.Errorf("failed to save users, got error %v", err)
t.Fatalf("failed to save users, got error %v", err)
}
return nil

View File

@ -113,6 +113,43 @@ func TestSerializer(t *testing.T) {
}
AssertEqual(t, result, data)
if err := DB.Model(&result).Update("roles", "").Error; err != nil {
t.Fatalf("failed to update data's roles, got error %v", err)
}
if err := DB.First(&result, data.ID).Error; err != nil {
t.Fatalf("failed to query data, got error %v", err)
}
}
func TestSerializerZeroValue(t *testing.T) {
schema.RegisterSerializer("custom", NewCustomSerializer("hello"))
DB.Migrator().DropTable(&SerializerStruct{})
if err := DB.Migrator().AutoMigrate(&SerializerStruct{}); err != nil {
t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)
}
data := SerializerStruct{}
if err := DB.Create(&data).Error; err != nil {
t.Fatalf("failed to create data, got error %v", err)
}
var result SerializerStruct
if err := DB.First(&result, data.ID).Error; err != nil {
t.Fatalf("failed to query data, got error %v", err)
}
AssertEqual(t, result, data)
if err := DB.Model(&result).Update("roles", "").Error; err != nil {
t.Fatalf("failed to update data's roles, got error %v", err)
}
if err := DB.First(&result, data.ID).Error; err != nil {
t.Fatalf("failed to query data, got error %v", err)
}
}
func TestSerializerAssignFirstOrCreate(t *testing.T) {

View File

@ -102,7 +102,7 @@ func TestTransactionWithBlock(t *testing.T) {
return errors.New("the error message")
})
if err.Error() != "the error message" {
if err != nil && err.Error() != "the error message" {
t.Fatalf("Transaction return error will equal the block returns error")
}

View File

@ -41,4 +41,19 @@ func TestUpdateBelongsTo(t *testing.T) {
var user4 User
DB.Preload("Company").Preload("Manager").Find(&user4, "id = ?", user.ID)
CheckUser(t, user4, user)
user.Company.Name += "new2"
user.Manager.Name += "new2"
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Select("`Company`").Save(&user).Error; err != nil {
t.Fatalf("errors happened when update: %v", err)
}
var user5 User
DB.Preload("Company").Preload("Manager").Find(&user5, "id = ?", user.ID)
if user5.Manager.Name != user4.Manager.Name {
t.Errorf("should not update user's manager")
} else {
user.Manager.Name = user4.Manager.Name
}
CheckUser(t, user, user5)
}

View File

@ -92,6 +92,7 @@ func TestUpdateHasOne(t *testing.T) {
gorm.Model
UserID sql.NullInt64
Number string `gorm:"<-:create"`
Number2 string
}
type CustomizeUser struct {
@ -115,6 +116,7 @@ func TestUpdateHasOne(t *testing.T) {
Name: "update-has-one-associations",
Account: CustomizeAccount{
Number: number,
Number2: number,
},
}
@ -122,6 +124,7 @@ func TestUpdateHasOne(t *testing.T) {
t.Fatalf("errors happened when create: %v", err)
}
cusUser.Account.Number += "-update"
cusUser.Account.Number2 += "-update"
if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Updates(&cusUser).Error; err != nil {
t.Fatalf("errors happened when create: %v", err)
}
@ -129,5 +132,6 @@ func TestUpdateHasOne(t *testing.T) {
var account2 CustomizeAccount
DB.Find(&account2, "user_id = ?", cusUser.ID)
AssertEqual(t, account2.Number, number)
AssertEqual(t, account2.Number2, cusUser.Account.Number2)
})
}

View File

@ -307,6 +307,8 @@ func TestSelectWithUpdate(t *testing.T) {
if utils.AssertEqual(result.UpdatedAt, user.UpdatedAt) {
t.Fatalf("Update struct should update UpdatedAt, was %+v, got %+v", result.UpdatedAt, user.UpdatedAt)
}
AssertObjEqual(t, result, User{Name: "update_with_select"}, "Name", "Age")
}
func TestSelectWithUpdateWithMap(t *testing.T) {

View File

@ -62,7 +62,7 @@ func TestUpsert(t *testing.T) {
}
r := DB.Session(&gorm.Session{DryRun: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(&RestrictedLanguage{Code: "upsert_code", Name: "upsert_name", Lang: "upsert_lang"})
if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.[^\w]*$`).MatchString(r.Statement.SQL.String()) {
if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.\W*$`).MatchString(r.Statement.SQL.String()) {
t.Errorf("Table with escape character, got %v", r.Statement.SQL.String())
}
}

View File

@ -64,8 +64,8 @@ type Language struct {
type Coupon struct {
ID int `gorm:"primarykey; size:255"`
AppliesToProduct []*CouponProduct `gorm:"foreignKey:CouponId;constraint:OnDelete:CASCADE"`
AmountOff uint32 `gorm:"amount_off"`
PercentOff float32 `gorm:"percent_off"`
AmountOff uint32 `gorm:"column:amount_off"`
PercentOff float32 `gorm:"column:percent_off"`
}
type CouponProduct struct {

View File

@ -12,3 +12,20 @@ func TestIsValidDBNameChar(t *testing.T) {
}
}
}
func TestToStringKey(t *testing.T) {
cases := []struct {
values []interface{}
key string
}{
{[]interface{}{"a"}, "a"},
{[]interface{}{1, 2, 3}, "1_2_3"},
{[]interface{}{[]interface{}{1, 2, 3}}, "[1 2 3]"},
{[]interface{}{[]interface{}{"1", "2", "3"}}, "[1 2 3]"},
}
for _, c := range cases {
if key := ToStringKey(c.values...); key != c.key {
t.Errorf("%v: expected %v, got %v", c.values, c.key, key)
}
}
}