feat: gofmt code

This commit is contained in:
daheige 2021-01-31 17:31:04 +08:00
parent 699093ae6c
commit 0df6b19a8d
12 changed files with 338 additions and 129 deletions

View File

@ -24,10 +24,12 @@ func (db *DB) Association(column string) *Association {
if err := db.Statement.Parse(db.Statement.Model); err == nil {
db.Statement.Table = table
association.Relationship = db.Statement.Schema.Relationships.Relations[column]
association.Relationship = db.Statement.Schema.
Relationships.Relations[column]
if association.Relationship == nil {
association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column)
association.Error = fmt.Errorf("%w: %v",
ErrUnsupportedRelation, column)
}
db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
@ -41,9 +43,11 @@ func (db *DB) Association(column string) *Association {
return association
}
func (association *Association) Find(out interface{}, conds ...interface{}) error {
func (association *Association) Find(out interface{},
conds ...interface{}) error {
if association.Error == nil {
association.Error = association.buildCondition().Find(out, conds...).Error
association.Error = association.buildCondition().
Find(out, conds...).Error
}
return association.Error
}
@ -80,10 +84,12 @@ func (association *Association) Replace(values ...interface{}) error {
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
association.Error = rel.Field.Set(reflectValue.Index(i),
reflect.Zero(rel.Field.FieldType).Interface())
}
case reflect.Struct:
association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
association.Error = rel.Field.Set(reflectValue,
reflect.Zero(rel.Field.FieldType).Interface())
}
for _, ref := range rel.References {
@ -118,9 +124,11 @@ func (association *Association) Replace(values ...interface{}) error {
}
}
if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 {
if _, pvs := schema.
GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 {
column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
association.Error = tx.Where(clause.IN{Column: column, Values: values}).
UpdateColumns(updateMap).Error
}
case schema.Many2Many:
var (
@ -152,7 +160,8 @@ func (association *Association) Replace(values ...interface{}) error {
}
_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
if relColumn, relValues := schema.
ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
}

View File

@ -82,8 +82,11 @@ func (p *processor) Execute(db *DB) {
}
if stmt.Model != nil {
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) {
if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" {
if err := stmt.Parse(stmt.Model); err != nil &&
(!errors.Is(err, schema.ErrUnsupportedDataType) ||
(stmt.Table == "" && stmt.SQL.Len() == 0)) {
if errors.Is(err, schema.ErrUnsupportedDataType) &&
stmt.Table == "" {
db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
} else {
db.AddError(err)
@ -163,7 +166,8 @@ func (p *processor) compile() (err error) {
p.callbacks = callbacks
if p.fns, err = sortCallbacks(p.callbacks); err != nil {
p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
p.db.Logger.Error(context.Background(),
"Got error when compile callbacks, got %v", err)
}
return
}
@ -186,7 +190,8 @@ func (c *callback) Register(name string, fn func(*DB)) error {
}
func (c *callback) Remove(name string) error {
c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum())
c.processor.db.Logger.Warn(context.Background(),
"removing callback `%v` from %v\n", name, utils.FileWithLineNum())
c.name = name
c.remove = true
c.processor.callbacks = append(c.processor.callbacks, c)
@ -194,7 +199,8 @@ func (c *callback) Remove(name string) error {
}
func (c *callback) Replace(name string, fn func(*DB)) error {
c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
c.processor.db.Logger.Info(context.Background(),
"replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
c.name = name
c.handler = fn
c.replace = true
@ -223,8 +229,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
for _, c := range cs {
// show warning message the callback name already exists
if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum())
if idx := getRIndex(names, c.name); idx > -1 && !c.replace &&
!c.remove && !cs[idx].remove {
c.processor.db.Logger.Warn(context.Background(),
"duplicated callback `%v` from %v\n", c.name,
utils.FileWithLineNum())
}
names = append(names, c.name)
}
@ -238,9 +247,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
// if before callback already sorted, append current callback just after it
sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
sorted = append(sorted[:sortedIdx],
append([]string{c.name}, sorted[sortedIdx:]...)...)
} else if curIdx > sortedIdx {
return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before)
return fmt.Errorf("conflicting callback %v with before %v",
c.name, c.before)
}
} else if idx := getRIndex(names, c.before); idx != -1 {
// if before callback exists
@ -258,7 +269,8 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
// if after callback sorted, append current callback to last
sorted = append(sorted, c.name)
} else if curIdx < sortedIdx {
return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after)
return fmt.Errorf("conflicting callback %v with before %v",
c.name, c.after)
}
} else if idx := getRIndex(names, c.after); idx != -1 {
// if after callback exists but haven't sorted

View File

@ -12,7 +12,8 @@ import (
// Model specify the model you would like to run db operations
// // update all users's name to `hello`
// db.Model(&User{}).Update("name", "hello")
// // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
// // if user's primary key is non-blank, will use it as condition,
// then will only update the user's name to `hello`
// db.Model(&user).Update("name", "hello")
func (db *DB) Model(value interface{}) (tx *DB) {
tx = db.getInstance()
@ -36,7 +37,8 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
}
if len(whereConds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(whereConds[0], whereConds[1:]...)})
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.
BuildCondition(whereConds[0], whereConds[1:]...)})
}
return
}
@ -46,9 +48,11 @@ var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`)
// Table specify the table you would like to run db operations
func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 {
if strings.Contains(name, " ") || strings.Contains(name, "`") ||
len(args) > 0 {
tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args}
if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 {
if results := tableRegexp.
FindStringSubmatch(name); len(results) == 2 {
tx.Statement.Table = results[1]
return
}
@ -87,7 +91,8 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
case []string:
tx.Statement.Selects = append(tx.Statement.Selects, arg...)
default:
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
tx.AddError(fmt.Errorf("unsupported select args %v %v",
query, args))
return
}
}
@ -125,12 +130,14 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
return
}
// Omit specify fields that you want to ignore when creating, updating and querying
// Omit specify fields that you want to ignore when creating,
// updating and querying
func (db *DB) Omit(columns ...string) (tx *DB) {
tx = db.getInstance()
if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar)
tx.Statement.Omits = strings.FieldsFunc(columns[0],
utils.IsValidDBNameChar)
} else {
tx.Statement.Omits = columns
}
@ -150,7 +157,9 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(conds...)}})
tx.Statement.AddClause(clause.Where{
Exprs: []clause.Expression{clause.Not(conds...)},
})
}
return
}
@ -158,8 +167,11 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
// Or add OR conditions
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(clause.And(conds...))}})
if conds := tx.Statement.
BuildCondition(query, args...); len(conds) > 0 {
tx.Statement.AddClause(clause.Where{
Exprs: []clause.Expression{clause.Or(clause.And(conds...))}},
)
}
return
}
@ -169,7 +181,10 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args})
tx.Statement.Joins = append(tx.Statement.Joins, join{
Name: query, Conds: args,
})
return
}

