diff --git a/.github/workflows/invalid_question.yml b/.github/workflows/invalid_question.yml index aa1812d4..bc4487ae 100644 --- a/.github/workflows/invalid_question.yml +++ b/.github/workflows/invalid_question.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v5 + uses: actions/stale@v6 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml index c3c92beb..f9f51aa0 100644 --- a/.github/workflows/missing_playground.yml +++ b/.github/workflows/missing_playground.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v5 + uses: actions/stale@v6 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index af8d3636..a9aff43a 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -16,7 +16,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v5 + uses: actions/stale@v6 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. Remove stale label or comment or this will be closed in 180 days" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b97da3f4..367f4ccd 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,7 +16,7 @@ jobs: sqlite: strategy: matrix: - go: ['1.18', '1.17', '1.16'] + go: ['1.19', '1.18', '1.17', '1.16'] platform: [ubuntu-latest] # can not run in windows OS runs-on: ${{ matrix.platform }} @@ -42,7 +42,7 @@ jobs: strategy: matrix: dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest'] - go: ['1.18', '1.17', '1.16'] + go: ['1.19', '1.18', '1.17', '1.16'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -86,7 +86,7 @@ jobs: strategy: matrix: dbversion: ['postgres:latest', 'postgres:13', 'postgres:12', 'postgres:11', 'postgres:10'] - go: ['1.18', '1.17', '1.16'] + go: ['1.19', '1.18', '1.17', '1.16'] platform: [ubuntu-latest] # can not run in macOS and Windows runs-on: ${{ matrix.platform }} @@ -128,7 +128,7 @@ jobs: sqlserver: strategy: matrix: - go: ['1.18', '1.17', '1.16'] + go: ['1.19', '1.18', '1.17', '1.16'] platform: [ubuntu-latest] # can not run test in macOS and windows runs-on: ${{ matrix.platform }} diff --git a/callbacks/update.go b/callbacks/update.go index 48c61bf4..b596df9a 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -70,10 +70,12 @@ func Update(config *Config) func(db *gorm.DB) { if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Update{}) - if set := ConvertToAssignments(db.Statement); len(set) != 0 { - db.Statement.AddClause(set) - } else if _, ok := db.Statement.Clauses["SET"]; !ok { - return + if _, ok := db.Statement.Clauses["SET"]; !ok { + if set := ConvertToAssignments(db.Statement); len(set) != 0 { + db.Statement.AddClause(set) + } else { + return + } } db.Statement.Build(db.Statement.BuildClauses...) diff --git a/finisher_api.go b/finisher_api.go index af9afb63..835a6984 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -13,7 +13,7 @@ import ( "gorm.io/gorm/utils" ) -// Create insert the value into database +// Create inserts value, returning the inserted data's primary key in value's id func (db *DB) Create(value interface{}) (tx *DB) { if db.CreateBatchSize > 0 { return db.CreateInBatches(value, db.CreateBatchSize) @@ -24,7 +24,7 @@ func (db *DB) Create(value interface{}) (tx *DB) { return tx.callbacks.Create().Execute(tx) } -// CreateInBatches insert the value in batches into database +// CreateInBatches inserts value in batches of batchSize func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { reflectValue := reflect.Indirect(reflect.ValueOf(value)) @@ -68,7 +68,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { return } -// Save update value in database, if the value doesn't have primary key, will insert it +// Save updates value in database. If value doesn't contain a matching primary key, value is inserted. func (db *DB) Save(value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = value @@ -114,7 +114,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { return } -// First find first record that match given conditions, order by primary key +// First finds the first record ordered by primary key, matching given conditions conds func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, @@ -129,7 +129,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { return tx.callbacks.Query().Execute(tx) } -// Take return a record that match given conditions, the order will depend on the database implementation +// Take finds the first record returned by the database in no specified order, matching given conditions conds func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1) if len(conds) > 0 { @@ -142,7 +142,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { return tx.callbacks.Query().Execute(tx) } -// Last find last record that match given conditions, order by primary key +// Last finds the last record ordered by primary key, matching given conditions conds func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, @@ -158,7 +158,7 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { return tx.callbacks.Query().Execute(tx) } -// Find find records that match given conditions +// Find finds all records matching given conditions conds func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() if len(conds) > 0 { @@ -170,7 +170,7 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { return tx.callbacks.Query().Execute(tx) } -// FindInBatches find records in batches +// FindInBatches finds all records in batches of batchSize func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { var ( tx = db.Order(clause.OrderByColumn{ @@ -286,7 +286,8 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) { } } -// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions) +// FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds. +// Each conds must be a struct or map. func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, @@ -312,7 +313,8 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { return } -// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions) +// FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds. +// Each conds must be a struct or map. func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ @@ -360,14 +362,14 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { return tx } -// Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields +// Update updates column with value using callbacks. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields func (db *DB) Update(column string, value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} return tx.callbacks.Update().Execute(tx) } -// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields +// Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields func (db *DB) Updates(values interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = values @@ -388,7 +390,9 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) { return tx.callbacks.Update().Execute(tx) } -// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition +// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If +// value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current +// time if null. func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() if len(conds) > 0 { @@ -482,7 +486,7 @@ func (db *DB) Rows() (*sql.Rows, error) { return rows, tx.Error } -// Scan scan value to a struct +// Scan scans selected value to the struct dest func (db *DB) Scan(dest interface{}) (tx *DB) { config := *db.Config currentLogger, newLogger := config.Logger, logger.Recorder.New() @@ -507,7 +511,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { return } -// Pluck used to query single column from a model as a map +// Pluck queries a single column from a model, returning in the slice dest. E.g.: // var ages []int64 // db.Model(&users).Pluck("age", &ages) func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { @@ -550,7 +554,8 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { return tx.Error } -// Connection use a db conn to execute Multiple commands,this conn will put conn pool after it is executed. +// Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is +// returned to the connection pool. func (db *DB) Connection(fc func(tx *DB) error) (err error) { if db.Error != nil { return db.Error @@ -572,7 +577,9 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) { return fc(tx) } -// Transaction start a transaction as a block, return error will rollback, otherwise to commit. +// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an +// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs +// they are rolled back. func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { panicked := true @@ -615,7 +622,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er return } -// Begin begins a transaction +// Begin begins a transaction with any transaction options opts func (db *DB) Begin(opts ...*sql.TxOptions) *DB { var ( // clone statement @@ -644,7 +651,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB { return tx } -// Commit commit a transaction +// Commit commits the changes in a transaction func (db *DB) Commit() *DB { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() { db.AddError(committer.Commit()) @@ -654,7 +661,7 @@ func (db *DB) Commit() *DB { return db } -// Rollback rollback a transaction +// Rollback rollbacks the changes in a transaction func (db *DB) Rollback() *DB { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { if !reflect.ValueOf(committer).IsNil() { @@ -684,7 +691,7 @@ func (db *DB) RollbackTo(name string) *DB { return db } -// Exec execute raw sql +// Exec executes raw sql func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.SQL = strings.Builder{} diff --git a/go.mod b/go.mod index 57362745..03f84379 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module gorm.io/gorm -go 1.14 +go 1.16 require ( github.com/jinzhu/inflection v1.0.0 diff --git a/gorm.go b/gorm.go index 419f9155..589fc4ff 100644 --- a/gorm.go +++ b/gorm.go @@ -248,10 +248,18 @@ func (db *DB) Session(config *Session) *DB { if config.PrepareStmt { if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { preparedStmt := v.(*PreparedStmtDB) - tx.Statement.ConnPool = &PreparedStmtDB{ - ConnPool: db.Config.ConnPool, - Mux: preparedStmt.Mux, - Stmts: preparedStmt.Stmts, + switch t := tx.Statement.ConnPool.(type) { + case Tx: + tx.Statement.ConnPool = &PreparedStmtTX{ + Tx: t, + PreparedStmtDB: preparedStmt, + } + default: + tx.Statement.ConnPool = &PreparedStmtDB{ + ConnPool: db.Config.ConnPool, + Mux: preparedStmt.Mux, + Stmts: preparedStmt.Stmts, + } } txConfig.ConnPool = tx.Statement.ConnPool txConfig.PrepareStmt = true @@ -413,7 +421,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac relation, ok := modelSchema.Relationships.Relations[field] isRelation := ok && relation.JoinTable != nil if !isRelation { - return fmt.Errorf("failed to found relation: %s", field) + return fmt.Errorf("failed to find relation: %s", field) } for _, ref := range relation.References { diff --git a/logger/logger.go b/logger/logger.go index 2ffd28d5..ce088561 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -4,7 +4,7 @@ import ( "context" "errors" "fmt" - "io/ioutil" + "io" "log" "os" "time" @@ -68,8 +68,8 @@ type Interface interface { } var ( - // Discard Discard logger will print any log to ioutil.Discard - Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) + // Discard Discard logger will print any log to io.Discard + Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{}) // Default Default logger Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ SlowThreshold: 200 * time.Millisecond, diff --git a/migrator.go b/migrator.go index 34e888f2..882fc4cc 100644 --- a/migrator.go +++ b/migrator.go @@ -68,6 +68,7 @@ type Migrator interface { // Database CurrentDatabase() string FullDataTypeOf(*schema.Field) clause.Expr + GetTypeAliases(databaseTypeName string) []string // Tables CreateTable(dst ...interface{}) error diff --git a/migrator/migrator.go b/migrator/migrator.go index 87ac7745..29c0c00c 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -15,7 +15,7 @@ import ( ) var ( - regFullDataType = regexp.MustCompile(`[^\d]*(\d+)[^\d]?`) + regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) ) // Migrator m struct @@ -135,12 +135,12 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } } } + } - for _, chk := range stmt.Schema.ParseCheckConstraints() { - if !tx.Migrator().HasConstraint(value, chk.Name) { - if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil { - return err - } + for _, chk := range stmt.Schema.ParseCheckConstraints() { + if !tx.Migrator().HasConstraint(value, chk.Name) { + if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil { + return err } } } @@ -408,9 +408,27 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy alterColumn := false - // check type - if !field.PrimaryKey && !strings.HasPrefix(fullDataType, realDataType) { - alterColumn = true + if !field.PrimaryKey { + // check type + var isSameType bool + if strings.HasPrefix(fullDataType, realDataType) { + isSameType = true + } + + // check type aliases + if !isSameType { + aliases := m.DB.Migrator().GetTypeAliases(realDataType) + for _, alias := range aliases { + if strings.HasPrefix(fullDataType, alias) { + isSameType = true + break + } + } + } + + if !isSameType { + alterColumn = true + } } // check size @@ -478,7 +496,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } if alterColumn && !field.IgnoreMigration { - return m.DB.Migrator().AlterColumn(value, field.Name) + return m.DB.Migrator().AlterColumn(value, field.DBName) } return nil @@ -863,3 +881,8 @@ func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} { func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) { return nil, errors.New("not support") } + +// GetTypeAliases return database type aliases +func (m Migrator) GetTypeAliases(databaseTypeName string) []string { + return nil +} diff --git a/scan.go b/scan.go index 6250fb57..3a753dca 100644 --- a/scan.go +++ b/scan.go @@ -66,30 +66,32 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int db.RowsAffected++ db.AddError(rows.Scan(values...)) - joinedSchemaMap := make(map[*schema.Field]interface{}, 0) + joinedSchemaMap := make(map[*schema.Field]interface{}) for idx, field := range fields { - if field != nil { - if len(joinFields) == 0 || joinFields[idx][0] == nil { - db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) - } else { - joinSchema := joinFields[idx][0] - relValue := joinSchema.ReflectValueOf(db.Statement.Context, reflectValue) - if relValue.Kind() == reflect.Ptr { - if _, ok := joinedSchemaMap[joinSchema]; !ok { - if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { - continue - } - - relValue.Set(reflect.New(relValue.Type().Elem())) - joinedSchemaMap[joinSchema] = nil - } - } - db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])) - } - - // release data to pool - field.NewValuePool.Put(values[idx]) + if field == nil { + continue } + + if len(joinFields) == 0 || joinFields[idx][0] == nil { + db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) + } else { + joinSchema := joinFields[idx][0] + relValue := joinSchema.ReflectValueOf(db.Statement.Context, reflectValue) + if relValue.Kind() == reflect.Ptr { + if _, ok := joinedSchemaMap[joinSchema]; !ok { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { + continue + } + + relValue.Set(reflect.New(relValue.Type().Elem())) + joinedSchemaMap[joinSchema] = nil + } + } + db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])) + } + + // release data to pool + field.NewValuePool.Put(values[idx]) } } @@ -161,11 +163,10 @@ func Scan(rows Rows, db *DB, mode ScanMode) { } default: var ( - fields = make([]*schema.Field, len(columns)) - selectedColumnsMap = make(map[string]int, len(columns)) - joinFields [][2]*schema.Field - sch = db.Statement.Schema - reflectValue = db.Statement.ReflectValue + fields = make([]*schema.Field, len(columns)) + joinFields [][2]*schema.Field + sch = db.Statement.Schema + reflectValue = db.Statement.ReflectValue ) if reflectValue.Kind() == reflect.Interface { @@ -198,26 +199,24 @@ func Scan(rows Rows, db *DB, mode ScanMode) { // Not Pluck if sch != nil { - schFieldsCount := len(sch.Fields) + matchedFieldCount := make(map[string]int, len(columns)) for idx, column := range columns { if field := sch.LookUpField(column); field != nil && field.Readable { - if curIndex, ok := selectedColumnsMap[column]; ok { - fields[idx] = field // handle duplicate fields - offset := curIndex + 1 - // handle sch inconsistent with database - // like Raw(`...`).Scan - if schFieldsCount > offset { - for fieldIndex, selectField := range sch.Fields[offset:] { - if selectField.DBName == column && selectField.Readable { - selectedColumnsMap[column] = curIndex + fieldIndex + 1 + fields[idx] = field + if count, ok := matchedFieldCount[column]; ok { + // handle duplicate fields + for _, selectField := range sch.Fields { + if selectField.DBName == column && selectField.Readable { + if count == 0 { + matchedFieldCount[column]++ fields[idx] = selectField break } + count-- } } } else { - fields[idx] = field - selectedColumnsMap[column] = idx + matchedFieldCount[column] = 1 } } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := sch.Relationships.Relations[names[0]]; ok { @@ -241,12 +240,21 @@ func Scan(rows Rows, db *DB, mode ScanMode) { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: - var elem reflect.Value - recyclableStruct := reflect.New(reflectValueType) + var ( + elem reflect.Value + recyclableStruct = reflect.New(reflectValueType) + isArrayKind = reflectValue.Kind() == reflect.Array + ) if !update || reflectValue.Len() == 0 { update = false - db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) + // if the slice cap is externally initialized, the externally initialized slice is directly used here + if reflectValue.Cap() == 0 { + db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20)) + } else if !isArrayKind { + reflectValue.SetLen(0) + db.Statement.ReflectValue.Set(reflectValue) + } } for initialized || rows.Next() { @@ -277,10 +285,15 @@ func Scan(rows Rows, db *DB, mode ScanMode) { db.scanIntoStruct(rows, elem, values, fields, joinFields) if !update { - if isPtr { - reflectValue = reflect.Append(reflectValue, elem) + if !isPtr { + elem = elem.Elem() + } + if isArrayKind { + if reflectValue.Len() >= int(db.RowsAffected) { + reflectValue.Index(int(db.RowsAffected - 1)).Set(elem) + } } else { - reflectValue = reflect.Append(reflectValue, elem.Elem()) + reflectValue = reflect.Append(reflectValue, elem) } } } @@ -304,4 +317,4 @@ func Scan(rows Rows, db *DB, mode ScanMode) { if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil { db.AddError(ErrRecordNotFound) } -} +} \ No newline at end of file diff --git a/schema/relationship.go b/schema/relationship.go index 0aa33e51..bb8aeb64 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -191,7 +191,8 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel err error joinTableFields []reflect.StructField fieldsMap = map[string]*Field{} - ownFieldsMap = map[string]bool{} // fix self join many2many + ownFieldsMap = map[string]*Field{} // fix self join many2many + referFieldsMap = map[string]*Field{} joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"]) joinReferences = toColumns(field.TagSettings["JOINREFERENCES"]) ) @@ -229,7 +230,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel joinFieldName = strings.Title(joinForeignKeys[idx]) } - ownFieldsMap[joinFieldName] = true + ownFieldsMap[joinFieldName] = ownField fieldsMap[joinFieldName] = ownField joinTableFields = append(joinTableFields, reflect.StructField{ Name: joinFieldName, @@ -242,9 +243,6 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel for idx, relField := range refForeignFields { joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name - if len(joinReferences) > idx { - joinFieldName = strings.Title(joinReferences[idx]) - } if _, ok := ownFieldsMap[joinFieldName]; ok { if field.Name != relation.FieldSchema.Name { @@ -254,14 +252,22 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } } - fieldsMap[joinFieldName] = relField - joinTableFields = append(joinTableFields, reflect.StructField{ - Name: joinFieldName, - PkgPath: relField.StructField.PkgPath, - Type: relField.StructField.Type, - Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"), - "column", "autoincrement", "index", "unique", "uniqueindex"), - }) + if len(joinReferences) > idx { + joinFieldName = strings.Title(joinReferences[idx]) + } + + referFieldsMap[joinFieldName] = relField + + if _, ok := fieldsMap[joinFieldName]; !ok { + fieldsMap[joinFieldName] = relField + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: joinFieldName, + PkgPath: relField.StructField.PkgPath, + Type: relField.StructField.Type, + Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"), + "column", "autoincrement", "index", "unique", "uniqueindex"), + }) + } } joinTableFields = append(joinTableFields, reflect.StructField{ @@ -317,31 +323,37 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel f.Size = fieldsMap[f.Name].Size } relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f) - ownPrimaryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] - if ownPrimaryField { + if of, ok := ownFieldsMap[f.Name]; ok { joinRel := relation.JoinTable.Relationships.Relations[relName] joinRel.Field = relation.Field joinRel.References = append(joinRel.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], + PrimaryKey: of, ForeignKey: f, }) - } else { + + relation.References = append(relation.References, &Reference{ + PrimaryKey: of, + ForeignKey: f, + OwnPrimaryKey: true, + }) + } + + if rf, ok := referFieldsMap[f.Name]; ok { joinRefRel := relation.JoinTable.Relationships.Relations[relRefName] if joinRefRel.Field == nil { joinRefRel.Field = relation.Field } joinRefRel.References = append(joinRefRel.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], + PrimaryKey: rf, + ForeignKey: f, + }) + + relation.References = append(relation.References, &Reference{ + PrimaryKey: rf, ForeignKey: f, }) } - - relation.References = append(relation.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], - ForeignKey: f, - OwnPrimaryKey: ownPrimaryField, - }) } } } diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 6fffbfcb..85c45589 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -10,7 +10,7 @@ import ( func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) { if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil { - t.Errorf("Failed to parse schema") + t.Errorf("Failed to parse schema, got error %v", err) } else { for _, rel := range relations { checkSchemaRelation(t, s, rel) @@ -305,6 +305,33 @@ func TestMany2ManyOverrideForeignKey(t *testing.T) { }) } +func TestMany2ManySharedForeignKey(t *testing.T) { + type Profile struct { + gorm.Model + Name string + Kind string + ProfileRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"many2many:user_profiles;foreignKey:Refer,Kind;joinForeignKey:UserRefer,Kind;References:ProfileRefer,Kind;joinReferences:ProfileR,Kind"` + Kind string + Refer uint + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"}, + References: []Reference{ + {"Refer", "User", "UserRefer", "user_profiles", "", true}, + {"Kind", "User", "Kind", "user_profiles", "", true}, + {"ProfileRefer", "Profile", "ProfileR", "user_profiles", "", false}, + {"Kind", "Profile", "Kind", "user_profiles", "", false}, + }, + }) +} + func TestMany2ManyOverrideJoinForeignKey(t *testing.T) { type Profile struct { gorm.Model diff --git a/schema/schema.go b/schema/schema.go index 3791237d..42ff5c45 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -239,6 +239,14 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam field.HasDefaultValue = true field.AutoIncrement = true } + case String: + if _, ok := field.TagSettings["PRIMARYKEY"]; !ok { + if !field.HasDefaultValue || field.DefaultValueInterface != nil { + schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) + } + + field.HasDefaultValue = true + } } } diff --git a/statement.go b/statement.go index 12687810..cc26fe37 100644 --- a/statement.go +++ b/statement.go @@ -650,7 +650,7 @@ func (stmt *Statement) Changed(fields ...string) bool { return false } -var nameMatcher = regexp.MustCompile(`^(?:[\W]?([A-Za-z_0-9]+?)[\W]?\.)?[\W]?([A-Za-z_0-9]+?)[\W]?$`) +var nameMatcher = regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?\W?(\w+?)\W?$`) // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { diff --git a/statement_test.go b/statement_test.go index a537c7be..761daf37 100644 --- a/statement_test.go +++ b/statement_test.go @@ -37,18 +37,18 @@ func TestWhereCloneCorruption(t *testing.T) { func TestNameMatcher(t *testing.T) { for k, v := range map[string][]string{ - "table.name": []string{"table", "name"}, - "`table`.`name`": []string{"table", "name"}, - "'table'.'name'": []string{"table", "name"}, - "'table'.name": []string{"table", "name"}, - "table1.name_23": []string{"table1", "name_23"}, - "`table_1`.`name23`": []string{"table_1", "name23"}, - "'table23'.'name_1'": []string{"table23", "name_1"}, - "'table23'.name1": []string{"table23", "name1"}, - "'name1'": []string{"", "name1"}, - "`name_1`": []string{"", "name_1"}, - "`Name_1`": []string{"", "Name_1"}, - "`Table`.`nAme`": []string{"Table", "nAme"}, + "table.name": {"table", "name"}, + "`table`.`name`": {"table", "name"}, + "'table'.'name'": {"table", "name"}, + "'table'.name": {"table", "name"}, + "table1.name_23": {"table1", "name_23"}, + "`table_1`.`name23`": {"table_1", "name23"}, + "'table23'.'name_1'": {"table23", "name_1"}, + "'table23'.name1": {"table23", "name1"}, + "'name1'": {"", "name1"}, + "`name_1`": {"", "name_1"}, + "`Name_1`": {"", "Name_1"}, + "`Table`.`nAme`": {"Table", "nAme"}, } { if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 3 || matches[1] != v[0] || matches[2] != v[1] { t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v) diff --git a/tests/callbacks_test.go b/tests/callbacks_test.go index 2bf9496b..4479da4c 100644 --- a/tests/callbacks_test.go +++ b/tests/callbacks_test.go @@ -113,6 +113,9 @@ func TestCallbacks(t *testing.T) { for idx, data := range datas { db, err := gorm.Open(nil, nil) + if err != nil { + t.Fatal(err) + } callbacks := db.Callback() for _, c := range data.callbacks { diff --git a/tests/go.mod b/tests/go.mod index eb8f336d..c1e1e0ce 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -1,19 +1,20 @@ module gorm.io/gorm/tests -go 1.14 +go 1.16 require ( + github.com/denisenkom/go-mssqldb v0.12.2 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.5 - github.com/lib/pq v1.10.6 - github.com/mattn/go-sqlite3 v1.14.14 // indirect - golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d // indirect - gorm.io/driver/mysql v1.3.5 - gorm.io/driver/postgres v1.3.8 + github.com/lib/pq v1.10.7 + github.com/mattn/go-sqlite3 v1.14.15 // indirect + golang.org/x/crypto v0.0.0-20220926161630-eccd6366d1be // indirect + gorm.io/driver/mysql v1.3.6 + gorm.io/driver/postgres v1.3.10 gorm.io/driver/sqlite v1.3.6 gorm.io/driver/sqlserver v1.3.2 - gorm.io/gorm v1.23.8 + gorm.io/gorm v1.23.10 ) replace gorm.io/gorm => ../ diff --git a/tests/joins_test.go b/tests/joins_test.go index 4908e5ba..7519db82 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -229,3 +229,34 @@ func TestJoinWithSoftDeleted(t *testing.T) { t.Fatalf("joins NamedPet and Account should not empty:%v", user2) } } + +func TestJoinWithSameColumnName(t *testing.T) { + user := GetUser("TestJoinWithSameColumnName", Config{ + Languages: 1, + Pets: 1, + }) + DB.Create(user) + type UserSpeak struct { + UserID uint + LanguageCode string + } + type Result struct { + User + UserSpeak + Language + Pet + } + + results := make([]Result, 0, 1) + DB.Select("users.*, user_speaks.*, languages.*, pets.*").Table("users").Joins("JOIN user_speaks ON user_speaks.user_id = users.id"). + Joins("JOIN languages ON languages.code = user_speaks.language_code"). + Joins("LEFT OUTER JOIN pets ON pets.user_id = users.id").Find(&results) + + if len(results) == 0 { + t.Fatalf("no record find") + } else if results[0].Pet.UserID == nil || *(results[0].Pet.UserID) != user.ID { + t.Fatalf("wrong user id in pet") + } else if results[0].Pet.Name != user.Pets[0].Name { + t.Fatalf("wrong pet name") + } +} diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 0b5bc5eb..32e84e77 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -959,3 +959,41 @@ func TestMigrateArrayTypeModel(t *testing.T) { AssertEqual(t, nil, err) AssertEqual(t, "integer[]", ct.DatabaseTypeName()) } + +func TestMigrateSameEmbeddedFieldName(t *testing.T) { + type UserStat struct { + GroundDestroyCount int + } + + type GameUser struct { + gorm.Model + StatAb UserStat `gorm:"embedded;embeddedPrefix:stat_ab_"` + } + + type UserStat1 struct { + GroundDestroyCount string + } + + type GroundRate struct { + GroundDestroyCount int + } + + type GameUser1 struct { + gorm.Model + StatAb UserStat1 `gorm:"embedded;embeddedPrefix:stat_ab_"` + GroundRateRb GroundRate `gorm:"embedded;embeddedPrefix:rate_ground_rb_"` + } + + DB.Migrator().DropTable(&GameUser{}) + err := DB.AutoMigrate(&GameUser{}) + AssertEqual(t, nil, err) + + err = DB.Table("game_users").AutoMigrate(&GameUser1{}) + AssertEqual(t, nil, err) + + _, err = findColumnType(&GameUser{}, "stat_ab_ground_destory_count") + AssertEqual(t, nil, err) + + _, err = findColumnType(&GameUser{}, "rate_ground_rb_ground_destory_count") + AssertEqual(t, nil, err) +} diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 66b988c3..b5b672a9 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -9,6 +9,56 @@ import ( "gorm.io/gorm" ) +func TestPostgresReturningIDWhichHasStringType(t *testing.T) { + if DB.Dialector.Name() != "postgres" { + t.Skip() + } + + type Yasuo struct { + ID string `gorm:"default:gen_random_uuid()"` + Name string + CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"` + UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"` + } + + if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil { + t.Errorf("Failed to create extension pgcrypto, got error %v", err) + } + + DB.Migrator().DropTable(&Yasuo{}) + + if err := DB.AutoMigrate(&Yasuo{}); err != nil { + t.Fatalf("Failed to migrate for uuid default value, got error: %v", err) + } + + yasuo := Yasuo{Name: "jinzhu"} + if err := DB.Create(&yasuo).Error; err != nil { + t.Fatalf("should be able to create data, but got %v", err) + } + + if yasuo.ID == "" { + t.Fatal("should be able to has ID, but got zero value") + } + + var result Yasuo + if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu" { + t.Errorf("No error should happen, but got %v", err) + } + + if err := DB.Where("id = $1", yasuo.ID).First(&Yasuo{}).Error; err != nil || yasuo.Name != "jinzhu" { + t.Errorf("No error should happen, but got %v", err) + } + + yasuo.Name = "jinzhu1" + if err := DB.Save(&yasuo).Error; err != nil { + t.Errorf("Failed to update date, got error %v", err) + } + + if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu1" { + t.Errorf("No error should happen, but got %v", err) + } +} + func TestPostgres(t *testing.T) { if DB.Dialector.Name() != "postgres" { t.Skip() @@ -63,13 +113,13 @@ func TestPostgres(t *testing.T) { } type Post struct { - ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"` + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();"` Title string Categories []*Category `gorm:"Many2Many:post_categories"` } type Category struct { - ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();autoincrement"` + ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();"` Title string Posts []*Post `gorm:"Many2Many:post_categories"` } diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go index c6a2dafe..c7f251f2 100644 --- a/tests/prepared_stmt_test.go +++ b/tests/prepared_stmt_test.go @@ -3,6 +3,7 @@ package tests_test import ( "context" "sync" + "errors" "testing" "time" @@ -151,3 +152,19 @@ func TestPreparedStmtError(t *testing.T) { AssertEqual(t, len(conn.Stmts), 0) AssertEqual(t, sqlDB.Stats().InUse, 0) } + +func TestPreparedStmtInTransaction(t *testing.T) { + user := User{Name: "jinzhu"} + + if err := DB.Transaction(func(tx *gorm.DB) error { + tx.Session(&gorm.Session{PrepareStmt: true}).Create(&user) + return errors.New("test") + }); err == nil { + t.Error(err) + } + + var result User + if err := DB.First(&result, user.ID).Error; err == nil { + t.Errorf("Failed, got error: %v", err) + } +} diff --git a/tests/query_test.go b/tests/query_test.go index 4569fe1a..eccf0133 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -216,6 +216,30 @@ func TestFind(t *testing.T) { } } + // test array + var models2 [3]User + if err := DB.Where("name in (?)", []string{"find"}).Find(&models2).Error; err != nil || len(models2) != 3 { + t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models2)) + } else { + for idx, user := range users { + t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, models2[idx], user) + }) + } + } + + // test smaller array + var models3 [2]User + if err := DB.Where("name in (?)", []string{"find"}).Find(&models3).Error; err != nil || len(models3) != 2 { + t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models3)) + } else { + for idx, user := range users[:2] { + t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, models3[idx], user) + }) + } + } + var none []User if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 { t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none)) diff --git a/tests/transaction_test.go b/tests/transaction_test.go index 0ac04a04..5872da94 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -102,7 +102,7 @@ func TestTransactionWithBlock(t *testing.T) { return errors.New("the error message") }) - if err.Error() != "the error message" { + if err != nil && err.Error() != "the error message" { t.Fatalf("Transaction return error will equal the block returns error") } diff --git a/tests/upsert_test.go b/tests/upsert_test.go index f90c4518..e84dc14a 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -62,7 +62,7 @@ func TestUpsert(t *testing.T) { } r := DB.Session(&gorm.Session{DryRun: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(&RestrictedLanguage{Code: "upsert_code", Name: "upsert_name", Lang: "upsert_lang"}) - if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.[^\w]*$`).MatchString(r.Statement.SQL.String()) { + if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.\W*$`).MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } } diff --git a/utils/tests/models.go b/utils/tests/models.go index 22e8e659..ec1651a3 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -64,8 +64,8 @@ type Language struct { type Coupon struct { ID int `gorm:"primarykey; size:255"` AppliesToProduct []*CouponProduct `gorm:"foreignKey:CouponId;constraint:OnDelete:CASCADE"` - AmountOff uint32 `gorm:"amount_off"` - PercentOff float32 `gorm:"percent_off"` + AmountOff uint32 `gorm:"column:amount_off"` + PercentOff float32 `gorm:"column:percent_off"` } type CouponProduct struct {