feat: gofmt code
This commit is contained in:
parent
699093ae6c
commit
0df6b19a8d
@ -24,10 +24,12 @@ func (db *DB) Association(column string) *Association {
|
||||
|
||||
if err := db.Statement.Parse(db.Statement.Model); err == nil {
|
||||
db.Statement.Table = table
|
||||
association.Relationship = db.Statement.Schema.Relationships.Relations[column]
|
||||
association.Relationship = db.Statement.Schema.
|
||||
Relationships.Relations[column]
|
||||
|
||||
if association.Relationship == nil {
|
||||
association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column)
|
||||
association.Error = fmt.Errorf("%w: %v",
|
||||
ErrUnsupportedRelation, column)
|
||||
}
|
||||
|
||||
db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
|
||||
@ -41,9 +43,11 @@ func (db *DB) Association(column string) *Association {
|
||||
return association
|
||||
}
|
||||
|
||||
func (association *Association) Find(out interface{}, conds ...interface{}) error {
|
||||
func (association *Association) Find(out interface{},
|
||||
conds ...interface{}) error {
|
||||
if association.Error == nil {
|
||||
association.Error = association.buildCondition().Find(out, conds...).Error
|
||||
association.Error = association.buildCondition().
|
||||
Find(out, conds...).Error
|
||||
}
|
||||
return association.Error
|
||||
}
|
||||
@ -80,10 +84,12 @@ func (association *Association) Replace(values ...interface{}) error {
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
|
||||
association.Error = rel.Field.Set(reflectValue.Index(i),
|
||||
reflect.Zero(rel.Field.FieldType).Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
|
||||
association.Error = rel.Field.Set(reflectValue,
|
||||
reflect.Zero(rel.Field.FieldType).Interface())
|
||||
}
|
||||
|
||||
for _, ref := range rel.References {
|
||||
@ -118,9 +124,11 @@ func (association *Association) Replace(values ...interface{}) error {
|
||||
}
|
||||
}
|
||||
|
||||
if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 {
|
||||
if _, pvs := schema.
|
||||
GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 {
|
||||
column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
|
||||
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
|
||||
association.Error = tx.Where(clause.IN{Column: column, Values: values}).
|
||||
UpdateColumns(updateMap).Error
|
||||
}
|
||||
case schema.Many2Many:
|
||||
var (
|
||||
@ -152,7 +160,8 @@ func (association *Association) Replace(values ...interface{}) error {
|
||||
}
|
||||
|
||||
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
|
||||
if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
|
||||
if relColumn, relValues := schema.
|
||||
ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
|
||||
tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
|
||||
}
|
||||
|
||||
|
32
callbacks.go
32
callbacks.go
@ -82,8 +82,11 @@ func (p *processor) Execute(db *DB) {
|
||||
}
|
||||
|
||||
if stmt.Model != nil {
|
||||
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) {
|
||||
if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" {
|
||||
if err := stmt.Parse(stmt.Model); err != nil &&
|
||||
(!errors.Is(err, schema.ErrUnsupportedDataType) ||
|
||||
(stmt.Table == "" && stmt.SQL.Len() == 0)) {
|
||||
if errors.Is(err, schema.ErrUnsupportedDataType) &&
|
||||
stmt.Table == "" {
|
||||
db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
|
||||
} else {
|
||||
db.AddError(err)
|
||||
@ -163,7 +166,8 @@ func (p *processor) compile() (err error) {
|
||||
p.callbacks = callbacks
|
||||
|
||||
if p.fns, err = sortCallbacks(p.callbacks); err != nil {
|
||||
p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
|
||||
p.db.Logger.Error(context.Background(),
|
||||
"Got error when compile callbacks, got %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -186,7 +190,8 @@ func (c *callback) Register(name string, fn func(*DB)) error {
|
||||
}
|
||||
|
||||
func (c *callback) Remove(name string) error {
|
||||
c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum())
|
||||
c.processor.db.Logger.Warn(context.Background(),
|
||||
"removing callback `%v` from %v\n", name, utils.FileWithLineNum())
|
||||
c.name = name
|
||||
c.remove = true
|
||||
c.processor.callbacks = append(c.processor.callbacks, c)
|
||||
@ -194,7 +199,8 @@ func (c *callback) Remove(name string) error {
|
||||
}
|
||||
|
||||
func (c *callback) Replace(name string, fn func(*DB)) error {
|
||||
c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
|
||||
c.processor.db.Logger.Info(context.Background(),
|
||||
"replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
|
||||
c.name = name
|
||||
c.handler = fn
|
||||
c.replace = true
|
||||
@ -223,8 +229,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
||||
|
||||
for _, c := range cs {
|
||||
// show warning message the callback name already exists
|
||||
if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
|
||||
c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum())
|
||||
if idx := getRIndex(names, c.name); idx > -1 && !c.replace &&
|
||||
!c.remove && !cs[idx].remove {
|
||||
c.processor.db.Logger.Warn(context.Background(),
|
||||
"duplicated callback `%v` from %v\n", c.name,
|
||||
utils.FileWithLineNum())
|
||||
}
|
||||
names = append(names, c.name)
|
||||
}
|
||||
@ -238,9 +247,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
||||
} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
|
||||
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
||||
// if before callback already sorted, append current callback just after it
|
||||
sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
|
||||
sorted = append(sorted[:sortedIdx],
|
||||
append([]string{c.name}, sorted[sortedIdx:]...)...)
|
||||
} else if curIdx > sortedIdx {
|
||||
return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before)
|
||||
return fmt.Errorf("conflicting callback %v with before %v",
|
||||
c.name, c.before)
|
||||
}
|
||||
} else if idx := getRIndex(names, c.before); idx != -1 {
|
||||
// if before callback exists
|
||||
@ -258,7 +269,8 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
||||
// if after callback sorted, append current callback to last
|
||||
sorted = append(sorted, c.name)
|
||||
} else if curIdx < sortedIdx {
|
||||
return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after)
|
||||
return fmt.Errorf("conflicting callback %v with before %v",
|
||||
c.name, c.after)
|
||||
}
|
||||
} else if idx := getRIndex(names, c.after); idx != -1 {
|
||||
// if after callback exists but haven't sorted
|
||||
|
@ -12,7 +12,8 @@ import (
|
||||
// Model specify the model you would like to run db operations
|
||||
// // update all users's name to `hello`
|
||||
// db.Model(&User{}).Update("name", "hello")
|
||||
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
|
||||
// // if user's primary key is non-blank, will use it as condition,
|
||||
// then will only update the user's name to `hello`
|
||||
// db.Model(&user).Update("name", "hello")
|
||||
func (db *DB) Model(value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
@ -36,7 +37,8 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
|
||||
}
|
||||
|
||||
if len(whereConds) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(whereConds[0], whereConds[1:]...)})
|
||||
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.
|
||||
BuildCondition(whereConds[0], whereConds[1:]...)})
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -46,9 +48,11 @@ var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`)
|
||||
// Table specify the table you would like to run db operations
|
||||
func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 {
|
||||
if strings.Contains(name, " ") || strings.Contains(name, "`") ||
|
||||
len(args) > 0 {
|
||||
tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args}
|
||||
if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 {
|
||||
if results := tableRegexp.
|
||||
FindStringSubmatch(name); len(results) == 2 {
|
||||
tx.Statement.Table = results[1]
|
||||
return
|
||||
}
|
||||
@ -87,7 +91,8 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
|
||||
case []string:
|
||||
tx.Statement.Selects = append(tx.Statement.Selects, arg...)
|
||||
default:
|
||||
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
|
||||
tx.AddError(fmt.Errorf("unsupported select args %v %v",
|
||||
query, args))
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -125,12 +130,14 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
|
||||
return
|
||||
}
|
||||
|
||||
// Omit specify fields that you want to ignore when creating, updating and querying
|
||||
// Omit specify fields that you want to ignore when creating,
|
||||
// updating and querying
|
||||
func (db *DB) Omit(columns ...string) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
|
||||
tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar)
|
||||
tx.Statement.Omits = strings.FieldsFunc(columns[0],
|
||||
utils.IsValidDBNameChar)
|
||||
} else {
|
||||
tx.Statement.Omits = columns
|
||||
}
|
||||
@ -150,7 +157,9 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
|
||||
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(conds...)}})
|
||||
tx.Statement.AddClause(clause.Where{
|
||||
Exprs: []clause.Expression{clause.Not(conds...)},
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -158,8 +167,11 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
|
||||
// Or add OR conditions
|
||||
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(clause.And(conds...))}})
|
||||
if conds := tx.Statement.
|
||||
BuildCondition(query, args...); len(conds) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{
|
||||
Exprs: []clause.Expression{clause.Or(clause.And(conds...))}},
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -169,7 +181,10 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
|
||||
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
|
||||
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args})
|
||||
tx.Statement.Joins = append(tx.Statement.Joins, join{
|
||||
Name: query, Conds: args,
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -7,7 +7,8 @@ import (
|
||||
var (
|
||||
// ErrRecordNotFound record not found error
|
||||
ErrRecordNotFound = errors.New("record not found")
|
||||
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
|
||||
// ErrInvalidTransaction invalid transaction
|
||||
// when you are trying to `Commit` or `Rollback`
|
||||
ErrInvalidTransaction = errors.New("no valid transaction")
|
||||
// ErrNotImplemented not implemented
|
||||
ErrNotImplemented = errors.New("not implemented")
|
||||
|
@ -67,7 +67,8 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
|
||||
return
|
||||
}
|
||||
|
||||
// Save update value in database, if the value doesn't have primary key, will insert it
|
||||
// Save update value in database, if the value doesn't have primary key,
|
||||
// will insert it
|
||||
func (db *DB) Save(value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = value
|
||||
@ -78,9 +79,11 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
||||
if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
|
||||
tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
|
||||
}
|
||||
tx.callbacks.Create().Execute(tx.InstanceSet("gorm:update_track_time", true))
|
||||
tx.callbacks.Create().
|
||||
Execute(tx.InstanceSet("gorm:update_track_time", true))
|
||||
case reflect.Struct:
|
||||
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
|
||||
if err := tx.Statement.Parse(value); err == nil &&
|
||||
tx.Statement.Schema != nil {
|
||||
for _, pf := range tx.Statement.Schema.PrimaryFields {
|
||||
if _, isZero := pf.ValueOf(reflectValue); isZero {
|
||||
tx.callbacks.Create().Execute(tx)
|
||||
@ -99,9 +102,11 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
||||
|
||||
tx.callbacks.Update().Execute(tx)
|
||||
|
||||
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
|
||||
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun &&
|
||||
!selectedUpdate {
|
||||
result := reflect.New(tx.Statement.Schema.ModelType).Interface()
|
||||
if err := tx.Session(&Session{}).First(result).Error; errors.Is(err, ErrRecordNotFound) {
|
||||
if err := tx.Session(&Session{}).
|
||||
First(result).Error; errors.Is(err, ErrRecordNotFound) {
|
||||
return tx.Create(value)
|
||||
}
|
||||
}
|
||||
@ -113,10 +118,15 @@ func (db *DB) Save(value interface{}) (tx *DB) {
|
||||
// First find first record that match given conditions, order by primary key
|
||||
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
tx = db.Limit(1).Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||
Column: clause.Column{
|
||||
Table: clause.CurrentTable,
|
||||
Name: clause.PrimaryKey,
|
||||
},
|
||||
})
|
||||
|
||||
if len(conds) > 0 {
|
||||
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
||||
if exprs := tx.
|
||||
Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
||||
}
|
||||
}
|
||||
@ -126,7 +136,8 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
return
|
||||
}
|
||||
|
||||
// Take return a record that match given conditions, the order will depend on the database implementation
|
||||
// Take return a record that match given conditions, the order will
|
||||
// depend on the database implementation
|
||||
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
tx = db.Limit(1)
|
||||
if len(conds) > 0 {
|
||||
@ -198,8 +209,15 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
|
||||
tx.AddError(ErrPrimaryKeyRequired)
|
||||
break
|
||||
} else {
|
||||
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1))
|
||||
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
|
||||
primaryValue, _ := result.Statement.Schema.
|
||||
PrioritizedPrimaryField.
|
||||
ValueOf(resultsValue.Index(resultsValue.Len() - 1))
|
||||
queryDB = tx.Clauses(clause.Gt{
|
||||
Column: clause.Column{
|
||||
Table: clause.CurrentTable,
|
||||
Name: clause.PrimaryKey,
|
||||
},
|
||||
Value: primaryValue})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -29,9 +29,12 @@ type Plugin interface {
|
||||
// ConnPool db conns pool interface
|
||||
type ConnPool interface {
|
||||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||
ExecContext(ctx context.Context, query string,
|
||||
args ...interface{}) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string,
|
||||
args ...interface{}) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string,
|
||||
args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
// SavePointerDialectorInterface save pointer interface
|
||||
|
@ -48,7 +48,8 @@ type Migrator interface {
|
||||
AddColumn(dst interface{}, field string) error
|
||||
DropColumn(dst interface{}, field string) error
|
||||
AlterColumn(dst interface{}, field string) error
|
||||
MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
|
||||
MigrateColumn(dst interface{}, field *schema.Field,
|
||||
columnType ColumnType) error
|
||||
HasColumn(dst interface{}, field string) bool
|
||||
RenameColumn(dst interface{}, oldName, field string) error
|
||||
ColumnTypes(dst interface{}) ([]ColumnType, error)
|
||||
|
6
model.go
6
model.go
@ -2,8 +2,10 @@ package gorm
|
||||
|
||||
import "time"
|
||||
|
||||
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
|
||||
// It may be embedded into your model or you may build your own model without it
|
||||
// Model a basic GoLang struct which includes the following fields: ID,
|
||||
// CreatedAt, UpdatedAt, DeletedAt
|
||||
// It may be embedded into your model
|
||||
// or you may build your own model without it
|
||||
// type User struct {
|
||||
// gorm.Model
|
||||
// }
|
||||
|
@ -30,9 +30,11 @@ func (db *PreparedStmtDB) Close() {
|
||||
db.Mux.Unlock()
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
|
||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool,
|
||||
isTransaction bool, query string) (Stmt, error) {
|
||||
db.Mux.RLock()
|
||||
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
|
||||
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction ||
|
||||
isTransaction) {
|
||||
db.Mux.RUnlock()
|
||||
return stmt, nil
|
||||
}
|
||||
@ -40,7 +42,8 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
|
||||
|
||||
db.Mux.Lock()
|
||||
// double check
|
||||
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
|
||||
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction ||
|
||||
isTransaction) {
|
||||
db.Mux.Unlock()
|
||||
return stmt, nil
|
||||
} else if ok {
|
||||
@ -57,7 +60,8 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
|
||||
return db.Stmts[query], err
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
|
||||
func (db *PreparedStmtDB) BeginTx(ctx context.Context,
|
||||
opt *sql.TxOptions) (ConnPool, error) {
|
||||
if beginner, ok := db.ConnPool.(TxBeginner); ok {
|
||||
tx, err := beginner.BeginTx(ctx, opt)
|
||||
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
|
||||
@ -65,7 +69,8 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
|
||||
return nil, ErrInvalidTransaction
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
|
||||
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string,
|
||||
args ...interface{}) (result sql.Result, err error) {
|
||||
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
|
||||
if err == nil {
|
||||
result, err = stmt.ExecContext(ctx, args...)
|
||||
@ -79,7 +84,8 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string,
|
||||
args ...interface{}) (rows *sql.Rows, err error) {
|
||||
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
|
||||
if err == nil {
|
||||
rows, err = stmt.QueryContext(ctx, args...)
|
||||
@ -93,7 +99,8 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string,
|
||||
args ...interface{}) *sql.Row {
|
||||
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
|
||||
if err == nil {
|
||||
return stmt.QueryRowContext(ctx, args...)
|
||||
@ -120,10 +127,12 @@ func (tx *PreparedStmtTX) Rollback() error {
|
||||
return ErrInvalidTransaction
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
|
||||
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string,
|
||||
args ...interface{}) (result sql.Result, err error) {
|
||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
||||
if err == nil {
|
||||
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
|
||||
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).
|
||||
ExecContext(ctx, args...)
|
||||
if err != nil {
|
||||
tx.PreparedStmtDB.Mux.Lock()
|
||||
stmt.Close()
|
||||
@ -134,7 +143,8 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string,
|
||||
args ...interface{}) (rows *sql.Rows, err error) {
|
||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
||||
if err == nil {
|
||||
rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...)
|
||||
@ -148,7 +158,8 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string,
|
||||
args ...interface{}) *sql.Row {
|
||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
||||
if err == nil {
|
||||
return tx.Tx.StmtContext(ctx, stmt.Stmt).QueryRowContext(ctx, args...)
|
||||
|
109
scan.go
109
scan.go
@ -10,19 +10,24 @@ import (
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
|
||||
func prepareValues(values []interface{}, db *DB,
|
||||
columnTypes []*sql.ColumnType, columns []string) {
|
||||
if db.Statement.Schema != nil {
|
||||
for idx, name := range columns {
|
||||
if field := db.Statement.Schema.LookUpField(name); field != nil {
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
|
||||
field := db.Statement.Schema.LookUpField(name)
|
||||
if field != nil {
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).
|
||||
Interface()
|
||||
continue
|
||||
}
|
||||
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
} else if len(columnTypes) > 0 {
|
||||
for idx, columnType := range columnTypes {
|
||||
if columnType.ScanType() != nil {
|
||||
values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
|
||||
values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).
|
||||
Interface()
|
||||
} else {
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
@ -34,9 +39,14 @@ func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType,
|
||||
}
|
||||
}
|
||||
|
||||
func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
|
||||
func scanIntoMap(mapValue map[string]interface{},
|
||||
values []interface{}, columns []string) {
|
||||
for idx, column := range columns {
|
||||
if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() {
|
||||
reflectValue := reflect.Indirect(
|
||||
reflect.Indirect(reflect.ValueOf(values[idx])),
|
||||
)
|
||||
|
||||
if reflectValue.IsValid() {
|
||||
mapValue[column] = reflectValue.Interface()
|
||||
if valuer, ok := mapValue[column].(driver.Valuer); ok {
|
||||
mapValue[column], _ = valuer.Value()
|
||||
@ -111,28 +121,42 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||
reflectValueType = reflectValueType.Elem()
|
||||
}
|
||||
|
||||
db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 20))
|
||||
db.Statement.ReflectValue.Set(
|
||||
reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 20),
|
||||
)
|
||||
|
||||
if Schema != nil {
|
||||
if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct {
|
||||
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
||||
if reflectValueType != Schema.ModelType &&
|
||||
reflectValueType.Kind() == reflect.Struct {
|
||||
Schema, _ = schema.Parse(db.Statement.Dest,
|
||||
db.cacheStore, db.NamingStrategy)
|
||||
}
|
||||
|
||||
for idx, column := range columns {
|
||||
if field := Schema.LookUpField(column); field != nil && field.Readable {
|
||||
if field := Schema.LookUpField(column); field != nil &&
|
||||
field.Readable {
|
||||
fields[idx] = field
|
||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
|
||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||
fields[idx] = field
|
||||
} else if names := strings.
|
||||
Split(column, "__"); len(names) > 1 {
|
||||
rel, ok := Schema.Relationships.Relations[names[0]]
|
||||
if ok {
|
||||
field2 := rel.FieldSchema.LookUpField(
|
||||
strings.Join(names[1:], "__"),
|
||||
)
|
||||
if field2 != nil && field2.Readable {
|
||||
fields[idx] = field2
|
||||
|
||||
if len(joinFields) == 0 {
|
||||
joinFields = make([][2]*schema.Field, len(columns))
|
||||
joinFields = make([][2]*schema.Field,
|
||||
len(columns))
|
||||
}
|
||||
joinFields[idx] = [2]*schema.Field{rel.Field, field}
|
||||
|
||||
joinFields[idx] = [2]*schema.Field{rel.Field,
|
||||
field2}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
values[idx] = &sql.RawBytes{}
|
||||
} else {
|
||||
values[idx] = &sql.RawBytes{}
|
||||
@ -143,9 +167,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||
// pluck values into slice of data
|
||||
isPluck := false
|
||||
if len(fields) == 1 {
|
||||
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok || // is scanner
|
||||
reflectValueType.Kind() != reflect.Struct || // is not struct
|
||||
Schema.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
|
||||
_, ok := reflect.New(reflectValueType).
|
||||
Interface().(sql.Scanner)
|
||||
// is scanner or is not struct or is time
|
||||
if ok || reflectValueType.Kind() != reflect.Struct ||
|
||||
Schema.ModelType.ConvertibleTo(schema.TimeReflectType) {
|
||||
isPluck = true
|
||||
}
|
||||
}
|
||||
@ -160,7 +186,9 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||
} else {
|
||||
for idx, field := range fields {
|
||||
if field != nil {
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
||||
values[idx] = reflect.New(
|
||||
reflect.PtrTo(field.IndirectFieldType),
|
||||
).Interface()
|
||||
}
|
||||
}
|
||||
|
||||
@ -171,11 +199,14 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||
value := reflect.ValueOf(values[idx]).Elem()
|
||||
relValue := joinFields[idx][0].ReflectValueOf(elem)
|
||||
|
||||
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
|
||||
if relValue.Kind() == reflect.Ptr &&
|
||||
relValue.IsNil() {
|
||||
if value.IsNil() {
|
||||
continue
|
||||
}
|
||||
relValue.Set(reflect.New(relValue.Type().Elem()))
|
||||
relValue.Set(
|
||||
reflect.New(relValue.Type().Elem()),
|
||||
)
|
||||
}
|
||||
|
||||
field.Set(relValue, values[idx])
|
||||
@ -186,24 +217,36 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||
}
|
||||
|
||||
if isPtr {
|
||||
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem))
|
||||
db.Statement.ReflectValue.Set(reflect.
|
||||
Append(db.Statement.ReflectValue, elem))
|
||||
} else {
|
||||
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Elem()))
|
||||
db.Statement.ReflectValue.Set(reflect.
|
||||
Append(db.Statement.ReflectValue, elem.Elem()))
|
||||
}
|
||||
}
|
||||
case reflect.Struct, reflect.Ptr:
|
||||
if db.Statement.ReflectValue.Type() != Schema.ModelType {
|
||||
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
||||
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore,
|
||||
db.NamingStrategy)
|
||||
}
|
||||
|
||||
if initialized || rows.Next() {
|
||||
for idx, column := range columns {
|
||||
if field := Schema.LookUpField(column); field != nil && field.Readable {
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
||||
if field := Schema.LookUpField(column); field != nil &&
|
||||
field.Readable {
|
||||
values[idx] = reflect.New(
|
||||
reflect.PtrTo(field.IndirectFieldType),
|
||||
).Interface()
|
||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
|
||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
|
||||
rel, ok := Schema.Relationships.Relations[names[0]]
|
||||
if ok {
|
||||
field := rel.FieldSchema.
|
||||
LookUpField(strings.Join(names[1:], "__"))
|
||||
if field != nil &&
|
||||
field.Readable {
|
||||
values[idx] = reflect.New(
|
||||
reflect.PtrTo(field.IndirectFieldType),
|
||||
).Interface()
|
||||
continue
|
||||
}
|
||||
}
|
||||
@ -217,11 +260,13 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
|
||||
db.AddError(rows.Scan(values...))
|
||||
|
||||
for idx, column := range columns {
|
||||
if field := Schema.LookUpField(column); field != nil && field.Readable {
|
||||
if field := Schema.LookUpField(column); field != nil &&
|
||||
field.Readable {
|
||||
field.Set(db.Statement.ReflectValue, values[idx])
|
||||
} else if names := strings.Split(column, "__"); len(names) > 1 {
|
||||
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
|
||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
|
||||
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil &&
|
||||
field.Readable {
|
||||
relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
|
||||
value := reflect.ValueOf(values[idx]).Elem()
|
||||
|
||||
|
@ -65,10 +65,15 @@ func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) {
|
||||
func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
|
||||
if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok {
|
||||
if c, ok := stmt.Clauses["WHERE"]; ok {
|
||||
if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) > 1 {
|
||||
if where, ok := c.Expression.(clause.Where); ok &&
|
||||
len(where.Exprs) > 1 {
|
||||
for _, expr := range where.Exprs {
|
||||
if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 {
|
||||
where.Exprs = []clause.Expression{clause.And(where.Exprs...)}
|
||||
if orCond, ok := expr.(clause.OrConditions); ok &&
|
||||
len(orCond.Exprs) == 1 {
|
||||
where.Exprs = []clause.Expression{
|
||||
clause.And(where.Exprs...),
|
||||
}
|
||||
|
||||
c.Expression = where
|
||||
stmt.Clauses["WHERE"] = c
|
||||
break
|
||||
@ -78,7 +83,11 @@ func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
|
||||
}
|
||||
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{
|
||||
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil},
|
||||
clause.Eq{Column: clause.Column{
|
||||
Table: clause.CurrentTable,
|
||||
Name: sd.Field.DBName},
|
||||
Value: nil,
|
||||
},
|
||||
}})
|
||||
stmt.Clauses["soft_delete_enabled"] = clause.Clause{}
|
||||
}
|
||||
@ -105,28 +114,50 @@ func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) {
|
||||
func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
|
||||
if stmt.SQL.String() == "" {
|
||||
curTime := stmt.DB.NowFunc()
|
||||
stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}})
|
||||
stmt.AddClause(clause.Set{
|
||||
{
|
||||
Column: clause.Column{Name: sd.Field.DBName},
|
||||
Value: curTime,
|
||||
},
|
||||
})
|
||||
|
||||
stmt.SetColumn(sd.Field.DBName, curTime, true)
|
||||
|
||||
if stmt.Schema != nil {
|
||||
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields)
|
||||
column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
|
||||
_, queryValues := schema.GetIdentityFieldValuesMap(
|
||||
stmt.ReflectValue, stmt.Schema.PrimaryFields,
|
||||
)
|
||||
column, values := schema.ToQueryValues(stmt.Table,
|
||||
stmt.Schema.PrimaryFieldDBNames, queryValues)
|
||||
|
||||
if len(values) > 0 {
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{
|
||||
clause.IN{Column: column, Values: values}}})
|
||||
}
|
||||
|
||||
if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil {
|
||||
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
|
||||
column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
|
||||
if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model &&
|
||||
stmt.Model != nil {
|
||||
_, queryValues = schema.GetIdentityFieldValuesMap(
|
||||
reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields,
|
||||
)
|
||||
|
||||
column, values = schema.ToQueryValues(
|
||||
stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues,
|
||||
)
|
||||
if len(values) > 0 {
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||
stmt.AddClause(clause.Where{
|
||||
Exprs: []clause.Expression{
|
||||
clause.IN{
|
||||
Column: column, Values: values,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok {
|
||||
if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate &&
|
||||
!ok {
|
||||
stmt.DB.AddError(ErrMissingWhereClause)
|
||||
} else {
|
||||
SoftDeleteQueryClause{Field: sd.Field}.ModifyStatement(stmt)
|
||||
|
113
statement.go
113
statement.go
@ -104,7 +104,8 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
||||
if stmt.Schema == nil {
|
||||
stmt.DB.AddError(ErrModelValueRequired)
|
||||
} else if stmt.Schema.PrioritizedPrimaryField != nil {
|
||||
stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName)
|
||||
stmt.DB.Dialector.QuoteTo(writer,
|
||||
stmt.Schema.PrioritizedPrimaryField.DBName)
|
||||
} else if len(stmt.Schema.DBNames) > 0 {
|
||||
stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.DBNames[0])
|
||||
}
|
||||
@ -181,7 +182,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
||||
writer.WriteString("(NULL)")
|
||||
}
|
||||
case *DB:
|
||||
subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
|
||||
subdb := v.Session(&Session{Logger: logger.Discard,
|
||||
DryRun: true}).getInstance()
|
||||
subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...)
|
||||
subdb.callbacks.Query().Execute(subdb)
|
||||
writer.WriteString(subdb.Statement.SQL.String())
|
||||
@ -230,7 +232,8 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
|
||||
}
|
||||
|
||||
// BuildCondition build condition
|
||||
func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []clause.Expression {
|
||||
func (stmt *Statement) BuildCondition(query interface{},
|
||||
args ...interface{}) []clause.Expression {
|
||||
if s, ok := query.(string); ok {
|
||||
// if it is a number, then treats it as primary key
|
||||
if _, err := strconv.Atoi(s); err != nil {
|
||||
@ -262,10 +265,14 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
|
||||
if where, ok := cs.Expression.(clause.Where); ok {
|
||||
if len(where.Exprs) == 1 {
|
||||
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
|
||||
where.Exprs[0] = clause.AndConditions{Exprs: orConds.Exprs}
|
||||
orConds, ok := where.Exprs[0].(clause.OrConditions)
|
||||
if ok {
|
||||
where.Exprs[0] = clause.AndConditions{
|
||||
Exprs: orConds.Exprs,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
conds = append(conds, clause.And(where.Exprs...))
|
||||
} else if cs.Expression != nil {
|
||||
conds = append(conds, cs.Expression)
|
||||
@ -297,16 +304,24 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if _, ok := v[key].(driver.Valuer); ok {
|
||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||
conds = append(conds, clause.Eq{
|
||||
Column: key,
|
||||
Value: v[key],
|
||||
})
|
||||
} else if _, ok := v[key].(Valuer); ok {
|
||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||
conds = append(conds, clause.Eq{
|
||||
Column: key,
|
||||
Value: v[key],
|
||||
})
|
||||
} else {
|
||||
values := make([]interface{}, reflectValue.Len())
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
values[i] = reflectValue.Index(i).Interface()
|
||||
}
|
||||
|
||||
conds = append(conds, clause.IN{Column: key, Values: values})
|
||||
conds = append(conds, clause.IN{
|
||||
Column: key, Values: values,
|
||||
})
|
||||
}
|
||||
default:
|
||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||
@ -314,7 +329,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
}
|
||||
default:
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(arg))
|
||||
if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
|
||||
if s, err := schema.Parse(arg, stmt.DB.cacheStore,
|
||||
stmt.DB.NamingStrategy); err == nil {
|
||||
selectedColumns := map[string]bool{}
|
||||
if idx == 0 {
|
||||
for _, v := range args[1:] {
|
||||
@ -328,27 +344,56 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
for _, field := range s.Fields {
|
||||
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
|
||||
selected := selectedColumns[field.DBName] ||
|
||||
selectedColumns[field.Name]
|
||||
if selected || (!restricted && field.Readable) {
|
||||
if v, isZero := field.ValueOf(reflectValue); !isZero || selected {
|
||||
v, isZero := field.ValueOf(reflectValue)
|
||||
if !isZero || selected {
|
||||
if field.DBName != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
||||
conds = append(conds, clause.Eq{
|
||||
Column: clause.Column{
|
||||
Table: clause.CurrentTable,
|
||||
Name: field.DBName,
|
||||
},
|
||||
Value: v,
|
||||
})
|
||||
} else if field.DataType != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
|
||||
conds = append(conds, clause.Eq{
|
||||
Column: clause.Column{
|
||||
Table: clause.CurrentTable,
|
||||
Name: field.Name,
|
||||
},
|
||||
Value: v,
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
for _, field := range s.Fields {
|
||||
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
|
||||
selected := selectedColumns[field.DBName] ||
|
||||
selectedColumns[field.Name]
|
||||
if selected || (!restricted && field.Readable) {
|
||||
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected {
|
||||
v, isZero := field.ValueOf(reflectValue.Index(i))
|
||||
if !isZero || selected {
|
||||
if field.DBName != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
||||
conds = append(conds, clause.Eq{
|
||||
Column: clause.Column{
|
||||
Table: clause.CurrentTable,
|
||||
Name: field.DBName,
|
||||
},
|
||||
Value: v,
|
||||
})
|
||||
} else if field.DataType != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
|
||||
conds = append(conds, clause.Eq{
|
||||
Column: clause.Column{
|
||||
Table: clause.CurrentTable,
|
||||
Name: field.Name,
|
||||
},
|
||||
Value: v,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -371,13 +416,20 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
|
||||
}
|
||||
|
||||
if len(values) > 0 {
|
||||
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
|
||||
conds = append(conds, clause.IN{
|
||||
Column: clause.PrimaryColumn,
|
||||
Values: values,
|
||||
})
|
||||
}
|
||||
|
||||
return conds
|
||||
}
|
||||
}
|
||||
|
||||
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
|
||||
conds = append(conds, clause.IN{
|
||||
Column: clause.PrimaryColumn,
|
||||
Values: args,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -406,7 +458,9 @@ func (stmt *Statement) Build(clauses ...string) {
|
||||
}
|
||||
|
||||
func (stmt *Statement) Parse(value interface{}) (err error) {
|
||||
if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" {
|
||||
stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore,
|
||||
stmt.DB.NamingStrategy)
|
||||
if err == nil && stmt.Table == "" {
|
||||
if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
|
||||
stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
|
||||
stmt.Table = tables[1]
|
||||
@ -415,6 +469,7 @@ func (stmt *Statement) Parse(value interface{}) (err error) {
|
||||
|
||||
stmt.Table = stmt.Schema.Table
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@ -463,7 +518,8 @@ func (stmt *Statement) clone() *Statement {
|
||||
// SetColumn set column's value
|
||||
// stmt.SetColumn("Name", "jinzhu") // Hooks Method
|
||||
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
|
||||
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
|
||||
func (stmt *Statement) SetColumn(name string, value interface{},
|
||||
fromCallbacks ...bool) {
|
||||
if v, ok := stmt.Dest.(map[string]interface{}); ok {
|
||||
v[name] = value
|
||||
} else if v, ok := stmt.Dest.([]map[string]interface{}); ok {
|
||||
@ -524,7 +580,8 @@ func (stmt *Statement) Changed(fields ...string) bool {
|
||||
selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
|
||||
changed := func(field *schema.Field) bool {
|
||||
fieldValue, _ := field.ValueOf(modelValue)
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) ||
|
||||
(!ok && !restricted) {
|
||||
if v, ok := stmt.Dest.(map[string]interface{}); ok {
|
||||
if fv, ok := v[field.Name]; ok {
|
||||
return !utils.AssertEqual(fv, fieldValue)
|
||||
@ -563,8 +620,10 @@ func (stmt *Statement) Changed(fields ...string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
|
||||
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
|
||||
// SelectAndOmitColumns get select and omit columns,
|
||||
// select -> true, omit -> false
|
||||
func (stmt *Statement) SelectAndOmitColumns(requireCreate,
|
||||
requireUpdate bool) (map[string]bool, bool) {
|
||||
results := map[string]bool{}
|
||||
notRestricted := false
|
||||
|
||||
@ -579,7 +638,8 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
results[rel.Name] = true
|
||||
}
|
||||
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
|
||||
} else if field := stmt.Schema.LookUpField(column); field != nil &&
|
||||
field.DBName != "" {
|
||||
results[field.DBName] = true
|
||||
} else {
|
||||
results[column] = true
|
||||
@ -594,7 +654,8 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (
|
||||
results[rel.Name] = false
|
||||
}
|
||||
}
|
||||
} else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
|
||||
} else if field := stmt.Schema.LookUpField(omit); field != nil &&
|
||||
field.DBName != "" {
|
||||
results[field.DBName] = false
|
||||
} else {
|
||||
results[omit] = false
|
||||
|
Loading…
x
Reference in New Issue
Block a user