diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..e4e81074 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,15 @@ +--- +version: 2 +updates: + - package-ecosystem: gomod + directory: / + schedule: + interval: weekly + - package-ecosystem: github-actions + directory: / + schedule: + interval: weekly + - package-ecosystem: gomod + directory: /tests + schedule: + interval: weekly diff --git a/.github/workflows/invalid_question.yml b/.github/workflows/invalid_question.yml index 5b0bd981..dfd2ddd9 100644 --- a/.github/workflows/invalid_question.yml +++ b/.github/workflows/invalid_question.yml @@ -10,7 +10,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v3.0.7 + uses: actions/stale@v4 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 2 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml index ea3207d6..cdb097de 100644 --- a/.github/workflows/missing_playground.yml +++ b/.github/workflows/missing_playground.yml @@ -10,7 +10,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v3.0.7 + uses: actions/stale@v4 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" diff --git a/.github/workflows/reviewdog.yml b/.github/workflows/reviewdog.yml index 4511c378..d55a4699 100644 --- a/.github/workflows/reviewdog.yml +++ b/.github/workflows/reviewdog.yml @@ -8,4 +8,4 @@ jobs: - name: Check out code into the Go module directory uses: actions/checkout@v1 - name: golangci-lint - uses: reviewdog/action-golangci-lint@v1 + uses: reviewdog/action-golangci-lint@v2 diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index f9c1bece..d5419295 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -10,7 +10,7 @@ jobs: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues - uses: actions/stale@v3.0.7 + uses: actions/stale@v4 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "This issue has been automatically marked as stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 30 days" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8bd2bcb3..d5ee1e88 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,7 +13,7 @@ jobs: sqlite: strategy: matrix: - go: ['1.16', '1.15'] + go: ['1.17', '1.16', '1.15'] platform: [ubuntu-latest] # can not run in windows OS runs-on: ${{ matrix.platform }} @@ -39,7 +39,7 @@ jobs: strategy: matrix: dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest'] - go: ['1.16', '1.15'] + go: ['1.17', '1.16', '1.15'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -83,7 +83,7 @@ jobs: strategy: matrix: dbversion: ['postgres:latest', 'postgres:12', 'postgres:11', 'postgres:10'] - go: ['1.16', '1.15'] + go: ['1.17', '1.16', '1.15'] platform: [ubuntu-latest] # can not run in macOS and Windows runs-on: ${{ matrix.platform }} @@ -125,7 +125,7 @@ jobs: sqlserver: strategy: matrix: - go: ['1.16', '1.15'] + go: ['1.17', '1.16', '1.15'] platform: [ubuntu-latest] # can not run test in macOS and windows runs-on: ${{ matrix.platform }} diff --git a/callbacks.go b/callbacks.go index 02e741e7..7ab38926 100644 --- a/callbacks.go +++ b/callbacks.go @@ -102,8 +102,8 @@ func (p *processor) Execute(db *DB) *DB { // parse model values if stmt.Model != nil { - if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) { - if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" { + if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) { + if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" && stmt.TableExpr == nil { db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err)) } else { db.AddError(err) diff --git a/callbacks/associations.go b/callbacks/associations.go index 78f976c3..d78bd968 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -310,33 +310,22 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { } } -func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) clause.OnConflict { - if stmt.DB.FullSaveAssociations { - defaultUpdatingColumns = make([]string, 0, len(s.DBNames)) - for _, dbName := range s.DBNames { - if v, ok := selectColumns[dbName]; (ok && !v) || (!ok && restricted) { - continue - } - - if !s.LookUpField(dbName).PrimaryKey { - defaultUpdatingColumns = append(defaultUpdatingColumns, dbName) - } - } - } - - if len(defaultUpdatingColumns) > 0 { - columns := make([]clause.Column, 0, len(s.PrimaryFieldDBNames)) +func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) (onConflict clause.OnConflict) { + if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations { + onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames)) for _, dbName := range s.PrimaryFieldDBNames { - columns = append(columns, clause.Column{Name: dbName}) + onConflict.Columns = append(onConflict.Columns, clause.Column{Name: dbName}) } - return clause.OnConflict{ - Columns: columns, - DoUpdates: clause.AssignmentColumns(defaultUpdatingColumns), + onConflict.UpdateAll = stmt.DB.FullSaveAssociations + if !onConflict.UpdateAll { + onConflict.DoUpdates = clause.AssignmentColumns(defaultUpdatingColumns) } + } else { + onConflict.DoNothing = true } - return clause.OnConflict{DoNothing: true} + return } func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { diff --git a/callbacks/create.go b/callbacks/create.go index 04ee6b30..8a3c593c 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -37,7 +37,6 @@ func Create(config *Config) func(db *gorm.DB) { return func(db *gorm.DB) { if db.Error != nil { - // maybe record logger TODO return } @@ -64,11 +63,9 @@ func Create(config *Config) func(db *gorm.DB) { } db.RowsAffected, _ = result.RowsAffected() - if !(db.RowsAffected > 0) { - return - } - if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + if db.RowsAffected != 0 && db.Statement.Schema != nil && + db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: @@ -107,7 +104,6 @@ func Create(config *Config) func(db *gorm.DB) { db.AddError(err) } } - } } } @@ -348,12 +344,16 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { if c, ok := stmt.Clauses["ON CONFLICT"]; ok { if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll { - if stmt.Schema != nil && len(values.Columns) > 1 { + if stmt.Schema != nil && len(values.Columns) >= 1 { + selectColumns, restricted := stmt.SelectAndOmitColumns(true, true) + columns := make([]string, 0, len(values.Columns)-1) for _, column := range values.Columns { if field := stmt.Schema.LookUpField(column.Name); field != nil { - if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 { - columns = append(columns, column.Name) + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 { + columns = append(columns, column.Name) + } } } } diff --git a/callbacks/helper.go b/callbacks/helper.go index ad85a1c6..d83d20ce 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -44,7 +44,7 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st columns = make([]string, 0, len(mapValues)) ) - // when the length of mapValues,return directly here + // when the length of mapValues is zero,return directly here // no need to call stmt.SelectAndOmitColumns method if len(mapValues) == 0 { stmt.AddError(gorm.ErrEmptySlice) diff --git a/callbacks/preload.go b/callbacks/preload.go index 25c5e659..9882590c 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -1,6 +1,7 @@ package callbacks import ( + "fmt" "reflect" "gorm.io/gorm" @@ -104,15 +105,17 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload reflectResults := rel.FieldSchema.MakeSlice().Elem() column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) - for _, cond := range conds { - if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { - tx = fc(tx) - } else { - inlineConds = append(inlineConds, cond) + if len(values) != 0 { + for _, cond := range conds { + if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { + tx = fc(tx) + } else { + inlineConds = append(inlineConds, cond) + } } - } - db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error) + db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error) + } fieldValues := make([]interface{}, len(relForeignFields)) @@ -142,23 +145,27 @@ func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preload fieldValues[idx], _ = field.ValueOf(elem) } - for _, data := range identityMap[utils.ToStringKey(fieldValues...)] { - reflectFieldValue := rel.Field.ReflectValueOf(data) - if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { - reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) - } + if datas, ok := identityMap[utils.ToStringKey(fieldValues...)]; ok { + for _, data := range datas { + reflectFieldValue := rel.Field.ReflectValueOf(data) + if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { + reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) + } - reflectFieldValue = reflect.Indirect(reflectFieldValue) - switch reflectFieldValue.Kind() { - case reflect.Struct: - rel.Field.Set(data, reflectResults.Index(i).Interface()) - case reflect.Slice, reflect.Array: - if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { - rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) - } else { - rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) + reflectFieldValue = reflect.Indirect(reflectFieldValue) + switch reflectFieldValue.Kind() { + case reflect.Struct: + rel.Field.Set(data, reflectResults.Index(i).Interface()) + case reflect.Slice, reflect.Array: + if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { + rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface()) + } else { + rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()) + } } } + } else { + db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface())) } } } diff --git a/callbacks/query.go b/callbacks/query.go index d0341284..1cfd618c 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -147,6 +147,21 @@ func BuildQuerySQL(db *gorm.DB) { } } + if join.On != nil { + onStmt := gorm.Statement{Table: tableAliasName, DB: db} + join.On.Build(&onStmt) + onSQL := onStmt.SQL.String() + vars := onStmt.Vars + for idx, v := range onStmt.Vars { + bindvar := strings.Builder{} + onStmt.Vars = vars[0 : idx+1] + db.Dialector.BindVarTo(&bindvar, &onStmt, v) + onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) + } + + exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) + } + joins = append(joins, clause.Join{ Type: clause.LeftJoin, Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, @@ -209,7 +224,7 @@ func Preload(db *gorm.DB) { if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil { preload(db, rel, db.Statement.Preloads[name], preloadMap[name]) } else { - db.AddError(fmt.Errorf("%v: %w for schema %v", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) + db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) } } } diff --git a/callbacks/update.go b/callbacks/update.go index 75bb02db..7d5ea4a4 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -51,37 +51,39 @@ func BeforeUpdate(db *gorm.DB) { } func Update(db *gorm.DB) { - if db.Error == nil { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.UpdateClauses { - db.Statement.AddClause(c) - } - } + if db.Error != nil { + return + } - if db.Statement.SQL.String() == "" { - db.Statement.SQL.Grow(180) - db.Statement.AddClauseIfNotExists(clause.Update{}) - if set := ConvertToAssignments(db.Statement); len(set) != 0 { - db.Statement.AddClause(set) - } else { - return - } - db.Statement.Build(db.Statement.BuildClauses...) + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.UpdateClauses { + db.Statement.AddClause(c) } + } - if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { - db.AddError(gorm.ErrMissingWhereClause) + if db.Statement.SQL.String() == "" { + db.Statement.SQL.Grow(180) + db.Statement.AddClauseIfNotExists(clause.Update{}) + if set := ConvertToAssignments(db.Statement); len(set) != 0 { + db.Statement.AddClause(set) + } else { return } + db.Statement.Build(db.Statement.BuildClauses...) + } - if !db.DryRun && db.Error == nil { - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { + db.AddError(gorm.ErrMissingWhereClause) + return + } - if err == nil { - db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) - } + if !db.DryRun && db.Error == nil { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) } } } @@ -220,35 +222,45 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } default: + var updatingSchema = stmt.Schema + if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { + // different schema + updatingStmt := &gorm.Statement{DB: stmt.DB} + if err := updatingStmt.Parse(stmt.Dest); err == nil { + updatingSchema = updatingStmt.Schema + } + } + switch updatingValue.Kind() { case reflect.Struct: set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) for _, dbName := range stmt.Schema.DBNames { - field := stmt.Schema.LookUpField(dbName) - if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) { - if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { - value, isZero := field.ValueOf(updatingValue) - if !stmt.SkipHooks && field.AutoUpdateTime > 0 { - if field.AutoUpdateTime == schema.UnixNanosecond { - value = stmt.DB.NowFunc().UnixNano() - } else if field.AutoUpdateTime == schema.UnixMillisecond { - value = stmt.DB.NowFunc().UnixNano() / 1e6 - } else if field.GORMDataType == schema.Time { - value = stmt.DB.NowFunc() - } else { - value = stmt.DB.NowFunc().Unix() + if field := updatingSchema.LookUpField(dbName); field != nil && field.Updatable { + if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { + if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { + value, isZero := field.ValueOf(updatingValue) + if !stmt.SkipHooks && field.AutoUpdateTime > 0 { + if field.AutoUpdateTime == schema.UnixNanosecond { + value = stmt.DB.NowFunc().UnixNano() + } else if field.AutoUpdateTime == schema.UnixMillisecond { + value = stmt.DB.NowFunc().UnixNano() / 1e6 + } else if field.GORMDataType == schema.Time { + value = stmt.DB.NowFunc() + } else { + value = stmt.DB.NowFunc().Unix() + } + isZero = false } - isZero = false - } - if ok || !isZero { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) - assignValue(field, value) + if ok || !isZero { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) + assignValue(field, value) + } + } + } else { + if value, isZero := field.ValueOf(updatingValue); !isZero { + stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) } - } - } else { - if value, isZero := field.ValueOf(updatingValue); !isZero { - stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) } } } diff --git a/chainable_api.go b/chainable_api.go index e17d9bb2..01ab2597 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -50,15 +50,14 @@ func (db *DB) Table(name string, args ...interface{}) (tx *DB) { tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args} if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 { tx.Statement.Table = results[1] - return } } else if tables := strings.Split(name, "."); len(tables) == 2 { tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} tx.Statement.Table = tables[1] - return + } else { + tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} + tx.Statement.Table = name } - - tx.Statement.Table = name return } @@ -172,8 +171,19 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { // Joins specify Joins conditions // db.Joins("Account").Find(&user) // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) +// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() + + if len(args) > 0 { + if db, ok := args[0].(*DB); ok { + if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args[1:], On: &where}) + } + return + } + } + tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args}) return } @@ -209,12 +219,14 @@ func (db *DB) Order(value interface{}) (tx *DB) { tx.Statement.AddClause(clause.OrderBy{ Columns: []clause.OrderByColumn{v}, }) - default: - tx.Statement.AddClause(clause.OrderBy{ - Columns: []clause.OrderByColumn{{ - Column: clause.Column{Name: fmt.Sprint(value), Raw: true}, - }}, - }) + case string: + if v != "" { + tx.Statement.AddClause(clause.OrderBy{ + Columns: []clause.OrderByColumn{{ + Column: clause.Column{Name: v, Raw: true}, + }}, + }) + } } return } diff --git a/clause/expression.go b/clause/expression.go index 2bdd4a30..f7b93f4c 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -173,7 +173,12 @@ func (expr NamedExpr) Build(builder Builder) { } if inName { - builder.AddVar(builder, namedMap[string(name)]) + if nv, ok := namedMap[string(name)]; ok { + builder.AddVar(builder, nv) + } else { + builder.WriteByte('@') + builder.WriteString(string(name)) + } } } @@ -205,11 +210,12 @@ func (in IN) Build(builder Builder) { } func (in IN) NegationBuild(builder Builder) { + builder.WriteQuoted(in.Column) switch len(in.Values) { case 0: + builder.WriteString(" IS NOT NULL") case 1: if _, ok := in.Values[0].([]interface{}); !ok { - builder.WriteQuoted(in.Column) builder.WriteString(" <> ") builder.AddVar(builder, in.Values[0]) break @@ -217,7 +223,6 @@ func (in IN) NegationBuild(builder Builder) { fallthrough default: - builder.WriteQuoted(in.Column) builder.WriteString(" NOT IN (") builder.AddVar(builder, in.Values...) builder.WriteByte(')') diff --git a/clause/expression_test.go b/clause/expression_test.go index 1c8217ed..05074865 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -60,6 +60,11 @@ func TestNamedExpr(t *testing.T) { Vars: []interface{}{sql.Named("name", "jinzhu")}, Result: "name1 = ? AND name2 = ?", ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, + }, { + SQL: "name1 = @name AND name2 = @@name", + Vars: []interface{}{map[string]interface{}{"name": "jinzhu"}}, + Result: "name1 = ? AND name2 = @@name", + ExpectedVars: []interface{}{"jinzhu"}, }, { SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1", Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, @@ -73,13 +78,13 @@ func TestNamedExpr(t *testing.T) { }, { SQL: "@@test AND name1 = @name1 AND name2 = @name2 AND name3 = @name1 @notexist", Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, - Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", - ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, + Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? @notexist", + ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"}, }, { - SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @Notexist", + SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @notexist", Vars: []interface{}{NamedArgument{Name1: "jinzhu", Base: Base{Name2: "jinzhu2"}}}, - Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", - ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, + Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? @notexist", + ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"}, }, { SQL: "create table ? (? ?, ? ?)", Vars: []interface{}{}, @@ -151,6 +156,18 @@ func TestExpression(t *testing.T) { }, ExpectedVars: []interface{}{"a", "b"}, Result: "`column-name` NOT IN (?,?)", + }, { + Expressions: []clause.Expression{ + clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100}, + }, + ExpectedVars: []interface{}{100}, + Result: "SUM(`id`) = ?", + }, { + Expressions: []clause.Expression{ + clause.Gte{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Table: "users", Name: "id"}}}, Value: 100}, + }, + ExpectedVars: []interface{}{100}, + Result: "SUM(`users`.`id`) >= ?", }} for idx, result := range results { diff --git a/clause/select.go b/clause/select.go index b93b8769..d8e9f801 100644 --- a/clause/select.go +++ b/clause/select.go @@ -43,3 +43,17 @@ func (s Select) MergeClause(clause *Clause) { clause.Expression = s } } + +// CommaExpression represents a group of expressions separated by commas. +type CommaExpression struct { + Exprs []Expression +} + +func (comma CommaExpression) Build(builder Builder) { + for idx, expr := range comma.Exprs { + if idx > 0 { + _, _ = builder.WriteString(", ") + } + expr.Build(builder) + } +} diff --git a/clause/select_test.go b/clause/select_test.go index b7296434..9fce0783 100644 --- a/clause/select_test.go +++ b/clause/select_test.go @@ -31,6 +31,18 @@ func TestSelect(t *testing.T) { }, clause.From{}}, "SELECT `name` FROM `users`", nil, }, + { + []clause.Interface{clause.Select{ + Expression: clause.CommaExpression{ + Exprs: []clause.Expression{ + clause.NamedExpr{"?", []interface{}{clause.Column{Name: "id"}}}, + clause.NamedExpr{"?", []interface{}{clause.Column{Name: "name"}}}, + clause.NamedExpr{"LENGTH(?)", []interface{}{clause.Column{Name: "mobile"}}}, + }, + }, + }, clause.From{}}, + "SELECT `id`, `name`, LENGTH(`mobile`) FROM `users`", nil, + }, } for idx, result := range results { diff --git a/finisher_api.go b/finisher_api.go index 51f394b4..34e1596b 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -376,7 +376,7 @@ func (db *DB) Count(count *int64) (tx *DB) { if selectClause, ok := db.Statement.Clauses["SELECT"]; ok { defer func() { - db.Statement.Clauses["SELECT"] = selectClause + tx.Statement.Clauses["SELECT"] = selectClause }() } else { defer delete(tx.Statement.Clauses, "SELECT") @@ -390,7 +390,7 @@ func (db *DB) Count(count *int64) (tx *DB) { if len(tx.Statement.Selects) == 1 { dbName := tx.Statement.Selects[0] fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar) - if len(fields) == 1 || (len(fields) == 3 && strings.ToUpper(fields[1]) == "AS") { + if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) { if tx.Statement.Parse(tx.Statement.Model) == nil { if f := tx.Statement.Schema.LookUpField(dbName); f != nil { dbName = f.DBName @@ -410,9 +410,9 @@ func (db *DB) Count(count *int64) (tx *DB) { if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok { if _, ok := db.Statement.Clauses["GROUP BY"]; !ok { - delete(db.Statement.Clauses, "ORDER BY") + delete(tx.Statement.Clauses, "ORDER BY") defer func() { - db.Statement.Clauses["ORDER BY"] = orderByClause + tx.Statement.Clauses["ORDER BY"] = orderByClause }() } } diff --git a/logger/logger.go b/logger/logger.go index 98d1b32e..69d41113 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -64,9 +64,10 @@ type Interface interface { var ( Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ - SlowThreshold: 200 * time.Millisecond, - LogLevel: Warn, - Colorful: true, + SlowThreshold: 200 * time.Millisecond, + LogLevel: Warn, + IgnoreRecordNotFoundError: false, + Colorful: true, }) Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} ) diff --git a/migrator/migrator.go b/migrator/migrator.go index 03ffdd02..48db151e 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -2,6 +2,7 @@ package migrator import ( "context" + "database/sql" "fmt" "reflect" "regexp" @@ -166,10 +167,12 @@ func (m Migrator) CreateTable(values ...interface{}) error { for _, dbName := range stmt.Schema.DBNames { field := stmt.Schema.FieldsByDBName[dbName] - createTableSQL += "? ?" - hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY") - values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field)) - createTableSQL += "," + if !field.IgnoreMigration { + createTableSQL += "? ?" + hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY") + values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field)) + createTableSQL += "," + } } if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 { @@ -195,6 +198,10 @@ func (m Migrator) CreateTable(values ...interface{}) error { } createTableSQL += "INDEX ? ?" + if idx.Comment != "" { + createTableSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment) + } + if idx.Option != "" { createTableSQL += " " + idx.Option } @@ -382,11 +389,11 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy alterColumn = true } else { // has size in data type and not equal - // Since the following code is frequently called in the for loop, reg optimization is needed here matches := regRealDataType.FindAllStringSubmatch(realDataType, -1) matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) - if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) { + if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) && + (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) { alterColumn = true } } @@ -414,22 +421,31 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy return nil } -func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) { - columnTypes = make([]gorm.ColumnType, 0) - err = m.RunWithValue(value, func(stmt *gorm.Statement) error { +// ColumnTypes return columnTypes []gorm.ColumnType and execErr error +func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { + columnTypes := make([]gorm.ColumnType, 0) + execErr := m.RunWithValue(value, func(stmt *gorm.Statement) error { rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() - if err == nil { - defer rows.Close() - rawColumnTypes, err := rows.ColumnTypes() - if err == nil { - for _, c := range rawColumnTypes { - columnTypes = append(columnTypes, c) - } - } + if err != nil { + return err } - return err + + defer rows.Close() + + var rawColumnTypes []*sql.ColumnType + rawColumnTypes, err = rows.ColumnTypes() + if err != nil { + return err + } + + for _, c := range rawColumnTypes { + columnTypes = append(columnTypes, c) + } + + return nil }) - return + + return columnTypes, execErr } func (m Migrator) CreateView(name string, option gorm.ViewOption) error { @@ -489,9 +505,10 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ } if field := stmt.Schema.LookUpField(name); field != nil { - for _, cc := range checkConstraints { - if cc.Field == field { - return nil, &cc, stmt.Table + for k := range checkConstraints { + if checkConstraints[k].Field == field { + v := checkConstraints[k] + return nil, &v, stmt.Table } } @@ -601,6 +618,10 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { createIndexSQL += " USING " + idx.Type } + if idx.Comment != "" { + createIndexSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment) + } + if idx.Option != "" { createIndexSQL += " " + idx.Option } @@ -608,7 +629,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { return m.DB.Exec(createIndexSQL, values...).Error } - return fmt.Errorf("failed to create index with name %v", name) + return fmt.Errorf("failed to create index with name %s", name) }) } diff --git a/prepare_stmt.go b/prepare_stmt.go index 48a614b7..5faea995 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -35,7 +35,7 @@ func (db *PreparedStmtDB) Close() { for _, query := range db.PreparedSQL { if stmt, ok := db.Stmts[query]; ok { delete(db.Stmts, query) - stmt.Close() + go stmt.Close() } } @@ -56,7 +56,7 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact db.Mux.Unlock() return stmt, nil } else if ok { - stmt.Close() + go stmt.Close() } stmt, err := conn.PrepareContext(ctx, query) @@ -83,7 +83,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. result, err = stmt.ExecContext(ctx, args...) if err != nil { db.Mux.Lock() - stmt.Close() + go stmt.Close() delete(db.Stmts, query) db.Mux.Unlock() } @@ -97,7 +97,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . rows, err = stmt.QueryContext(ctx, args...) if err != nil { db.Mux.Lock() - stmt.Close() + go stmt.Close() delete(db.Stmts, query) db.Mux.Unlock() } @@ -138,7 +138,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() - stmt.Close() + go stmt.Close() delete(tx.PreparedStmtDB.Stmts, query) tx.PreparedStmtDB.Mux.Unlock() } @@ -152,7 +152,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() - stmt.Close() + go stmt.Close() delete(tx.PreparedStmtDB.Stmts, query) tx.PreparedStmtDB.Mux.Unlock() } diff --git a/scan.go b/scan.go index e82e3f07..2beecd45 100644 --- a/scan.go +++ b/scan.go @@ -208,6 +208,8 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } } values[idx] = &sql.RawBytes{} + } else if len(columns) == 1 { + values[idx] = dest } else { values[idx] = &sql.RawBytes{} } @@ -238,6 +240,8 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } } } + default: + db.AddError(rows.Scan(dest)) } } diff --git a/schema/field.go b/schema/field.go index 9efaa44a..ce0e3c13 100644 --- a/schema/field.go +++ b/schema/field.go @@ -490,21 +490,22 @@ func (field *Field) setupValuerAndSetter() { return } else if field.FieldType.Kind() == reflect.Ptr { fieldValue := field.ReflectValueOf(value) + fieldType := field.FieldType.Elem() - if reflectValType.AssignableTo(field.FieldType.Elem()) { + if reflectValType.AssignableTo(fieldType) { if !fieldValue.IsValid() { - fieldValue = reflect.New(field.FieldType.Elem()) + fieldValue = reflect.New(fieldType) } else if fieldValue.IsNil() { - fieldValue.Set(reflect.New(field.FieldType.Elem())) + fieldValue.Set(reflect.New(fieldType)) } fieldValue.Elem().Set(reflectV) return - } else if reflectValType.ConvertibleTo(field.FieldType.Elem()) { + } else if reflectValType.ConvertibleTo(fieldType) { if fieldValue.IsNil() { - fieldValue.Set(reflect.New(field.FieldType.Elem())) + fieldValue.Set(reflect.New(fieldType)) } - fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem())) + fieldValue.Elem().Set(reflectV.Convert(fieldType)) return } } @@ -520,7 +521,7 @@ func (field *Field) setupValuerAndSetter() { err = setter(value, v) } } else { - return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) + return fmt.Errorf("failed to set value %+v to field %s", v, field.Name) } } diff --git a/schema/naming.go b/schema/naming.go index 47e313a7..8407bffa 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -4,6 +4,7 @@ import ( "crypto/sha1" "encoding/hex" "fmt" + "regexp" "strings" "unicode/utf8" @@ -13,6 +14,7 @@ import ( // Namer namer interface type Namer interface { TableName(table string) string + SchemaName(table string) string ColumnName(table, column string) string JoinTableName(joinTable string) string RelationshipFKName(Relationship) string @@ -41,6 +43,16 @@ func (ns NamingStrategy) TableName(str string) string { return ns.TablePrefix + inflection.Plural(ns.toDBName(str)) } +// SchemaName generate schema name from table name, don't guarantee it is the reverse value of TableName +func (ns NamingStrategy) SchemaName(table string) string { + table = strings.TrimPrefix(table, ns.TablePrefix) + + if ns.SingularTable { + return ns.toSchemaName(table) + } + return ns.toSchemaName(inflection.Singular(table)) +} + // ColumnName convert string to column name func (ns NamingStrategy) ColumnName(table, column string) string { return ns.toDBName(column) @@ -154,3 +166,11 @@ func (ns NamingStrategy) toDBName(name string) string { ret := buf.String() return ret } + +func (ns NamingStrategy) toSchemaName(name string) string { + result := strings.Replace(strings.Title(strings.Replace(name, "_", " ", -1)), " ", "", -1) + for _, initialism := range commonInitialisms { + result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1") + } + return result +} diff --git a/schema/naming_test.go b/schema/naming_test.go index face9364..6add338e 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -33,6 +33,26 @@ func TestToDBName(t *testing.T) { t.Errorf("%v toName should equal %v, but got %v", key, value, ns.toDBName(key)) } } + + maps = map[string]string{ + "x": "X", + "user_restrictions": "UserRestriction", + "this_is_a_test": "ThisIsATest", + "abc_and_jkl": "AbcAndJkl", + "employee_id": "EmployeeID", + "field_x": "FieldX", + "http_and_smtp": "HTTPAndSMTP", + "http_server_handler_for_url_id": "HTTPServerHandlerForURLID", + "uuid": "UUID", + "http_url": "HTTPURL", + "sha256_hash": "Sha256Hash", + "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id": "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIDCanBeUsedAtTheEndAsID", + } + for key, value := range maps { + if ns.SchemaName(key) != value { + t.Errorf("%v schema name should equal %v, but got %v", key, value, ns.SchemaName(key)) + } + } } func TestNamingStrategy(t *testing.T) { diff --git a/schema/relationship.go b/schema/relationship.go index db496e30..84556bae 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -238,7 +238,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } for idx, relField := range refForeignFields { - joinFieldName := relation.FieldSchema.Name + relField.Name + joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name if len(joinReferences) > idx { joinFieldName = strings.Title(joinReferences[idx]) } diff --git a/schema/schema.go b/schema/schema.go index 8ade2ed7..faba2e21 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -119,20 +119,13 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) // When the schema initialization is completed, the channel will be closed defer close(schema.initialized) - if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded { + if v, loaded := cacheStore.Load(modelType); loaded { s := v.(*Schema) // Wait for the initialization of other goroutines to complete <-s.initialized return s, s.err } - defer func() { - if schema.err != nil { - logger.Default.Error(context.Background(), schema.err.Error()) - cacheStore.Delete(modelType) - } - }() - for i := 0; i < modelType.NumField(); i++ { if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil { @@ -228,11 +221,25 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) case "func(*gorm.DB) error": // TODO hack reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) default: - logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name) + logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name) } } } + if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded { + s := v.(*Schema) + // Wait for the initialization of other goroutines to complete + <-s.initialized + return s, s.err + } + + defer func() { + if schema.err != nil { + logger.Default.Error(context.Background(), schema.err.Error()) + cacheStore.Delete(modelType) + } + }() + if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { for _, field := range schema.Fields { if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { @@ -244,19 +251,20 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } fieldValue := reflect.New(field.IndirectFieldType) - if fc, ok := fieldValue.Interface().(CreateClausesInterface); ok { + fieldInterface := fieldValue.Interface() + if fc, ok := fieldInterface.(CreateClausesInterface); ok { field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) } - if fc, ok := fieldValue.Interface().(QueryClausesInterface); ok { + if fc, ok := fieldInterface.(QueryClausesInterface); ok { field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) } - if fc, ok := fieldValue.Interface().(UpdateClausesInterface); ok { + if fc, ok := fieldInterface.(UpdateClausesInterface); ok { field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) } - if fc, ok := fieldValue.Interface().(DeleteClausesInterface); ok { + if fc, ok := fieldInterface.(DeleteClausesInterface); ok { field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) } } diff --git a/statement.go b/statement.go index 8b682c84..38363443 100644 --- a/statement.go +++ b/statement.go @@ -50,6 +50,7 @@ type Statement struct { type join struct { Name string Conds []interface{} + On *clause.Where } // StatementModifier statement modifier interface @@ -129,6 +130,8 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { stmt.QuoteTo(writer, d) } writer.WriteByte(')') + case clause.Expr: + v.Build(stmt) case string: stmt.DB.Dialector.QuoteTo(writer, v) case []string: diff --git a/tests/delete_test.go b/tests/delete_test.go index abe85b0e..f62cc606 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -22,8 +22,8 @@ func TestDelete(t *testing.T) { } } - if err := DB.Delete(&users[1]).Error; err != nil { - t.Errorf("errors happened when delete: %v", err) + if res := DB.Delete(&users[1]); res.Error != nil || res.RowsAffected != 1 { + t.Errorf("errors happened when delete: %v, affected: %v", res.Error, res.RowsAffected) } var result User diff --git a/tests/go.mod b/tests/go.mod index 815f8986..d7ab65ad 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,14 +3,14 @@ module gorm.io/gorm/tests go 1.14 require ( - github.com/google/uuid v1.2.0 + github.com/google/uuid v1.3.0 github.com/jinzhu/now v1.1.2 - github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v1.0.5 + github.com/lib/pq v1.10.3 + gorm.io/driver/mysql v1.1.2 gorm.io/driver/postgres v1.1.0 gorm.io/driver/sqlite v1.1.4 - gorm.io/driver/sqlserver v1.0.7 - gorm.io/gorm v1.21.9 + gorm.io/driver/sqlserver v1.0.9 + gorm.io/gorm v1.21.14 ) replace gorm.io/gorm => ../ diff --git a/tests/joins_test.go b/tests/joins_test.go index 46611f5f..e560f38a 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -104,6 +104,27 @@ func TestJoinConds(t *testing.T) { } } +func TestJoinOn(t *testing.T) { + var user = *GetUser("joins-on", Config{Pets: 2}) + DB.Save(&user) + + var user1 User + onQuery := DB.Where(&Pet{Name: "joins-on_pet_1"}) + + if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + + AssertEqual(t, user1.NamedPet.Name, "joins-on_pet_1") + + onQuery2 := DB.Where(&Pet{Name: "joins-on_pet_2"}) + var user2 User + if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil { + t.Fatalf("Failed to load with joins on, got error: %v", err) + } + AssertEqual(t, user2.NamedPet.Name, "joins-on_pet_2") +} + func TestJoinsWithSelect(t *testing.T) { type result struct { ID uint diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 4da3856f..599ca850 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -142,17 +142,36 @@ func TestSmartMigrateColumn(t *testing.T) { } -func TestMigrateWithComment(t *testing.T) { - type UserWithComment struct { +func TestMigrateWithColumnComment(t *testing.T) { + type UserWithColumnComment struct { gorm.Model - Name string `gorm:"size:111;index:,comment:这是一个index;comment:this is a 字段"` + Name string `gorm:"size:111;comment:this is a 字段"` } - if err := DB.Migrator().DropTable(&UserWithComment{}); err != nil { + if err := DB.Migrator().DropTable(&UserWithColumnComment{}); err != nil { t.Fatalf("Failed to drop table, got error %v", err) } - if err := DB.AutoMigrate(&UserWithComment{}); err != nil { + if err := DB.AutoMigrate(&UserWithColumnComment{}); err != nil { + t.Fatalf("Failed to auto migrate, but got error %v", err) + } +} + +func TestMigrateWithIndexComment(t *testing.T) { + if DB.Dialector.Name() != "mysql" { + t.Skip() + } + + type UserWithIndexComment struct { + gorm.Model + Name string `gorm:"size:111;index:,comment:这是一个index"` + } + + if err := DB.Migrator().DropTable(&UserWithIndexComment{}); err != nil { + t.Fatalf("Failed to drop table, got error %v", err) + } + + if err := DB.AutoMigrate(&UserWithIndexComment{}); err != nil { t.Fatalf("Failed to auto migrate, but got error %v", err) } } diff --git a/tests/query_test.go b/tests/query_test.go index 34999337..8a476598 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -436,6 +436,11 @@ func TestNot(t *testing.T) { t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) } + result = dryDB.Not(map[string]interface{}{"name": []string{}}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* IS NOT NULL").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + result = dryDB.Not(map[string]interface{}{"name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{}) if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* NOT IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) @@ -842,7 +847,17 @@ func TestSearchWithEmptyChain(t *testing.T) { func TestOrder(t *testing.T) { dryDB := DB.Session(&gorm.Session{DryRun: true}) - result := dryDB.Order("age desc, name").Find(&User{}) + result := dryDB.Order("").Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* IS NULL$").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Order(nil).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* IS NULL$").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Order("age desc, name").Find(&User{}) if !regexp.MustCompile("SELECT \\* FROM .*users.* ORDER BY age desc, name").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) } diff --git a/tests/scan_test.go b/tests/scan_test.go index 86cb0399..67d5f385 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -63,6 +63,13 @@ func TestScan(t *testing.T) { if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name { t.Errorf("Scan into struct map, got %#v", results) } + + type ID uint64 + var id ID + DB.Raw("select id from users where id = ?", user2.ID).Scan(&id) + if uint(id) != user2.ID { + t.Errorf("Failed to scan to customized data type") + } } func TestScanRows(t *testing.T) { diff --git a/tests/table_test.go b/tests/table_test.go index 0c6b3eb0..0289b7b8 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -30,6 +30,26 @@ func TestTable(t *testing.T) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } + r = dryDB.Table("`people`").Table("`user`").Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM `user`").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("people as p").Table("user as u").Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM user as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("people as p").Table("user").Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM .user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("gorm.people").Table("user").Find(&User{}).Statement + if !regexp.MustCompile("SELECT \\* FROM .user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + r = dryDB.Table("gorm.user").Select("name").Find(&User{}).Statement if !regexp.MustCompile("SELECT .name. FROM .gorm.\\..user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) diff --git a/tests/update_has_one_test.go b/tests/update_has_one_test.go index a61629f8..59d30e42 100644 --- a/tests/update_has_one_test.go +++ b/tests/update_has_one_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "database/sql" "testing" "time" @@ -85,4 +86,48 @@ func TestUpdateHasOne(t *testing.T) { DB.Preload("Toy").Find(&pet4, "id = ?", pet.ID) CheckPet(t, pet4, pet) }) + + t.Run("Restriction", func(t *testing.T) { + type CustomizeAccount struct { + gorm.Model + UserID sql.NullInt64 + Number string `gorm:"<-:create"` + } + + type CustomizeUser struct { + gorm.Model + Name string + Account CustomizeAccount `gorm:"foreignkey:UserID"` + } + + DB.Migrator().DropTable(&CustomizeUser{}) + DB.Migrator().DropTable(&CustomizeAccount{}) + + if err := DB.AutoMigrate(&CustomizeUser{}); err != nil { + t.Fatalf("failed to migrate, got error: %v", err) + } + if err := DB.AutoMigrate(&CustomizeAccount{}); err != nil { + t.Fatalf("failed to migrate, got error: %v", err) + } + + number := "number-has-one-associations" + cusUser := CustomizeUser{ + Name: "update-has-one-associations", + Account: CustomizeAccount{ + Number: number, + }, + } + + if err := DB.Create(&cusUser).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + cusUser.Account.Number += "-update" + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Updates(&cusUser).Error; err != nil { + t.Fatalf("errors happened when create: %v", err) + } + + var account2 CustomizeAccount + DB.Find(&account2, "user_id = ?", cusUser.ID) + AssertEqual(t, account2.Number, number) + }) } diff --git a/tests/update_test.go b/tests/update_test.go index 5ad1bb39..631d0d6d 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -69,8 +69,10 @@ func TestUpdate(t *testing.T) { } values := map[string]interface{}{"Active": true, "age": 5} - if err := DB.Model(user).Updates(values).Error; err != nil { - t.Errorf("errors happened when update: %v", err) + if res := DB.Model(user).Updates(values); res.Error != nil { + t.Errorf("errors happened when update: %v", res.Error) + } else if res.RowsAffected != 1 { + t.Errorf("rows affected should be 1, but got : %v", res.RowsAffected) } else if user.Age != 5 { t.Errorf("Age should equals to 5, but got %v", user.Age) } else if user.Active != true { @@ -131,7 +133,10 @@ func TestUpdates(t *testing.T) { lastUpdatedAt := users[0].UpdatedAt // update with map - DB.Model(users[0]).Updates(map[string]interface{}{"name": "updates_01_newname", "age": 100}) + if res := DB.Model(users[0]).Updates(map[string]interface{}{"name": "updates_01_newname", "age": 100}); res.Error != nil || res.RowsAffected != 1 { + t.Errorf("Failed to update users") + } + if users[0].Name != "updates_01_newname" || users[0].Age != 100 { t.Errorf("Record should be updated also with map") } @@ -642,6 +647,40 @@ func TestSave(t *testing.T) { if !regexp.MustCompile("WHERE .id. = [^ ]+$").MatchString(stmt.SQL.String()) { t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) } + + user3 := *GetUser("save3", Config{}) + DB.Create(&user3) + + if err := DB.First(&User{}, "name = ?", "save3").Error; err != nil { + t.Fatalf("failed to find created user") + } + + user3.Name = "save3_" + if err := DB.Model(User{Model: user3.Model}).Save(&user3).Error; err != nil { + t.Fatalf("failed to save user, got %v", err) + } + + var result2 User + if err := DB.First(&result2, "name = ?", "save3_").Error; err != nil || result2.ID != user3.ID { + t.Fatalf("failed to find updated user, got %v", err) + } + + if err := DB.Model(User{Model: user3.Model}).Save(&struct { + gorm.Model + Placeholder string + Name string + }{ + Model: user3.Model, + Placeholder: "placeholder", + Name: "save3__", + }).Error; err != nil { + t.Fatalf("failed to update user, got %v", err) + } + + var result3 User + if err := DB.First(&result3, "name = ?", "save3__").Error; err != nil || result3.ID != user3.ID { + t.Fatalf("failed to find updated user") + } } func TestSaveWithPrimaryValue(t *testing.T) { diff --git a/tests/upsert_test.go b/tests/upsert_test.go index 0ba8b9f0..867110d8 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -1,9 +1,11 @@ package tests_test import ( + "regexp" "testing" "time" + "gorm.io/gorm" "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -51,6 +53,19 @@ func TestUpsert(t *testing.T) { if err := DB.Find(&result, "code = ?", lang.Code).Error; err != nil || result.Name != lang.Name { t.Fatalf("failed to upsert, got name %v", result.Name) } + + if name := DB.Dialector.Name(); name != "sqlserver" { + type RestrictedLanguage struct { + Code string `gorm:"primarykey"` + Name string + Lang string `gorm:"<-:create"` + } + + r := DB.Session(&gorm.Session{DryRun: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(&RestrictedLanguage{Code: "upsert_code", Name: "upsert_name", Lang: "upsert_lang"}) + if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.[^\w]*$`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } } func TestUpsertSlice(t *testing.T) { diff --git a/utils/tests/models.go b/utils/tests/models.go index 2c5e71c0..8e833c93 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -11,6 +11,7 @@ import ( // He works in a Company (belongs to), he has a Manager (belongs to - single-table), and also managed a Team (has many - single-table) // He speaks many languages (many to many) and has many friends (many to many - single-table) // His pet also has one Toy (has one - polymorphic) +// NamedPet is a reference to a Named `Pets` (has many) type User struct { gorm.Model Name string @@ -18,6 +19,7 @@ type User struct { Birthday *time.Time Account Account Pets []*Pet + NamedPet *Pet Toys []Toy `gorm:"polymorphic:Owner"` CompanyID *int Company Company