diff --git a/.github/labels.json b/.github/labels.json index 8b1ce849..6b9c2034 100644 --- a/.github/labels.json +++ b/.github/labels.json @@ -10,6 +10,11 @@ "colour": "#EDEDED", "description": "general questions" }, + "invalid_question": { + "name": "type:invalid question", + "colour": "#CF2E1F", + "description": "invalid question (not related to GORM or described in document or not enough information provided)" + }, "with_playground": { "name": "type:with reproduction steps", "colour": "#00ff00", diff --git a/.github/workflows/invalid_question.yml b/.github/workflows/invalid_question.yml new file mode 100644 index 00000000..5b0bd981 --- /dev/null +++ b/.github/workflows/invalid_question.yml @@ -0,0 +1,22 @@ +name: "Close invalid questions issues" +on: + schedule: + - cron: "*/10 * * * *" + +jobs: + stale: + runs-on: ubuntu-latest + env: + ACTIONS_STEP_DEBUG: true + steps: + - name: Close Stale Issues + uses: actions/stale@v3.0.7 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 2 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" + stale-issue-label: "status:stale" + days-before-stale: 0 + days-before-close: 2 + remove-stale-when-updated: true + only-labels: "type:invalid question" + diff --git a/.github/workflows/missing_playground.yml b/.github/workflows/missing_playground.yml index 6fb714ca..ea3207d6 100644 --- a/.github/workflows/missing_playground.yml +++ b/.github/workflows/missing_playground.yml @@ -13,7 +13,7 @@ jobs: uses: actions/stale@v3.0.7 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - stale-issue-message: "This issue has been automatically marked as stale as it missing playground pull request link, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details, it will be closed in 2 days if no further activity occurs." + stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 2 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-label: "status:stale" days-before-stale: 0 days-before-close: 2 diff --git a/callbacks.go b/callbacks.go index baeb6c09..e21e0718 100644 --- a/callbacks.go +++ b/callbacks.go @@ -8,7 +8,6 @@ import ( "sort" "time" - "gorm.io/gorm/logger" "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) @@ -75,15 +74,20 @@ func (cs *callbacks) Raw() *processor { func (p *processor) Execute(db *DB) { curTime := time.Now() stmt := db.Statement - db.RowsAffected = 0 if stmt.Model == nil { stmt.Model = stmt.Dest + } else if stmt.Dest == nil { + stmt.Dest = stmt.Model } if stmt.Model != nil { if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) { - db.AddError(err) + if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" { + db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err)) + } else { + db.AddError(err) + } } } @@ -154,7 +158,7 @@ func (p *processor) compile() (err error) { p.callbacks = callbacks if p.fns, err = sortCallbacks(p.callbacks); err != nil { - logger.Default.Error(context.Background(), "Got error when compile callbacks, got %v", err) + p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err) } return } @@ -177,7 +181,7 @@ func (c *callback) Register(name string, fn func(*DB)) error { } func (c *callback) Remove(name string) error { - logger.Default.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum()) + c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum()) c.name = name c.remove = true c.processor.callbacks = append(c.processor.callbacks, c) @@ -185,7 +189,7 @@ func (c *callback) Remove(name string) error { } func (c *callback) Replace(name string, fn func(*DB)) error { - logger.Default.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) + c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) c.name = name c.handler = fn c.replace = true @@ -215,7 +219,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { for _, c := range cs { // show warning message the callback name already exists if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove { - logger.Default.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum()) + c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum()) } names = append(names, c.name) } diff --git a/callbacks/associations.go b/callbacks/associations.go index 2710ffe9..64d79f24 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -5,6 +5,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" + "gorm.io/gorm/schema" ) func SaveBeforeAssociations(db *gorm.DB) { @@ -65,9 +66,7 @@ func SaveBeforeAssociations(db *gorm.DB) { } if elems.Len() > 0 { - if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - DoNothing: true, - }).Create(elems.Interface()).Error) == nil { + if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil { for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i)) } @@ -80,9 +79,7 @@ func SaveBeforeAssociations(db *gorm.DB) { rv = rv.Addr() } - if db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - DoNothing: true, - }).Create(rv.Interface()).Error) == nil { + if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil { setupReferences(db.Statement.ReflectValue, rv) } } @@ -144,10 +141,9 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, - DoUpdates: clause.AssignmentColumns(assignmentColumns), - }).Create(elems.Interface()).Error) + db.AddError(db.Session(&gorm.Session{}).Clauses( + onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), + ).Create(elems.Interface()).Error) } case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { @@ -167,10 +163,9 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, - DoUpdates: clause.AssignmentColumns(assignmentColumns), - }).Create(f.Interface()).Error) + db.AddError(db.Session(&gorm.Session{}).Clauses( + onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), + ).Create(f.Interface()).Error) } } } @@ -229,10 +224,9 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: rel.FieldSchema.PrioritizedPrimaryField.DBName}}, - DoUpdates: clause.AssignmentColumns(assignmentColumns), - }).Create(elems.Interface()).Error) + db.AddError(db.Session(&gorm.Session{}).Clauses( + onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), + ).Create(elems.Interface()).Error) } } @@ -297,7 +291,7 @@ func SaveAfterAssociations(db *gorm.DB) { } if elems.Len() > 0 { - db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(elems.Interface()).Error) + db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) for i := 0; i < elems.Len(); i++ { appendToJoins(objs[i], elems.Index(i)) @@ -310,3 +304,32 @@ func SaveAfterAssociations(db *gorm.DB) { } } } + +func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) clause.OnConflict { + if stmt.DB.FullSaveAssociations { + defaultUpdatingColumns = make([]string, 0, len(s.DBNames)) + for _, dbName := range s.DBNames { + if !s.LookUpField(dbName).PrimaryKey { + defaultUpdatingColumns = append(defaultUpdatingColumns, dbName) + } + } + } + + if len(defaultUpdatingColumns) > 0 { + var columns []clause.Column + if s.PrioritizedPrimaryField != nil { + columns = []clause.Column{{Name: s.PrioritizedPrimaryField.DBName}} + } else { + for _, dbName := range s.PrimaryFieldDBNames { + columns = append(columns, clause.Column{Name: dbName}) + } + } + + return clause.OnConflict{ + Columns: columns, + DoUpdates: clause.AssignmentColumns(defaultUpdatingColumns), + } + } + + return clause.OnConflict{DoNothing: true} +} diff --git a/callbacks/callbacks.go b/callbacks/callbacks.go index 0a12468c..dda4b046 100644 --- a/callbacks/callbacks.go +++ b/callbacks/callbacks.go @@ -31,6 +31,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) { deleteCallback := db.Callback().Delete() deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction) deleteCallback.Register("gorm:before_delete", BeforeDelete) + deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations) deleteCallback.Register("gorm:delete", Delete) deleteCallback.Register("gorm:after_delete", AfterDelete) deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction) diff --git a/callbacks/create.go b/callbacks/create.go index 5de19d35..8e2454e8 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -1,6 +1,7 @@ package callbacks import ( + "fmt" "reflect" "gorm.io/gorm" @@ -87,7 +88,10 @@ func Create(config *Config) func(db *gorm.DB) { } case reflect.Struct: if insertID > 0 { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { + + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } } } } else { @@ -252,8 +256,18 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { stmt.SQL.Grow(stmt.ReflectValue.Len() * 15) values.Values = make([][]interface{}, stmt.ReflectValue.Len()) defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} + if stmt.ReflectValue.Len() == 0 { + stmt.AddError(gorm.ErrEmptySlice) + return + } + for i := 0; i < stmt.ReflectValue.Len(); i++ { rv := reflect.Indirect(stmt.ReflectValue.Index(i)) + if !rv.IsValid() { + stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData)) + return + } + values.Values[i] = make([]interface{}, len(values.Columns)) for idx, column := range values.Columns { field := stmt.Schema.FieldsByDBName[column.Name] diff --git a/callbacks/delete.go b/callbacks/delete.go index e95117a1..85f11f4b 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -2,6 +2,7 @@ package callbacks import ( "reflect" + "strings" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -21,6 +22,85 @@ func BeforeDelete(db *gorm.DB) { } } +func DeleteBeforeAssociations(db *gorm.DB) { + if db.Error == nil && db.Statement.Schema != nil { + selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) + + if restricted { + for column, v := range selectColumns { + if v { + if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok { + switch rel.Type { + case schema.HasOne, schema.HasMany: + queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) + modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() + tx := db.Session(&gorm.Session{}).Model(modelValue) + withoutConditions := false + + if len(db.Statement.Selects) > 0 { + var selects []string + for _, s := range db.Statement.Selects { + if s == clause.Associations { + selects = append(selects, s) + } else if strings.HasPrefix(s, column+".") { + selects = append(selects, strings.TrimPrefix(s, column+".")) + } + } + + if len(selects) > 0 { + tx = tx.Select(selects) + } + } + + for _, cond := range queryConds { + if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 { + withoutConditions = true + break + } + } + + if !withoutConditions { + if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } + } + case schema.Many2Many: + var ( + queryConds []clause.Expression + foreignFields []*schema.Field + relForeignKeys []string + modelValue = reflect.New(rel.JoinTable.ModelType).Interface() + table = rel.JoinTable.Table + tx = db.Session(&gorm.Session{}).Model(modelValue).Table(table) + ) + + for _, ref := range rel.References { + if ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.PrimaryKey) + relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) + } else if ref.PrimaryValue != "" { + queryConds = append(queryConds, clause.Eq{ + Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + }) + } + } + + _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields) + column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) + queryConds = append(queryConds, clause.IN{Column: column, Values: values}) + + if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { + return + } + } + } + } + } + } + } +} + func Delete(db *gorm.DB) { if db.Error == nil { if db.Statement.Schema != nil && !db.Statement.Unscoped { diff --git a/callbacks/helper.go b/callbacks/helper.go index e0a66dd2..09ec4582 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -46,6 +46,11 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) ) + if len(mapValues) == 0 { + stmt.AddError(gorm.ErrEmptySlice) + return + } + for idx, mapValue := range mapValues { for k, v := range mapValue { if stmt.Schema != nil { diff --git a/callbacks/preload.go b/callbacks/preload.go index 25b8cb2b..aec10ec5 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -107,6 +107,26 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { fieldValues := make([]interface{}, len(relForeignFields)) + // clean up old values before preloading + switch reflectValue.Kind() { + case reflect.Struct: + switch rel.Type { + case schema.HasMany, schema.Many2Many: + rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 0).Interface()) + default: + rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface()) + } + case reflect.Slice, reflect.Array: + for i := 0; i < reflectValue.Len(); i++ { + switch rel.Type { + case schema.HasMany, schema.Many2Many: + rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 0).Interface()) + default: + rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()) + } + } + } + for i := 0; i < reflectResults.Len(); i++ { elem := reflectResults.Index(i) for idx, field := range relForeignFields { diff --git a/callbacks/row.go b/callbacks/row.go index 7e70382e..4f985d7b 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -11,11 +11,13 @@ func RowQuery(db *gorm.DB) { } if !db.DryRun { - if _, ok := db.Get("rows"); ok { + if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) { db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } else { db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } + + db.RowsAffected = -1 } } } diff --git a/chainable_api.go b/chainable_api.go index c8417a6d..ae2ac4f1 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -198,7 +198,7 @@ func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { // Order specify order when retrieve records from database // db.Order("name DESC") -// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression +// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) func (db *DB) Order(value interface{}) (tx *DB) { tx = db.getInstance() diff --git a/clause/expression.go b/clause/expression.go index 4d5e328b..6a0dde8d 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -3,6 +3,7 @@ package clause import ( "database/sql" "database/sql/driver" + "go/ast" "reflect" ) @@ -30,18 +31,22 @@ func (expr Expr) Build(builder Builder) { ) for _, v := range []byte(expr.SQL) { - if v == '?' { + if v == '?' && len(expr.Vars) > idx { if afterParenthesis { if _, ok := expr.Vars[idx].(driver.Valuer); ok { builder.AddVar(builder, expr.Vars[idx]) } else { switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() { case reflect.Slice, reflect.Array: - for i := 0; i < rv.Len(); i++ { - if i > 0 { - builder.WriteByte(',') + if rv.Len() == 0 { + builder.AddVar(builder, nil) + } else { + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) } - builder.AddVar(builder, rv.Index(i).Interface()) } default: builder.AddVar(builder, expr.Vars[idx]) @@ -85,6 +90,17 @@ func (expr NamedExpr) Build(builder Builder) { for k, v := range value { namedMap[k] = v } + default: + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + switch reflectValue.Kind() { + case reflect.Struct: + modelType := reflectValue.Type() + for i := 0; i < modelType.NumField(); i++ { + if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { + namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface() + } + } + } } } @@ -94,7 +110,7 @@ func (expr NamedExpr) Build(builder Builder) { if v == '@' && !inName { inName = true name = []byte{} - } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' { + } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' { if inName { if nv, ok := namedMap[string(name)]; ok { builder.AddVar(builder, nv) @@ -106,7 +122,7 @@ func (expr NamedExpr) Build(builder Builder) { } builder.WriteByte(v) - } else if v == '?' { + } else if v == '?' && len(expr.Vars) > idx { builder.AddVar(builder, expr.Vars[idx]) idx++ } else if inName { diff --git a/clause/expression_test.go b/clause/expression_test.go index 17af737d..19e30e6c 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -37,6 +37,11 @@ func TestExpr(t *testing.T) { } func TestNamedExpr(t *testing.T) { + type NamedArgument struct { + Name1 string + Name2 string + } + results := []struct { SQL string Result string @@ -66,6 +71,15 @@ func TestNamedExpr(t *testing.T) { Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, + }, { + SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @Notexist", + Vars: []interface{}{NamedArgument{Name1: "jinzhu", Name2: "jinzhu2"}}, + Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? ?", + ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu", nil}, + }, { + SQL: "create table ? (? ?, ? ?)", + Vars: []interface{}{}, + Result: "create table ? (? ?, ? ?)", }} for idx, result := range results { diff --git a/clause/limit.go b/clause/limit.go index 1946820d..2082f4d9 100644 --- a/clause/limit.go +++ b/clause/limit.go @@ -33,10 +33,8 @@ func (limit Limit) MergeClause(clause *Clause) { clause.Name = "" if v, ok := clause.Expression.(Limit); ok { - if limit.Limit == 0 && v.Limit > 0 { + if limit.Limit == 0 && v.Limit != 0 { limit.Limit = v.Limit - } else if limit.Limit < 0 { - limit.Limit = 0 } if limit.Offset == 0 && v.Offset > 0 { diff --git a/clause/where.go b/clause/where.go index 9af9701c..a3774e1c 100644 --- a/clause/where.go +++ b/clause/where.go @@ -1,5 +1,9 @@ package clause +import ( + "strings" +) + // Where where clause type Where struct { Exprs []Expression @@ -22,6 +26,7 @@ func (where Where) Build(builder Builder) { } } + wrapInParentheses := false for idx, expr := range where.Exprs { if idx > 0 { if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 { @@ -31,7 +36,36 @@ func (where Where) Build(builder Builder) { } } - expr.Build(builder) + if len(where.Exprs) > 1 { + switch v := expr.(type) { + case OrConditions: + if len(v.Exprs) == 1 { + if e, ok := v.Exprs[0].(Expr); ok { + sql := strings.ToLower(e.SQL) + wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + } + } + case AndConditions: + if len(v.Exprs) == 1 { + if e, ok := v.Exprs[0].(Expr); ok { + sql := strings.ToLower(e.SQL) + wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + } + } + case Expr: + sql := strings.ToLower(v.SQL) + wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or") + } + } + + if wrapInParentheses { + builder.WriteString(`(`) + expr.Build(builder) + builder.WriteString(`)`) + wrapInParentheses = false + } else { + expr.Build(builder) + } } } @@ -50,6 +84,8 @@ func (where Where) MergeClause(clause *Clause) { func And(exprs ...Expression) Expression { if len(exprs) == 0 { return nil + } else if len(exprs) == 1 { + return exprs[0] } return AndConditions{Exprs: exprs} } @@ -118,6 +154,7 @@ func (not NotConditions) Build(builder Builder) { if len(not.Exprs) > 1 { builder.WriteByte('(') } + for idx, c := range not.Exprs { if idx > 0 { builder.WriteString(" AND ") @@ -127,9 +164,22 @@ func (not NotConditions) Build(builder Builder) { negationBuilder.NegationBuild(builder) } else { builder.WriteString("NOT ") + e, wrapInParentheses := c.(Expr) + if wrapInParentheses { + sql := strings.ToLower(e.SQL) + if wrapInParentheses = strings.Contains(sql, "and") || strings.Contains(sql, "or"); wrapInParentheses { + builder.WriteByte('(') + } + } + c.Build(builder) + + if wrapInParentheses { + builder.WriteByte(')') + } } } + if len(not.Exprs) > 1 { builder.WriteByte(')') } diff --git a/errors.go b/errors.go index 32ff8ec1..08755083 100644 --- a/errors.go +++ b/errors.go @@ -27,4 +27,8 @@ var ( ErrRegistered = errors.New("registered") // ErrInvalidField invalid field ErrInvalidField = errors.New("invalid field") + // ErrEmptySlice empty slice found + ErrEmptySlice = errors.New("empty slice found") + // ErrDryRunModeUnsupported dry run mode unsupported + ErrDryRunModeUnsupported = errors.New("dry run mode unsupported") ) diff --git a/finisher_api.go b/finisher_api.go index a205b859..63061553 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -8,6 +8,7 @@ import ( "strings" "gorm.io/gorm/clause" + "gorm.io/gorm/logger" "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) @@ -51,7 +52,8 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx.callbacks.Update().Execute(tx) if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { - if err := tx.Session(&Session{}).First(value).Error; errors.Is(err, ErrRecordNotFound) { + result := reflect.New(tx.Statement.Schema.ModelType).Interface() + if err := tx.Session(&Session{WithConditions: true}).First(result).Error; errors.Is(err, ErrRecordNotFound) { return tx.Create(value) } } @@ -331,22 +333,43 @@ func (db *DB) Count(count *int64) (tx *DB) { } func (db *DB) Row() *sql.Row { - tx := db.getInstance() + tx := db.getInstance().InstanceSet("rows", false) tx.callbacks.Row().Execute(tx) - return tx.Statement.Dest.(*sql.Row) + row, ok := tx.Statement.Dest.(*sql.Row) + if !ok && tx.DryRun { + db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error()) + } + return row } func (db *DB) Rows() (*sql.Rows, error) { - tx := db.Set("rows", true) + tx := db.getInstance().InstanceSet("rows", true) tx.callbacks.Row().Execute(tx) - return tx.Statement.Dest.(*sql.Rows), tx.Error + rows, ok := tx.Statement.Dest.(*sql.Rows) + if !ok && tx.DryRun && tx.Error == nil { + tx.Error = ErrDryRunModeUnsupported + } + return rows, tx.Error } // Scan scan value to a struct func (db *DB) Scan(dest interface{}) (tx *DB) { + currentLogger, newLogger := db.Logger, logger.Recorder.New() tx = db.getInstance() - tx.Statement.Dest = dest - tx.callbacks.Query().Execute(tx) + tx.Logger = newLogger + if rows, err := tx.Rows(); err != nil { + tx.AddError(err) + } else { + defer rows.Close() + if rows.Next() { + tx.ScanRows(rows, dest) + } + } + + currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) { + return newLogger.SQL, tx.RowsAffected + }, tx.Error) + tx.Logger = currentLogger return } @@ -377,9 +400,14 @@ func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { tx := db.getInstance() - tx.Error = tx.Statement.Parse(dest) + if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) { + tx.AddError(err) + } tx.Statement.Dest = dest - tx.Statement.ReflectValue = reflect.Indirect(reflect.ValueOf(dest)) + tx.Statement.ReflectValue = reflect.ValueOf(dest) + for tx.Statement.ReflectValue.Kind() == reflect.Ptr { + tx.Statement.ReflectValue = tx.Statement.ReflectValue.Elem() + } Scan(rows, tx, true) return tx.Error } diff --git a/gorm.go b/gorm.go index 3c187f42..e5c4a8a4 100644 --- a/gorm.go +++ b/gorm.go @@ -20,6 +20,8 @@ type Config struct { SkipDefaultTransaction bool // NamingStrategy tables, columns naming strategy NamingStrategy schema.Namer + // FullSaveAssociations full save associations + FullSaveAssociations bool // Logger Logger logger.Interface // NowFunc the function to be used when creating a new timestamp @@ -64,6 +66,7 @@ type Session struct { WithConditions bool SkipDefaultTransaction bool AllowGlobalUpdate bool + FullSaveAssociations bool Context context.Context Logger logger.Interface NowFunc func() time.Time @@ -161,6 +164,10 @@ func (db *DB) Session(config *Session) *DB { txConfig.AllowGlobalUpdate = true } + if config.FullSaveAssociations { + txConfig.FullSaveAssociations = true + } + if config.Context != nil { tx.Statement = tx.Statement.clone() tx.Statement.DB = tx @@ -169,12 +176,15 @@ func (db *DB) Session(config *Session) *DB { if config.PrepareStmt { if v, ok := db.cacheStore.Load("preparedStmt"); ok { + tx.Statement = tx.Statement.clone() preparedStmt := v.(*PreparedStmtDB) tx.Statement.ConnPool = &PreparedStmtDB{ ConnPool: db.Config.ConnPool, Mux: preparedStmt.Mux, Stmts: preparedStmt.Stmts, } + txConfig.ConnPool = tx.Statement.ConnPool + txConfig.PrepareStmt = true } } @@ -316,6 +326,9 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil { f.DataType = ref.ForeignKey.DataType f.GORMDataType = ref.ForeignKey.GORMDataType + if f.Size == 0 { + f.Size = ref.ForeignKey.Size + } ref.ForeignKey = f } else { return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName) diff --git a/logger/logger.go b/logger/logger.go index 49ae988c..b278ad5d 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -2,6 +2,7 @@ package logger import ( "context" + "io/ioutil" "log" "os" "time" @@ -19,6 +20,7 @@ const ( Magenta = "\033[35m" Cyan = "\033[36m" White = "\033[37m" + BlueBold = "\033[34;1m" MagentaBold = "\033[35;1m" RedBold = "\033[31;1m" YellowBold = "\033[33;1m" @@ -54,29 +56,33 @@ type Interface interface { Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) } -var Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ - SlowThreshold: 100 * time.Millisecond, - LogLevel: Warn, - Colorful: true, -}) +var ( + Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) + Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ + SlowThreshold: 100 * time.Millisecond, + LogLevel: Warn, + Colorful: true, + }) + Recorder = traceRecorder{Interface: Default} +) func New(writer Writer, config Config) Interface { var ( infoStr = "%s\n[info] " warnStr = "%s\n[warn] " errStr = "%s\n[error] " - traceStr = "%s\n[%v] [rows:%d] %s" - traceWarnStr = "%s\n[%v] [rows:%d] %s" - traceErrStr = "%s %s\n[%v] [rows:%d] %s" + traceStr = "%s\n[%.3fms] [rows:%v] %s" + traceWarnStr = "%s\n[%.3fms] [rows:%v] %s" + traceErrStr = "%s %s\n[%.3fms] [rows:%v] %s" ) if config.Colorful { infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset - warnStr = Blue + "%s\n" + Reset + Magenta + "[warn] " + Reset + warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset - traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s" - traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%d]" + Magenta + " %s" + Reset - traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + Blue + "[rows:%d]" + Reset + " %s" + traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s" + traceWarnStr = Green + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%v]" + Magenta + " %s" + Reset + traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s" } return &logger{ @@ -133,13 +139,43 @@ func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, i switch { case err != nil && l.LogLevel >= Error: sql, rows := fc() - l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) + if rows == -1 { + l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) + } case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn: sql, rows := fc() - l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + if rows == -1 { + l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceWarnStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + } case l.LogLevel >= Info: sql, rows := fc() - l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + if rows == -1 { + l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql) + } else { + l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) + } } } } + +type traceRecorder struct { + Interface + BeginAt time.Time + SQL string + RowsAffected int64 + Err error +} + +func (l traceRecorder) New() *traceRecorder { + return &traceRecorder{Interface: l.Interface} +} + +func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { + l.BeginAt = begin + l.SQL, l.RowsAffected = fc() + l.Err = err +} diff --git a/logger/sql.go b/logger/sql.go index 02d559c5..138a35ec 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -9,6 +9,8 @@ import ( "strings" "time" "unicode" + + "gorm.io/gorm/utils" ) func isPrintable(s []byte) bool { @@ -24,19 +26,38 @@ var convertableTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeO func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { var convertParams func(interface{}, int) - var vars = make([]interface{}, len(avars)) - copy(vars, avars) + var vars = make([]string, len(avars)) convertParams = func(v interface{}, idx int) { switch v := v.(type) { case bool: - vars[idx] = fmt.Sprint(v) + vars[idx] = strconv.FormatBool(v) case time.Time: if v.IsZero() { vars[idx] = escaper + "0000-00-00 00:00:00" + escaper } else { vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper } + case *time.Time: + if v != nil { + if v.IsZero() { + vars[idx] = escaper + "0000-00-00 00:00:00" + escaper + } else { + vars[idx] = escaper + v.Format("2006-01-02 15:04:05.999") + escaper + } + } else { + vars[idx] = "NULL" + } + case fmt.Stringer: + vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper + case driver.Valuer: + reflectValue := reflect.ValueOf(v) + if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { + r, _ := v.Value() + convertParams(r, idx) + } else { + vars[idx] = "NULL" + } case []byte: if isPrintable(v) { vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper @@ -44,7 +65,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a vars[idx] = escaper + "" + escaper } case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - vars[idx] = fmt.Sprintf("%d", v) + vars[idx] = utils.ToString(v) case float64, float32: vars[idx] = fmt.Sprintf("%.6f", v) case string: @@ -70,18 +91,30 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a } } - for idx, v := range vars { + for idx, v := range avars { convertParams(v, idx) } if numericPlaceholder == nil { - for _, v := range vars { - sql = strings.Replace(sql, "?", v.(string), 1) + var idx int + var newSQL strings.Builder + + for _, v := range []byte(sql) { + if v == '?' { + if len(vars) > idx { + newSQL.WriteString(vars[idx]) + idx++ + continue + } + } + newSQL.WriteByte(v) } + + sql = newSQL.String() } else { sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$") for idx, v := range vars { - sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v.(string), 1) + sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v, 1) } } diff --git a/logger/sql_test.go b/logger/sql_test.go index 180570b8..71aa841a 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -1,13 +1,39 @@ package logger_test import ( + "database/sql/driver" + "encoding/json" + "fmt" "regexp" + "strings" "testing" "github.com/jinzhu/now" "gorm.io/gorm/logger" ) +type JSON json.RawMessage + +func (j JSON) Value() (driver.Value, error) { + if len(j) == 0 { + return nil, nil + } + return json.RawMessage(j).MarshalJSON() +} + +type ExampleStruct struct { + Name string + Val string +} + +func (s ExampleStruct) Value() (driver.Value, error) { + return json.Marshal(s) +} + +func format(v []byte, escaper string) string { + return escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper +} + func TestExplainSQL(t *testing.T) { type role string type password []byte @@ -15,6 +41,10 @@ func TestExplainSQL(t *testing.T) { tt = now.MustParse("2020-02-23 11:10:10") myrole = role("admin") pwd = password([]byte("pass")) + jsVal = []byte(`{"Name":"test","Val":"test"}`) + js = JSON(jsVal) + esVal = []byte(`{"Name":"test","Val":"test"}`) + es = ExampleStruct{Name: "test", Val: "test"} ) results := []struct { @@ -29,6 +59,12 @@ func TestExplainSQL(t *testing.T) { Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, + Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass")`, + }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)", NumericRegexp: regexp.MustCompile(`@p(\d+)`), @@ -47,6 +83,18 @@ func TestExplainSQL(t *testing.T) { Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es}, + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + }, + { + SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + NumericRegexp: nil, + Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, + Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.990000, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.\"com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), + }, } for idx, r := range results { diff --git a/migrator.go b/migrator.go index ed8a8e26..162fe680 100644 --- a/migrator.go +++ b/migrator.go @@ -9,7 +9,7 @@ import ( // Migrator returns migrator func (db *DB) Migrator() Migrator { - return db.Dialector.Migrator(db) + return db.Dialector.Migrator(db.Session(&Session{WithConditions: true})) } // AutoMigrate run auto migration for given models diff --git a/migrator/migrator.go b/migrator/migrator.go index d93b8a6d..f390ff9f 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -71,7 +71,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface) - } else { + } else if field.DefaultValue != "(-)" { expr.SQL += " DEFAULT " + field.DefaultValue } } @@ -133,6 +133,15 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } } } + + for _, idx := range stmt.Schema.ParseIndexes() { + if !tx.Migrator().HasIndex(value, idx.Name) { + if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil { + return err + } + } + } + return nil }); err != nil { return err @@ -297,10 +306,12 @@ func (m Migrator) DropColumn(value interface{}, name string) error { func (m Migrator) AlterColumn(value interface{}, field string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { + fileType := clause.Expr{SQL: m.DataTypeOf(field)} return m.DB.Exec( "ALTER TABLE ? ALTER COLUMN ? TYPE ?", - clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field), + clause.Table{Name: stmt.Table}, clause.Column{Name: field.DBName}, fileType, ).Error + } return fmt.Errorf("failed to look up field with name: %s", field) }) @@ -354,9 +365,9 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy alterColumn = true } else { // has size in data type and not equal - matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllString(realDataType, 1) + matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllStringSubmatch(realDataType, -1) matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1) - if len(matches) > 0 && matches[1] != fmt.Sprint(field.Size) || len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) { + if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size)) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) { alterColumn = true } } @@ -386,8 +397,9 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) { err = m.RunWithValue(value, func(stmt *gorm.Statement) error { - rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows() + rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() if err == nil { + defer rows.Close() columnTypes, err = rows.ColumnTypes() } return err @@ -584,6 +596,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i var ( modelNames, orderedModelNames []string orderedModelNamesMap = map[string]bool{} + parsedSchemas = map[*schema.Schema]bool{} valuesMap = map[string]Dependency{} insertIntoOrderedList func(name string) parseDependence func(value interface{}, addToList bool) @@ -593,23 +606,35 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i dep := Dependency{ Statement: &gorm.Statement{DB: m.DB, Dest: value}, } + beDependedOn := map[*schema.Schema]bool{} if err := dep.Parse(value); err != nil { m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err) } + if _, ok := parsedSchemas[dep.Statement.Schema]; ok { + return + } + parsedSchemas[dep.Statement.Schema] = true for _, rel := range dep.Schema.Relationships.Relations { if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema { dep.Depends = append(dep.Depends, c.ReferenceSchema) } + if rel.Type == schema.HasOne || rel.Type == schema.HasMany { + beDependedOn[rel.FieldSchema] = true + } + if rel.JoinTable != nil { - if rel.Schema != rel.FieldSchema { - dep.Depends = append(dep.Depends, rel.FieldSchema) - } // append join value - defer func(joinValue interface{}) { + defer func(rel *schema.Relationship, joinValue interface{}) { + if !beDependedOn[rel.FieldSchema] { + dep.Depends = append(dep.Depends, rel.FieldSchema) + } else { + fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface() + parseDependence(fieldValue, autoAdd) + } parseDependence(joinValue, autoAdd) - }(reflect.New(rel.JoinTable.ModelType).Interface()) + }(rel, reflect.New(rel.JoinTable.ModelType).Interface()) } } diff --git a/prepare_stmt.go b/prepare_stmt.go index 7e87558d..14a6aaec 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -25,7 +25,7 @@ func (db *PreparedStmtDB) Close() { db.Mux.Unlock() } -func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { +func (db *PreparedStmtDB) prepare(ctx context.Context, query string) (*sql.Stmt, error) { db.Mux.RLock() if stmt, ok := db.Stmts[query]; ok { db.Mux.RUnlock() @@ -40,7 +40,7 @@ func (db *PreparedStmtDB) prepare(query string) (*sql.Stmt, error) { return stmt, nil } - stmt, err := db.ConnPool.PrepareContext(context.Background(), query) + stmt, err := db.ConnPool.PrepareContext(ctx, query) if err == nil { db.Stmts[query] = stmt db.PreparedSQL = append(db.PreparedSQL, query) @@ -59,7 +59,7 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn } func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { - stmt, err := db.prepare(query) + stmt, err := db.prepare(ctx, query) if err == nil { result, err = stmt.ExecContext(ctx, args...) if err != nil { @@ -73,7 +73,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. } func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { - stmt, err := db.prepare(query) + stmt, err := db.prepare(ctx, query) if err == nil { rows, err = stmt.QueryContext(ctx, args...) if err != nil { @@ -87,7 +87,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . } func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - stmt, err := db.prepare(query) + stmt, err := db.prepare(ctx, query) if err == nil { return stmt.QueryRowContext(ctx, args...) } @@ -99,10 +99,24 @@ type PreparedStmtTX struct { PreparedStmtDB *PreparedStmtDB } +func (tx *PreparedStmtTX) Commit() error { + if tx.Tx != nil { + return tx.Tx.Commit() + } + return ErrInvalidTransaction +} + +func (tx *PreparedStmtTX) Rollback() error { + if tx.Tx != nil { + return tx.Tx.Rollback() + } + return ErrInvalidTransaction +} + func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { - stmt, err := tx.PreparedStmtDB.prepare(query) + stmt, err := tx.PreparedStmtDB.prepare(ctx, query) if err == nil { - result, err = tx.Tx.Stmt(stmt).ExecContext(ctx, args...) + result, err = tx.Tx.StmtContext(ctx, stmt).ExecContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() stmt.Close() @@ -114,7 +128,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. } func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { - stmt, err := tx.PreparedStmtDB.prepare(query) + stmt, err := tx.PreparedStmtDB.prepare(ctx, query) if err == nil { rows, err = tx.Tx.Stmt(stmt).QueryContext(ctx, args...) if err != nil { @@ -128,9 +142,9 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . } func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { - stmt, err := tx.PreparedStmtDB.prepare(query) + stmt, err := tx.PreparedStmtDB.prepare(ctx, query) if err == nil { - return tx.Tx.Stmt(stmt).QueryRowContext(ctx, args...) + return tx.Tx.StmtContext(ctx, stmt).QueryRowContext(ctx, args...) } return &sql.Row{} } diff --git a/scan.go b/scan.go index 0b199029..8d737b17 100644 --- a/scan.go +++ b/scan.go @@ -2,22 +2,63 @@ package gorm import ( "database/sql" + "database/sql/driver" "reflect" "strings" + "time" "gorm.io/gorm/schema" ) +func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) { + if db.Statement.Schema != nil { + for idx, name := range columns { + if field := db.Statement.Schema.LookUpField(name); field != nil { + values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() + continue + } + values[idx] = new(interface{}) + } + } else if len(columnTypes) > 0 { + for idx, columnType := range columnTypes { + if columnType.ScanType() != nil { + values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface() + } else { + values[idx] = new(interface{}) + } + } + } else { + for idx := range columns { + values[idx] = new(interface{}) + } + } +} + +func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) { + for idx, column := range columns { + if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() { + mapValue[column] = reflectValue.Interface() + if valuer, ok := mapValue[column].(driver.Valuer); ok { + mapValue[column], _ = valuer.Value() + } else if b, ok := mapValue[column].(sql.RawBytes); ok { + mapValue[column] = string(b) + } + } else { + mapValue[column] = nil + } + } +} + func Scan(rows *sql.Rows, db *DB, initialized bool) { columns, _ := rows.Columns() values := make([]interface{}, len(columns)) + db.RowsAffected = 0 switch dest := db.Statement.Dest.(type) { case map[string]interface{}, *map[string]interface{}: if initialized || rows.Next() { - for idx := range columns { - values[idx] = new(interface{}) - } + columnTypes, _ := rows.ColumnTypes() + prepareValues(values, db, columnTypes, columns) db.RowsAffected++ db.AddError(rows.Scan(values...)) @@ -28,41 +69,22 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { mapValue = *v } } - - for idx, column := range columns { - if v, ok := values[idx].(*interface{}); ok { - if v == nil { - mapValue[column] = nil - } else { - mapValue[column] = *v - } - } - } + scanIntoMap(mapValue, values, columns) } case *[]map[string]interface{}: + columnTypes, _ := rows.ColumnTypes() for initialized || rows.Next() { - for idx := range columns { - values[idx] = new(interface{}) - } + prepareValues(values, db, columnTypes, columns) initialized = false db.RowsAffected++ db.AddError(rows.Scan(values...)) mapValue := map[string]interface{}{} - for idx, column := range columns { - if v, ok := values[idx].(*interface{}); ok { - if v == nil { - mapValue[column] = nil - } else { - mapValue[column] = *v - } - } - } - + scanIntoMap(mapValue, values, columns) *dest = append(*dest, mapValue) } - case *int, *int64, *uint, *uint64, *float32, *float64: + case *int, *int64, *uint, *uint64, *float32, *float64, *string, *time.Time: for initialized || rows.Next() { initialized = false db.RowsAffected++ @@ -114,7 +136,15 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } // pluck values into slice of data - isPluck := len(fields) == 1 && reflectValueType.Kind() != reflect.Struct + isPluck := false + if len(fields) == 1 { + if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok { + isPluck = true + } else if reflectValueType.Kind() != reflect.Struct || reflectValueType.ConvertibleTo(schema.TimeReflectType) { + isPluck = true + } + } + for initialized || rows.Next() { initialized = false db.RowsAffected++ diff --git a/schema/field.go b/schema/field.go index 524d19fb..db516c33 100644 --- a/schema/field.go +++ b/schema/field.go @@ -18,6 +18,8 @@ type DataType string type TimeType int64 +var TimeReflectType = reflect.TypeOf(time.Time{}) + const ( UnixSecond TimeType = 1 UnixMillisecond TimeType = 2 @@ -70,6 +72,8 @@ type Field struct { } func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { + var err error + field := &Field{ Name: fieldStruct.Name, BindNames: []string{fieldStruct.Name}, @@ -100,7 +104,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { var getRealFieldValue func(reflect.Value) getRealFieldValue = func(v reflect.Value) { rv := reflect.Indirect(v) - if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) { for i := 0; i < rv.Type().NumField(); i++ { newFieldType := rv.Type().Field(i).Type for newFieldType.Kind() == reflect.Ptr { @@ -151,7 +155,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } if num, ok := field.TagSettings["SIZE"]; ok { - var err error if field.Size, err = strconv.Atoi(num); err != nil { field.Size = -1 } @@ -177,33 +180,42 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.Comment = val } + // default value is function or null or blank (primary keys) + skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") && + strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" || field.DefaultValue == "" switch reflect.Indirect(fieldValue).Kind() { case reflect.Bool: field.DataType = Bool - if field.HasDefaultValue && field.DefaultValue != "" { - field.DefaultValueInterface, _ = strconv.ParseBool(field.DefaultValue) + if field.HasDefaultValue && !skipParseDefaultValue { + if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil { + schema.err = fmt.Errorf("failed to parse %v as default value for bool, got error: %v", field.DefaultValue, err) + } } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: field.DataType = Int - if field.HasDefaultValue && field.DefaultValue != "" { - field.DefaultValueInterface, _ = strconv.ParseInt(field.DefaultValue, 0, 64) + if field.HasDefaultValue && !skipParseDefaultValue { + if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil { + schema.err = fmt.Errorf("failed to parse %v as default value for int, got error: %v", field.DefaultValue, err) + } } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: field.DataType = Uint - if field.HasDefaultValue && field.DefaultValue != "" { - field.DefaultValueInterface, _ = strconv.ParseUint(field.DefaultValue, 0, 64) + if field.HasDefaultValue && !skipParseDefaultValue { + if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil { + schema.err = fmt.Errorf("failed to parse %v as default value for uint, got error: %v", field.DefaultValue, err) + } } case reflect.Float32, reflect.Float64: field.DataType = Float - if field.HasDefaultValue && field.DefaultValue != "" { - field.DefaultValueInterface, _ = strconv.ParseFloat(field.DefaultValue, 64) + if field.HasDefaultValue && !skipParseDefaultValue { + if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil { + schema.err = fmt.Errorf("failed to parse %v as default value for float, got error: %v", field.DefaultValue, err) + } } case reflect.String: field.DataType = String - isFunc := strings.Contains(field.DefaultValue, "(") && - strings.Contains(field.DefaultValue, ")") - if field.HasDefaultValue && !isFunc { + if field.HasDefaultValue && !skipParseDefaultValue { field.DefaultValue = strings.Trim(field.DefaultValue, "'") field.DefaultValue = strings.Trim(field.DefaultValue, "\"") field.DefaultValueInterface = field.DefaultValue @@ -211,7 +223,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { case reflect.Struct: if _, ok := fieldValue.Interface().(*time.Time); ok { field.DataType = Time - } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { + } else if fieldValue.Type().ConvertibleTo(TimeReflectType) { field.DataType = Time } else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) { field.DataType = Time @@ -316,7 +328,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { cacheStore := &sync.Map{} cacheStore.Store(embeddedCacheKey, true) - if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, schema.namer); err != nil { + if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil { schema.err = err } @@ -331,23 +343,25 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...) } - if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok { + if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok && ef.DBName != "" { ef.DBName = prefix + ef.DBName } - if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { - ef.PrimaryKey = true - } else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { - ef.PrimaryKey = true - } else { - ef.PrimaryKey = false + if ef.PrimaryKey { + if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { + ef.PrimaryKey = true + } else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) { + ef.PrimaryKey = true + } else { + ef.PrimaryKey = false - if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) { - ef.AutoIncrement = false - } + if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) { + ef.AutoIncrement = false + } - if ef.DefaultValue == "" { - ef.HasDefaultValue = false + if ef.DefaultValue == "" { + ef.HasDefaultValue = false + } } } @@ -612,6 +626,14 @@ func (field *Field) setupValuerAndSetter() { field.ReflectValueOf(value).SetUint(uint64(data)) case []byte: return field.Set(value, string(data)) + case time.Time: + if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond { + field.ReflectValueOf(value).SetUint(uint64(data.UnixNano())) + } else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond { + field.ReflectValueOf(value).SetUint(uint64(data.UnixNano() / 1e6)) + } else { + field.ReflectValueOf(value).SetUint(uint64(data.Unix())) + } case string: if i, err := strconv.ParseUint(data, 0, 64); err == nil { field.ReflectValueOf(value).SetUint(i) @@ -671,7 +693,7 @@ func (field *Field) setupValuerAndSetter() { case []byte: field.ReflectValueOf(value).SetString(string(data)) case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - field.ReflectValueOf(value).SetString(fmt.Sprint(data)) + field.ReflectValueOf(value).SetString(utils.ToString(data)) case float64, float32: field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) default: diff --git a/schema/naming.go b/schema/naming.go index 9b7c9471..dbc71e04 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -14,7 +14,7 @@ import ( type Namer interface { TableName(table string) string ColumnName(table, column string) string - JoinTableName(table string) string + JoinTableName(joinTable string) string RelationshipFKName(Relationship) string CheckerName(table, column string) string IndexName(table, column string) string @@ -41,6 +41,10 @@ func (ns NamingStrategy) ColumnName(table, column string) string { // JoinTableName convert string to join table name func (ns NamingStrategy) JoinTableName(str string) string { + if strings.ToLower(str) == str { + return ns.TablePrefix + str + } + if ns.SingularTable { return ns.TablePrefix + toDBName(str) } @@ -49,17 +53,18 @@ func (ns NamingStrategy) JoinTableName(str string) string { // RelationshipFKName generate fk name for relation func (ns NamingStrategy) RelationshipFKName(rel Relationship) string { - return fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Name)) + return strings.Replace(fmt.Sprintf("fk_%s_%s", rel.Schema.Table, toDBName(rel.Name)), ".", "_", -1) } // CheckerName generate checker name func (ns NamingStrategy) CheckerName(table, column string) string { - return fmt.Sprintf("chk_%s_%s", table, column) + return strings.Replace(fmt.Sprintf("chk_%s_%s", table, column), ".", "_", -1) } // IndexName generate index name func (ns NamingStrategy) IndexName(table, column string) string { idxName := fmt.Sprintf("idx_%v_%v", table, toDBName(column)) + idxName = strings.Replace(idxName, ".", "_", -1) if utf8.RuneCountInString(idxName) > 64 { h := sha1.New() diff --git a/schema/naming_test.go b/schema/naming_test.go index 96b83ced..26b0dcf6 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -32,3 +32,35 @@ func TestToDBName(t *testing.T) { } } } + +func TestNamingStrategy(t *testing.T) { + var ns = NamingStrategy{ + TablePrefix: "public.", + SingularTable: true, + } + idxName := ns.IndexName("public.table", "name") + + if idxName != "idx_public_table_name" { + t.Errorf("invalid index name generated, got %v", idxName) + } + + chkName := ns.CheckerName("public.table", "name") + if chkName != "chk_public_table_name" { + t.Errorf("invalid checker name generated, got %v", chkName) + } + + joinTable := ns.JoinTableName("user_languages") + if joinTable != "public.user_languages" { + t.Errorf("invalid join table generated, got %v", joinTable) + } + + joinTable2 := ns.JoinTableName("UserLanguage") + if joinTable2 != "public.user_language" { + t.Errorf("invalid join table generated, got %v", joinTable2) + } + + tableName := ns.TableName("Company") + if tableName != "public.company" { + t.Errorf("invalid table name generated, got %v", tableName) + } +} diff --git a/schema/relationship.go b/schema/relationship.go index 5132ff74..35af111f 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -165,6 +165,9 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi // use same data type for foreign keys relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType + if relation.Polymorphic.PolymorphicID.Size == 0 { + relation.Polymorphic.PolymorphicID.Size = primaryKeyField.Size + } relation.References = append(relation.References, &Reference{ PrimaryKey: primaryKeyField, @@ -254,12 +257,18 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel }) } + joinTableFields = append(joinTableFields, reflect.StructField{ + Name: schema.Name + field.Name, + Type: schema.ModelType, + Tag: `gorm:"-"`, + }) + if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { schema.err = err } relation.JoinTable.Name = many2many relation.JoinTable.Table = schema.namer.JoinTableName(many2many) - relation.JoinTable.PrimaryFields = make([]*Field, len(relation.JoinTable.Fields)) + relation.JoinTable.PrimaryFields = make([]*Field, 0, len(relation.JoinTable.Fields)) relName := relation.Schema.Name relRefName := relation.FieldSchema.Name @@ -290,36 +299,41 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel } // build references - for idx, f := range relation.JoinTable.Fields { - // use same data type for foreign keys - f.DataType = fieldsMap[f.Name].DataType - f.GORMDataType = fieldsMap[f.Name].GORMDataType - relation.JoinTable.PrimaryFields[idx] = f - ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] - - if ownPriamryField { - joinRel := relation.JoinTable.Relationships.Relations[relName] - joinRel.Field = relation.Field - joinRel.References = append(joinRel.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], - ForeignKey: f, - }) - } else { - joinRefRel := relation.JoinTable.Relationships.Relations[relRefName] - if joinRefRel.Field == nil { - joinRefRel.Field = relation.Field + for _, f := range relation.JoinTable.Fields { + if f.Creatable || f.Readable || f.Updatable { + // use same data type for foreign keys + f.DataType = fieldsMap[f.Name].DataType + f.GORMDataType = fieldsMap[f.Name].GORMDataType + if f.Size == 0 { + f.Size = fieldsMap[f.Name].Size } - joinRefRel.References = append(joinRefRel.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], - ForeignKey: f, + relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f) + ownPriamryField := schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name] + + if ownPriamryField { + joinRel := relation.JoinTable.Relationships.Relations[relName] + joinRel.Field = relation.Field + joinRel.References = append(joinRel.References, &Reference{ + PrimaryKey: fieldsMap[f.Name], + ForeignKey: f, + }) + } else { + joinRefRel := relation.JoinTable.Relationships.Relations[relRefName] + if joinRefRel.Field == nil { + joinRefRel.Field = relation.Field + } + joinRefRel.References = append(joinRefRel.References, &Reference{ + PrimaryKey: fieldsMap[f.Name], + ForeignKey: f, + }) + } + + relation.References = append(relation.References, &Reference{ + PrimaryKey: fieldsMap[f.Name], + ForeignKey: f, + OwnPrimaryKey: ownPriamryField, }) } - - relation.References = append(relation.References, &Reference{ - PrimaryKey: fieldsMap[f.Name], - ForeignKey: f, - OwnPrimaryKey: ownPriamryField, - }) } } @@ -428,6 +442,9 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl gue // use same data type for foreign keys foreignField.DataType = primaryFields[idx].DataType foreignField.GORMDataType = primaryFields[idx].GORMDataType + if foreignField.Size == 0 { + foreignField.Size = primaryFields[idx].Size + } relation.References = append(relation.References, &Reference{ PrimaryKey: primaryFields[idx], diff --git a/schema/relationship_test.go b/schema/relationship_test.go index 2e85c538..7d7fd9c9 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -206,16 +206,39 @@ func TestMany2ManyOverrideJoinForeignKey(t *testing.T) { type User struct { gorm.Model - Profiles []Profile `gorm:"many2many:user_profiles;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"` + Profiles []Profile `gorm:"many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"` Refer uint } checkStructRelation(t, &User{}, Relation{ Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", - JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"}, + JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"}, References: []Reference{ - {"ID", "User", "UserReferID", "user_profiles", "", true}, - {"ID", "Profile", "ProfileRefer", "user_profiles", "", false}, + {"ID", "User", "UserReferID", "user_profile", "", true}, + {"ID", "Profile", "ProfileRefer", "user_profile", "", false}, + }, + }) +} + +func TestBuildReadonlyMany2ManyRelation(t *testing.T) { + type Profile struct { + gorm.Model + Name string + UserRefer uint + } + + type User struct { + gorm.Model + Profiles []Profile `gorm:"->;many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"` + Refer uint + } + + checkStructRelation(t, &User{}, Relation{ + Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile", + JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"}, + References: []Reference{ + {"ID", "User", "UserReferID", "user_profile", "", true}, + {"ID", "Profile", "ProfileRefer", "user_profile", "", false}, }, }) } @@ -267,3 +290,34 @@ func TestMany2ManyWithMultiPrimaryKeys(t *testing.T) { }, ) } + +func TestMultipleMany2Many(t *testing.T) { + type Thing struct { + ID int + } + + type Person struct { + ID int + Likes []Thing `gorm:"many2many:likes"` + Dislikes []Thing `gorm:"many2many:dislikes"` + } + + checkStructRelation(t, &Person{}, + Relation{ + Name: "Likes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing", + JoinTable: JoinTable{Name: "likes", Table: "likes"}, + References: []Reference{ + {"ID", "Person", "PersonID", "likes", "", true}, + {"ID", "Thing", "ThingID", "likes", "", false}, + }, + }, + Relation{ + Name: "Dislikes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing", + JoinTable: JoinTable{Name: "dislikes", Table: "dislikes"}, + References: []Reference{ + {"ID", "Person", "PersonID", "dislikes", "", true}, + {"ID", "Thing", "ThingID", "dislikes", "", false}, + }, + }, + ) +} diff --git a/schema/schema.go b/schema/schema.go index ea81d683..cffc19a7 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -97,6 +97,9 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() } + if en, ok := namer.(embeddedNamer); ok { + tableName = en.Table + } schema := &Schema{ Name: modelType.Name(), @@ -133,7 +136,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if field.DBName != "" { // nonexistence or shortest path or first appear prioritized if has permission - if v, ok := schema.FieldsByDBName[field.DBName]; !ok || (field.Creatable && len(field.BindNames) < len(v.BindNames)) { + if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) { if _, ok := schema.FieldsByDBName[field.DBName]; !ok { schema.DBNames = append(schema.DBNames, field.DBName) } @@ -219,7 +222,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if _, loaded := cacheStore.LoadOrStore(modelType, schema); !loaded { if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { for _, field := range schema.Fields { - if field.DataType == "" && field.Creatable { + if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { if schema.parseRelation(field); schema.err != nil { return schema, schema.err } diff --git a/schema/schema_test.go b/schema/schema_test.go index 4d13ebd2..a426cd90 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -1,6 +1,7 @@ package schema_test import ( + "strings" "sync" "testing" @@ -194,6 +195,7 @@ func TestEmbeddedStruct(t *testing.T) { ID int OwnerID int Name string + Ignored string `gorm:"-"` } type Corp struct { @@ -211,15 +213,87 @@ func TestEmbeddedStruct(t *testing.T) { {Name: "ID", DBName: "id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}}, {Name: "ID", DBName: "company_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, {Name: "Name", DBName: "company_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Ignored", BindNames: []string{"Base", "Ignored"}, TagSettings: map[string]string{"-": "-", "EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, {Name: "OwnerID", DBName: "company_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, {Name: "OwnerID", DBName: "owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String}, } for _, f := range fields { checkSchemaField(t, cropSchema, &f, func(f *schema.Field) { - f.Creatable = true - f.Updatable = true - f.Readable = true + if f.Name != "Ignored" { + f.Creatable = true + f.Updatable = true + f.Readable = true + } + }) + } +} + +type CustomizedNamingStrategy struct { + schema.NamingStrategy +} + +func (ns CustomizedNamingStrategy) ColumnName(table, column string) string { + baseColumnName := ns.NamingStrategy.ColumnName(table, column) + + if table == "" { + return baseColumnName + } + + s := strings.Split(table, "_") + + var prefix string + switch len(s) { + case 1: + prefix = s[0][:3] + case 2: + prefix = s[0][:1] + s[1][:2] + default: + prefix = s[0][:1] + s[1][:1] + s[2][:1] + } + return prefix + "_" + baseColumnName +} + +func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) { + type CorpBase struct { + gorm.Model + OwnerID string + } + + type Company struct { + ID int + OwnerID int + Name string + Ignored string `gorm:"-"` + } + + type Corp struct { + CorpBase + Base Company `gorm:"embedded;embeddedPrefix:company_"` + } + + cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, CustomizedNamingStrategy{schema.NamingStrategy{}}) + + if err != nil { + t.Fatalf("failed to parse embedded struct with primary key, got error %v", err) + } + + fields := []schema.Field{ + {Name: "ID", DBName: "cor_id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}}, + {Name: "ID", DBName: "company_cor_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Name", DBName: "company_cor_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "Ignored", BindNames: []string{"Base", "Ignored"}, TagSettings: map[string]string{"-": "-", "EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "company_cor_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}}, + {Name: "OwnerID", DBName: "cor_owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String}, + } + + for _, f := range fields { + checkSchemaField(t, cropSchema, &f, func(f *schema.Field) { + if f.Name != "Ignored" { + f.Creatable = true + f.Updatable = true + f.Readable = true + } }) } } diff --git a/schema/utils.go b/schema/utils.go index 41bd9d60..55cbdeb4 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -190,3 +190,8 @@ func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interfa return columns, queryValues } } + +type embeddedNamer struct { + Table string + Namer +} diff --git a/soft_delete.go b/soft_delete.go index 484f565c..b13fc63f 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -25,14 +25,7 @@ func (n DeletedAt) Value() (driver.Value, error) { } func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface { - return []clause.Interface{ - clause.Where{Exprs: []clause.Expression{ - clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, - Value: nil, - }, - }}, - } + return []clause.Interface{SoftDeleteQueryClause{Field: f}} } type SoftDeleteQueryClause struct { diff --git a/statement.go b/statement.go index d72a086f..38d35926 100644 --- a/statement.go +++ b/statement.go @@ -12,6 +12,7 @@ import ( "sync" "gorm.io/gorm/clause" + "gorm.io/gorm/logger" "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) @@ -189,7 +190,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { writer.WriteString("(NULL)") } case *DB: - subdb := v.Session(&Session{DryRun: true, WithConditions: true}).getInstance() + subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true, WithConditions: true}).getInstance() subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...) subdb.callbacks.Query().Execute(subdb) writer.WriteString(subdb.Statement.SQL.String()) @@ -298,12 +299,18 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) switch reflectValue.Kind() { case reflect.Slice, reflect.Array: - values := make([]interface{}, reflectValue.Len()) - for i := 0; i < reflectValue.Len(); i++ { - values[i] = reflectValue.Index(i).Interface() - } + if _, ok := v[key].(driver.Valuer); ok { + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } else if _, ok := v[key].(Valuer); ok { + conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + } else { + values := make([]interface{}, reflectValue.Len()) + for i := 0; i < reflectValue.Len(); i++ { + values[i] = reflectValue.Index(i).Interface() + } - conds = append(conds, clause.IN{Column: key, Values: values}) + conds = append(conds, clause.IN{Column: key, Values: values}) + } default: conds = append(conds, clause.Eq{Column: key, Value: v[key]}) } @@ -317,9 +324,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c if field.Readable { if v, isZero := field.ValueOf(reflectValue); !isZero { if field.DBName != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) } } } @@ -330,9 +337,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) (c if field.Readable { if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero { if field.DBName != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.DBName}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: s.Table, Name: field.Name}, Value: v}) + conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) } } } diff --git a/tests/create_test.go b/tests/create_test.go index ab0a78d4..00674eec 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "errors" "testing" "time" @@ -287,6 +288,30 @@ func TestCreateEmptyStruct(t *testing.T) { } } +func TestCreateEmptySlice(t *testing.T) { + var data = []User{} + if err := DB.Create(&data).Error; err != gorm.ErrEmptySlice { + t.Errorf("no data should be created, got %v", err) + } + + var sliceMap = []map[string]interface{}{} + if err := DB.Model(&User{}).Create(&sliceMap).Error; err != gorm.ErrEmptySlice { + t.Errorf("no data should be created, got %v", err) + } +} + +func TestCreateInvalidSlice(t *testing.T) { + users := []*User{ + GetUser("invalid_slice_1", Config{}), + GetUser("invalid_slice_2", Config{}), + nil, + } + + if err := DB.Create(&users).Error; !errors.Is(err, gorm.ErrInvalidData) { + t.Errorf("should returns error invalid data when creating from slice that contains invalid data") + } +} + func TestCreateWithExistingTimestamp(t *testing.T) { user := User{Name: "CreateUserExistingTimestamp"} curTime := now.MustParse("2016-01-01") diff --git a/tests/customize_field_test.go b/tests/customize_field_test.go index bf3c78fa..7802eb11 100644 --- a/tests/customize_field_test.go +++ b/tests/customize_field_test.go @@ -69,12 +69,12 @@ func TestCustomizeField(t *testing.T) { FieldAllowSave3 string `gorm:"->:false;<-:create"` FieldReadonly string `gorm:"->"` FieldIgnore string `gorm:"-"` - AutoUnixCreateTime int64 `gorm:"autocreatetime"` - AutoUnixMilliCreateTime int64 `gorm:"autocreatetime:milli"` + AutoUnixCreateTime int32 `gorm:"autocreatetime"` + AutoUnixMilliCreateTime int `gorm:"autocreatetime:milli"` AutoUnixNanoCreateTime int64 `gorm:"autocreatetime:nano"` - AutoUnixUpdateTime int64 `gorm:"autoupdatetime"` - AutoUnixMilliUpdateTime int64 `gorm:"autoupdatetime:milli"` - AutoUnixNanoUpdateTime int64 `gorm:"autoupdatetime:nano"` + AutoUnixUpdateTime uint32 `gorm:"autoupdatetime"` + AutoUnixMilliUpdateTime int `gorm:"autoupdatetime:milli"` + AutoUnixNanoUpdateTime uint64 `gorm:"autoupdatetime:nano"` } DB.Migrator().DropTable(&CustomizeFieldStruct{}) @@ -116,15 +116,15 @@ func TestCustomizeField(t *testing.T) { t.Fatalf("invalid result: %#v", result) } - if result.AutoUnixCreateTime != result.AutoUnixUpdateTime || result.AutoUnixCreateTime == 0 { + if int(result.AutoUnixCreateTime) != int(result.AutoUnixUpdateTime) || result.AutoUnixCreateTime == 0 { t.Fatalf("invalid create/update unix time: %#v", result) } - if result.AutoUnixMilliCreateTime != result.AutoUnixMilliUpdateTime || result.AutoUnixMilliCreateTime == 0 || result.AutoUnixMilliCreateTime/result.AutoUnixCreateTime < 1e3 { + if int(result.AutoUnixMilliCreateTime) != int(result.AutoUnixMilliUpdateTime) || result.AutoUnixMilliCreateTime == 0 || int(result.AutoUnixMilliCreateTime)/int(result.AutoUnixCreateTime) < 1e3 { t.Fatalf("invalid create/update unix milli time: %#v", result) } - if result.AutoUnixNanoCreateTime != result.AutoUnixNanoUpdateTime || result.AutoUnixNanoCreateTime == 0 || result.AutoUnixNanoCreateTime/result.AutoUnixCreateTime < 1e6 { + if int(result.AutoUnixNanoCreateTime) != int(result.AutoUnixNanoUpdateTime) || result.AutoUnixNanoCreateTime == 0 || int(result.AutoUnixNanoCreateTime)/int(result.AutoUnixCreateTime) < 1e6 { t.Fatalf("invalid create/update unix nano time: %#v", result) } @@ -178,15 +178,15 @@ func TestCustomizeField(t *testing.T) { var createWithDefaultTimeResult CustomizeFieldStruct DB.Find(&createWithDefaultTimeResult, "name = ?", createWithDefaultTime.Name) - if createWithDefaultTimeResult.AutoUnixCreateTime != createWithDefaultTimeResult.AutoUnixUpdateTime || createWithDefaultTimeResult.AutoUnixCreateTime != 100 { + if int(createWithDefaultTimeResult.AutoUnixCreateTime) != int(createWithDefaultTimeResult.AutoUnixUpdateTime) || createWithDefaultTimeResult.AutoUnixCreateTime != 100 { t.Fatalf("invalid create/update unix time: %#v", createWithDefaultTimeResult) } - if createWithDefaultTimeResult.AutoUnixMilliCreateTime != createWithDefaultTimeResult.AutoUnixMilliUpdateTime || createWithDefaultTimeResult.AutoUnixMilliCreateTime != 100 { + if int(createWithDefaultTimeResult.AutoUnixMilliCreateTime) != int(createWithDefaultTimeResult.AutoUnixMilliUpdateTime) || createWithDefaultTimeResult.AutoUnixMilliCreateTime != 100 { t.Fatalf("invalid create/update unix milli time: %#v", createWithDefaultTimeResult) } - if createWithDefaultTimeResult.AutoUnixNanoCreateTime != createWithDefaultTimeResult.AutoUnixNanoUpdateTime || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 { + if int(createWithDefaultTimeResult.AutoUnixNanoCreateTime) != int(createWithDefaultTimeResult.AutoUnixNanoUpdateTime) || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 { t.Fatalf("invalid create/update unix nano time: %#v", createWithDefaultTimeResult) } } diff --git a/tests/delete_test.go b/tests/delete_test.go index 17299677..ecd5ec39 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -5,6 +5,7 @@ import ( "testing" "gorm.io/gorm" + "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -127,3 +128,56 @@ func TestBlockGlobalDelete(t *testing.T) { t.Errorf("should returns no error while enable global update, but got err %v", err) } } + +func TestDeleteWithAssociations(t *testing.T) { + user := GetUser("delete_with_associations", Config{Account: true, Pets: 2, Toys: 4, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 1}) + + if err := DB.Create(user).Error; err != nil { + t.Fatalf("failed to create user, got error %v", err) + } + + if err := DB.Select(clause.Associations, "Pets.Toy").Delete(&user).Error; err != nil { + t.Fatalf("failed to delete user, got error %v", err) + } + + for key, value := range map[string]int64{"Account": 1, "Pets": 2, "Toys": 4, "Company": 1, "Manager": 1, "Team": 1, "Languages": 0, "Friends": 0} { + if count := DB.Unscoped().Model(&user).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } + + for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} { + if count := DB.Model(&user).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } +} + +func TestDeleteSliceWithAssociations(t *testing.T) { + users := []User{ + *GetUser("delete_slice_with_associations1", Config{Account: true, Pets: 4, Toys: 1, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 4}), + *GetUser("delete_slice_with_associations2", Config{Account: true, Pets: 3, Toys: 2, Company: true, Manager: true, Team: 2, Languages: 2, Friends: 3}), + *GetUser("delete_slice_with_associations3", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 2}), + *GetUser("delete_slice_with_associations4", Config{Account: true, Pets: 1, Toys: 4, Company: true, Manager: true, Team: 4, Languages: 4, Friends: 1}), + } + + if err := DB.Create(users).Error; err != nil { + t.Fatalf("failed to create user, got error %v", err) + } + + if err := DB.Select(clause.Associations).Delete(&users).Error; err != nil { + t.Fatalf("failed to delete user, got error %v", err) + } + + for key, value := range map[string]int64{"Account": 4, "Pets": 10, "Toys": 10, "Company": 4, "Manager": 4, "Team": 10, "Languages": 0, "Friends": 0} { + if count := DB.Unscoped().Model(&users).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } + + for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 4, "Manager": 4, "Team": 0, "Languages": 0, "Friends": 0} { + if count := DB.Model(&users).Association(key).Count(); count != value { + t.Errorf("user's %v expects: %v, got %v", key, value, count) + } + } +} diff --git a/tests/embedded_struct_test.go b/tests/embedded_struct_test.go index c29078bd..312a5c37 100644 --- a/tests/embedded_struct_test.go +++ b/tests/embedded_struct_test.go @@ -163,6 +163,8 @@ func TestEmbeddedRelations(t *testing.T) { DB.Migrator().DropTable(&AdvancedUser{}) if err := DB.AutoMigrate(&AdvancedUser{}); err != nil { - t.Errorf("Failed to auto migrate advanced user, got error %v", err) + if DB.Dialector.Name() != "sqlite" { + t.Errorf("Failed to auto migrate advanced user, got error %v", err) + } } } diff --git a/tests/go.mod b/tests/go.mod index 5c1f6e3d..01f5d49a 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -7,11 +7,13 @@ require ( github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 github.com/stevefan1999-personal/gorm-driver-oracle latest - gorm.io/driver/mysql v1.0.0 - gorm.io/driver/postgres v1.0.0 - gorm.io/driver/sqlite v1.1.0 - gorm.io/driver/sqlserver v1.0.1 - gorm.io/gorm v1.9.19 + gorm.io/driver/mysql v1.0.1 + gorm.io/driver/postgres v1.0.2 + gorm.io/driver/sqlite v1.1.3 + gorm.io/driver/sqlserver v1.0.4 + gorm.io/gorm v1.20.2 ) replace gorm.io/gorm => ../ + +replace github.com/jackc/pgx/v4 => github.com/jinzhu/pgx/v4 v4.8.2 diff --git a/tests/helper_test.go b/tests/helper_test.go index cc0d808c..eee34e99 100644 --- a/tests/helper_test.go +++ b/tests/helper_test.go @@ -115,7 +115,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Run("Pets", func(t *testing.T) { if len(user.Pets) != len(expect.Pets) { - t.Errorf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets)) + t.Fatalf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets)) } sort.Slice(user.Pets, func(i, j int) bool { @@ -137,7 +137,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Run("Toys", func(t *testing.T) { if len(user.Toys) != len(expect.Toys) { - t.Errorf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys)) + t.Fatalf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys)) } sort.Slice(user.Toys, func(i, j int) bool { @@ -177,7 +177,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Run("Team", func(t *testing.T) { if len(user.Team) != len(expect.Team) { - t.Errorf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team)) + t.Fatalf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team)) } sort.Slice(user.Team, func(i, j int) bool { @@ -195,7 +195,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Run("Languages", func(t *testing.T) { if len(user.Languages) != len(expect.Languages) { - t.Errorf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages)) + t.Fatalf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages)) } sort.Slice(user.Languages, func(i, j int) bool { @@ -212,7 +212,7 @@ func CheckUser(t *testing.T, user User, expect User) { t.Run("Friends", func(t *testing.T) { if len(user.Friends) != len(expect.Friends) { - t.Errorf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends)) + t.Fatalf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends)) } sort.Slice(user.Friends, func(i, j int) bool { diff --git a/tests/joins_table_test.go b/tests/joins_table_test.go index b8c1be77..084c2f2c 100644 --- a/tests/joins_table_test.go +++ b/tests/joins_table_test.go @@ -5,12 +5,14 @@ import ( "time" "gorm.io/gorm" + "gorm.io/gorm/clause" ) type Person struct { ID int Name string Addresses []Address `gorm:"many2many:person_addresses;"` + DeletedAt gorm.DeletedAt } type Address struct { @@ -95,4 +97,20 @@ func TestOverrideJoinTable(t *testing.T) { if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 { t.Fatalf("address should be deleted when clear with unscoped") } + + address2_1 := Address{Name: "address 2-1"} + address2_2 := Address{Name: "address 2-2"} + person2 := Person{Name: "person_2", Addresses: []Address{address2_1, address2_2}} + DB.Create(&person2) + if err := DB.Select(clause.Associations).Delete(&person2).Error; err != nil { + t.Fatalf("failed to delete person, got error: %v", err) + } + + if count := DB.Unscoped().Model(&person2).Association("Addresses").Count(); count != 2 { + t.Errorf("person's addresses expects 2, got %v", count) + } + + if count := DB.Model(&person2).Association("Addresses").Count(); count != 0 { + t.Errorf("person's addresses expects 2, got %v", count) + } } diff --git a/tests/multi_primary_keys_test.go b/tests/multi_primary_keys_test.go index 051e3ee2..68da8a88 100644 --- a/tests/multi_primary_keys_test.go +++ b/tests/multi_primary_keys_test.go @@ -6,6 +6,7 @@ import ( "testing" "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" ) type Blog struct { @@ -410,3 +411,38 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { t.Fatalf("EN Blog's tags should be cleared") } } + +func TestCompositePrimaryKeysAssociations(t *testing.T) { + type Label struct { + BookID *uint `gorm:"primarykey"` + Name string `gorm:"primarykey"` + Value string + } + + type Book struct { + ID int + Name string + Labels []Label + } + + DB.Migrator().DropTable(&Label{}, &Book{}) + if err := DB.AutoMigrate(&Label{}, &Book{}); err != nil { + t.Fatalf("failed to migrate") + } + + book := Book{ + Name: "my book", + Labels: []Label{ + {Name: "region", Value: "emea"}, + }, + } + + DB.Create(&book) + + var result Book + if err := DB.Preload("Labels").First(&result, book.ID).Error; err != nil { + t.Fatalf("failed to preload, got error %v", err) + } + + AssertEqual(t, book, result) +} diff --git a/tests/named_argument_test.go b/tests/named_argument_test.go index 56fad5f4..d0a6f915 100644 --- a/tests/named_argument_test.go +++ b/tests/named_argument_test.go @@ -48,10 +48,22 @@ func TestNamedArg(t *testing.T) { t.Errorf("failed to update with named arg") } + namedUser.Name1 = "jinzhu-new" + namedUser.Name2 = "jinzhu-new2" + namedUser.Name3 = "jinzhu-new" + var result5 NamedUser if err := DB.Raw("SELECT * FROM named_users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2", map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result5).Error; err != nil { t.Errorf("failed to update with named arg") } - AssertEqual(t, result4, namedUser) + AssertEqual(t, result5, namedUser) + + var result6 NamedUser + if err := DB.Raw(`SELECT * FROM named_users WHERE (name1 = @name + AND name3 = @name) AND name2 = @name2`, map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result6).Error; err != nil { + t.Errorf("failed to update with named arg") + } + + AssertEqual(t, result6, namedUser) } diff --git a/tests/preload_test.go b/tests/preload_test.go index 7e5d2622..d9035661 100644 --- a/tests/preload_test.go +++ b/tests/preload_test.go @@ -1,6 +1,8 @@ package tests_test import ( + "encoding/json" + "regexp" "sort" "strconv" "testing" @@ -31,6 +33,20 @@ func TestPreloadWithAssociations(t *testing.T) { var user2 User DB.Preload(clause.Associations).Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + + var user3 = *GetUser("preload_with_associations_new", Config{ + Account: true, + Pets: 2, + Toys: 3, + Company: true, + Manager: true, + Team: 4, + Languages: 3, + Friends: 1, + }) + + DB.Preload(clause.Associations).Find(&user3, "id = ?", user.ID) + CheckUser(t, user3, user) } func TestNestedPreload(t *testing.T) { @@ -174,3 +190,25 @@ func TestNestedPreloadWithConds(t *testing.T) { CheckPet(t, *users2[2].Pets[2], *users[2].Pets[2]) } } + +func TestPreloadEmptyData(t *testing.T) { + var user = *GetUser("user_without_associations", Config{}) + DB.Create(&user) + + DB.Preload("Team").Preload("Languages").Preload("Friends").First(&user, "name = ?", user.Name) + + if r, err := json.Marshal(&user); err != nil { + t.Errorf("failed to marshal users, got error %v", err) + } else if !regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) { + t.Errorf("json marshal is not empty slice, got %v", string(r)) + } + + var results []User + DB.Preload("Team").Preload("Languages").Preload("Friends").Find(&results, "name = ?", user.Name) + + if r, err := json.Marshal(&results); err != nil { + t.Errorf("failed to marshal users, got error %v", err) + } else if !regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) { + t.Errorf("json marshal is not empty slice, got %v", string(r)) + } +} diff --git a/tests/prepared_stmt_test.go b/tests/prepared_stmt_test.go new file mode 100644 index 00000000..6b10b6dc --- /dev/null +++ b/tests/prepared_stmt_test.go @@ -0,0 +1,52 @@ +package tests_test + +import ( + "context" + "testing" + "time" + + "gorm.io/gorm" + . "gorm.io/gorm/utils/tests" +) + +func TestPreparedStmt(t *testing.T) { + tx := DB.Session(&gorm.Session{PrepareStmt: true}) + + if _, ok := tx.ConnPool.(*gorm.PreparedStmtDB); !ok { + t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") + } + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + txCtx := tx.WithContext(ctx) + + user := *GetUser("prepared_stmt", Config{}) + + txCtx.Create(&user) + + var result1 User + if err := txCtx.Find(&result1, user.ID).Error; err != nil { + t.Fatalf("no error should happen but got %v", err) + } + + time.Sleep(time.Second) + + var result2 User + if err := tx.Find(&result2, user.ID).Error; err != nil { + t.Fatalf("no error should happen but got %v", err) + } + + user2 := *GetUser("prepared_stmt2", Config{}) + if err := txCtx.Create(&user2).Error; err == nil { + t.Fatalf("should failed to create with timeout context") + } + + if err := tx.Create(&user2).Error; err != nil { + t.Fatalf("failed to create, got error %v", err) + } + + var result3 User + if err := tx.Find(&result3, user2.ID).Error; err != nil { + t.Fatalf("no error should happen but got %v", err) + } +} diff --git a/tests/query_test.go b/tests/query_test.go index d71c813a..431ccce2 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -1,11 +1,13 @@ package tests_test import ( + "database/sql" "fmt" "reflect" "regexp" "sort" "strconv" + "strings" "testing" "time" @@ -61,6 +63,54 @@ func TestFind(t *testing.T) { for _, name := range []string{"Name", "Age", "Birthday"} { t.Run(name, func(t *testing.T) { dbName := DB.NamingStrategy.ColumnName("", name) + + switch name { + case "Name": + if _, ok := first[dbName].(string); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) + } + case "Age": + if _, ok := first[dbName].(uint); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) + } + case "Birthday": + if _, ok := first[dbName].(*time.Time); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) + } + } + + reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) + AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) + }) + } + } + }) + + t.Run("FirstMapWithTable", func(t *testing.T) { + var first = map[string]interface{}{} + if err := DB.Table("users").Where("name = ?", "find").Find(first).Error; err != nil { + t.Errorf("errors happened when query first: %v", err) + } else { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + resultType := reflect.ValueOf(first[dbName]).Type().Name() + + switch name { + case "Name": + if !strings.Contains(resultType, "string") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) + } + case "Age": + if !strings.Contains(resultType, "int") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) + } + case "Birthday": + if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) + } + } + reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) }) @@ -86,13 +136,29 @@ func TestFind(t *testing.T) { t.Run("FirstSliceOfMap", func(t *testing.T) { var allMap = []map[string]interface{}{} if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { - t.Errorf("errors happened when query first: %v", err) + t.Errorf("errors happened when query find: %v", err) } else { for idx, user := range users { t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { for _, name := range []string{"Name", "Age", "Birthday"} { t.Run(name, func(t *testing.T) { dbName := DB.NamingStrategy.ColumnName("", name) + + switch name { + case "Name": + if _, ok := allMap[idx][dbName].(string); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) + } + case "Age": + if _, ok := allMap[idx][dbName].(uint); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) + } + case "Birthday": + if _, ok := allMap[idx][dbName].(*time.Time); !ok { + t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) + } + } + reflectValue := reflect.Indirect(reflect.ValueOf(user)) AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) }) @@ -101,6 +167,58 @@ func TestFind(t *testing.T) { } } }) + + t.Run("FindSliceOfMapWithTable", func(t *testing.T) { + var allMap = []map[string]interface{}{} + if err := DB.Table("users").Where("name = ?", "find").Find(&allMap).Error; err != nil { + t.Errorf("errors happened when query find: %v", err) + } else { + for idx, user := range users { + t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { + for _, name := range []string{"Name", "Age", "Birthday"} { + t.Run(name, func(t *testing.T) { + dbName := DB.NamingStrategy.ColumnName("", name) + resultType := reflect.ValueOf(allMap[idx][dbName]).Type().Name() + + switch name { + case "Name": + if !strings.Contains(resultType, "string") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName]) + } + case "Age": + if !strings.Contains(resultType, "int") { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName]) + } + case "Birthday": + if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) { + t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName]) + } + } + + reflectValue := reflect.Indirect(reflect.ValueOf(user)) + AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) + }) + } + }) + } + } + }) + + var models []User + if err := DB.Where("name in (?)", []string{"find"}).Find(&models).Error; err != nil || len(models) != 3 { + t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models)) + } else { + for idx, user := range users { + t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) { + CheckUser(t, models[idx], user) + }) + } + } + + var none []User + if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 { + t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none)) + } } func TestQueryWithAssociation(t *testing.T) { @@ -228,6 +346,11 @@ func TestNot(t *testing.T) { t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) } + result = dryDB.Where(map[string]interface{}{"name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{}) + if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + result = dryDB.Not("name = ?", "jinzhu").Find(&User{}) if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE NOT.*name.* = .+").MatchString(result.Statement.SQL.String()) { t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) @@ -309,6 +432,33 @@ func TestPluck(t *testing.T) { t.Errorf("Unexpected result on pluck id, got %+v", ids) } } + + var times []time.Time + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", ×).Error; err != nil { + t.Errorf("got error when pluck time: %v", err) + } + + for idx, tv := range times { + AssertEqual(t, tv, users[idx].CreatedAt) + } + + var ptrtimes []*time.Time + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &ptrtimes).Error; err != nil { + t.Errorf("got error when pluck time: %v", err) + } + + for idx, tv := range ptrtimes { + AssertEqual(t, tv, users[idx].CreatedAt) + } + + var nulltimes []sql.NullTime + if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &nulltimes).Error; err != nil { + t.Errorf("got error when pluck time: %v", err) + } + + for idx, tv := range nulltimes { + AssertEqual(t, tv.Time, users[idx].CreatedAt) + } } func TestSelect(t *testing.T) { @@ -508,6 +658,7 @@ func TestLimit(t *testing.T) { {Name: "LimitUser3", Age: 20}, {Name: "LimitUser4", Age: 10}, {Name: "LimitUser5", Age: 20}, + {Name: "LimitUser6", Age: 20}, } DB.Create(&users) @@ -516,7 +667,7 @@ func TestLimit(t *testing.T) { DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3) if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 { - t.Errorf("Limit should works") + t.Errorf("Limit should works, users1 %v users2 %v users3 %v", len(users1), len(users2), len(users3)) } } @@ -531,6 +682,7 @@ func TestOffset(t *testing.T) { if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { t.Errorf("Offset should work") } + DB.Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4) if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) { @@ -698,3 +850,11 @@ func TestScanNullValue(t *testing.T) { t.Fatalf("failed to query slice data with null age, got error %v", err) } } + +func TestQueryWithTableAndConditions(t *testing.T) { + result := DB.Session(&gorm.Session{DryRun: true}).Table("user").Find(&User{}, User{Name: "jinzhu"}) + + if !regexp.MustCompile(`SELECT \* FROM .user. WHERE .user.\..name. = .+ AND .user.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } +} diff --git a/tests/scan_test.go b/tests/scan_test.go index d6a372bb..785bb97e 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -6,6 +6,7 @@ import ( "strings" "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -16,14 +17,25 @@ func TestScan(t *testing.T) { DB.Save(&user1).Save(&user2).Save(&user3) type result struct { + ID uint Name string Age int } var res result - DB.Table("users").Select("name, age").Where("id = ?", user3.ID).Scan(&res) - if res.Name != user3.Name || res.Age != int(user3.Age) { - t.Errorf("Scan into struct should work") + DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&res) + if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) { + t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) + } + + DB.Table("users").Select("id, name, age").Where("id = ?", user2.ID).Scan(&res) + if res.ID != user2.ID || res.Name != user2.Name || res.Age != int(user2.Age) { + t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user2) + } + + DB.Model(&User{Model: gorm.Model{ID: user3.ID}}).Select("id, name, age").Scan(&res) + if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) { + t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) } var doubleAgeRes = &result{} @@ -39,11 +51,11 @@ func TestScan(t *testing.T) { DB.Table("users").Select("name, age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&results) sort.Slice(results, func(i, j int) bool { - return strings.Compare(results[i].Name, results[j].Name) < -1 + return strings.Compare(results[i].Name, results[j].Name) <= -1 }) if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name { - t.Errorf("Scan into struct map") + t.Errorf("Scan into struct map, got %#v", results) } } @@ -72,7 +84,21 @@ func TestScanRows(t *testing.T) { results = append(results, result) } + sort.Slice(results, func(i, j int) bool { + return strings.Compare(results[i].Name, results[j].Name) <= -1 + }) + if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { t.Errorf("Should find expected results") } + + var ages int + if err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("SUM(age)").Scan(&ages).Error; err != nil || ages != 30 { + t.Fatalf("failed to scan ages, got error %v, ages: %v", err, ages) + } + + var name string + if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name { + t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name) + } } diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index e6038947..c0176fc3 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "regexp" "strings" "testing" @@ -188,3 +189,56 @@ func TestGroupConditions(t *testing.T) { t.Errorf("expects: %v, got %v", expects, result) } } + +func TestCombineStringConditions(t *testing.T) { + dryRunDB := DB.Session(&gorm.Session{DryRun: true}) + sql := dryRunDB.Where("a = ? or b = ?", "a", "b").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ?", "c").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR c = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Or("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) OR \(c = .+ and d = .+\) OR \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ?", "c").Not("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND c = .+ AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ?", "e").Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT e = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Or("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } + + sql = dryRunDB.Not("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String() + if !regexp.MustCompile(`WHERE NOT \(a = .+ or b = .+\)$`).MatchString(sql) { + t.Fatalf("invalid sql generated, got %v", sql) + } +} diff --git a/tests/tests_all.sh b/tests/tests_all.sh index 59fa4881..87e302a8 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -13,7 +13,7 @@ if [ -d tests ] then cd tests cp go.mod go.mod.bak - sed '/$[[:space:]]*gorm.io\/driver/d' go.mod.bak > go.mod + sed '/^[[:blank:]]*gorm.io\/driver/d' go.mod.bak > go.mod cd .. fi diff --git a/tests/tests_test.go b/tests/tests_test.go index d71c15af..9fcf5e24 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -22,7 +22,7 @@ var DB *gorm.DB func init() { var err error if DB, err = OpenTestConnection(); err != nil { - log.Printf("failed to connect database, got error %v\n", err) + log.Printf("failed to connect database, got error %v", err) os.Exit(1) } else { sqlDB, err := DB.DB() @@ -31,7 +31,7 @@ func init() { } if err != nil { - log.Printf("failed to connect database, got error %v\n", err) + log.Printf("failed to connect database, got error %v", err) } RunMigrations() diff --git a/tests/transaction_test.go b/tests/transaction_test.go index aea151d9..334600b8 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -282,3 +282,24 @@ func TestNestedTransactionWithBlock(t *testing.T) { t.Fatalf("Should find saved record") } } + +func TestTransactionOnClosedConn(t *testing.T) { + DB, err := OpenTestConnection() + if err != nil { + t.Fatalf("failed to connect database, got error %v", err) + } + rawDB, _ := DB.DB() + rawDB.Close() + + if err := DB.Transaction(func(tx *gorm.DB) error { + return nil + }); err == nil { + t.Errorf("should returns error when commit with closed conn, got error %v", err) + } + + if err := DB.Session(&gorm.Session{PrepareStmt: true}).Transaction(func(tx *gorm.DB) error { + return nil + }); err == nil { + t.Errorf("should returns error when commit with closed conn, got error %v", err) + } +} diff --git a/tests/update_belongs_to_test.go b/tests/update_belongs_to_test.go index 47076e69..736dfc5b 100644 --- a/tests/update_belongs_to_test.go +++ b/tests/update_belongs_to_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -22,4 +23,22 @@ func TestUpdateBelongsTo(t *testing.T) { var user2 User DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + + user.Company.Name += "new" + user.Manager.Name += "new" + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Company").Preload("Manager").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Company").Preload("Manager").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) } diff --git a/tests/update_has_many_test.go b/tests/update_has_many_test.go index 01ea2e3a..9066cbac 100644 --- a/tests/update_has_many_test.go +++ b/tests/update_has_many_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -22,6 +23,26 @@ func TestUpdateHasManyAssociations(t *testing.T) { DB.Preload("Pets").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + for _, pet := range user.Pets { + pet.Name += "new" + } + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Pets").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Pets").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) + t.Run("Polymorphic", func(t *testing.T) { var user = *GetUser("update-has-many", Config{}) @@ -37,5 +58,25 @@ func TestUpdateHasManyAssociations(t *testing.T) { var user2 User DB.Preload("Toys").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + + for idx := range user.Toys { + user.Toys[idx].Name += "new" + } + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Toys").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Toys").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) }) } diff --git a/tests/update_has_one_test.go b/tests/update_has_one_test.go index 7b29f424..54568546 100644 --- a/tests/update_has_one_test.go +++ b/tests/update_has_one_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -23,6 +24,23 @@ func TestUpdateHasOne(t *testing.T) { DB.Preload("Account").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + user.Account.Number += "new" + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Account").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Account").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) + t.Run("Polymorphic", func(t *testing.T) { var pet = Pet{Name: "create"} @@ -39,5 +57,22 @@ func TestUpdateHasOne(t *testing.T) { var pet2 Pet DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID) CheckPet(t, pet2, pet) + + pet.Toy.Name += "new" + if err := DB.Save(&pet).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var pet3 Pet + DB.Preload("Toy").Find(&pet3, "id = ?", pet.ID) + CheckPet(t, pet2, pet3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&pet).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var pet4 Pet + DB.Preload("Toy").Find(&pet4, "id = ?", pet.ID) + CheckPet(t, pet4, pet) }) } diff --git a/tests/update_many2many_test.go b/tests/update_many2many_test.go index a46deeb0..d94ef4ab 100644 --- a/tests/update_many2many_test.go +++ b/tests/update_many2many_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -26,4 +27,28 @@ func TestUpdateMany2ManyAssociations(t *testing.T) { var user2 User DB.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) + + for idx := range user.Friends { + user.Friends[idx].Name += "new" + } + + for idx := range user.Languages { + user.Languages[idx].Name += "new" + } + + if err := DB.Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user3 User + DB.Preload("Languages").Preload("Friends").Find(&user3, "id = ?", user.ID) + CheckUser(t, user2, user3) + + if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { + t.Fatalf("errors happened when update: %v", err) + } + + var user4 User + DB.Preload("Languages").Preload("Friends").Find(&user4, "id = ?", user.ID) + CheckUser(t, user4, user) } diff --git a/tests/update_test.go b/tests/update_test.go index 1944ed3f..a660647c 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -629,4 +629,26 @@ func TestSaveWithPrimaryValue(t *testing.T) { var result2 Language DB.First(&result2, "code = ?", "save") AssertEqual(t, result2, lang) + + DB.Table("langs").Migrator().DropTable(&Language{}) + DB.Table("langs").AutoMigrate(&Language{}) + + if err := DB.Table("langs").Save(&lang).Error; err != nil { + t.Errorf("no error should happen when creating data, but got %v", err) + } + + var result3 Language + if err := DB.Table("langs").First(&result3, "code = ?", lang.Code).Error; err != nil || result3.Name != lang.Name { + t.Errorf("failed to find created record, got error: %v, result: %+v", err, result3) + } + + lang.Name += "name2" + if err := DB.Table("langs").Save(&lang).Error; err != nil { + t.Errorf("no error should happen when creating data, but got %v", err) + } + + var result4 Language + if err := DB.Table("langs").First(&result4, "code = ?", lang.Code).Error; err != nil || result4.Name != lang.Name { + t.Errorf("failed to find created record, got error: %v, result: %+v", err, result4) + } } diff --git a/utils/utils.go b/utils/utils.go index 71336f4b..ecba7fb9 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -30,7 +30,7 @@ func FileWithLineNum() string { } func IsValidDBNameChar(c rune) bool { - return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' + return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@' } func CheckTruth(val interface{}) bool { @@ -83,3 +83,31 @@ func AssertEqual(src, dst interface{}) bool { } return true } + +func ToString(value interface{}) string { + switch v := value.(type) { + case string: + return v + case int: + return strconv.FormatInt(int64(v), 10) + case int8: + return strconv.FormatInt(int64(v), 10) + case int16: + return strconv.FormatInt(int64(v), 10) + case int32: + return strconv.FormatInt(int64(v), 10) + case int64: + return strconv.FormatInt(v, 10) + case uint: + return strconv.FormatUint(uint64(v), 10) + case uint8: + return strconv.FormatUint(uint64(v), 10) + case uint16: + return strconv.FormatUint(uint64(v), 10) + case uint32: + return strconv.FormatUint(uint64(v), 10) + case uint64: + return strconv.FormatUint(v, 10) + } + return "" +}