feat: gofmt code
This commit is contained in:
		
							parent
							
								
									699093ae6c
								
							
						
					
					
						commit
						0df6b19a8d
					
				@ -24,10 +24,12 @@ func (db *DB) Association(column string) *Association {
 | 
			
		||||
 | 
			
		||||
	if err := db.Statement.Parse(db.Statement.Model); err == nil {
 | 
			
		||||
		db.Statement.Table = table
 | 
			
		||||
		association.Relationship = db.Statement.Schema.Relationships.Relations[column]
 | 
			
		||||
		association.Relationship = db.Statement.Schema.
 | 
			
		||||
			Relationships.Relations[column]
 | 
			
		||||
 | 
			
		||||
		if association.Relationship == nil {
 | 
			
		||||
			association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column)
 | 
			
		||||
			association.Error = fmt.Errorf("%w: %v",
 | 
			
		||||
				ErrUnsupportedRelation, column)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
 | 
			
		||||
@ -41,9 +43,11 @@ func (db *DB) Association(column string) *Association {
 | 
			
		||||
	return association
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (association *Association) Find(out interface{}, conds ...interface{}) error {
 | 
			
		||||
func (association *Association) Find(out interface{},
 | 
			
		||||
	conds ...interface{}) error {
 | 
			
		||||
	if association.Error == nil {
 | 
			
		||||
		association.Error = association.buildCondition().Find(out, conds...).Error
 | 
			
		||||
		association.Error = association.buildCondition().
 | 
			
		||||
			Find(out, conds...).Error
 | 
			
		||||
	}
 | 
			
		||||
	return association.Error
 | 
			
		||||
}
 | 
			
		||||
@ -80,10 +84,12 @@ func (association *Association) Replace(values ...interface{}) error {
 | 
			
		||||
				switch reflectValue.Kind() {
 | 
			
		||||
				case reflect.Slice, reflect.Array:
 | 
			
		||||
					for i := 0; i < reflectValue.Len(); i++ {
 | 
			
		||||
						association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
 | 
			
		||||
						association.Error = rel.Field.Set(reflectValue.Index(i),
 | 
			
		||||
							reflect.Zero(rel.Field.FieldType).Interface())
 | 
			
		||||
					}
 | 
			
		||||
				case reflect.Struct:
 | 
			
		||||
					association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
 | 
			
		||||
					association.Error = rel.Field.Set(reflectValue,
 | 
			
		||||
						reflect.Zero(rel.Field.FieldType).Interface())
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				for _, ref := range rel.References {
 | 
			
		||||
@ -118,9 +124,11 @@ func (association *Association) Replace(values ...interface{}) error {
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 {
 | 
			
		||||
			if _, pvs := schema.
 | 
			
		||||
				GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 {
 | 
			
		||||
				column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
 | 
			
		||||
				association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
 | 
			
		||||
				association.Error = tx.Where(clause.IN{Column: column, Values: values}).
 | 
			
		||||
					UpdateColumns(updateMap).Error
 | 
			
		||||
			}
 | 
			
		||||
		case schema.Many2Many:
 | 
			
		||||
			var (
 | 
			
		||||
@ -152,7 +160,8 @@ func (association *Association) Replace(values ...interface{}) error {
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
 | 
			
		||||
			if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
 | 
			
		||||
			if relColumn, relValues := schema.
 | 
			
		||||
				ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
 | 
			
		||||
				tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										32
									
								
								callbacks.go
									
									
									
									
									
								
							
							
						
						
									
										32
									
								
								callbacks.go
									
									
									
									
									
								
							@ -82,8 +82,11 @@ func (p *processor) Execute(db *DB) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if stmt.Model != nil {
 | 
			
		||||
		if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) {
 | 
			
		||||
			if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" {
 | 
			
		||||
		if err := stmt.Parse(stmt.Model); err != nil &&
 | 
			
		||||
			(!errors.Is(err, schema.ErrUnsupportedDataType) ||
 | 
			
		||||
				(stmt.Table == "" && stmt.SQL.Len() == 0)) {
 | 
			
		||||
			if errors.Is(err, schema.ErrUnsupportedDataType) &&
 | 
			
		||||
				stmt.Table == "" {
 | 
			
		||||
				db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
 | 
			
		||||
			} else {
 | 
			
		||||
				db.AddError(err)
 | 
			
		||||
@ -163,7 +166,8 @@ func (p *processor) compile() (err error) {
 | 
			
		||||
	p.callbacks = callbacks
 | 
			
		||||
 | 
			
		||||
	if p.fns, err = sortCallbacks(p.callbacks); err != nil {
 | 
			
		||||
		p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
 | 
			
		||||
		p.db.Logger.Error(context.Background(),
 | 
			
		||||
			"Got error when compile callbacks, got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
@ -186,7 +190,8 @@ func (c *callback) Register(name string, fn func(*DB)) error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *callback) Remove(name string) error {
 | 
			
		||||
	c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum())
 | 
			
		||||
	c.processor.db.Logger.Warn(context.Background(),
 | 
			
		||||
		"removing callback `%v` from %v\n", name, utils.FileWithLineNum())
 | 
			
		||||
	c.name = name
 | 
			
		||||
	c.remove = true
 | 
			
		||||
	c.processor.callbacks = append(c.processor.callbacks, c)
 | 
			
		||||
@ -194,7 +199,8 @@ func (c *callback) Remove(name string) error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *callback) Replace(name string, fn func(*DB)) error {
 | 
			
		||||
	c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
 | 
			
		||||
	c.processor.db.Logger.Info(context.Background(),
 | 
			
		||||
		"replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
 | 
			
		||||
	c.name = name
 | 
			
		||||
	c.handler = fn
 | 
			
		||||
	c.replace = true
 | 
			
		||||
@ -223,8 +229,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
 | 
			
		||||
 | 
			
		||||
	for _, c := range cs {
 | 
			
		||||
		// show warning message the callback name already exists
 | 
			
		||||
		if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
 | 
			
		||||
			c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum())
 | 
			
		||||
		if idx := getRIndex(names, c.name); idx > -1 && !c.replace &&
 | 
			
		||||
			!c.remove && !cs[idx].remove {
 | 
			
		||||
			c.processor.db.Logger.Warn(context.Background(),
 | 
			
		||||
				"duplicated callback `%v` from %v\n", c.name,
 | 
			
		||||
				utils.FileWithLineNum())
 | 
			
		||||
		}
 | 
			
		||||
		names = append(names, c.name)
 | 
			
		||||
	}
 | 
			
		||||
@ -238,9 +247,11 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
 | 
			
		||||
			} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
 | 
			
		||||
				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
 | 
			
		||||
					// if before callback already sorted, append current callback just after it
 | 
			
		||||
					sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
 | 
			
		||||
					sorted = append(sorted[:sortedIdx],
 | 
			
		||||
						append([]string{c.name}, sorted[sortedIdx:]...)...)
 | 
			
		||||
				} else if curIdx > sortedIdx {
 | 
			
		||||
					return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before)
 | 
			
		||||
					return fmt.Errorf("conflicting callback %v with before %v",
 | 
			
		||||
						c.name, c.before)
 | 
			
		||||
				}
 | 
			
		||||
			} else if idx := getRIndex(names, c.before); idx != -1 {
 | 
			
		||||
				// if before callback exists
 | 
			
		||||
@ -258,7 +269,8 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
 | 
			
		||||
					// if after callback sorted, append current callback to last
 | 
			
		||||
					sorted = append(sorted, c.name)
 | 
			
		||||
				} else if curIdx < sortedIdx {
 | 
			
		||||
					return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after)
 | 
			
		||||
					return fmt.Errorf("conflicting callback %v with before %v",
 | 
			
		||||
						c.name, c.after)
 | 
			
		||||
				}
 | 
			
		||||
			} else if idx := getRIndex(names, c.after); idx != -1 {
 | 
			
		||||
				// if after callback exists but haven't sorted
 | 
			
		||||
 | 
			
		||||
@ -12,7 +12,8 @@ import (
 | 
			
		||||
// Model specify the model you would like to run db operations
 | 
			
		||||
//    // update all users's name to `hello`
 | 
			
		||||
//    db.Model(&User{}).Update("name", "hello")
 | 
			
		||||
//    // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
 | 
			
		||||
//    // if user's primary key is non-blank, will use it as condition,
 | 
			
		||||
//    then will only update the user's name to `hello`
 | 
			
		||||
//    db.Model(&user).Update("name", "hello")
 | 
			
		||||
func (db *DB) Model(value interface{}) (tx *DB) {
 | 
			
		||||
	tx = db.getInstance()
 | 
			
		||||
@ -36,7 +37,8 @@ func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(whereConds) > 0 {
 | 
			
		||||
		tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(whereConds[0], whereConds[1:]...)})
 | 
			
		||||
		tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.
 | 
			
		||||
			BuildCondition(whereConds[0], whereConds[1:]...)})
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
@ -46,9 +48,11 @@ var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`)
 | 
			
		||||
// Table specify the table you would like to run db operations
 | 
			
		||||
func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
 | 
			
		||||
	tx = db.getInstance()
 | 
			
		||||
	if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 {
 | 
			
		||||
	if strings.Contains(name, " ") || strings.Contains(name, "`") ||
 | 
			
		||||
		len(args) > 0 {
 | 
			
		||||
		tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args}
 | 
			
		||||
		if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 {
 | 
			
		||||
		if results := tableRegexp.
 | 
			
		||||
			FindStringSubmatch(name); len(results) == 2 {
 | 
			
		||||
			tx.Statement.Table = results[1]
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
@ -87,7 +91,8 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
 | 
			
		||||
			case []string:
 | 
			
		||||
				tx.Statement.Selects = append(tx.Statement.Selects, arg...)
 | 
			
		||||
			default:
 | 
			
		||||
				tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
 | 
			
		||||
				tx.AddError(fmt.Errorf("unsupported select args %v %v",
 | 
			
		||||
					query, args))
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
@ -125,12 +130,14 @@ func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Omit specify fields that you want to ignore when creating, updating and querying
 | 
			
		||||
// Omit specify fields that you want to ignore when creating,
 | 
			
		||||
// updating and querying
 | 
			
		||||
func (db *DB) Omit(columns ...string) (tx *DB) {
 | 
			
		||||
	tx = db.getInstance()
 | 
			
		||||
 | 
			
		||||
	if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
 | 
			
		||||
		tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar)
 | 
			
		||||
		tx.Statement.Omits = strings.FieldsFunc(columns[0],
 | 
			
		||||
			utils.IsValidDBNameChar)
 | 
			
		||||
	} else {
 | 
			
		||||
		tx.Statement.Omits = columns
 | 
			
		||||
	}
 | 
			
		||||
@ -150,7 +157,9 @@ func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
 | 
			
		||||
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
 | 
			
		||||
	tx = db.getInstance()
 | 
			
		||||
	if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
 | 
			
		||||
		tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(conds...)}})
 | 
			
		||||
		tx.Statement.AddClause(clause.Where{
 | 
			
		||||
			Exprs: []clause.Expression{clause.Not(conds...)},
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
@ -158,8 +167,11 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
 | 
			
		||||
// Or add OR conditions
 | 
			
		||||
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
 | 
			
		||||
	tx = db.getInstance()
 | 
			
		||||
	if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
 | 
			
		||||
		tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(clause.And(conds...))}})
 | 
			
		||||
	if conds := tx.Statement.
 | 
			
		||||
		BuildCondition(query, args...); len(conds) > 0 {
 | 
			
		||||
		tx.Statement.AddClause(clause.Where{
 | 
			
		||||
			Exprs: []clause.Expression{clause.Or(clause.And(conds...))}},
 | 
			
		||||
		)
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
@ -169,7 +181,10 @@ func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
 | 
			
		||||
//     db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
 | 
			
		||||
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
 | 
			
		||||
	tx = db.getInstance()
 | 
			
		||||
	tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args})
 | 
			
		||||
	tx.Statement.Joins = append(tx.Statement.Joins, join{
 | 
			
		||||
		Name: query, Conds: args,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -7,7 +7,8 @@ import (
 | 
			
		||||
var (
 | 
			
		||||
	// ErrRecordNotFound record not found error
 | 
			
		||||
	ErrRecordNotFound = errors.New("record not found")
 | 
			
		||||
	// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
 | 
			
		||||
	// ErrInvalidTransaction invalid transaction
 | 
			
		||||
	// when you are trying to `Commit` or `Rollback`
 | 
			
		||||
	ErrInvalidTransaction = errors.New("no valid transaction")
 | 
			
		||||
	// ErrNotImplemented not implemented
 | 
			
		||||
	ErrNotImplemented = errors.New("not implemented")
 | 
			
		||||
 | 
			
		||||
@ -67,7 +67,8 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Save update value in database, if the value doesn't have primary key, will insert it
 | 
			
		||||
// Save update value in database, if the value doesn't have primary key,
 | 
			
		||||
// will insert it
 | 
			
		||||
func (db *DB) Save(value interface{}) (tx *DB) {
 | 
			
		||||
	tx = db.getInstance()
 | 
			
		||||
	tx.Statement.Dest = value
 | 
			
		||||
@ -78,9 +79,11 @@ func (db *DB) Save(value interface{}) (tx *DB) {
 | 
			
		||||
		if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
 | 
			
		||||
			tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
 | 
			
		||||
		}
 | 
			
		||||
		tx.callbacks.Create().Execute(tx.InstanceSet("gorm:update_track_time", true))
 | 
			
		||||
		tx.callbacks.Create().
 | 
			
		||||
			Execute(tx.InstanceSet("gorm:update_track_time", true))
 | 
			
		||||
	case reflect.Struct:
 | 
			
		||||
		if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
 | 
			
		||||
		if err := tx.Statement.Parse(value); err == nil &&
 | 
			
		||||
			tx.Statement.Schema != nil {
 | 
			
		||||
			for _, pf := range tx.Statement.Schema.PrimaryFields {
 | 
			
		||||
				if _, isZero := pf.ValueOf(reflectValue); isZero {
 | 
			
		||||
					tx.callbacks.Create().Execute(tx)
 | 
			
		||||
@ -99,9 +102,11 @@ func (db *DB) Save(value interface{}) (tx *DB) {
 | 
			
		||||
 | 
			
		||||
		tx.callbacks.Update().Execute(tx)
 | 
			
		||||
 | 
			
		||||
		if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate {
 | 
			
		||||
		if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun &&
 | 
			
		||||
			!selectedUpdate {
 | 
			
		||||
			result := reflect.New(tx.Statement.Schema.ModelType).Interface()
 | 
			
		||||
			if err := tx.Session(&Session{}).First(result).Error; errors.Is(err, ErrRecordNotFound) {
 | 
			
		||||
			if err := tx.Session(&Session{}).
 | 
			
		||||
				First(result).Error; errors.Is(err, ErrRecordNotFound) {
 | 
			
		||||
				return tx.Create(value)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
@ -113,10 +118,15 @@ func (db *DB) Save(value interface{}) (tx *DB) {
 | 
			
		||||
// First find first record that match given conditions, order by primary key
 | 
			
		||||
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
 | 
			
		||||
	tx = db.Limit(1).Order(clause.OrderByColumn{
 | 
			
		||||
		Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
 | 
			
		||||
		Column: clause.Column{
 | 
			
		||||
			Table: clause.CurrentTable,
 | 
			
		||||
			Name:  clause.PrimaryKey,
 | 
			
		||||
		},
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if len(conds) > 0 {
 | 
			
		||||
		if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
 | 
			
		||||
		if exprs := tx.
 | 
			
		||||
			Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
 | 
			
		||||
			tx.Statement.AddClause(clause.Where{Exprs: exprs})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
@ -126,7 +136,8 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Take return a record that match given conditions, the order will depend on the database implementation
 | 
			
		||||
// Take return a record that match given conditions, the order will
 | 
			
		||||
// depend on the database implementation
 | 
			
		||||
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
 | 
			
		||||
	tx = db.Limit(1)
 | 
			
		||||
	if len(conds) > 0 {
 | 
			
		||||
@ -198,8 +209,15 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
 | 
			
		||||
				tx.AddError(ErrPrimaryKeyRequired)
 | 
			
		||||
				break
 | 
			
		||||
			} else {
 | 
			
		||||
				primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1))
 | 
			
		||||
				queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
 | 
			
		||||
				primaryValue, _ := result.Statement.Schema.
 | 
			
		||||
					PrioritizedPrimaryField.
 | 
			
		||||
					ValueOf(resultsValue.Index(resultsValue.Len() - 1))
 | 
			
		||||
				queryDB = tx.Clauses(clause.Gt{
 | 
			
		||||
					Column: clause.Column{
 | 
			
		||||
						Table: clause.CurrentTable,
 | 
			
		||||
						Name:  clause.PrimaryKey,
 | 
			
		||||
					},
 | 
			
		||||
					Value: primaryValue})
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -29,9 +29,12 @@ type Plugin interface {
 | 
			
		||||
// ConnPool db conns pool interface
 | 
			
		||||
type ConnPool interface {
 | 
			
		||||
	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
 | 
			
		||||
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
 | 
			
		||||
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
 | 
			
		||||
	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
 | 
			
		||||
	ExecContext(ctx context.Context, query string,
 | 
			
		||||
		args ...interface{}) (sql.Result, error)
 | 
			
		||||
	QueryContext(ctx context.Context, query string,
 | 
			
		||||
		args ...interface{}) (*sql.Rows, error)
 | 
			
		||||
	QueryRowContext(ctx context.Context, query string,
 | 
			
		||||
		args ...interface{}) *sql.Row
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SavePointerDialectorInterface save pointer interface
 | 
			
		||||
 | 
			
		||||
@ -48,7 +48,8 @@ type Migrator interface {
 | 
			
		||||
	AddColumn(dst interface{}, field string) error
 | 
			
		||||
	DropColumn(dst interface{}, field string) error
 | 
			
		||||
	AlterColumn(dst interface{}, field string) error
 | 
			
		||||
	MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
 | 
			
		||||
	MigrateColumn(dst interface{}, field *schema.Field,
 | 
			
		||||
		columnType ColumnType) error
 | 
			
		||||
	HasColumn(dst interface{}, field string) bool
 | 
			
		||||
	RenameColumn(dst interface{}, oldName, field string) error
 | 
			
		||||
	ColumnTypes(dst interface{}) ([]ColumnType, error)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										6
									
								
								model.go
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								model.go
									
									
									
									
									
								
							@ -2,8 +2,10 @@ package gorm
 | 
			
		||||
 | 
			
		||||
import "time"
 | 
			
		||||
 | 
			
		||||
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
 | 
			
		||||
// It may be embedded into your model or you may build your own model without it
 | 
			
		||||
// Model a basic GoLang struct which includes the following fields: ID,
 | 
			
		||||
// CreatedAt, UpdatedAt, DeletedAt
 | 
			
		||||
// It may be embedded into your model
 | 
			
		||||
// or you may build your own model without it
 | 
			
		||||
//    type User struct {
 | 
			
		||||
//      gorm.Model
 | 
			
		||||
//    }
 | 
			
		||||
 | 
			
		||||
@ -30,9 +30,11 @@ func (db *PreparedStmtDB) Close() {
 | 
			
		||||
	db.Mux.Unlock()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
 | 
			
		||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool,
 | 
			
		||||
	isTransaction bool, query string) (Stmt, error) {
 | 
			
		||||
	db.Mux.RLock()
 | 
			
		||||
	if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
 | 
			
		||||
	if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction ||
 | 
			
		||||
		isTransaction) {
 | 
			
		||||
		db.Mux.RUnlock()
 | 
			
		||||
		return stmt, nil
 | 
			
		||||
	}
 | 
			
		||||
@ -40,7 +42,8 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
 | 
			
		||||
 | 
			
		||||
	db.Mux.Lock()
 | 
			
		||||
	// double check
 | 
			
		||||
	if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
 | 
			
		||||
	if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction ||
 | 
			
		||||
		isTransaction) {
 | 
			
		||||
		db.Mux.Unlock()
 | 
			
		||||
		return stmt, nil
 | 
			
		||||
	} else if ok {
 | 
			
		||||
@ -57,7 +60,8 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
 | 
			
		||||
	return db.Stmts[query], err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
 | 
			
		||||
func (db *PreparedStmtDB) BeginTx(ctx context.Context,
 | 
			
		||||
	opt *sql.TxOptions) (ConnPool, error) {
 | 
			
		||||
	if beginner, ok := db.ConnPool.(TxBeginner); ok {
 | 
			
		||||
		tx, err := beginner.BeginTx(ctx, opt)
 | 
			
		||||
		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
 | 
			
		||||
@ -65,7 +69,8 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn
 | 
			
		||||
	return nil, ErrInvalidTransaction
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
 | 
			
		||||
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string,
 | 
			
		||||
	args ...interface{}) (result sql.Result, err error) {
 | 
			
		||||
	stmt, err := db.prepare(ctx, db.ConnPool, false, query)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		result, err = stmt.ExecContext(ctx, args...)
 | 
			
		||||
@ -79,7 +84,8 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
 | 
			
		||||
	return result, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
 | 
			
		||||
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string,
 | 
			
		||||
	args ...interface{}) (rows *sql.Rows, err error) {
 | 
			
		||||
	stmt, err := db.prepare(ctx, db.ConnPool, false, query)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		rows, err = stmt.QueryContext(ctx, args...)
 | 
			
		||||
@ -93,7 +99,8 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
 | 
			
		||||
	return rows, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
 | 
			
		||||
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string,
 | 
			
		||||
	args ...interface{}) *sql.Row {
 | 
			
		||||
	stmt, err := db.prepare(ctx, db.ConnPool, false, query)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		return stmt.QueryRowContext(ctx, args...)
 | 
			
		||||
@ -120,10 +127,12 @@ func (tx *PreparedStmtTX) Rollback() error {
 | 
			
		||||
	return ErrInvalidTransaction
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
 | 
			
		||||
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string,
 | 
			
		||||
	args ...interface{}) (result sql.Result, err error) {
 | 
			
		||||
	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
 | 
			
		||||
		result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).
 | 
			
		||||
			ExecContext(ctx, args...)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			tx.PreparedStmtDB.Mux.Lock()
 | 
			
		||||
			stmt.Close()
 | 
			
		||||
@ -134,7 +143,8 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
 | 
			
		||||
	return result, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
 | 
			
		||||
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string,
 | 
			
		||||
	args ...interface{}) (rows *sql.Rows, err error) {
 | 
			
		||||
	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...)
 | 
			
		||||
@ -148,7 +158,8 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
 | 
			
		||||
	return rows, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
 | 
			
		||||
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string,
 | 
			
		||||
	args ...interface{}) *sql.Row {
 | 
			
		||||
	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		return tx.Tx.StmtContext(ctx, stmt.Stmt).QueryRowContext(ctx, args...)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										109
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										109
									
								
								scan.go
									
									
									
									
									
								
							@ -10,19 +10,24 @@ import (
 | 
			
		||||
	"gorm.io/gorm/schema"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
 | 
			
		||||
func prepareValues(values []interface{}, db *DB,
 | 
			
		||||
	columnTypes []*sql.ColumnType, columns []string) {
 | 
			
		||||
	if db.Statement.Schema != nil {
 | 
			
		||||
		for idx, name := range columns {
 | 
			
		||||
			if field := db.Statement.Schema.LookUpField(name); field != nil {
 | 
			
		||||
				values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
 | 
			
		||||
			field := db.Statement.Schema.LookUpField(name)
 | 
			
		||||
			if field != nil {
 | 
			
		||||
				values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).
 | 
			
		||||
					Interface()
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			values[idx] = new(interface{})
 | 
			
		||||
		}
 | 
			
		||||
	} else if len(columnTypes) > 0 {
 | 
			
		||||
		for idx, columnType := range columnTypes {
 | 
			
		||||
			if columnType.ScanType() != nil {
 | 
			
		||||
				values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
 | 
			
		||||
				values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).
 | 
			
		||||
					Interface()
 | 
			
		||||
			} else {
 | 
			
		||||
				values[idx] = new(interface{})
 | 
			
		||||
			}
 | 
			
		||||
@ -34,9 +39,14 @@ func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
 | 
			
		||||
func scanIntoMap(mapValue map[string]interface{},
 | 
			
		||||
	values []interface{}, columns []string) {
 | 
			
		||||
	for idx, column := range columns {
 | 
			
		||||
		if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() {
 | 
			
		||||
		reflectValue := reflect.Indirect(
 | 
			
		||||
			reflect.Indirect(reflect.ValueOf(values[idx])),
 | 
			
		||||
		)
 | 
			
		||||
 | 
			
		||||
		if reflectValue.IsValid() {
 | 
			
		||||
			mapValue[column] = reflectValue.Interface()
 | 
			
		||||
			if valuer, ok := mapValue[column].(driver.Valuer); ok {
 | 
			
		||||
				mapValue[column], _ = valuer.Value()
 | 
			
		||||
@ -111,28 +121,42 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
 | 
			
		||||
				reflectValueType = reflectValueType.Elem()
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 20))
 | 
			
		||||
			db.Statement.ReflectValue.Set(
 | 
			
		||||
				reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 20),
 | 
			
		||||
			)
 | 
			
		||||
 | 
			
		||||
			if Schema != nil {
 | 
			
		||||
				if reflectValueType != Schema.ModelType && reflectValueType.Kind() == reflect.Struct {
 | 
			
		||||
					Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
 | 
			
		||||
				if reflectValueType != Schema.ModelType &&
 | 
			
		||||
					reflectValueType.Kind() == reflect.Struct {
 | 
			
		||||
					Schema, _ = schema.Parse(db.Statement.Dest,
 | 
			
		||||
						db.cacheStore, db.NamingStrategy)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				for idx, column := range columns {
 | 
			
		||||
					if field := Schema.LookUpField(column); field != nil && field.Readable {
 | 
			
		||||
					if field := Schema.LookUpField(column); field != nil &&
 | 
			
		||||
						field.Readable {
 | 
			
		||||
						fields[idx] = field
 | 
			
		||||
					} else if names := strings.Split(column, "__"); len(names) > 1 {
 | 
			
		||||
						if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
 | 
			
		||||
							if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
 | 
			
		||||
								fields[idx] = field
 | 
			
		||||
					} else if names := strings.
 | 
			
		||||
						Split(column, "__"); len(names) > 1 {
 | 
			
		||||
						rel, ok := Schema.Relationships.Relations[names[0]]
 | 
			
		||||
						if ok {
 | 
			
		||||
							field2 := rel.FieldSchema.LookUpField(
 | 
			
		||||
								strings.Join(names[1:], "__"),
 | 
			
		||||
							)
 | 
			
		||||
							if field2 != nil && field2.Readable {
 | 
			
		||||
								fields[idx] = field2
 | 
			
		||||
 | 
			
		||||
								if len(joinFields) == 0 {
 | 
			
		||||
									joinFields = make([][2]*schema.Field, len(columns))
 | 
			
		||||
									joinFields = make([][2]*schema.Field,
 | 
			
		||||
										len(columns))
 | 
			
		||||
								}
 | 
			
		||||
								joinFields[idx] = [2]*schema.Field{rel.Field, field}
 | 
			
		||||
 | 
			
		||||
								joinFields[idx] = [2]*schema.Field{rel.Field,
 | 
			
		||||
									field2}
 | 
			
		||||
								continue
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
 | 
			
		||||
						values[idx] = &sql.RawBytes{}
 | 
			
		||||
					} else {
 | 
			
		||||
						values[idx] = &sql.RawBytes{}
 | 
			
		||||
@ -143,9 +167,11 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
 | 
			
		||||
			// pluck values into slice of data
 | 
			
		||||
			isPluck := false
 | 
			
		||||
			if len(fields) == 1 {
 | 
			
		||||
				if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok || // is scanner
 | 
			
		||||
					reflectValueType.Kind() != reflect.Struct || // is not struct
 | 
			
		||||
					Schema.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
 | 
			
		||||
				_, ok := reflect.New(reflectValueType).
 | 
			
		||||
					Interface().(sql.Scanner)
 | 
			
		||||
				// is scanner or is not struct or is time
 | 
			
		||||
				if ok || reflectValueType.Kind() != reflect.Struct ||
 | 
			
		||||
					Schema.ModelType.ConvertibleTo(schema.TimeReflectType) {
 | 
			
		||||
					isPluck = true
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
@ -160,7 +186,9 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
 | 
			
		||||
				} else {
 | 
			
		||||
					for idx, field := range fields {
 | 
			
		||||
						if field != nil {
 | 
			
		||||
							values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
 | 
			
		||||
							values[idx] = reflect.New(
 | 
			
		||||
								reflect.PtrTo(field.IndirectFieldType),
 | 
			
		||||
							).Interface()
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
@ -171,11 +199,14 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
 | 
			
		||||
							value := reflect.ValueOf(values[idx]).Elem()
 | 
			
		||||
							relValue := joinFields[idx][0].ReflectValueOf(elem)
 | 
			
		||||
 | 
			
		||||
							if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
 | 
			
		||||
							if relValue.Kind() == reflect.Ptr &&
 | 
			
		||||
								relValue.IsNil() {
 | 
			
		||||
								if value.IsNil() {
 | 
			
		||||
									continue
 | 
			
		||||
								}
 | 
			
		||||
								relValue.Set(reflect.New(relValue.Type().Elem()))
 | 
			
		||||
								relValue.Set(
 | 
			
		||||
									reflect.New(relValue.Type().Elem()),
 | 
			
		||||
								)
 | 
			
		||||
							}
 | 
			
		||||
 | 
			
		||||
							field.Set(relValue, values[idx])
 | 
			
		||||
@ -186,24 +217,36 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if isPtr {
 | 
			
		||||
					db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem))
 | 
			
		||||
					db.Statement.ReflectValue.Set(reflect.
 | 
			
		||||
						Append(db.Statement.ReflectValue, elem))
 | 
			
		||||
				} else {
 | 
			
		||||
					db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Elem()))
 | 
			
		||||
					db.Statement.ReflectValue.Set(reflect.
 | 
			
		||||
						Append(db.Statement.ReflectValue, elem.Elem()))
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		case reflect.Struct, reflect.Ptr:
 | 
			
		||||
			if db.Statement.ReflectValue.Type() != Schema.ModelType {
 | 
			
		||||
				Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
 | 
			
		||||
				Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore,
 | 
			
		||||
					db.NamingStrategy)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if initialized || rows.Next() {
 | 
			
		||||
				for idx, column := range columns {
 | 
			
		||||
					if field := Schema.LookUpField(column); field != nil && field.Readable {
 | 
			
		||||
						values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
 | 
			
		||||
					if field := Schema.LookUpField(column); field != nil &&
 | 
			
		||||
						field.Readable {
 | 
			
		||||
						values[idx] = reflect.New(
 | 
			
		||||
							reflect.PtrTo(field.IndirectFieldType),
 | 
			
		||||
						).Interface()
 | 
			
		||||
					} else if names := strings.Split(column, "__"); len(names) > 1 {
 | 
			
		||||
						if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
 | 
			
		||||
							if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
 | 
			
		||||
								values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
 | 
			
		||||
						rel, ok := Schema.Relationships.Relations[names[0]]
 | 
			
		||||
						if ok {
 | 
			
		||||
							field := rel.FieldSchema.
 | 
			
		||||
								LookUpField(strings.Join(names[1:], "__"))
 | 
			
		||||
							if field != nil &&
 | 
			
		||||
								field.Readable {
 | 
			
		||||
								values[idx] = reflect.New(
 | 
			
		||||
									reflect.PtrTo(field.IndirectFieldType),
 | 
			
		||||
								).Interface()
 | 
			
		||||
								continue
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
@ -217,11 +260,13 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
 | 
			
		||||
				db.AddError(rows.Scan(values...))
 | 
			
		||||
 | 
			
		||||
				for idx, column := range columns {
 | 
			
		||||
					if field := Schema.LookUpField(column); field != nil && field.Readable {
 | 
			
		||||
					if field := Schema.LookUpField(column); field != nil &&
 | 
			
		||||
						field.Readable {
 | 
			
		||||
						field.Set(db.Statement.ReflectValue, values[idx])
 | 
			
		||||
					} else if names := strings.Split(column, "__"); len(names) > 1 {
 | 
			
		||||
						if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
 | 
			
		||||
							if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
 | 
			
		||||
							if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil &&
 | 
			
		||||
								field.Readable {
 | 
			
		||||
								relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
 | 
			
		||||
								value := reflect.ValueOf(values[idx]).Elem()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -65,10 +65,15 @@ func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) {
 | 
			
		||||
func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
 | 
			
		||||
	if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok {
 | 
			
		||||
		if c, ok := stmt.Clauses["WHERE"]; ok {
 | 
			
		||||
			if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) > 1 {
 | 
			
		||||
			if where, ok := c.Expression.(clause.Where); ok &&
 | 
			
		||||
				len(where.Exprs) > 1 {
 | 
			
		||||
				for _, expr := range where.Exprs {
 | 
			
		||||
					if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 {
 | 
			
		||||
						where.Exprs = []clause.Expression{clause.And(where.Exprs...)}
 | 
			
		||||
					if orCond, ok := expr.(clause.OrConditions); ok &&
 | 
			
		||||
						len(orCond.Exprs) == 1 {
 | 
			
		||||
						where.Exprs = []clause.Expression{
 | 
			
		||||
							clause.And(where.Exprs...),
 | 
			
		||||
						}
 | 
			
		||||
 | 
			
		||||
						c.Expression = where
 | 
			
		||||
						stmt.Clauses["WHERE"] = c
 | 
			
		||||
						break
 | 
			
		||||
@ -78,7 +83,11 @@ func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		stmt.AddClause(clause.Where{Exprs: []clause.Expression{
 | 
			
		||||
			clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: nil},
 | 
			
		||||
			clause.Eq{Column: clause.Column{
 | 
			
		||||
				Table: clause.CurrentTable,
 | 
			
		||||
				Name:  sd.Field.DBName},
 | 
			
		||||
				Value: nil,
 | 
			
		||||
			},
 | 
			
		||||
		}})
 | 
			
		||||
		stmt.Clauses["soft_delete_enabled"] = clause.Clause{}
 | 
			
		||||
	}
 | 
			
		||||
@ -105,28 +114,50 @@ func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) {
 | 
			
		||||
func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
 | 
			
		||||
	if stmt.SQL.String() == "" {
 | 
			
		||||
		curTime := stmt.DB.NowFunc()
 | 
			
		||||
		stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}})
 | 
			
		||||
		stmt.AddClause(clause.Set{
 | 
			
		||||
			{
 | 
			
		||||
				Column: clause.Column{Name: sd.Field.DBName},
 | 
			
		||||
				Value:  curTime,
 | 
			
		||||
			},
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		stmt.SetColumn(sd.Field.DBName, curTime, true)
 | 
			
		||||
 | 
			
		||||
		if stmt.Schema != nil {
 | 
			
		||||
			_, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields)
 | 
			
		||||
			column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
 | 
			
		||||
			_, queryValues := schema.GetIdentityFieldValuesMap(
 | 
			
		||||
				stmt.ReflectValue, stmt.Schema.PrimaryFields,
 | 
			
		||||
			)
 | 
			
		||||
			column, values := schema.ToQueryValues(stmt.Table,
 | 
			
		||||
				stmt.Schema.PrimaryFieldDBNames, queryValues)
 | 
			
		||||
 | 
			
		||||
			if len(values) > 0 {
 | 
			
		||||
				stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
 | 
			
		||||
				stmt.AddClause(clause.Where{Exprs: []clause.Expression{
 | 
			
		||||
					clause.IN{Column: column, Values: values}}})
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil {
 | 
			
		||||
				_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
 | 
			
		||||
				column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
 | 
			
		||||
			if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model &&
 | 
			
		||||
				stmt.Model != nil {
 | 
			
		||||
				_, queryValues = schema.GetIdentityFieldValuesMap(
 | 
			
		||||
					reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields,
 | 
			
		||||
				)
 | 
			
		||||
 | 
			
		||||
				column, values = schema.ToQueryValues(
 | 
			
		||||
					stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues,
 | 
			
		||||
				)
 | 
			
		||||
				if len(values) > 0 {
 | 
			
		||||
					stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
 | 
			
		||||
					stmt.AddClause(clause.Where{
 | 
			
		||||
						Exprs: []clause.Expression{
 | 
			
		||||
							clause.IN{
 | 
			
		||||
								Column: column, Values: values,
 | 
			
		||||
							},
 | 
			
		||||
						},
 | 
			
		||||
					})
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok {
 | 
			
		||||
		if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate &&
 | 
			
		||||
			!ok {
 | 
			
		||||
			stmt.DB.AddError(ErrMissingWhereClause)
 | 
			
		||||
		} else {
 | 
			
		||||
			SoftDeleteQueryClause{Field: sd.Field}.ModifyStatement(stmt)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										113
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										113
									
								
								statement.go
									
									
									
									
									
								
							@ -104,7 +104,8 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
 | 
			
		||||
			if stmt.Schema == nil {
 | 
			
		||||
				stmt.DB.AddError(ErrModelValueRequired)
 | 
			
		||||
			} else if stmt.Schema.PrioritizedPrimaryField != nil {
 | 
			
		||||
				stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.PrioritizedPrimaryField.DBName)
 | 
			
		||||
				stmt.DB.Dialector.QuoteTo(writer,
 | 
			
		||||
					stmt.Schema.PrioritizedPrimaryField.DBName)
 | 
			
		||||
			} else if len(stmt.Schema.DBNames) > 0 {
 | 
			
		||||
				stmt.DB.Dialector.QuoteTo(writer, stmt.Schema.DBNames[0])
 | 
			
		||||
			}
 | 
			
		||||
@ -181,7 +182,8 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
 | 
			
		||||
				writer.WriteString("(NULL)")
 | 
			
		||||
			}
 | 
			
		||||
		case *DB:
 | 
			
		||||
			subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
 | 
			
		||||
			subdb := v.Session(&Session{Logger: logger.Discard,
 | 
			
		||||
				DryRun: true}).getInstance()
 | 
			
		||||
			subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...)
 | 
			
		||||
			subdb.callbacks.Query().Execute(subdb)
 | 
			
		||||
			writer.WriteString(subdb.Statement.SQL.String())
 | 
			
		||||
@ -230,7 +232,8 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// BuildCondition build condition
 | 
			
		||||
func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []clause.Expression {
 | 
			
		||||
func (stmt *Statement) BuildCondition(query interface{},
 | 
			
		||||
	args ...interface{}) []clause.Expression {
 | 
			
		||||
	if s, ok := query.(string); ok {
 | 
			
		||||
		// if it is a number, then treats it as primary key
 | 
			
		||||
		if _, err := strconv.Atoi(s); err != nil {
 | 
			
		||||
@ -262,10 +265,14 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
 | 
			
		||||
			if cs, ok := v.Statement.Clauses["WHERE"]; ok {
 | 
			
		||||
				if where, ok := cs.Expression.(clause.Where); ok {
 | 
			
		||||
					if len(where.Exprs) == 1 {
 | 
			
		||||
						if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
 | 
			
		||||
							where.Exprs[0] = clause.AndConditions{Exprs: orConds.Exprs}
 | 
			
		||||
						orConds, ok := where.Exprs[0].(clause.OrConditions)
 | 
			
		||||
						if ok {
 | 
			
		||||
							where.Exprs[0] = clause.AndConditions{
 | 
			
		||||
								Exprs: orConds.Exprs,
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					conds = append(conds, clause.And(where.Exprs...))
 | 
			
		||||
				} else if cs.Expression != nil {
 | 
			
		||||
					conds = append(conds, cs.Expression)
 | 
			
		||||
@ -297,16 +304,24 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
 | 
			
		||||
				switch reflectValue.Kind() {
 | 
			
		||||
				case reflect.Slice, reflect.Array:
 | 
			
		||||
					if _, ok := v[key].(driver.Valuer); ok {
 | 
			
		||||
						conds = append(conds, clause.Eq{Column: key, Value: v[key]})
 | 
			
		||||
						conds = append(conds, clause.Eq{
 | 
			
		||||
							Column: key,
 | 
			
		||||
							Value:  v[key],
 | 
			
		||||
						})
 | 
			
		||||
					} else if _, ok := v[key].(Valuer); ok {
 | 
			
		||||
						conds = append(conds, clause.Eq{Column: key, Value: v[key]})
 | 
			
		||||
						conds = append(conds, clause.Eq{
 | 
			
		||||
							Column: key,
 | 
			
		||||
							Value:  v[key],
 | 
			
		||||
						})
 | 
			
		||||
					} else {
 | 
			
		||||
						values := make([]interface{}, reflectValue.Len())
 | 
			
		||||
						for i := 0; i < reflectValue.Len(); i++ {
 | 
			
		||||
							values[i] = reflectValue.Index(i).Interface()
 | 
			
		||||
						}
 | 
			
		||||
 | 
			
		||||
						conds = append(conds, clause.IN{Column: key, Values: values})
 | 
			
		||||
						conds = append(conds, clause.IN{
 | 
			
		||||
							Column: key, Values: values,
 | 
			
		||||
						})
 | 
			
		||||
					}
 | 
			
		||||
				default:
 | 
			
		||||
					conds = append(conds, clause.Eq{Column: key, Value: v[key]})
 | 
			
		||||
@ -314,7 +329,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
 | 
			
		||||
			}
 | 
			
		||||
		default:
 | 
			
		||||
			reflectValue := reflect.Indirect(reflect.ValueOf(arg))
 | 
			
		||||
			if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
 | 
			
		||||
			if s, err := schema.Parse(arg, stmt.DB.cacheStore,
 | 
			
		||||
				stmt.DB.NamingStrategy); err == nil {
 | 
			
		||||
				selectedColumns := map[string]bool{}
 | 
			
		||||
				if idx == 0 {
 | 
			
		||||
					for _, v := range args[1:] {
 | 
			
		||||
@ -328,27 +344,56 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
 | 
			
		||||
				switch reflectValue.Kind() {
 | 
			
		||||
				case reflect.Struct:
 | 
			
		||||
					for _, field := range s.Fields {
 | 
			
		||||
						selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
 | 
			
		||||
						selected := selectedColumns[field.DBName] ||
 | 
			
		||||
							selectedColumns[field.Name]
 | 
			
		||||
						if selected || (!restricted && field.Readable) {
 | 
			
		||||
							if v, isZero := field.ValueOf(reflectValue); !isZero || selected {
 | 
			
		||||
							v, isZero := field.ValueOf(reflectValue)
 | 
			
		||||
							if !isZero || selected {
 | 
			
		||||
								if field.DBName != "" {
 | 
			
		||||
									conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
 | 
			
		||||
									conds = append(conds, clause.Eq{
 | 
			
		||||
										Column: clause.Column{
 | 
			
		||||
											Table: clause.CurrentTable,
 | 
			
		||||
											Name:  field.DBName,
 | 
			
		||||
										},
 | 
			
		||||
										Value: v,
 | 
			
		||||
									})
 | 
			
		||||
								} else if field.DataType != "" {
 | 
			
		||||
									conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
 | 
			
		||||
									conds = append(conds, clause.Eq{
 | 
			
		||||
										Column: clause.Column{
 | 
			
		||||
											Table: clause.CurrentTable,
 | 
			
		||||
											Name:  field.Name,
 | 
			
		||||
										},
 | 
			
		||||
										Value: v,
 | 
			
		||||
									})
 | 
			
		||||
								}
 | 
			
		||||
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				case reflect.Slice, reflect.Array:
 | 
			
		||||
					for i := 0; i < reflectValue.Len(); i++ {
 | 
			
		||||
						for _, field := range s.Fields {
 | 
			
		||||
							selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
 | 
			
		||||
							selected := selectedColumns[field.DBName] ||
 | 
			
		||||
								selectedColumns[field.Name]
 | 
			
		||||
							if selected || (!restricted && field.Readable) {
 | 
			
		||||
								if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected {
 | 
			
		||||
								v, isZero := field.ValueOf(reflectValue.Index(i))
 | 
			
		||||
								if !isZero || selected {
 | 
			
		||||
									if field.DBName != "" {
 | 
			
		||||
										conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
 | 
			
		||||
										conds = append(conds, clause.Eq{
 | 
			
		||||
											Column: clause.Column{
 | 
			
		||||
												Table: clause.CurrentTable,
 | 
			
		||||
												Name:  field.DBName,
 | 
			
		||||
											},
 | 
			
		||||
											Value: v,
 | 
			
		||||
										})
 | 
			
		||||
									} else if field.DataType != "" {
 | 
			
		||||
										conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
 | 
			
		||||
										conds = append(conds, clause.Eq{
 | 
			
		||||
											Column: clause.Column{
 | 
			
		||||
												Table: clause.CurrentTable,
 | 
			
		||||
												Name:  field.Name,
 | 
			
		||||
											},
 | 
			
		||||
											Value: v,
 | 
			
		||||
										})
 | 
			
		||||
									}
 | 
			
		||||
								}
 | 
			
		||||
							}
 | 
			
		||||
@ -371,13 +416,20 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
 | 
			
		||||
						}
 | 
			
		||||
 | 
			
		||||
						if len(values) > 0 {
 | 
			
		||||
							conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
 | 
			
		||||
							conds = append(conds, clause.IN{
 | 
			
		||||
								Column: clause.PrimaryColumn,
 | 
			
		||||
								Values: values,
 | 
			
		||||
							})
 | 
			
		||||
						}
 | 
			
		||||
 | 
			
		||||
						return conds
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
 | 
			
		||||
				conds = append(conds, clause.IN{
 | 
			
		||||
					Column: clause.PrimaryColumn,
 | 
			
		||||
					Values: args,
 | 
			
		||||
				})
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
@ -406,7 +458,9 @@ func (stmt *Statement) Build(clauses ...string) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (stmt *Statement) Parse(value interface{}) (err error) {
 | 
			
		||||
	if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" {
 | 
			
		||||
	stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore,
 | 
			
		||||
		stmt.DB.NamingStrategy)
 | 
			
		||||
	if err == nil && stmt.Table == "" {
 | 
			
		||||
		if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
 | 
			
		||||
			stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
 | 
			
		||||
			stmt.Table = tables[1]
 | 
			
		||||
@ -415,6 +469,7 @@ func (stmt *Statement) Parse(value interface{}) (err error) {
 | 
			
		||||
 | 
			
		||||
		stmt.Table = stmt.Schema.Table
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -463,7 +518,8 @@ func (stmt *Statement) clone() *Statement {
 | 
			
		||||
// SetColumn set column's value
 | 
			
		||||
//   stmt.SetColumn("Name", "jinzhu") // Hooks Method
 | 
			
		||||
//   stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
 | 
			
		||||
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
 | 
			
		||||
func (stmt *Statement) SetColumn(name string, value interface{},
 | 
			
		||||
	fromCallbacks ...bool) {
 | 
			
		||||
	if v, ok := stmt.Dest.(map[string]interface{}); ok {
 | 
			
		||||
		v[name] = value
 | 
			
		||||
	} else if v, ok := stmt.Dest.([]map[string]interface{}); ok {
 | 
			
		||||
@ -524,7 +580,8 @@ func (stmt *Statement) Changed(fields ...string) bool {
 | 
			
		||||
	selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
 | 
			
		||||
	changed := func(field *schema.Field) bool {
 | 
			
		||||
		fieldValue, _ := field.ValueOf(modelValue)
 | 
			
		||||
		if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
 | 
			
		||||
		if v, ok := selectColumns[field.DBName]; (ok && v) ||
 | 
			
		||||
			(!ok && !restricted) {
 | 
			
		||||
			if v, ok := stmt.Dest.(map[string]interface{}); ok {
 | 
			
		||||
				if fv, ok := v[field.Name]; ok {
 | 
			
		||||
					return !utils.AssertEqual(fv, fieldValue)
 | 
			
		||||
@ -563,8 +620,10 @@ func (stmt *Statement) Changed(fields ...string) bool {
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
 | 
			
		||||
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
 | 
			
		||||
// SelectAndOmitColumns get select and omit columns,
 | 
			
		||||
// select -> true, omit -> false
 | 
			
		||||
func (stmt *Statement) SelectAndOmitColumns(requireCreate,
 | 
			
		||||
	requireUpdate bool) (map[string]bool, bool) {
 | 
			
		||||
	results := map[string]bool{}
 | 
			
		||||
	notRestricted := false
 | 
			
		||||
 | 
			
		||||
@ -579,7 +638,8 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (
 | 
			
		||||
			for _, rel := range stmt.Schema.Relationships.Relations {
 | 
			
		||||
				results[rel.Name] = true
 | 
			
		||||
			}
 | 
			
		||||
		} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
 | 
			
		||||
		} else if field := stmt.Schema.LookUpField(column); field != nil &&
 | 
			
		||||
			field.DBName != "" {
 | 
			
		||||
			results[field.DBName] = true
 | 
			
		||||
		} else {
 | 
			
		||||
			results[column] = true
 | 
			
		||||
@ -594,7 +654,8 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (
 | 
			
		||||
					results[rel.Name] = false
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		} else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
 | 
			
		||||
		} else if field := stmt.Schema.LookUpField(omit); field != nil &&
 | 
			
		||||
			field.DBName != "" {
 | 
			
		||||
			results[field.DBName] = false
 | 
			
		||||
		} else {
 | 
			
		||||
			results[omit] = false
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user