diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 370417fc..8bd2bcb3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,7 +13,7 @@ jobs: sqlite: strategy: matrix: - go: ['1.16', '1.15', '1.14'] + go: ['1.16', '1.15'] platform: [ubuntu-latest] # can not run in windows OS runs-on: ${{ matrix.platform }} @@ -38,8 +38,8 @@ jobs: mysql: strategy: matrix: - dbversion: ['mysql:latest', 'mysql:5.7', 'mysql:5.6', 'mariadb:latest'] - go: ['1.16', '1.15', '1.14'] + dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest'] + go: ['1.16', '1.15'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} @@ -82,8 +82,8 @@ jobs: postgres: strategy: matrix: - dbversion: ['postgres:latest', 'postgres:11', 'postgres:10'] - go: ['1.16', '1.15', '1.14'] + dbversion: ['postgres:latest', 'postgres:12', 'postgres:11', 'postgres:10'] + go: ['1.16', '1.15'] platform: [ubuntu-latest] # can not run in macOS and Windows runs-on: ${{ matrix.platform }} @@ -125,7 +125,7 @@ jobs: sqlserver: strategy: matrix: - go: ['1.16', '1.15', '1.14'] + go: ['1.16', '1.15'] platform: [ubuntu-latest] # can not run test in macOS and windows runs-on: ${{ matrix.platform }} diff --git a/association.go b/association.go index 572f1526..62c25b71 100644 --- a/association.go +++ b/association.go @@ -26,7 +26,7 @@ func (db *DB) Association(column string) *Association { association.Relationship = db.Statement.Schema.Relationships.Relations[column] if association.Relationship == nil { - association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column) + association.Error = fmt.Errorf("%w: %s", ErrUnsupportedRelation, column) } db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) @@ -355,7 +355,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } else if ev.Type().Elem().AssignableTo(elemType) { fieldValue = reflect.Append(fieldValue, ev.Elem()) } else { - association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name) + association.Error = fmt.Errorf("unsupported data type: %v for relation %s", ev.Type(), association.Relationship.Name) } if elemType.Kind() == reflect.Struct { diff --git a/callbacks.go b/callbacks.go index 26e9c40d..02e741e7 100644 --- a/callbacks.go +++ b/callbacks.go @@ -212,7 +212,7 @@ func (c *callback) Register(name string, fn func(*DB)) error { } func (c *callback) Remove(name string) error { - c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum()) + c.processor.db.Logger.Warn(context.Background(), "removing callback `%s` from %s\n", name, utils.FileWithLineNum()) c.name = name c.remove = true c.processor.callbacks = append(c.processor.callbacks, c) @@ -220,7 +220,7 @@ func (c *callback) Remove(name string) error { } func (c *callback) Replace(name string, fn func(*DB)) error { - c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) + c.processor.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum()) c.name = name c.handler = fn c.replace = true @@ -250,7 +250,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { for _, c := range cs { // show warning message the callback name already exists if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove { - c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum()) + c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum()) } names = append(names, c.name) } @@ -266,7 +266,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { // if before callback already sorted, append current callback just after it sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...) } else if curIdx > sortedIdx { - return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before) + return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before) } } else if idx := getRIndex(names, c.before); idx != -1 { // if before callback exists @@ -284,7 +284,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { // if after callback sorted, append current callback to last sorted = append(sorted, c.name) } else if curIdx < sortedIdx { - return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after) + return fmt.Errorf("conflicting callback %s with before %s", c.name, c.after) } } else if idx := getRIndex(names, c.after); idx != -1 { // if after callback exists but haven't sorted diff --git a/callbacks/associations.go b/callbacks/associations.go index 6d74f20d..78f976c3 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -373,7 +373,7 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, }) if tx.Statement.FullSaveAssociations { - tx = tx.InstanceSet("gorm:update_track_time", true) + tx = tx.Set("gorm:update_track_time", true) } if len(selects) > 0 { diff --git a/callbacks/create.go b/callbacks/create.go index 727bd380..04ee6b30 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -33,75 +33,81 @@ func BeforeCreate(db *gorm.DB) { func Create(config *Config) func(db *gorm.DB) { if config.WithReturning { return CreateWithReturning - } else { - return func(db *gorm.DB) { - if db.Error == nil { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.CreateClauses { - db.Statement.AddClause(c) - } - } + } - if db.Statement.SQL.String() == "" { - db.Statement.SQL.Grow(180) - db.Statement.AddClauseIfNotExists(clause.Insert{}) - db.Statement.AddClause(ConvertToCreateValues(db.Statement)) + return func(db *gorm.DB) { + if db.Error != nil { + // maybe record logger TODO + return + } - db.Statement.Build(db.Statement.BuildClauses...) - } + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.CreateClauses { + db.Statement.AddClause(c) + } + } - if !db.DryRun && db.Error == nil { - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if db.Statement.SQL.String() == "" { + db.Statement.SQL.Grow(180) + db.Statement.AddClauseIfNotExists(clause.Insert{}) + db.Statement.AddClause(ConvertToCreateValues(db.Statement)) - if err == nil { - db.RowsAffected, _ = result.RowsAffected() + db.Statement.Build(db.Statement.BuildClauses...) + } - if db.RowsAffected > 0 { - if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { - if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - if config.LastInsertIDReversed { - for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - rv := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(rv).Kind() != reflect.Struct { - break - } + if !db.DryRun && db.Error == nil { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) - if isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) - insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement - } - } - } else { - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - rv := db.Statement.ReflectValue.Index(i) - if reflect.Indirect(rv).Kind() != reflect.Struct { - break - } + if err != nil { + db.AddError(err) + return + } - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) - insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement - } - } - } - case reflect.Struct: - if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) - } - } - } else { - db.AddError(err) + db.RowsAffected, _ = result.RowsAffected() + if !(db.RowsAffected > 0) { + return + } + + if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { + if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) + if isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement + } + } + } else { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + rv := db.Statement.ReflectValue.Index(i) + if reflect.Indirect(rv).Kind() != reflect.Struct { + break + } + + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) + insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement } } } - } else { - db.AddError(err) + case reflect.Struct: + if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } } + } else { + db.AddError(err) } } + } } } @@ -237,9 +243,12 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { default: var ( selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) + _, updateTrackTime = stmt.Get("gorm:update_track_time") curTime = stmt.DB.NowFunc() isZero bool ) + stmt.Settings.Delete("gorm:update_track_time") + values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))} for _, db := range stmt.Schema.DBNames { @@ -278,11 +287,9 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { field.Set(rv, curTime) values.Values[i][idx], _ = field.ValueOf(rv) } - } else if field.AutoUpdateTime > 0 { - if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok { - field.Set(rv, curTime) - values.Values[i][idx], _ = field.ValueOf(rv) - } + } else if field.AutoUpdateTime > 0 && updateTrackTime { + field.Set(rv, curTime) + values.Values[i][idx], _ = field.ValueOf(rv) } } @@ -320,11 +327,9 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { field.Set(stmt.ReflectValue, curTime) values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) } - } else if field.AutoUpdateTime > 0 { - if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok { - field.Set(stmt.ReflectValue, curTime) - values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) - } + } else if field.AutoUpdateTime > 0 && updateTrackTime { + field.Set(stmt.ReflectValue, curTime) + values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) } } diff --git a/callbacks/row.go b/callbacks/row.go index 10e880e1..407c32d7 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -9,7 +9,8 @@ func RowQuery(db *gorm.DB) { BuildQuerySQL(db) if !db.DryRun { - if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) { + if isRows, ok := db.Get("rows"); ok && isRows.(bool) { + db.Statement.Settings.Delete("rows") db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } else { db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) diff --git a/clause/expression.go b/clause/expression.go index f76ce138..2bdd4a30 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -233,11 +233,24 @@ type Eq struct { func (eq Eq) Build(builder Builder) { builder.WriteQuoted(eq.Column) - if eqNil(eq.Value) { - builder.WriteString(" IS NULL") - } else { - builder.WriteString(" = ") - builder.AddVar(builder, eq.Value) + switch eq.Value.(type) { + case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}: + builder.WriteString(" IN (") + rv := reflect.ValueOf(eq.Value) + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) + } + builder.WriteByte(')') + default: + if eqNil(eq.Value) { + builder.WriteString(" IS NULL") + } else { + builder.WriteString(" = ") + builder.AddVar(builder, eq.Value) + } } } @@ -251,11 +264,24 @@ type Neq Eq func (neq Neq) Build(builder Builder) { builder.WriteQuoted(neq.Column) - if eqNil(neq.Value) { - builder.WriteString(" IS NOT NULL") - } else { - builder.WriteString(" <> ") - builder.AddVar(builder, neq.Value) + switch neq.Value.(type) { + case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}: + builder.WriteString(" NOT IN (") + rv := reflect.ValueOf(neq.Value) + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) + } + builder.WriteByte(')') + default: + if eqNil(neq.Value) { + builder.WriteString(" IS NOT NULL") + } else { + builder.WriteString(" <> ") + builder.AddVar(builder, neq.Value) + } } } diff --git a/clause/expression_test.go b/clause/expression_test.go index 4472bdb1..1c8217ed 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -105,13 +105,15 @@ func TestNamedExpr(t *testing.T) { func TestExpression(t *testing.T) { column := "column-name" results := []struct { - Expressions []clause.Expression - Result string + Expressions []clause.Expression + ExpectedVars []interface{} + Result string }{{ Expressions: []clause.Expression{ clause.Eq{Column: column, Value: "column-value"}, }, - Result: "`column-name` = ?", + ExpectedVars: []interface{}{"column-value"}, + Result: "`column-name` = ?", }, { Expressions: []clause.Expression{ clause.Eq{Column: column, Value: nil}, @@ -126,7 +128,8 @@ func TestExpression(t *testing.T) { Expressions: []clause.Expression{ clause.Neq{Column: column, Value: "column-value"}, }, - Result: "`column-name` <> ?", + ExpectedVars: []interface{}{"column-value"}, + Result: "`column-name` <> ?", }, { Expressions: []clause.Expression{ clause.Neq{Column: column, Value: nil}, @@ -136,6 +139,18 @@ func TestExpression(t *testing.T) { clause.Neq{Column: column, Value: (interface{})(nil)}, }, Result: "`column-name` IS NOT NULL", + }, { + Expressions: []clause.Expression{ + clause.Eq{Column: column, Value: []string{"a", "b"}}, + }, + ExpectedVars: []interface{}{"a", "b"}, + Result: "`column-name` IN (?,?)", + }, { + Expressions: []clause.Expression{ + clause.Neq{Column: column, Value: []string{"a", "b"}}, + }, + ExpectedVars: []interface{}{"a", "b"}, + Result: "`column-name` NOT IN (?,?)", }} for idx, result := range results { @@ -147,6 +162,10 @@ func TestExpression(t *testing.T) { if stmt.SQL.String() != result.Result { t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String()) } + + if !reflect.DeepEqual(result.ExpectedVars, stmt.Vars) { + t.Errorf("generated vars is not equal, expects %v, but got %v", result.ExpectedVars, stmt.Vars) + } }) } } diff --git a/clause/on_conflict.go b/clause/on_conflict.go index 127d9bc1..64ee7f53 100644 --- a/clause/on_conflict.go +++ b/clause/on_conflict.go @@ -3,6 +3,7 @@ package clause type OnConflict struct { Columns []Column Where Where + TargetWhere Where OnConstraint string DoNothing bool DoUpdates Set @@ -25,6 +26,12 @@ func (onConflict OnConflict) Build(builder Builder) { } builder.WriteString(`) `) } + + if len(onConflict.TargetWhere.Exprs) > 0 { + builder.WriteString(" WHERE ") + onConflict.TargetWhere.Build(builder) + builder.WriteByte(' ') + } if onConflict.OnConstraint != "" { builder.WriteString("ON CONSTRAINT ") diff --git a/finisher_api.go b/finisher_api.go index c3941784..51f394b4 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -79,7 +79,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { tx = tx.Clauses(clause.OnConflict{UpdateAll: true}) } - tx = tx.callbacks.Create().Execute(tx.InstanceSet("gorm:update_track_time", true)) + tx = tx.callbacks.Create().Execute(tx.Set("gorm:update_track_time", true)) case reflect.Struct: if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { for _, pf := range tx.Statement.Schema.PrimaryFields { @@ -190,16 +190,17 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat if tx.Error != nil || int(result.RowsAffected) < batchSize { break - } else { - resultsValue := reflect.Indirect(reflect.ValueOf(dest)) - if result.Statement.Schema.PrioritizedPrimaryField == nil { - tx.AddError(ErrPrimaryKeyRequired) - break - } else { - primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1)) - queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) - } } + + // Optimize for-break + resultsValue := reflect.Indirect(reflect.ValueOf(dest)) + if result.Statement.Schema.PrioritizedPrimaryField == nil { + tx.AddError(ErrPrimaryKeyRequired) + break + } + + primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1)) + queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } tx.RowsAffected = rowsAffected @@ -304,7 +305,7 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { return tx.Create(dest) } else if len(db.Statement.assigns) > 0 { - exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) + exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...) assigns := map[string]interface{}{} for _, expr := range exprs { if eq, ok := expr.(clause.Eq); ok { @@ -382,9 +383,9 @@ func (db *DB) Count(count *int64) (tx *DB) { } if len(tx.Statement.Selects) == 0 { - tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}}) + tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(*)"}}) } else if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") { - expr := clause.Expr{SQL: "count(1)"} + expr := clause.Expr{SQL: "count(*)"} if len(tx.Statement.Selects) == 1 { dbName := tx.Statement.Selects[0] @@ -425,7 +426,7 @@ func (db *DB) Count(count *int64) (tx *DB) { } func (db *DB) Row() *sql.Row { - tx := db.getInstance().InstanceSet("rows", false) + tx := db.getInstance().Set("rows", false) tx = tx.callbacks.Row().Execute(tx) row, ok := tx.Statement.Dest.(*sql.Row) if !ok && tx.DryRun { @@ -435,7 +436,7 @@ func (db *DB) Row() *sql.Row { } func (db *DB) Rows() (*sql.Rows, error) { - tx := db.getInstance().InstanceSet("rows", true) + tx := db.getInstance().Set("rows", true) tx = tx.callbacks.Row().Execute(tx) rows, ok := tx.Statement.Dest.(*sql.Rows) if !ok && tx.DryRun && tx.Error == nil { @@ -473,7 +474,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { // Pluck used to query single column from a model as a map // var ages []int64 -// db.Find(&users).Pluck("age", &ages) +// db.Model(&users).Pluck("age", &ages) func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { tx = db.getInstance() if tx.Statement.Model != nil { diff --git a/gorm.go b/gorm.go index e105a933..7f7bad26 100644 --- a/gorm.go +++ b/gorm.go @@ -409,7 +409,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac } ref.ForeignKey = f } else { - return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName) + return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName) } } @@ -422,7 +422,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac relation.JoinTable = joinSchema } else { - return fmt.Errorf("failed to found relation: %v", field) + return fmt.Errorf("failed to found relation: %s", field) } return nil diff --git a/logger/logger.go b/logger/logger.go index 381199d5..98d1b32e 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -58,7 +58,7 @@ type Interface interface { Info(context.Context, string, ...interface{}) Warn(context.Context, string, ...interface{}) Error(context.Context, string, ...interface{}) - Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) + Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) } var ( diff --git a/migrator/migrator.go b/migrator/migrator.go index 1800ab54..03ffdd02 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -119,13 +119,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { for _, rel := range stmt.Schema.Relationships.Relations { if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating { - if constraint := rel.ParseConstraint(); constraint != nil { - if constraint.Schema == stmt.Schema { - if !tx.Migrator().HasConstraint(value, constraint.Name) { - if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { - return err - } - } + if constraint := rel.ParseConstraint(); constraint != nil && + constraint.Schema == stmt.Schema && !tx.Migrator().HasConstraint(value, constraint.Name) { + if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { + return err } } } @@ -294,16 +291,20 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error { func (m Migrator) AddColumn(value interface{}, field string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(field); field != nil { - if !field.IgnoreMigration { - return m.DB.Exec( - "ALTER TABLE ? ADD ? ?", - m.CurrentTable(stmt), clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field), - ).Error - } - return nil + // avoid using the same name field + f := stmt.Schema.LookUpField(field) + if f == nil { + return fmt.Errorf("failed to look up field with name: %s", field) } - return fmt.Errorf("failed to look up field with name: %s", field) + + if !f.IgnoreMigration { + return m.DB.Exec( + "ALTER TABLE ? ADD ? ?", + m.CurrentTable(stmt), clause.Column{Name: f.DBName}, m.DB.Migrator().FullDataTypeOf(f), + ).Error + } + + return nil }) } diff --git a/prepare_stmt.go b/prepare_stmt.go index 14570061..48a614b7 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -64,7 +64,7 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction} db.PreparedSQL = append(db.PreparedSQL, query) } - db.Mux.Unlock() + defer db.Mux.Unlock() return db.Stmts[query], err } diff --git a/schema/field.go b/schema/field.go index 5dbc96f1..9efaa44a 100644 --- a/schema/field.go +++ b/schema/field.go @@ -198,28 +198,28 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DataType = Bool if field.HasDefaultValue && !skipParseDefaultValue { if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil { - schema.err = fmt.Errorf("failed to parse %v as default value for bool, got error: %v", field.DefaultValue, err) + schema.err = fmt.Errorf("failed to parse %s as default value for bool, got error: %v", field.DefaultValue, err) } } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: field.DataType = Int if field.HasDefaultValue && !skipParseDefaultValue { if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil { - schema.err = fmt.Errorf("failed to parse %v as default value for int, got error: %v", field.DefaultValue, err) + schema.err = fmt.Errorf("failed to parse %s as default value for int, got error: %v", field.DefaultValue, err) } } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: field.DataType = Uint if field.HasDefaultValue && !skipParseDefaultValue { if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil { - schema.err = fmt.Errorf("failed to parse %v as default value for uint, got error: %v", field.DefaultValue, err) + schema.err = fmt.Errorf("failed to parse %s as default value for uint, got error: %v", field.DefaultValue, err) } } case reflect.Float32, reflect.Float64: field.DataType = Float if field.HasDefaultValue && !skipParseDefaultValue { if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil { - schema.err = fmt.Errorf("failed to parse %v as default value for float, got error: %v", field.DefaultValue, err) + schema.err = fmt.Errorf("failed to parse %s as default value for float, got error: %v", field.DefaultValue, err) } } case reflect.String: @@ -227,7 +227,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { if field.HasDefaultValue && !skipParseDefaultValue { field.DefaultValue = strings.Trim(field.DefaultValue, "'") - field.DefaultValue = strings.Trim(field.DefaultValue, "\"") + field.DefaultValue = strings.Trim(field.DefaultValue, `"`) field.DefaultValueInterface = field.DefaultValue } case reflect.Struct: @@ -392,7 +392,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } } else { - schema.err = fmt.Errorf("invalid embedded struct for %v's field %v, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType) + schema.err = fmt.Errorf("invalid embedded struct for %s's field %s, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType) } } @@ -423,12 +423,12 @@ func (field *Field) setupValuerAndSetter() { } else { v = v.Field(-idx - 1) - if v.Type().Elem().Kind() == reflect.Struct { - if !v.IsNil() { - v = v.Elem() - } else { - return nil, true - } + if v.Type().Elem().Kind() != reflect.Struct { + return nil, true + } + + if !v.IsNil() { + v = v.Elem() } else { return nil, true } @@ -736,7 +736,7 @@ func (field *Field) setupValuerAndSetter() { if t, err := now.Parse(data); err == nil { field.ReflectValueOf(value).Set(reflect.ValueOf(t)) } else { - return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) + return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) } default: return fallbackSetter(value, v, field.Set) @@ -765,7 +765,7 @@ func (field *Field) setupValuerAndSetter() { } fieldValue.Elem().Set(reflect.ValueOf(t)) } else { - return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) + return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err) } default: return fallbackSetter(value, v, field.Set) diff --git a/schema/naming.go b/schema/naming.go index d53942e4..47e313a7 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -74,7 +74,9 @@ func (ns NamingStrategy) IndexName(table, column string) string { } func (ns NamingStrategy) formatName(prefix, table, name string) string { - formattedName := strings.Replace(fmt.Sprintf("%v_%v_%v", prefix, table, name), ".", "_", -1) + formattedName := strings.Replace(strings.Join([]string{ + prefix, table, name, + }, "_"), ".", "_", -1) if utf8.RuneCountInString(formattedName) > 64 { h := sha1.New() diff --git a/schema/relationship.go b/schema/relationship.go index 62256c28..db496e30 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -85,7 +85,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { case reflect.Slice: schema.guessRelation(relation, field, guessHas) default: - schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name) + schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name) } } @@ -143,11 +143,11 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi } if relation.Polymorphic.PolymorphicType == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type") + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type") } if relation.Polymorphic.PolymorphicID == nil { - schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID") + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID") } if schema.err == nil { @@ -159,7 +159,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi primaryKeyField := schema.PrioritizedPrimaryField if len(relation.foreignKeys) > 0 { if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 { - schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name) + schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name) } } @@ -203,7 +203,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel if field := schema.LookUpField(foreignKey); field != nil { ownForeignFields = append(ownForeignFields, field) } else { - schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey) + schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey) return } } @@ -215,7 +215,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel if field := relation.FieldSchema.LookUpField(foreignKey); field != nil { refForeignFields = append(refForeignFields, field) } else { - schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey) + schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey) return } } @@ -379,7 +379,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu schema.guessRelation(relation, field, guessEmbeddedHas) // case guessEmbeddedHas: default: - schema.err = fmt.Errorf("invalid field found for struct %v's field %v, need to define a valid foreign key for relations or it need to implement the Valuer/Scanner interface", schema, field.Name) + schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name) } } diff --git a/schema/schema.go b/schema/schema.go index 1ce88fa5..8ade2ed7 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -45,9 +45,9 @@ type Schema struct { func (schema Schema) String() string { if schema.ModelType.Name() == "" { - return fmt.Sprintf("%v(%v)", schema.Name, schema.Table) + return fmt.Sprintf("%s(%s)", schema.Name, schema.Table) } - return fmt.Sprintf("%v.%v", schema.ModelType.PkgPath(), schema.ModelType.Name()) + return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name()) } func (schema Schema) MakeSlice() reflect.Value { @@ -86,7 +86,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) if modelType.PkgPath() == "" { return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } - return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } if v, ok := cacheStore.Load(modelType); ok { @@ -275,7 +275,7 @@ func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, e if modelType.PkgPath() == "" { return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } - return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) + return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } if v, ok := cacheStore.Load(modelType); ok { diff --git a/schema/utils.go b/schema/utils.go index add22047..e005cc74 100644 --- a/schema/utils.go +++ b/schema/utils.go @@ -178,17 +178,18 @@ func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interfa } return clause.Column{Table: table, Name: foreignKeys[0]}, queryValues - } else { - columns := make([]clause.Column, len(foreignKeys)) - for idx, key := range foreignKeys { - columns[idx] = clause.Column{Table: table, Name: key} - } - - for idx, r := range foreignValues { - queryValues[idx] = r - } - return columns, queryValues } + + columns := make([]clause.Column, len(foreignKeys)) + for idx, key := range foreignKeys { + columns[idx] = clause.Column{Table: table, Name: key} + } + + for idx, r := range foreignValues { + queryValues[idx] = r + } + + return columns, queryValues } type embeddedNamer struct { diff --git a/soft_delete.go b/soft_delete.go index b16041f1..af02f8fd 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -84,6 +84,32 @@ func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { } } +func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface { + return []clause.Interface{SoftDeleteUpdateClause{Field: f}} +} + +type SoftDeleteUpdateClause struct { + Field *schema.Field +} + +func (sd SoftDeleteUpdateClause) Name() string { + return "" +} + +func (sd SoftDeleteUpdateClause) Build(clause.Builder) { +} + +func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) { +} + +func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) { + if stmt.SQL.String() == "" { + if _, ok := stmt.Clauses["WHERE"]; stmt.DB.AllowGlobalUpdate || ok { + SoftDeleteQueryClause(sd).ModifyStatement(stmt) + } + } +} + func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface { return []clause.Interface{SoftDeleteDeleteClause{Field: f}} } diff --git a/statement.go b/statement.go index a87fd212..8b682c84 100644 --- a/statement.go +++ b/statement.go @@ -57,12 +57,12 @@ type StatementModifier interface { ModifyStatement(*Statement) } -// Write write string +// WriteString write string func (stmt *Statement) WriteString(str string) (int, error) { return stmt.SQL.WriteString(str) } -// Write write string +// WriteByte write byte func (stmt *Statement) WriteByte(c byte) error { return stmt.SQL.WriteByte(c) } @@ -152,7 +152,7 @@ func (stmt *Statement) Quote(field interface{}) string { return builder.String() } -// Write write string +// AddVar add var func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { for idx, v := range vars { if idx > 0 { @@ -506,7 +506,6 @@ func (stmt *Statement) clone() *Statement { return newStmt } -// Helpers // SetColumn set column's value // stmt.SetColumn("Name", "jinzhu") // Hooks Method // stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method @@ -540,11 +539,6 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . } } - if !stmt.ReflectValue.CanAddr() { - stmt.AddError(ErrInvalidValue) - return - } - switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: if len(fromCallbacks) > 0 { @@ -555,6 +549,11 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks . field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value) } case reflect.Struct: + if !stmt.ReflectValue.CanAddr() { + stmt.AddError(ErrInvalidValue) + return + } + field.Set(stmt.ReflectValue, value) } } else { diff --git a/tests/associations_test.go b/tests/associations_test.go index f470338f..3b270625 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -64,7 +64,7 @@ func TestAssociationNotNullClear(t *testing.T) { } if err := DB.Model(member).Association("Profiles").Clear(); err == nil { - t.Fatalf("No error occured during clearind not null association") + t.Fatalf("No error occurred during clearind not null association") } } diff --git a/tests/count_test.go b/tests/count_test.go index 0fef82f7..dd25f8b6 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -124,7 +124,6 @@ func TestCount(t *testing.T) { var count9 int64 if err := DB.Debug().Scopes(func(tx *gorm.DB) *gorm.DB { - fmt.Println("kdkdkdkdk") return tx.Table("users") }).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Count(&count9).Find(&users).Error; err != nil || count9 != 3 { t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err)) diff --git a/tests/go.mod b/tests/go.mod index 643b72c7..815f8986 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -3,10 +3,9 @@ module gorm.io/gorm/tests go 1.14 require ( - github.com/google/uuid v1.1.1 + github.com/google/uuid v1.2.0 github.com/jinzhu/now v1.1.2 github.com/lib/pq v1.6.0 - github.com/stretchr/testify v1.5.1 gorm.io/driver/mysql v1.0.5 gorm.io/driver/postgres v1.1.0 gorm.io/driver/sqlite v1.1.4 diff --git a/tests/scopes_test.go b/tests/scopes_test.go index 0ec4783b..94fff308 100644 --- a/tests/scopes_test.go +++ b/tests/scopes_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "context" "testing" "gorm.io/gorm" @@ -62,4 +63,12 @@ func TestScopes(t *testing.T) { if result.RowsAffected != 2 { t.Errorf("Should found two users's name in 1, 2, but got %v", result.RowsAffected) } + + var maxId int64 + userTable := func(db *gorm.DB) *gorm.DB { + return db.WithContext(context.Background()).Table("users") + } + if err := DB.Scopes(userTable).Select("max(id)").Scan(&maxId).Error; err != nil { + t.Errorf("select max(id)") + } } diff --git a/tests/tests_all.sh b/tests/tests_all.sh index e0ed97a4..f5657df1 100755 --- a/tests/tests_all.sh +++ b/tests/tests_all.sh @@ -11,6 +11,7 @@ then cd tests go get -u ./... go mod download + go mod tidy cd .. fi diff --git a/utils/utils.go b/utils/utils.go index ce6f35df..1110c7a7 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -15,17 +15,20 @@ var gormSourceDir string func init() { _, file, _, _ := runtime.Caller(0) + // compatible solution to get gorm source directory with various operating systems gormSourceDir = regexp.MustCompile(`utils.utils\.go`).ReplaceAllString(file, "") } +// FileWithLineNum return the file name and line number of the current file func FileWithLineNum() string { + // the second caller usually from gorm internal, so set i start from 2 for i := 2; i < 15; i++ { _, file, line, ok := runtime.Caller(i) - if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) { return file + ":" + strconv.FormatInt(int64(line), 10) } } + return "" }