View File

@ -7,7 +7,8 @@ import (
var (
// ErrRecordNotFound record not found error
ErrRecordNotFound = errors.New("record not found")
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
// ErrInvalidTransaction invalid transaction
// when you are trying to `Commit` or `Rollback`
ErrInvalidTransaction = errors.New("no valid transaction")
// ErrNotImplemented not implemented
ErrNotImplemented = errors.New("not implemented")

View File

@ -67,7 +67,8 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
return
}
// Save update value in database, if the value doesn't have primary key, will insert it
// Save update value in database, if the value doesn't have primary key,
// will insert it
func (db *DB) Save(value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = value
@ -78,9 +79,11 @@ func (db *DB) Save(value interface{}) (tx *DB) {
if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
}
tx.callbacks.Create().Execute(tx.InstanceSet("gorm:update_track_time", true))
tx.callbacks.Create().
Execute(tx.InstanceSet("gorm:update_track_time", true))
case reflect.Struct:
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
if err := tx.Statement.Parse(value); err == nil &&
tx.Statement.Schema != nil {
for _, pf := range tx.Statement.Schema.PrimaryFields {
if _, isZero := pf.ValueOf(reflectValue); isZero {
tx.callbacks.Create().Execute(tx)
@ -99,9 +102,11 @@ func (db *DB) Save(value interface{}) (tx *DB) {
tx.callbacks.Update().Execute(tx)
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun &&
!selectedUpdate {
result := reflect.New(tx.Statement.Schema.ModelType).Interface()
if err := tx.Session(&Session{}).First(result).Error; errors.Is(err, ErrRecordNotFound) {
if err := tx.Session(&Session{}).
First(result).Error; errors.Is(err, ErrRecordNotFound) {
return tx.Create(value)
}
}
@ -113,10 +118,15 @@ func (db *DB) Save(value interface{}) (tx *DB) {
// First find first record that match given conditions, order by primary key
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
Column: clause.Column{
Table: clause.CurrentTable,
Name: clause.PrimaryKey,
},
})
if len(conds) > 0 {
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
if exprs := tx.
Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: exprs})
}
}
@ -126,7 +136,8 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
return
}
// Take return a record that match given conditions, the order will depend on the database implementation
// Take return a record that match given conditions, the order will
// depend on the database implementation
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1)
if len(conds) > 0 {
@ -198,8 +209,15 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
tx.AddError(ErrPrimaryKeyRequired)
break
} else {
primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1))
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
primaryValue, _ := result.Statement.Schema.
PrioritizedPrimaryField.
ValueOf(resultsValue.Index(resultsValue.Len() - 1))
queryDB = tx.Clauses(clause.Gt{
Column: clause.Column{
Table: clause.CurrentTable,
Name: clause.PrimaryKey,
},
Value: primaryValue})
}
}
}

