diff --git a/association.go b/association.go index 3a2942fd..932d5f55 100644 --- a/association.go +++ b/association.go @@ -24,10 +24,12 @@ func (db *DB) Association(column string) *Association { if err := db.Statement.Parse(db.Statement.Model); err == nil { db.Statement.Table = table - association.Relationship = db.Statement.Schema.Relationships.Relations[column] + association.Relationship = db.Statement.Schema. + Relationships.Relations[column] if association.Relationship == nil { - association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column) + association.Error = fmt.Errorf("%w: %v", + ErrUnsupportedRelation, column) } db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) @@ -41,9 +43,11 @@ func (db *DB) Association(column string) *Association { return association } -func (association *Association) Find(out interface{}, conds ...interface{}) error { +func (association *Association) Find(out interface{}, + conds ...interface{}) error { if association.Error == nil { - association.Error = association.buildCondition().Find(out, conds...).Error + association.Error = association.buildCondition(). + Find(out, conds...).Error } return association.Error } @@ -80,10 +84,12 @@ func (association *Association) Replace(values ...interface{}) error { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { - association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) + association.Error = rel.Field.Set(reflectValue.Index(i), + reflect.Zero(rel.Field.FieldType).Interface()) } case reflect.Struct: - association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) + association.Error = rel.Field.Set(reflectValue, + reflect.Zero(rel.Field.FieldType).Interface()) } for _, ref := range rel.References { @@ -118,9 +124,11 @@ func (association *Association) Replace(values ...interface{}) error { } } - if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 { + if _, pvs := schema. + GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 { column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) - association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error + association.Error = tx.Where(clause.IN{Column: column, Values: values}). + UpdateColumns(updateMap).Error } case schema.Many2Many: var ( @@ -152,7 +160,8 @@ func (association *Association) Replace(values ...interface{}) error { } _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) - if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 { + if relColumn, relValues := schema. + ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 { tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues})) } diff --git a/callbacks.go b/callbacks.go index cb14aff1..eefde14a 100644 --- a/callbacks.go +++ b/callbacks.go @@ -82,8 +82,11 @@ func (p *processor) Execute(db *DB) { } if stmt.Model != nil { - if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) { - if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" { + if err := stmt.Parse(stmt.Model); err != nil && + (!errors.Is(err, schema.ErrUnsupportedDataType) || + (stmt.Table == "" && stmt.SQL.Len() == 0)) { + if errors.Is(err, schema.ErrUnsupportedDataType) && + stmt.Table == "" { db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err)) } else { db.AddError(err) @@ -163,7 +166,8 @@ func (p *processor) compile() (err error) { p.callbacks = callbacks if p.fns, err = sortCallbacks(p.callbacks); err != nil { - p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err) + p.db.Logger.Error(context.Background(), + "Got error when compile callbacks, got %v", err) } return } @@ -186,7 +190,8 @@ func (c *callback) Register(name string, fn func(*DB)) error { } func (c *callback) Remove(name string) error { - c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum()) + c.processor.db.Logger.Warn(context.Background(), + "removing callback `%v` from %v\n", name, utils.FileWithLineNum()) c.name = name c.remove = true c.processor.callbacks = append(c.processor.callbacks, c) @@ -194,7 +199,8 @@ func (c *callback) Remove(name string) error { } func (c *callback) Replace(name string, fn func(*DB)) error { - c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) + c.processor.db.Logger.Info(context.Background(), + "replacing callback `%v` from %v\n", name, utils.FileWithLineNum()) c.name = name c.handler = fn c.replace = true @@ -223,8 +229,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { for _, c := range cs { // show warning message the callback name already exists - if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove { - c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum()) + if idx := getRIndex(names, c.name); idx > -1 && !c.replace && + !c.remove && !cs[idx].remove { + c.processor.db.Logger.Warn(context.Background(), + "duplicated callback `%v` from %v\n", c.name, + utils.FileWithLineNum()) } names = append(names, c.name) } @@ -238,9 +247,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { } else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 { if curIdx := getRIndex(sorted, c.name); curIdx == -1 { // if before callback already sorted, append current callback just after it - sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...) + sorted = append(sorted[:sortedIdx], + append([]string{c.name}, sorted[sortedIdx:]...)...) } else if curIdx > sortedIdx { - return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before) + return fmt.Errorf("conflicting callback %v with before %v", + c.name, c.before) } } else if idx := getRIndex(names, c.before); idx != -1 { // if before callback exists @@ -258,7 +269,8 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { // if after callback sorted, append current callback to last sorted = append(sorted, c.name) } else if curIdx < sortedIdx { - return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after) + return fmt.Errorf("conflicting callback %v with before %v", + c.name, c.after) } } else if idx := getRIndex(names, c.after); idx != -1 { // if after callback exists but haven't sorted diff --git a/chainable_api.go b/chainable_api.go index 58b9336f..ded5cbbd 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -12,7 +12,8 @@ import ( // Model specify the model you would like to run db operations // // update all users's name to `hello` // db.Model(&User{}).Update("name", "hello") -// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello` +// // if user's primary key is non-blank, will use it as condition, +// then will only update the user's name to `hello` // db.Model(&user).Update("name", "hello") func (db *DB) Model(value interface{}) (tx *DB) { tx = db.getInstance() @@ -36,7 +37,8 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { } if len(whereConds) > 0 { - tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(whereConds[0], whereConds[1:]...)}) + tx.Statement.AddClause(clause.Where{Exprs: tx.Statement. + BuildCondition(whereConds[0], whereConds[1:]...)}) } return } @@ -46,9 +48,11 @@ var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`) // Table specify the table you would like to run db operations func (db *DB) Table(name string, args ...interface{}) (tx *DB) { tx = db.getInstance() - if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 { + if strings.Contains(name, " ") || strings.Contains(name, "`") || + len(args) > 0 { tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args} - if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 { + if results := tableRegexp. + FindStringSubmatch(name); len(results) == 2 { tx.Statement.Table = results[1] return } @@ -87,7 +91,8 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { case []string: tx.Statement.Selects = append(tx.Statement.Selects, arg...) default: - tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args)) + tx.AddError(fmt.Errorf("unsupported select args %v %v", + query, args)) return } } @@ -125,12 +130,14 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) { return } -// Omit specify fields that you want to ignore when creating, updating and querying +// Omit specify fields that you want to ignore when creating, +// updating and querying func (db *DB) Omit(columns ...string) (tx *DB) { tx = db.getInstance() if len(columns) == 1 && strings.ContainsRune(columns[0], ',') { - tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar) + tx.Statement.Omits = strings.FieldsFunc(columns[0], + utils.IsValidDBNameChar) } else { tx.Statement.Omits = columns } @@ -150,7 +157,9 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { - tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(conds...)}}) + tx.Statement.AddClause(clause.Where{ + Exprs: []clause.Expression{clause.Not(conds...)}, + }) } return } @@ -158,8 +167,11 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { // Or add OR conditions func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() - if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { - tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(clause.And(conds...))}}) + if conds := tx.Statement. + BuildCondition(query, args...); len(conds) > 0 { + tx.Statement.AddClause(clause.Where{ + Exprs: []clause.Expression{clause.Or(clause.And(conds...))}}, + ) } return } @@ -169,7 +181,10 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() - tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args}) + tx.Statement.Joins = append(tx.Statement.Joins, join{ + Name: query, Conds: args, + }) + return } diff --git a/errors.go b/errors.go index 08755083..137d442e 100644 --- a/errors.go +++ b/errors.go @@ -7,7 +7,8 @@ import ( var ( // ErrRecordNotFound record not found error ErrRecordNotFound = errors.New("record not found") - // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` + // ErrInvalidTransaction invalid transaction + // when you are trying to `Commit` or `Rollback` ErrInvalidTransaction = errors.New("no valid transaction") // ErrNotImplemented not implemented ErrNotImplemented = errors.New("not implemented") diff --git a/finisher_api.go b/finisher_api.go index 4a3c323b..4ce6c58b 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -67,7 +67,8 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { return } -// Save update value in database, if the value doesn't have primary key, will insert it +// Save update value in database, if the value doesn't have primary key, +// will insert it func (db *DB) Save(value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = value @@ -78,9 +79,11 @@ func (db *DB) Save(value interface{}) (tx *DB) { if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { tx = tx.Clauses(clause.OnConflict{UpdateAll: true}) } - tx.callbacks.Create().Execute(tx.InstanceSet("gorm:update_track_time", true)) + tx.callbacks.Create(). + Execute(tx.InstanceSet("gorm:update_track_time", true)) case reflect.Struct: - if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { + if err := tx.Statement.Parse(value); err == nil && + tx.Statement.Schema != nil { for _, pf := range tx.Statement.Schema.PrimaryFields { if _, isZero := pf.ValueOf(reflectValue); isZero { tx.callbacks.Create().Execute(tx) @@ -99,9 +102,11 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx.callbacks.Update().Execute(tx) - if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { + if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && + !selectedUpdate { result := reflect.New(tx.Statement.Schema.ModelType).Interface() - if err := tx.Session(&Session{}).First(result).Error; errors.Is(err, ErrRecordNotFound) { + if err := tx.Session(&Session{}). + First(result).Error; errors.Is(err, ErrRecordNotFound) { return tx.Create(value) } } @@ -113,10 +118,15 @@ func (db *DB) Save(value interface{}) (tx *DB) { // First find first record that match given conditions, order by primary key func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1).Order(clause.OrderByColumn{ - Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + Column: clause.Column{ + Table: clause.CurrentTable, + Name: clause.PrimaryKey, + }, }) + if len(conds) > 0 { - if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { + if exprs := tx. + Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { tx.Statement.AddClause(clause.Where{Exprs: exprs}) } } @@ -126,7 +136,8 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { return } -// Take return a record that match given conditions, the order will depend on the database implementation +// Take return a record that match given conditions, the order will +// depend on the database implementation func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1) if len(conds) > 0 { @@ -198,8 +209,15 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat tx.AddError(ErrPrimaryKeyRequired) break } else { - primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1)) - queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) + primaryValue, _ := result.Statement.Schema. + PrioritizedPrimaryField. + ValueOf(resultsValue.Index(resultsValue.Len() - 1)) + queryDB = tx.Clauses(clause.Gt{ + Column: clause.Column{ + Table: clause.CurrentTable, + Name: clause.PrimaryKey, + }, + Value: primaryValue}) } } } diff --git a/interfaces.go b/interfaces.go index e933952b..7468cf9b 100644 --- a/interfaces.go +++ b/interfaces.go @@ -29,9 +29,12 @@ type Plugin interface { // ConnPool db conns pool interface type ConnPool interface { PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) - ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) - QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) - QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row + ExecContext(ctx context.Context, query string, + args ...interface{}) (sql.Result, error) + QueryContext(ctx context.Context, query string, + args ...interface{}) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, + args ...interface{}) *sql.Row } // SavePointerDialectorInterface save pointer interface diff --git a/migrator.go b/migrator.go index 28ac35e7..86d4d683 100644 --- a/migrator.go +++ b/migrator.go @@ -48,7 +48,8 @@ type Migrator interface { AddColumn(dst interface{}, field string) error DropColumn(dst interface{}, field string) error AlterColumn(dst interface{}, field string) error - MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error + MigrateColumn(dst interface{}, field *schema.Field, + columnType ColumnType) error HasColumn(dst interface{}, field string) bool RenameColumn(dst interface{}, oldName, field string) error ColumnTypes(dst interface{}) ([]ColumnType, error) diff --git a/model.go b/model.go index 3334d17c..0245bb65 100644 --- a/model.go +++ b/model.go @@ -2,8 +2,10 @@ package gorm import "time" -// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt -// It may be embedded into your model or you may build your own model without it +// Model a basic GoLang struct which includes the following fields: ID, +// CreatedAt, UpdatedAt, DeletedAt +// It may be embedded into your model +// or you may build your own model without it // type User struct { // gorm.Model // } diff --git a/prepare_stmt.go b/prepare_stmt.go index 78a8adb4..ea57fb41 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -30,9 +30,11 @@ func (db *PreparedStmtDB) Close() { db.Mux.Unlock() } -func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { +func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, + isTransaction bool, query string) (Stmt, error) { db.Mux.RLock() - if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { + if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || + isTransaction) { db.Mux.RUnlock() return stmt, nil } @@ -40,7 +42,8 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact db.Mux.Lock() // double check - if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) { + if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || + isTransaction) { db.Mux.Unlock() return stmt, nil } else if ok { @@ -57,7 +60,8 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact return db.Stmts[query], err } -func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) { +func (db *PreparedStmtDB) BeginTx(ctx context.Context, + opt *sql.TxOptions) (ConnPool, error) { if beginner, ok := db.ConnPool.(TxBeginner); ok { tx, err := beginner.BeginTx(ctx, opt) return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err @@ -65,7 +69,8 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn return nil, ErrInvalidTransaction } -func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { +func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, + args ...interface{}) (result sql.Result, err error) { stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { result, err = stmt.ExecContext(ctx, args...) @@ -79,7 +84,8 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args .. return result, err } -func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { +func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, + args ...interface{}) (rows *sql.Rows, err error) { stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { rows, err = stmt.QueryContext(ctx, args...) @@ -93,7 +99,8 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args . return rows, err } -func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { +func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, + args ...interface{}) *sql.Row { stmt, err := db.prepare(ctx, db.ConnPool, false, query) if err == nil { return stmt.QueryRowContext(ctx, args...) @@ -120,10 +127,12 @@ func (tx *PreparedStmtTX) Rollback() error { return ErrInvalidTransaction } -func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) { +func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, + args ...interface{}) (result sql.Result, err error) { stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { - result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...) + result, err = tx.Tx.StmtContext(ctx, stmt.Stmt). + ExecContext(ctx, args...) if err != nil { tx.PreparedStmtDB.Mux.Lock() stmt.Close() @@ -134,7 +143,8 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args .. return result, err } -func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) { +func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, + args ...interface{}) (rows *sql.Rows, err error) { stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...) @@ -148,7 +158,8 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args . return rows, err } -func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { +func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, + args ...interface{}) *sql.Row { stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query) if err == nil { return tx.Tx.StmtContext(ctx, stmt.Stmt).QueryRowContext(ctx, args...) diff --git a/scan.go b/scan.go index acd637a4..24e9c2ad 100644 --- a/scan.go +++ b/scan.go @@ -10,19 +10,24 @@ import ( "gorm.io/gorm/schema" ) -func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) { +func prepareValues(values []interface{}, db *DB, + columnTypes []*sql.ColumnType, columns []string) { if db.Statement.Schema != nil { for idx, name := range columns { - if field := db.Statement.Schema.LookUpField(name); field != nil { - values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface() + field := db.Statement.Schema.LookUpField(name) + if field != nil { + values[idx] = reflect.New(reflect.PtrTo(field.FieldType)). + Interface() continue } + values[idx] = new(interface{}) } } else if len(columnTypes) > 0 { for idx, columnType := range columnTypes { if columnType.ScanType() != nil { - values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface() + values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())). + Interface() } else { values[idx] = new(interface{}) } @@ -34,9 +39,14 @@ func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, } } -func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) { +func scanIntoMap(mapValue map[string]interface{}, + values []interface{}, columns []string) { for idx, column := range columns { - if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() { + reflectValue := reflect.Indirect( + reflect.Indirect(reflect.ValueOf(values[idx])), + ) + + if reflectValue.IsValid() { mapValue[column] = reflectValue.Interface() if valuer, ok := mapValue[column].(driver.Valuer); ok { mapValue[column], _ = valuer.Value() @@ -111,28 +121,42 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { reflectValueType = reflectValueType.Elem() } - db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 20)) + db.Statement.ReflectValue.Set( + reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 20), + ) if Schema != nil { - if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct { - Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) + if reflectValueType != Schema.ModelType && + reflectValueType.Kind() == reflect.Struct { + Schema, _ = schema.Parse(db.Statement.Dest, + db.cacheStore, db.NamingStrategy) } for idx, column := range columns { - if field := Schema.LookUpField(column); field != nil && field.Readable { + if field := Schema.LookUpField(column); field != nil && + field.Readable { fields[idx] = field - } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := Schema.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - fields[idx] = field + } else if names := strings. + Split(column, "__"); len(names) > 1 { + rel, ok := Schema.Relationships.Relations[names[0]] + if ok { + field2 := rel.FieldSchema.LookUpField( + strings.Join(names[1:], "__"), + ) + if field2 != nil && field2.Readable { + fields[idx] = field2 if len(joinFields) == 0 { - joinFields = make([][2]*schema.Field, len(columns)) + joinFields = make([][2]*schema.Field, + len(columns)) } - joinFields[idx] = [2]*schema.Field{rel.Field, field} + + joinFields[idx] = [2]*schema.Field{rel.Field, + field2} continue } } + values[idx] = &sql.RawBytes{} } else { values[idx] = &sql.RawBytes{} @@ -143,9 +167,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { // pluck values into slice of data isPluck := false if len(fields) == 1 { - if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok || // is scanner - reflectValueType.Kind() != reflect.Struct || // is not struct - Schema.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time + _, ok := reflect.New(reflectValueType). + Interface().(sql.Scanner) + // is scanner or is not struct or is time + if ok || reflectValueType.Kind() != reflect.Struct || + Schema.ModelType.ConvertibleTo(schema.TimeReflectType) { isPluck = true } } @@ -160,7 +186,9 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } else { for idx, field := range fields { if field != nil { - values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + values[idx] = reflect.New( + reflect.PtrTo(field.IndirectFieldType), + ).Interface() } } @@ -171,11 +199,14 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { value := reflect.ValueOf(values[idx]).Elem() relValue := joinFields[idx][0].ReflectValueOf(elem) - if relValue.Kind() == reflect.Ptr && relValue.IsNil() { + if relValue.Kind() == reflect.Ptr && + relValue.IsNil() { if value.IsNil() { continue } - relValue.Set(reflect.New(relValue.Type().Elem())) + relValue.Set( + reflect.New(relValue.Type().Elem()), + ) } field.Set(relValue, values[idx]) @@ -186,24 +217,36 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { } if isPtr { - db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem)) + db.Statement.ReflectValue.Set(reflect. + Append(db.Statement.ReflectValue, elem)) } else { - db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Elem())) + db.Statement.ReflectValue.Set(reflect. + Append(db.Statement.ReflectValue, elem.Elem())) } } case reflect.Struct, reflect.Ptr: if db.Statement.ReflectValue.Type() != Schema.ModelType { - Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy) + Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, + db.NamingStrategy) } if initialized || rows.Next() { for idx, column := range columns { - if field := Schema.LookUpField(column); field != nil && field.Readable { - values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + if field := Schema.LookUpField(column); field != nil && + field.Readable { + values[idx] = reflect.New( + reflect.PtrTo(field.IndirectFieldType), + ).Interface() } else if names := strings.Split(column, "__"); len(names) > 1 { - if rel, ok := Schema.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { - values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface() + rel, ok := Schema.Relationships.Relations[names[0]] + if ok { + field := rel.FieldSchema. + LookUpField(strings.Join(names[1:], "__")) + if field != nil && + field.Readable { + values[idx] = reflect.New( + reflect.PtrTo(field.IndirectFieldType), + ).Interface() continue } } @@ -217,11 +260,13 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { db.AddError(rows.Scan(values...)) for idx, column := range columns { - if field := Schema.LookUpField(column); field != nil && field.Readable { + if field := Schema.LookUpField(column); field != nil && + field.Readable { field.Set(db.Statement.ReflectValue, values[idx]) } else if names := strings.Split(column, "__"); len(names) > 1 { if rel, ok := Schema.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && + field.Readable { relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) value := reflect.ValueOf(values[idx]).Elem() diff --git a/soft_delete.go b/soft_delete.go index bdbf03c2..bad21e6a 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -65,10 +65,15 @@ func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) { func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok { if c, ok := stmt.Clauses["WHERE"]; ok { - if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) > 1 { + if where, ok := c.Expression.(clause.Where); ok && + len(where.Exprs) > 1 { for _, expr := range where.Exprs { - if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 { - where.Exprs = []clause.Expression{clause.And(where.Exprs...)} + if orCond, ok := expr.(clause.OrConditions); ok && + len(orCond.Exprs) == 1 { + where.Exprs = []clause.Expression{ + clause.And(where.Exprs...), + } + c.Expression = where stmt.Clauses["WHERE"] = c break @@ -78,7 +83,11 @@ func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) { } stmt.AddClause(clause.Where{Exprs: []clause.Expression{ - clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil}, + clause.Eq{Column: clause.Column{ + Table: clause.CurrentTable, + Name: sd.Field.DBName}, + Value: nil, + }, }}) stmt.Clauses["soft_delete_enabled"] = clause.Clause{} } @@ -105,28 +114,50 @@ func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) { func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { if stmt.SQL.String() == "" { curTime := stmt.DB.NowFunc() - stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}}) + stmt.AddClause(clause.Set{ + { + Column: clause.Column{Name: sd.Field.DBName}, + Value: curTime, + }, + }) + stmt.SetColumn(sd.Field.DBName, curTime, true) if stmt.Schema != nil { - _, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields) - column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) + _, queryValues := schema.GetIdentityFieldValuesMap( + stmt.ReflectValue, stmt.Schema.PrimaryFields, + ) + column, values := schema.ToQueryValues(stmt.Table, + stmt.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { - stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + stmt.AddClause(clause.Where{Exprs: []clause.Expression{ + clause.IN{Column: column, Values: values}}}) } - if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil { - _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields) - column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues) + if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && + stmt.Model != nil { + _, queryValues = schema.GetIdentityFieldValuesMap( + reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields, + ) + column, values = schema.ToQueryValues( + stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues, + ) if len(values) > 0 { - stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + stmt.AddClause(clause.Where{ + Exprs: []clause.Expression{ + clause.IN{ + Column: column, Values: values, + }, + }, + }) } } } - if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok { + if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && + !ok { stmt.DB.AddError(ErrMissingWhereClause) } else { SoftDeleteQueryClause{Field: sd.Field}.ModifyStatement(stmt) diff --git a/statement.go b/statement.go index de1b300f..306dc633 100644 --- a/statement.go +++ b/statement.go @@ -104,7 +104,8 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { if stmt.Schema == nil { stmt.DB.AddError(ErrModelValueRequired) } else if stmt.Schema.PrioritizedPrimaryField != nil { - stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName) + stmt.DB.Dialector.QuoteTo(writer, + stmt.Schema.PrioritizedPrimaryField.DBName) } else if len(stmt.Schema.DBNames) > 0 { stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.DBNames[0]) } @@ -181,7 +182,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { writer.WriteString("(NULL)") } case *DB: - subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() + subdb := v.Session(&Session{Logger: logger.Discard, + DryRun: true}).getInstance() subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...) subdb.callbacks.Query().Execute(subdb) writer.WriteString(subdb.Statement.SQL.String()) @@ -230,7 +232,8 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { } // BuildCondition build condition -func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []clause.Expression { +func (stmt *Statement) BuildCondition(query interface{}, + args ...interface{}) []clause.Expression { if s, ok := query.(string); ok { // if it is a number, then treats it as primary key if _, err := strconv.Atoi(s); err != nil { @@ -262,10 +265,14 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] if cs, ok := v.Statement.Clauses["WHERE"]; ok { if where, ok := cs.Expression.(clause.Where); ok { if len(where.Exprs) == 1 { - if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { - where.Exprs[0] = clause.AndConditions{Exprs: orConds.Exprs} + orConds, ok := where.Exprs[0].(clause.OrConditions) + if ok { + where.Exprs[0] = clause.AndConditions{ + Exprs: orConds.Exprs, + } } } + conds = append(conds, clause.And(where.Exprs...)) } else if cs.Expression != nil { conds = append(conds, cs.Expression) @@ -297,16 +304,24 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] switch reflectValue.Kind() { case reflect.Slice, reflect.Array: if _, ok := v[key].(driver.Valuer); ok { - conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + conds = append(conds, clause.Eq{ + Column: key, + Value: v[key], + }) } else if _, ok := v[key].(Valuer); ok { - conds = append(conds, clause.Eq{Column: key, Value: v[key]}) + conds = append(conds, clause.Eq{ + Column: key, + Value: v[key], + }) } else { values := make([]interface{}, reflectValue.Len()) for i := 0; i < reflectValue.Len(); i++ { values[i] = reflectValue.Index(i).Interface() } - conds = append(conds, clause.IN{Column: key, Values: values}) + conds = append(conds, clause.IN{ + Column: key, Values: values, + }) } default: conds = append(conds, clause.Eq{Column: key, Value: v[key]}) @@ -314,7 +329,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } default: reflectValue := reflect.Indirect(reflect.ValueOf(arg)) - if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { + if s, err := schema.Parse(arg, stmt.DB.cacheStore, + stmt.DB.NamingStrategy); err == nil { selectedColumns := map[string]bool{} if idx == 0 { for _, v := range args[1:] { @@ -328,27 +344,56 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] switch reflectValue.Kind() { case reflect.Struct: for _, field := range s.Fields { - selected := selectedColumns[field.DBName] || selectedColumns[field.Name] + selected := selectedColumns[field.DBName] || + selectedColumns[field.Name] if selected || (!restricted && field.Readable) { - if v, isZero := field.ValueOf(reflectValue); !isZero || selected { + v, isZero := field.ValueOf(reflectValue) + if !isZero || selected { if field.DBName != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) + conds = append(conds, clause.Eq{ + Column: clause.Column{ + Table: clause.CurrentTable, + Name: field.DBName, + }, + Value: v, + }) } else if field.DataType != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) + conds = append(conds, clause.Eq{ + Column: clause.Column{ + Table: clause.CurrentTable, + Name: field.Name, + }, + Value: v, + }) } + } } } case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { for _, field := range s.Fields { - selected := selectedColumns[field.DBName] || selectedColumns[field.Name] + selected := selectedColumns[field.DBName] || + selectedColumns[field.Name] if selected || (!restricted && field.Readable) { - if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected { + v, isZero := field.ValueOf(reflectValue.Index(i)) + if !isZero || selected { if field.DBName != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) + conds = append(conds, clause.Eq{ + Column: clause.Column{ + Table: clause.CurrentTable, + Name: field.DBName, + }, + Value: v, + }) } else if field.DataType != "" { - conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) + conds = append(conds, clause.Eq{ + Column: clause.Column{ + Table: clause.CurrentTable, + Name: field.Name, + }, + Value: v, + }) } } } @@ -371,13 +416,20 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } if len(values) > 0 { - conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) + conds = append(conds, clause.IN{ + Column: clause.PrimaryColumn, + Values: values, + }) } + return conds } } - conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) + conds = append(conds, clause.IN{ + Column: clause.PrimaryColumn, + Values: args, + }) } } } @@ -406,7 +458,9 @@ func (stmt *Statement) Build(clauses ...string) { } func (stmt *Statement) Parse(value interface{}) (err error) { - if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" { + stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, + stmt.DB.NamingStrategy) + if err == nil && stmt.Table == "" { if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} stmt.Table = tables[1] @@ -415,6 +469,7 @@ func (stmt *Statement) Parse(value interface{}) (err error) { stmt.Table = stmt.Schema.Table } + return err } @@ -463,7 +518,8 @@ func (stmt *Statement) clone() *Statement { // SetColumn set column's value // stmt.SetColumn("Name", "jinzhu") // Hooks Method // stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method -func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) { +func (stmt *Statement) SetColumn(name string, value interface{}, + fromCallbacks ...bool) { if v, ok := stmt.Dest.(map[string]interface{}); ok { v[name] = value } else if v, ok := stmt.Dest.([]map[string]interface{}); ok { @@ -524,7 +580,8 @@ func (stmt *Statement) Changed(fields ...string) bool { selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) changed := func(field *schema.Field) bool { fieldValue, _ := field.ValueOf(modelValue) - if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { + if v, ok := selectColumns[field.DBName]; (ok && v) || + (!ok && !restricted) { if v, ok := stmt.Dest.(map[string]interface{}); ok { if fv, ok := v[field.Name]; ok { return !utils.AssertEqual(fv, fieldValue) @@ -563,8 +620,10 @@ func (stmt *Statement) Changed(fields ...string) bool { return false } -// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false -func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { +// SelectAndOmitColumns get select and omit columns, +// select -> true, omit -> false +func (stmt *Statement) SelectAndOmitColumns(requireCreate, + requireUpdate bool) (map[string]bool, bool) { results := map[string]bool{} notRestricted := false @@ -579,7 +638,8 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( for _, rel := range stmt.Schema.Relationships.Relations { results[rel.Name] = true } - } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { + } else if field := stmt.Schema.LookUpField(column); field != nil && + field.DBName != "" { results[field.DBName] = true } else { results[column] = true @@ -594,7 +654,8 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( results[rel.Name] = false } } - } else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { + } else if field := stmt.Schema.LookUpField(omit); field != nil && + field.DBName != "" { results[field.DBName] = false } else { results[omit] = false