View File

@ -29,9 +29,12 @@ type Plugin interface {
// ConnPool db conns pool interface
type ConnPool interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
ExecContext(ctx context.Context, query string,
args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string,
args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string,
args ...interface{}) *sql.Row
}
// SavePointerDialectorInterface save pointer interface

View File

@ -48,7 +48,8 @@ type Migrator interface {
AddColumn(dst interface{}, field string) error
DropColumn(dst interface{}, field string) error
AlterColumn(dst interface{}, field string) error
MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
MigrateColumn(dst interface{}, field *schema.Field,
columnType ColumnType) error
HasColumn(dst interface{}, field string) bool
RenameColumn(dst interface{}, oldName, field string) error
ColumnTypes(dst interface{}) ([]ColumnType, error)

View File

@ -2,8 +2,10 @@ package gorm
import "time"
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
// It may be embedded into your model or you may build your own model without it
// Model a basic GoLang struct which includes the following fields: ID,
// CreatedAt, UpdatedAt, DeletedAt
// It may be embedded into your model
// or you may build your own model without it
// type User struct {
// gorm.Model
// }

View File

@ -30,9 +30,11 @@ func (db *PreparedStmtDB) Close() {
db.Mux.Unlock()
}
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool,
isTransaction bool, query string) (Stmt, error) {
db.Mux.RLock()
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction ||
isTransaction) {
db.Mux.RUnlock()
return stmt, nil
}
@ -40,7 +42,8 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
db.Mux.Lock()
// double check
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction ||
isTransaction) {
db.Mux.Unlock()
return stmt, nil
} else if ok {
@ -57,7 +60,8 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
return db.Stmts[query], err
}
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
func (db *PreparedStmtDB) BeginTx(ctx context.Context,
opt *sql.TxOptions) (ConnPool, error) {
if beginner, ok := db.ConnPool.(TxBeginner); ok {
tx, err := beginner.BeginTx(ctx, opt)
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
@ -65,7 +69,8 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
return nil, ErrInvalidTransaction
}
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string,
args ...interface{}) (result sql.Result, err error) {
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil {
result, err = stmt.ExecContext(ctx, args...)
@ -79,7 +84,8 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
return result, err
}
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string,
args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil {
rows, err = stmt.QueryContext(ctx, args...)
@ -93,7 +99,8 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
return rows, err
}
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string,
args ...interface{}) *sql.Row {
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil {
return stmt.QueryRowContext(ctx, args...)
@ -120,10 +127,12 @@ func (tx *PreparedStmtTX) Rollback() error {
return ErrInvalidTransaction
}
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string,
args ...interface{}) (result sql.Result, err error) {
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).
ExecContext(ctx, args...)
if err != nil {
tx.PreparedStmtDB.Mux.Lock()
stmt.Close()
@ -134,7 +143,8 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
return result, err
}
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string,
args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...)
@ -148,7 +158,8 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
return rows, err
}
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string,
args ...interface{}) *sql.Row {
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
return tx.Tx.StmtContext(ctx, stmt.Stmt).QueryRowContext(ctx, args...)

109
scan.go
View File

@ -10,19 +10,24 @@ import (
"gorm.io/gorm/schema"
)
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
func prepareValues(values []interface{}, db *DB,
columnTypes []*sql.ColumnType, columns []string) {
if db.Statement.Schema != nil {
for idx, name := range columns {
if field := db.Statement.Schema.LookUpField(name); field != nil {
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
field := db.Statement.Schema.LookUpField(name)
if field != nil {
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).
Interface()
continue
}
values[idx] = new(interface{})
}
} else if len(columnTypes) > 0 {
for idx, columnType := range columnTypes {
if columnType.ScanType() != nil {
values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).
Interface()
} else {
values[idx] = new(interface{})
}
@ -34,9 +39,14 @@ func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType,
}
}
func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
func scanIntoMap(mapValue map[string]interface{},
values []interface{}, columns []string) {
for idx, column := range columns {
if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() {
reflectValue := reflect.Indirect(
reflect.Indirect(reflect.ValueOf(values[idx])),
)
if reflectValue.IsValid() {
mapValue[column] = reflectValue.Interface()
if valuer, ok := mapValue[column].(driver.Valuer); ok {
mapValue[column], _ = valuer.Value()
@ -111,28 +121,42 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
reflectValueType = reflectValueType.Elem()
}
db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 20))
db.Statement.ReflectValue.Set(
reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 20),
)
if Schema != nil {
if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct {
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
if reflectValueType != Schema.ModelType &&
reflectValueType.Kind() == reflect.Struct {
Schema, _ = schema.Parse(db.Statement.Dest,
db.cacheStore, db.NamingStrategy)
}
for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable {
if field := Schema.LookUpField(column); field != nil &&
field.Readable {
fields[idx] = field
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
fields[idx] = field
} else if names := strings.
Split(column, "__"); len(names) > 1 {
rel, ok := Schema.Relationships.Relations[names[0]]
if ok {
field2 := rel.FieldSchema.LookUpField(
strings.Join(names[1:], "__"),
)
if field2 != nil && field2.Readable {
fields[idx] = field2
if len(joinFields) == 0 {
joinFields = make([][2]*schema.Field, len(columns))
joinFields = make([][2]*schema.Field,
len(columns))
}
joinFields[idx] = [2]*schema.Field{rel.Field, field}
joinFields[idx] = [2]*schema.Field{rel.Field,
field2}
continue
}
}
values[idx] = &sql.RawBytes{}
} else {
values[idx] = &sql.RawBytes{}
@ -143,9 +167,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
// pluck values into slice of data
isPluck := false
if len(fields) == 1 {
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok || // is scanner
reflectValueType.Kind() != reflect.Struct || // is not struct
Schema.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
_, ok := reflect.New(reflectValueType).
Interface().(sql.Scanner)
// is scanner or is not struct or is time
if ok || reflectValueType.Kind() != reflect.Struct ||
Schema.ModelType.ConvertibleTo(schema.TimeReflectType) {
isPluck = true
}
}
@ -160,7 +186,9 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
} else {
for idx, field := range fields {
if field != nil {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
values[idx] = reflect.New(
reflect.PtrTo(field.IndirectFieldType),
).Interface()
}
}
@ -171,11 +199,14 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
value := reflect.ValueOf(values[idx]).Elem()
relValue := joinFields[idx][0].ReflectValueOf(elem)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
if relValue.Kind() == reflect.Ptr &&
relValue.IsNil() {
if value.IsNil() {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem()))
relValue.Set(
reflect.New(relValue.Type().Elem()),
)
}
field.Set(relValue, values[idx])
@ -186,24 +217,36 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
}
if isPtr {
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem))
db.Statement.ReflectValue.Set(reflect.
Append(db.Statement.ReflectValue, elem))
} else {
db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Elem()))
db.Statement.ReflectValue.Set(reflect.
Append(db.Statement.ReflectValue, elem.Elem()))
}
}
case reflect.Struct, reflect.Ptr:
if db.Statement.ReflectValue.Type() != Schema.ModelType {
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore,
db.NamingStrategy)
}
if initialized || rows.Next() {
for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
if field := Schema.LookUpField(column); field != nil &&
field.Readable {
values[idx] = reflect.New(
reflect.PtrTo(field.IndirectFieldType),
).Interface()
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
rel, ok := Schema.Relationships.Relations[names[0]]
if ok {
field := rel.FieldSchema.
LookUpField(strings.Join(names[1:], "__"))
if field != nil &&
field.Readable {
values[idx] = reflect.New(
reflect.PtrTo(field.IndirectFieldType),
).Interface()
continue
}
}
@ -217,11 +260,13 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
db.AddError(rows.Scan(values...))
for idx, column := range columns {
if field := Schema.LookUpField(column); field != nil && field.Readable {
if field := Schema.LookUpField(column); field != nil &&
field.Readable {
field.Set(db.Statement.ReflectValue, values[idx])
} else if names := strings.Split(column, "__"); len(names) > 1 {
if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil &&
field.Readable {
relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
value := reflect.ValueOf(values[idx]).Elem()

View File

@ -65,10 +65,15 @@ func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) {
func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok {
if c, ok := stmt.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) > 1 {
if where, ok := c.Expression.(clause.Where); ok &&
len(where.Exprs) > 1 {
for _, expr := range where.Exprs {
if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 {
where.Exprs = []clause.Expression{clause.And(where.Exprs...)}
if orCond, ok := expr.(clause.OrConditions); ok &&
len(orCond.Exprs) == 1 {
where.Exprs = []clause.Expression{
clause.And(where.Exprs...),
}
c.Expression = where
stmt.Clauses["WHERE"] = c
break
@ -78,7 +83,11 @@ func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
}
stmt.AddClause(clause.Where{Exprs: []clause.Expression{
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil},
clause.Eq{Column: clause.Column{
Table: clause.CurrentTable,
Name: sd.Field.DBName},
Value: nil,
},
}})
stmt.Clauses["soft_delete_enabled"] = clause.Clause{}
}
@ -105,28 +114,50 @@ func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) {
func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
if stmt.SQL.String() == "" {
curTime := stmt.DB.NowFunc()
stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}})
stmt.AddClause(clause.Set{
{
Column: clause.Column{Name: sd.Field.DBName},
Value: curTime,
},
})
stmt.SetColumn(sd.Field.DBName, curTime, true)
if stmt.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields)
column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
_, queryValues := schema.GetIdentityFieldValuesMap(
stmt.ReflectValue, stmt.Schema.PrimaryFields,
)
column, values := schema.ToQueryValues(stmt.Table,
stmt.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
stmt.AddClause(clause.Where{Exprs: []clause.Expression{
clause.IN{Column: column, Values: values}}})
}
if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model &&
stmt.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(
reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields,
)
column, values = schema.ToQueryValues(
stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues,
)
if len(values) > 0 {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
stmt.AddClause(clause.Where{
Exprs: []clause.Expression{
clause.IN{
Column: column, Values: values,
},
},
})
}
}
}
if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok {
if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate &&
!ok {
stmt.DB.AddError(ErrMissingWhereClause)
} else {
SoftDeleteQueryClause{Field: sd.Field}.ModifyStatement(stmt)

View File

@ -104,7 +104,8 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
if stmt.Schema == nil {
stmt.DB.AddError(ErrModelValueRequired)
} else if stmt.Schema.PrioritizedPrimaryField != nil {
stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName)
stmt.DB.Dialector.QuoteTo(writer,
stmt.Schema.PrioritizedPrimaryField.DBName)
} else if len(stmt.Schema.DBNames) > 0 {
stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.DBNames[0])
}
@ -181,7 +182,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
writer.WriteString("(NULL)")
}
case *DB:
subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
subdb := v.Session(&Session{Logger: logger.Discard,
DryRun: true}).getInstance()
subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...)
subdb.callbacks.Query().Execute(subdb)
writer.WriteString(subdb.Statement.SQL.String())
@ -230,7 +232,8 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
}
// BuildCondition build condition
func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []clause.Expression {
func (stmt *Statement) BuildCondition(query interface{},
args ...interface{}) []clause.Expression {
if s, ok := query.(string); ok {
// if it is a number, then treats it as primary key
if _, err := strconv.Atoi(s); err != nil {
@ -262,10 +265,14 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
if len(where.Exprs) == 1 {
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
where.Exprs[0] = clause.AndConditions{Exprs: orConds.Exprs}
orConds, ok := where.Exprs[0].(clause.OrConditions)
if ok {
where.Exprs[0] = clause.AndConditions{
Exprs: orConds.Exprs,
}
}
}
conds = append(conds, clause.And(where.Exprs...))
} else if cs.Expression != nil {
conds = append(conds, cs.Expression)
@ -297,16 +304,24 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
if _, ok := v[key].(driver.Valuer); ok {
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
conds = append(conds, clause.Eq{
Column: key,
Value: v[key],
})
} else if _, ok := v[key].(Valuer); ok {
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
conds = append(conds, clause.Eq{
Column: key,
Value: v[key],
})
} else {
values := make([]interface{}, reflectValue.Len())
for i := 0; i < reflectValue.Len(); i++ {
values[i] = reflectValue.Index(i).Interface()
}
conds = append(conds, clause.IN{Column: key, Values: values})
conds = append(conds, clause.IN{
Column: key, Values: values,
})
}
default:
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
@ -314,7 +329,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
}
default:
reflectValue := reflect.Indirect(reflect.ValueOf(arg))
if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
if s, err := schema.Parse(arg, stmt.DB.cacheStore,
stmt.DB.NamingStrategy); err == nil {
selectedColumns := map[string]bool{}
if idx == 0 {
for _, v := range args[1:] {
@ -328,27 +344,56 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
switch reflectValue.Kind() {
case reflect.Struct:
for _, field := range s.Fields {
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
selected := selectedColumns[field.DBName] ||
selectedColumns[field.Name]
if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(reflectValue); !isZero || selected {
v, isZero := field.ValueOf(reflectValue)
if !isZero || selected {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
conds = append(conds, clause.Eq{
Column: clause.Column{
Table: clause.CurrentTable,
Name: field.DBName,
},
Value: v,
})
} else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
conds = append(conds, clause.Eq{
Column: clause.Column{
Table: clause.CurrentTable,
Name: field.Name,
},
Value: v,
})
}
}
}
}
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
for _, field := range s.Fields {
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
selected := selectedColumns[field.DBName] ||
selectedColumns[field.Name]
if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected {
v, isZero := field.ValueOf(reflectValue.Index(i))
if !isZero || selected {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
conds = append(conds, clause.Eq{
Column: clause.Column{
Table: clause.CurrentTable,
Name: field.DBName,
},
Value: v,
})
} else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
conds = append(conds, clause.Eq{
Column: clause.Column{
Table: clause.CurrentTable,
Name: field.Name,
},
Value: v,
})
}
}
}
@ -371,13 +416,20 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
}
if len(values) > 0 {
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
conds = append(conds, clause.IN{
Column: clause.PrimaryColumn,
Values: values,
})
}
return conds
}
}
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
conds = append(conds, clause.IN{
Column: clause.PrimaryColumn,
Values: args,
})
}
}
}
@ -406,7 +458,9 @@ func (stmt *Statement) Build(clauses ...string) {
}
func (stmt *Statement) Parse(value interface{}) (err error) {
if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" {
stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore,
stmt.DB.NamingStrategy)
if err == nil && stmt.Table == "" {
if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
stmt.Table = tables[1]
@ -415,6 +469,7 @@ func (stmt *Statement) Parse(value interface{}) (err error) {
stmt.Table = stmt.Schema.Table
}
return err
}
@ -463,7 +518,8 @@ func (stmt *Statement) clone() *Statement {
// SetColumn set column's value
// stmt.SetColumn("Name", "jinzhu") // Hooks Method
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
func (stmt *Statement) SetColumn(name string, value interface{},
fromCallbacks ...bool) {
if v, ok := stmt.Dest.(map[string]interface{}); ok {
v[name] = value
} else if v, ok := stmt.Dest.([]map[string]interface{}); ok {
@ -524,7 +580,8 @@ func (stmt *Statement) Changed(fields ...string) bool {
selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
changed := func(field *schema.Field) bool {
fieldValue, _ := field.ValueOf(modelValue)
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if v, ok := selectColumns[field.DBName]; (ok && v) ||
(!ok && !restricted) {
if v, ok := stmt.Dest.(map[string]interface{}); ok {
if fv, ok := v[field.Name]; ok {
return !utils.AssertEqual(fv, fieldValue)
@ -563,8 +620,10 @@ func (stmt *Statement) Changed(fields ...string) bool {
return false
}
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
// SelectAndOmitColumns get select and omit columns,
// select -> true, omit -> false
func (stmt *Statement) SelectAndOmitColumns(requireCreate,
requireUpdate bool) (map[string]bool, bool) {
results := map[string]bool{}
notRestricted := false
@ -579,7 +638,8 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (
for _, rel := range stmt.Schema.Relationships.Relations {
results[rel.Name] = true
}
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
} else if field := stmt.Schema.LookUpField(column); field != nil &&
field.DBName != "" {
results[field.DBName] = true
} else {
results[column] = true
@ -594,7 +654,8 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (
results[rel.Name] = false
}
}
} else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
} else if field := stmt.Schema.LookUpField(omit); field != nil &&
field.DBName != "" {
results[field.DBName] = false
} else {
results[omit] = false