Merge branch 'go-gorm:master' into master
This commit is contained in:
		
						commit
						01d96b52a2
					
				
							
								
								
									
										12
									
								
								.github/workflows/tests.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										12
									
								
								.github/workflows/tests.yml
									
									
									
									
										vendored
									
									
								
							@ -13,7 +13,7 @@ jobs:
 | 
			
		||||
  sqlite:
 | 
			
		||||
    strategy:
 | 
			
		||||
      matrix:
 | 
			
		||||
        go: ['1.16', '1.15', '1.14']
 | 
			
		||||
        go: ['1.16', '1.15']
 | 
			
		||||
        platform: [ubuntu-latest] # can not run in windows OS
 | 
			
		||||
    runs-on: ${{ matrix.platform }}
 | 
			
		||||
 | 
			
		||||
@ -38,8 +38,8 @@ jobs:
 | 
			
		||||
  mysql:
 | 
			
		||||
    strategy:
 | 
			
		||||
      matrix:
 | 
			
		||||
        dbversion: ['mysql:latest', 'mysql:5.7', 'mysql:5.6', 'mariadb:latest']
 | 
			
		||||
        go: ['1.16', '1.15', '1.14']
 | 
			
		||||
        dbversion: ['mysql:latest', 'mysql:5.7', 'mariadb:latest']
 | 
			
		||||
        go: ['1.16', '1.15']
 | 
			
		||||
        platform: [ubuntu-latest]
 | 
			
		||||
    runs-on: ${{ matrix.platform }}
 | 
			
		||||
 | 
			
		||||
@ -82,8 +82,8 @@ jobs:
 | 
			
		||||
  postgres:
 | 
			
		||||
    strategy:
 | 
			
		||||
      matrix:
 | 
			
		||||
        dbversion: ['postgres:latest', 'postgres:11', 'postgres:10']
 | 
			
		||||
        go: ['1.16', '1.15', '1.14']
 | 
			
		||||
        dbversion: ['postgres:latest', 'postgres:12', 'postgres:11', 'postgres:10']
 | 
			
		||||
        go: ['1.16', '1.15']
 | 
			
		||||
        platform: [ubuntu-latest] # can not run in macOS and Windows
 | 
			
		||||
    runs-on: ${{ matrix.platform }}
 | 
			
		||||
 | 
			
		||||
@ -125,7 +125,7 @@ jobs:
 | 
			
		||||
  sqlserver:
 | 
			
		||||
    strategy:
 | 
			
		||||
      matrix:
 | 
			
		||||
        go: ['1.16', '1.15', '1.14']
 | 
			
		||||
        go: ['1.16', '1.15']
 | 
			
		||||
        platform: [ubuntu-latest] # can not run test in macOS and windows
 | 
			
		||||
    runs-on: ${{ matrix.platform }}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -26,7 +26,7 @@ func (db *DB) Association(column string) *Association {
 | 
			
		||||
		association.Relationship = db.Statement.Schema.Relationships.Relations[column]
 | 
			
		||||
 | 
			
		||||
		if association.Relationship == nil {
 | 
			
		||||
			association.Error = fmt.Errorf("%w: %v", ErrUnsupportedRelation, column)
 | 
			
		||||
			association.Error = fmt.Errorf("%w: %s", ErrUnsupportedRelation, column)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
 | 
			
		||||
@ -355,7 +355,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
 | 
			
		||||
				} else if ev.Type().Elem().AssignableTo(elemType) {
 | 
			
		||||
					fieldValue = reflect.Append(fieldValue, ev.Elem())
 | 
			
		||||
				} else {
 | 
			
		||||
					association.Error = fmt.Errorf("unsupported data type: %v for relation %v", ev.Type(), association.Relationship.Name)
 | 
			
		||||
					association.Error = fmt.Errorf("unsupported data type: %v for relation %s", ev.Type(), association.Relationship.Name)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if elemType.Kind() == reflect.Struct {
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										10
									
								
								callbacks.go
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								callbacks.go
									
									
									
									
									
								
							@ -212,7 +212,7 @@ func (c *callback) Register(name string, fn func(*DB)) error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *callback) Remove(name string) error {
 | 
			
		||||
	c.processor.db.Logger.Warn(context.Background(), "removing callback `%v` from %v\n", name, utils.FileWithLineNum())
 | 
			
		||||
	c.processor.db.Logger.Warn(context.Background(), "removing callback `%s` from %s\n", name, utils.FileWithLineNum())
 | 
			
		||||
	c.name = name
 | 
			
		||||
	c.remove = true
 | 
			
		||||
	c.processor.callbacks = append(c.processor.callbacks, c)
 | 
			
		||||
@ -220,7 +220,7 @@ func (c *callback) Remove(name string) error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *callback) Replace(name string, fn func(*DB)) error {
 | 
			
		||||
	c.processor.db.Logger.Info(context.Background(), "replacing callback `%v` from %v\n", name, utils.FileWithLineNum())
 | 
			
		||||
	c.processor.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum())
 | 
			
		||||
	c.name = name
 | 
			
		||||
	c.handler = fn
 | 
			
		||||
	c.replace = true
 | 
			
		||||
@ -250,7 +250,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
 | 
			
		||||
	for _, c := range cs {
 | 
			
		||||
		// show warning message the callback name already exists
 | 
			
		||||
		if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
 | 
			
		||||
			c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%v` from %v\n", c.name, utils.FileWithLineNum())
 | 
			
		||||
			c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum())
 | 
			
		||||
		}
 | 
			
		||||
		names = append(names, c.name)
 | 
			
		||||
	}
 | 
			
		||||
@ -266,7 +266,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
 | 
			
		||||
					// if before callback already sorted, append current callback just after it
 | 
			
		||||
					sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
 | 
			
		||||
				} else if curIdx > sortedIdx {
 | 
			
		||||
					return fmt.Errorf("conflicting callback %v with before %v", c.name, c.before)
 | 
			
		||||
					return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before)
 | 
			
		||||
				}
 | 
			
		||||
			} else if idx := getRIndex(names, c.before); idx != -1 {
 | 
			
		||||
				// if before callback exists
 | 
			
		||||
@ -284,7 +284,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
 | 
			
		||||
					// if after callback sorted, append current callback to last
 | 
			
		||||
					sorted = append(sorted, c.name)
 | 
			
		||||
				} else if curIdx < sortedIdx {
 | 
			
		||||
					return fmt.Errorf("conflicting callback %v with before %v", c.name, c.after)
 | 
			
		||||
					return fmt.Errorf("conflicting callback %s with before %s", c.name, c.after)
 | 
			
		||||
				}
 | 
			
		||||
			} else if idx := getRIndex(names, c.after); idx != -1 {
 | 
			
		||||
				// if after callback exists but haven't sorted
 | 
			
		||||
 | 
			
		||||
@ -373,7 +373,7 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{},
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if tx.Statement.FullSaveAssociations {
 | 
			
		||||
		tx = tx.InstanceSet("gorm:update_track_time", true)
 | 
			
		||||
		tx = tx.Set("gorm:update_track_time", true)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(selects) > 0 {
 | 
			
		||||
 | 
			
		||||
@ -33,75 +33,81 @@ func BeforeCreate(db *gorm.DB) {
 | 
			
		||||
func Create(config *Config) func(db *gorm.DB) {
 | 
			
		||||
	if config.WithReturning {
 | 
			
		||||
		return CreateWithReturning
 | 
			
		||||
	} else {
 | 
			
		||||
		return func(db *gorm.DB) {
 | 
			
		||||
			if db.Error == nil {
 | 
			
		||||
				if db.Statement.Schema != nil && !db.Statement.Unscoped {
 | 
			
		||||
					for _, c := range db.Statement.Schema.CreateClauses {
 | 
			
		||||
						db.Statement.AddClause(c)
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
				if db.Statement.SQL.String() == "" {
 | 
			
		||||
					db.Statement.SQL.Grow(180)
 | 
			
		||||
					db.Statement.AddClauseIfNotExists(clause.Insert{})
 | 
			
		||||
					db.Statement.AddClause(ConvertToCreateValues(db.Statement))
 | 
			
		||||
	return func(db *gorm.DB) {
 | 
			
		||||
		if db.Error != nil {
 | 
			
		||||
			// maybe record logger TODO
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
					db.Statement.Build(db.Statement.BuildClauses...)
 | 
			
		||||
				}
 | 
			
		||||
		if db.Statement.Schema != nil && !db.Statement.Unscoped {
 | 
			
		||||
			for _, c := range db.Statement.Schema.CreateClauses {
 | 
			
		||||
				db.Statement.AddClause(c)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
				if !db.DryRun && db.Error == nil {
 | 
			
		||||
					result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
 | 
			
		||||
		if db.Statement.SQL.String() == "" {
 | 
			
		||||
			db.Statement.SQL.Grow(180)
 | 
			
		||||
			db.Statement.AddClauseIfNotExists(clause.Insert{})
 | 
			
		||||
			db.Statement.AddClause(ConvertToCreateValues(db.Statement))
 | 
			
		||||
 | 
			
		||||
					if err == nil {
 | 
			
		||||
						db.RowsAffected, _ = result.RowsAffected()
 | 
			
		||||
			db.Statement.Build(db.Statement.BuildClauses...)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
						if db.RowsAffected > 0 {
 | 
			
		||||
							if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
 | 
			
		||||
								if insertID, err := result.LastInsertId(); err == nil && insertID > 0 {
 | 
			
		||||
									switch db.Statement.ReflectValue.Kind() {
 | 
			
		||||
									case reflect.Slice, reflect.Array:
 | 
			
		||||
										if config.LastInsertIDReversed {
 | 
			
		||||
											for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
 | 
			
		||||
												rv := db.Statement.ReflectValue.Index(i)
 | 
			
		||||
												if reflect.Indirect(rv).Kind() != reflect.Struct {
 | 
			
		||||
													break
 | 
			
		||||
												}
 | 
			
		||||
		if !db.DryRun && db.Error == nil {
 | 
			
		||||
			result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
 | 
			
		||||
 | 
			
		||||
												_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv)
 | 
			
		||||
												if isZero {
 | 
			
		||||
													db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
 | 
			
		||||
													insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
 | 
			
		||||
												}
 | 
			
		||||
											}
 | 
			
		||||
										} else {
 | 
			
		||||
											for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
 | 
			
		||||
												rv := db.Statement.ReflectValue.Index(i)
 | 
			
		||||
												if reflect.Indirect(rv).Kind() != reflect.Struct {
 | 
			
		||||
													break
 | 
			
		||||
												}
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				db.AddError(err)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
												if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero {
 | 
			
		||||
													db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
 | 
			
		||||
													insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
 | 
			
		||||
												}
 | 
			
		||||
											}
 | 
			
		||||
										}
 | 
			
		||||
									case reflect.Struct:
 | 
			
		||||
										if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero {
 | 
			
		||||
											db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
 | 
			
		||||
										}
 | 
			
		||||
									}
 | 
			
		||||
								} else {
 | 
			
		||||
									db.AddError(err)
 | 
			
		||||
			db.RowsAffected, _ = result.RowsAffected()
 | 
			
		||||
			if !(db.RowsAffected > 0) {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
 | 
			
		||||
				if insertID, err := result.LastInsertId(); err == nil && insertID > 0 {
 | 
			
		||||
					switch db.Statement.ReflectValue.Kind() {
 | 
			
		||||
					case reflect.Slice, reflect.Array:
 | 
			
		||||
						if config.LastInsertIDReversed {
 | 
			
		||||
							for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
 | 
			
		||||
								rv := db.Statement.ReflectValue.Index(i)
 | 
			
		||||
								if reflect.Indirect(rv).Kind() != reflect.Struct {
 | 
			
		||||
									break
 | 
			
		||||
								}
 | 
			
		||||
 | 
			
		||||
								_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv)
 | 
			
		||||
								if isZero {
 | 
			
		||||
									db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
 | 
			
		||||
									insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
 | 
			
		||||
								}
 | 
			
		||||
							}
 | 
			
		||||
						} else {
 | 
			
		||||
							for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
 | 
			
		||||
								rv := db.Statement.ReflectValue.Index(i)
 | 
			
		||||
								if reflect.Indirect(rv).Kind() != reflect.Struct {
 | 
			
		||||
									break
 | 
			
		||||
								}
 | 
			
		||||
 | 
			
		||||
								if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero {
 | 
			
		||||
									db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
 | 
			
		||||
									insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
 | 
			
		||||
								}
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
					} else {
 | 
			
		||||
						db.AddError(err)
 | 
			
		||||
					case reflect.Struct:
 | 
			
		||||
						if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue); isZero {
 | 
			
		||||
							db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				} else {
 | 
			
		||||
					db.AddError(err)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -237,9 +243,12 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
 | 
			
		||||
	default:
 | 
			
		||||
		var (
 | 
			
		||||
			selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
 | 
			
		||||
			_, updateTrackTime        = stmt.Get("gorm:update_track_time")
 | 
			
		||||
			curTime                   = stmt.DB.NowFunc()
 | 
			
		||||
			isZero                    bool
 | 
			
		||||
		)
 | 
			
		||||
		stmt.Settings.Delete("gorm:update_track_time")
 | 
			
		||||
 | 
			
		||||
		values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
 | 
			
		||||
 | 
			
		||||
		for _, db := range stmt.Schema.DBNames {
 | 
			
		||||
@ -278,11 +287,9 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
 | 
			
		||||
							field.Set(rv, curTime)
 | 
			
		||||
							values.Values[i][idx], _ = field.ValueOf(rv)
 | 
			
		||||
						}
 | 
			
		||||
					} else if field.AutoUpdateTime > 0 {
 | 
			
		||||
						if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok {
 | 
			
		||||
							field.Set(rv, curTime)
 | 
			
		||||
							values.Values[i][idx], _ = field.ValueOf(rv)
 | 
			
		||||
						}
 | 
			
		||||
					} else if field.AutoUpdateTime > 0 && updateTrackTime {
 | 
			
		||||
						field.Set(rv, curTime)
 | 
			
		||||
						values.Values[i][idx], _ = field.ValueOf(rv)
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
@ -320,11 +327,9 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
 | 
			
		||||
						field.Set(stmt.ReflectValue, curTime)
 | 
			
		||||
						values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
 | 
			
		||||
					}
 | 
			
		||||
				} else if field.AutoUpdateTime > 0 {
 | 
			
		||||
					if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok {
 | 
			
		||||
						field.Set(stmt.ReflectValue, curTime)
 | 
			
		||||
						values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
 | 
			
		||||
					}
 | 
			
		||||
				} else if field.AutoUpdateTime > 0 && updateTrackTime {
 | 
			
		||||
					field.Set(stmt.ReflectValue, curTime)
 | 
			
		||||
					values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -9,7 +9,8 @@ func RowQuery(db *gorm.DB) {
 | 
			
		||||
		BuildQuerySQL(db)
 | 
			
		||||
 | 
			
		||||
		if !db.DryRun {
 | 
			
		||||
			if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) {
 | 
			
		||||
			if isRows, ok := db.Get("rows"); ok && isRows.(bool) {
 | 
			
		||||
				db.Statement.Settings.Delete("rows")
 | 
			
		||||
				db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
 | 
			
		||||
			} else {
 | 
			
		||||
				db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
 | 
			
		||||
 | 
			
		||||
@ -233,11 +233,24 @@ type Eq struct {
 | 
			
		||||
func (eq Eq) Build(builder Builder) {
 | 
			
		||||
	builder.WriteQuoted(eq.Column)
 | 
			
		||||
 | 
			
		||||
	if eqNil(eq.Value) {
 | 
			
		||||
		builder.WriteString(" IS NULL")
 | 
			
		||||
	} else {
 | 
			
		||||
		builder.WriteString(" = ")
 | 
			
		||||
		builder.AddVar(builder, eq.Value)
 | 
			
		||||
	switch eq.Value.(type) {
 | 
			
		||||
	case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
 | 
			
		||||
		builder.WriteString(" IN (")
 | 
			
		||||
		rv := reflect.ValueOf(eq.Value)
 | 
			
		||||
		for i := 0; i < rv.Len(); i++ {
 | 
			
		||||
			if i > 0 {
 | 
			
		||||
				builder.WriteByte(',')
 | 
			
		||||
			}
 | 
			
		||||
			builder.AddVar(builder, rv.Index(i).Interface())
 | 
			
		||||
		}
 | 
			
		||||
		builder.WriteByte(')')
 | 
			
		||||
	default:
 | 
			
		||||
		if eqNil(eq.Value) {
 | 
			
		||||
			builder.WriteString(" IS NULL")
 | 
			
		||||
		} else {
 | 
			
		||||
			builder.WriteString(" = ")
 | 
			
		||||
			builder.AddVar(builder, eq.Value)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -251,11 +264,24 @@ type Neq Eq
 | 
			
		||||
func (neq Neq) Build(builder Builder) {
 | 
			
		||||
	builder.WriteQuoted(neq.Column)
 | 
			
		||||
 | 
			
		||||
	if eqNil(neq.Value) {
 | 
			
		||||
		builder.WriteString(" IS NOT NULL")
 | 
			
		||||
	} else {
 | 
			
		||||
		builder.WriteString(" <> ")
 | 
			
		||||
		builder.AddVar(builder, neq.Value)
 | 
			
		||||
	switch neq.Value.(type) {
 | 
			
		||||
	case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
 | 
			
		||||
		builder.WriteString(" NOT IN (")
 | 
			
		||||
		rv := reflect.ValueOf(neq.Value)
 | 
			
		||||
		for i := 0; i < rv.Len(); i++ {
 | 
			
		||||
			if i > 0 {
 | 
			
		||||
				builder.WriteByte(',')
 | 
			
		||||
			}
 | 
			
		||||
			builder.AddVar(builder, rv.Index(i).Interface())
 | 
			
		||||
		}
 | 
			
		||||
		builder.WriteByte(')')
 | 
			
		||||
	default:
 | 
			
		||||
		if eqNil(neq.Value) {
 | 
			
		||||
			builder.WriteString(" IS NOT NULL")
 | 
			
		||||
		} else {
 | 
			
		||||
			builder.WriteString(" <> ")
 | 
			
		||||
			builder.AddVar(builder, neq.Value)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -105,13 +105,15 @@ func TestNamedExpr(t *testing.T) {
 | 
			
		||||
func TestExpression(t *testing.T) {
 | 
			
		||||
	column := "column-name"
 | 
			
		||||
	results := []struct {
 | 
			
		||||
		Expressions []clause.Expression
 | 
			
		||||
		Result      string
 | 
			
		||||
		Expressions  []clause.Expression
 | 
			
		||||
		ExpectedVars []interface{}
 | 
			
		||||
		Result       string
 | 
			
		||||
	}{{
 | 
			
		||||
		Expressions: []clause.Expression{
 | 
			
		||||
			clause.Eq{Column: column, Value: "column-value"},
 | 
			
		||||
		},
 | 
			
		||||
		Result: "`column-name` = ?",
 | 
			
		||||
		ExpectedVars: []interface{}{"column-value"},
 | 
			
		||||
		Result:       "`column-name` = ?",
 | 
			
		||||
	}, {
 | 
			
		||||
		Expressions: []clause.Expression{
 | 
			
		||||
			clause.Eq{Column: column, Value: nil},
 | 
			
		||||
@ -126,7 +128,8 @@ func TestExpression(t *testing.T) {
 | 
			
		||||
		Expressions: []clause.Expression{
 | 
			
		||||
			clause.Neq{Column: column, Value: "column-value"},
 | 
			
		||||
		},
 | 
			
		||||
		Result: "`column-name` <> ?",
 | 
			
		||||
		ExpectedVars: []interface{}{"column-value"},
 | 
			
		||||
		Result:       "`column-name` <> ?",
 | 
			
		||||
	}, {
 | 
			
		||||
		Expressions: []clause.Expression{
 | 
			
		||||
			clause.Neq{Column: column, Value: nil},
 | 
			
		||||
@ -136,6 +139,18 @@ func TestExpression(t *testing.T) {
 | 
			
		||||
			clause.Neq{Column: column, Value: (interface{})(nil)},
 | 
			
		||||
		},
 | 
			
		||||
		Result: "`column-name` IS NOT NULL",
 | 
			
		||||
	}, {
 | 
			
		||||
		Expressions: []clause.Expression{
 | 
			
		||||
			clause.Eq{Column: column, Value: []string{"a", "b"}},
 | 
			
		||||
		},
 | 
			
		||||
		ExpectedVars: []interface{}{"a", "b"},
 | 
			
		||||
		Result:       "`column-name` IN (?,?)",
 | 
			
		||||
	}, {
 | 
			
		||||
		Expressions: []clause.Expression{
 | 
			
		||||
			clause.Neq{Column: column, Value: []string{"a", "b"}},
 | 
			
		||||
		},
 | 
			
		||||
		ExpectedVars: []interface{}{"a", "b"},
 | 
			
		||||
		Result:       "`column-name` NOT IN (?,?)",
 | 
			
		||||
	}}
 | 
			
		||||
 | 
			
		||||
	for idx, result := range results {
 | 
			
		||||
@ -147,6 +162,10 @@ func TestExpression(t *testing.T) {
 | 
			
		||||
				if stmt.SQL.String() != result.Result {
 | 
			
		||||
					t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String())
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if !reflect.DeepEqual(result.ExpectedVars, stmt.Vars) {
 | 
			
		||||
					t.Errorf("generated vars is not equal, expects %v, but got %v", result.ExpectedVars, stmt.Vars)
 | 
			
		||||
				}
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -3,6 +3,7 @@ package clause
 | 
			
		||||
type OnConflict struct {
 | 
			
		||||
	Columns      []Column
 | 
			
		||||
	Where        Where
 | 
			
		||||
	TargetWhere  Where
 | 
			
		||||
	OnConstraint string
 | 
			
		||||
	DoNothing    bool
 | 
			
		||||
	DoUpdates    Set
 | 
			
		||||
@ -25,6 +26,12 @@ func (onConflict OnConflict) Build(builder Builder) {
 | 
			
		||||
		}
 | 
			
		||||
		builder.WriteString(`) `)
 | 
			
		||||
	}
 | 
			
		||||
	
 | 
			
		||||
	if len(onConflict.TargetWhere.Exprs) > 0 {
 | 
			
		||||
		builder.WriteString(" WHERE ")
 | 
			
		||||
		onConflict.TargetWhere.Build(builder)
 | 
			
		||||
		builder.WriteByte(' ')
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if onConflict.OnConstraint != "" {
 | 
			
		||||
		builder.WriteString("ON CONSTRAINT ")
 | 
			
		||||
 | 
			
		||||
@ -79,7 +79,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
 | 
			
		||||
		if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
 | 
			
		||||
			tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
 | 
			
		||||
		}
 | 
			
		||||
		tx = tx.callbacks.Create().Execute(tx.InstanceSet("gorm:update_track_time", true))
 | 
			
		||||
		tx = tx.callbacks.Create().Execute(tx.Set("gorm:update_track_time", true))
 | 
			
		||||
	case reflect.Struct:
 | 
			
		||||
		if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
 | 
			
		||||
			for _, pf := range tx.Statement.Schema.PrimaryFields {
 | 
			
		||||
@ -190,16 +190,17 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
 | 
			
		||||
 | 
			
		||||
		if tx.Error != nil || int(result.RowsAffected) < batchSize {
 | 
			
		||||
			break
 | 
			
		||||
		} else {
 | 
			
		||||
			resultsValue := reflect.Indirect(reflect.ValueOf(dest))
 | 
			
		||||
			if result.Statement.Schema.PrioritizedPrimaryField == nil {
 | 
			
		||||
				tx.AddError(ErrPrimaryKeyRequired)
 | 
			
		||||
				break
 | 
			
		||||
			} else {
 | 
			
		||||
				primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1))
 | 
			
		||||
				queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Optimize for-break
 | 
			
		||||
		resultsValue := reflect.Indirect(reflect.ValueOf(dest))
 | 
			
		||||
		if result.Statement.Schema.PrioritizedPrimaryField == nil {
 | 
			
		||||
			tx.AddError(ErrPrimaryKeyRequired)
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1))
 | 
			
		||||
		queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tx.RowsAffected = rowsAffected
 | 
			
		||||
@ -304,7 +305,7 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
 | 
			
		||||
 | 
			
		||||
		return tx.Create(dest)
 | 
			
		||||
	} else if len(db.Statement.assigns) > 0 {
 | 
			
		||||
		exprs := tx.Statement.BuildCondition(tx.Statement.assigns[0], tx.Statement.assigns[1:]...)
 | 
			
		||||
		exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
 | 
			
		||||
		assigns := map[string]interface{}{}
 | 
			
		||||
		for _, expr := range exprs {
 | 
			
		||||
			if eq, ok := expr.(clause.Eq); ok {
 | 
			
		||||
@ -382,9 +383,9 @@ func (db *DB) Count(count *int64) (tx *DB) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(tx.Statement.Selects) == 0 {
 | 
			
		||||
		tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}})
 | 
			
		||||
		tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(*)"}})
 | 
			
		||||
	} else if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") {
 | 
			
		||||
		expr := clause.Expr{SQL: "count(1)"}
 | 
			
		||||
		expr := clause.Expr{SQL: "count(*)"}
 | 
			
		||||
 | 
			
		||||
		if len(tx.Statement.Selects) == 1 {
 | 
			
		||||
			dbName := tx.Statement.Selects[0]
 | 
			
		||||
@ -425,7 +426,7 @@ func (db *DB) Count(count *int64) (tx *DB) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db *DB) Row() *sql.Row {
 | 
			
		||||
	tx := db.getInstance().InstanceSet("rows", false)
 | 
			
		||||
	tx := db.getInstance().Set("rows", false)
 | 
			
		||||
	tx = tx.callbacks.Row().Execute(tx)
 | 
			
		||||
	row, ok := tx.Statement.Dest.(*sql.Row)
 | 
			
		||||
	if !ok && tx.DryRun {
 | 
			
		||||
@ -435,7 +436,7 @@ func (db *DB) Row() *sql.Row {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (db *DB) Rows() (*sql.Rows, error) {
 | 
			
		||||
	tx := db.getInstance().InstanceSet("rows", true)
 | 
			
		||||
	tx := db.getInstance().Set("rows", true)
 | 
			
		||||
	tx = tx.callbacks.Row().Execute(tx)
 | 
			
		||||
	rows, ok := tx.Statement.Dest.(*sql.Rows)
 | 
			
		||||
	if !ok && tx.DryRun && tx.Error == nil {
 | 
			
		||||
@ -473,7 +474,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
 | 
			
		||||
 | 
			
		||||
// Pluck used to query single column from a model as a map
 | 
			
		||||
//     var ages []int64
 | 
			
		||||
//     db.Find(&users).Pluck("age", &ages)
 | 
			
		||||
//     db.Model(&users).Pluck("age", &ages)
 | 
			
		||||
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
 | 
			
		||||
	tx = db.getInstance()
 | 
			
		||||
	if tx.Statement.Model != nil {
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										4
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								gorm.go
									
									
									
									
									
								
							@ -409,7 +409,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac
 | 
			
		||||
				}
 | 
			
		||||
				ref.ForeignKey = f
 | 
			
		||||
			} else {
 | 
			
		||||
				return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName)
 | 
			
		||||
				return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
@ -422,7 +422,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac
 | 
			
		||||
 | 
			
		||||
		relation.JoinTable = joinSchema
 | 
			
		||||
	} else {
 | 
			
		||||
		return fmt.Errorf("failed to found relation: %v", field)
 | 
			
		||||
		return fmt.Errorf("failed to found relation: %s", field)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
 | 
			
		||||
@ -58,7 +58,7 @@ type Interface interface {
 | 
			
		||||
	Info(context.Context, string, ...interface{})
 | 
			
		||||
	Warn(context.Context, string, ...interface{})
 | 
			
		||||
	Error(context.Context, string, ...interface{})
 | 
			
		||||
	Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error)
 | 
			
		||||
	Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
 | 
			
		||||
@ -119,13 +119,10 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
 | 
			
		||||
 | 
			
		||||
				for _, rel := range stmt.Schema.Relationships.Relations {
 | 
			
		||||
					if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating {
 | 
			
		||||
						if constraint := rel.ParseConstraint(); constraint != nil {
 | 
			
		||||
							if constraint.Schema == stmt.Schema {
 | 
			
		||||
								if !tx.Migrator().HasConstraint(value, constraint.Name) {
 | 
			
		||||
									if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
 | 
			
		||||
										return err
 | 
			
		||||
									}
 | 
			
		||||
								}
 | 
			
		||||
						if constraint := rel.ParseConstraint(); constraint != nil &&
 | 
			
		||||
							constraint.Schema == stmt.Schema && !tx.Migrator().HasConstraint(value, constraint.Name) {
 | 
			
		||||
							if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
 | 
			
		||||
								return err
 | 
			
		||||
							}
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
@ -294,16 +291,20 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error {
 | 
			
		||||
 | 
			
		||||
func (m Migrator) AddColumn(value interface{}, field string) error {
 | 
			
		||||
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
			
		||||
		if field := stmt.Schema.LookUpField(field); field != nil {
 | 
			
		||||
			if !field.IgnoreMigration {
 | 
			
		||||
				return m.DB.Exec(
 | 
			
		||||
					"ALTER TABLE ? ADD ? ?",
 | 
			
		||||
					m.CurrentTable(stmt), clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field),
 | 
			
		||||
				).Error
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		// avoid using the same name field
 | 
			
		||||
		f := stmt.Schema.LookUpField(field)
 | 
			
		||||
		if f == nil {
 | 
			
		||||
			return fmt.Errorf("failed to look up field with name: %s", field)
 | 
			
		||||
		}
 | 
			
		||||
		return fmt.Errorf("failed to look up field with name: %s", field)
 | 
			
		||||
 | 
			
		||||
		if !f.IgnoreMigration {
 | 
			
		||||
			return m.DB.Exec(
 | 
			
		||||
				"ALTER TABLE ? ADD ? ?",
 | 
			
		||||
				m.CurrentTable(stmt), clause.Column{Name: f.DBName}, m.DB.Migrator().FullDataTypeOf(f),
 | 
			
		||||
			).Error
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -64,7 +64,7 @@ func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransact
 | 
			
		||||
		db.Stmts[query] = Stmt{Stmt: stmt, Transaction: isTransaction}
 | 
			
		||||
		db.PreparedSQL = append(db.PreparedSQL, query)
 | 
			
		||||
	}
 | 
			
		||||
	db.Mux.Unlock()
 | 
			
		||||
	defer db.Mux.Unlock()
 | 
			
		||||
 | 
			
		||||
	return db.Stmts[query], err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -198,28 +198,28 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 | 
			
		||||
		field.DataType = Bool
 | 
			
		||||
		if field.HasDefaultValue && !skipParseDefaultValue {
 | 
			
		||||
			if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil {
 | 
			
		||||
				schema.err = fmt.Errorf("failed to parse %v as default value for bool, got error: %v", field.DefaultValue, err)
 | 
			
		||||
				schema.err = fmt.Errorf("failed to parse %s as default value for bool, got error: %v", field.DefaultValue, err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
 | 
			
		||||
		field.DataType = Int
 | 
			
		||||
		if field.HasDefaultValue && !skipParseDefaultValue {
 | 
			
		||||
			if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil {
 | 
			
		||||
				schema.err = fmt.Errorf("failed to parse %v as default value for int, got error: %v", field.DefaultValue, err)
 | 
			
		||||
				schema.err = fmt.Errorf("failed to parse %s as default value for int, got error: %v", field.DefaultValue, err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
 | 
			
		||||
		field.DataType = Uint
 | 
			
		||||
		if field.HasDefaultValue && !skipParseDefaultValue {
 | 
			
		||||
			if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil {
 | 
			
		||||
				schema.err = fmt.Errorf("failed to parse %v as default value for uint, got error: %v", field.DefaultValue, err)
 | 
			
		||||
				schema.err = fmt.Errorf("failed to parse %s as default value for uint, got error: %v", field.DefaultValue, err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	case reflect.Float32, reflect.Float64:
 | 
			
		||||
		field.DataType = Float
 | 
			
		||||
		if field.HasDefaultValue && !skipParseDefaultValue {
 | 
			
		||||
			if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil {
 | 
			
		||||
				schema.err = fmt.Errorf("failed to parse %v as default value for float, got error: %v", field.DefaultValue, err)
 | 
			
		||||
				schema.err = fmt.Errorf("failed to parse %s as default value for float, got error: %v", field.DefaultValue, err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	case reflect.String:
 | 
			
		||||
@ -227,7 +227,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 | 
			
		||||
 | 
			
		||||
		if field.HasDefaultValue && !skipParseDefaultValue {
 | 
			
		||||
			field.DefaultValue = strings.Trim(field.DefaultValue, "'")
 | 
			
		||||
			field.DefaultValue = strings.Trim(field.DefaultValue, "\"")
 | 
			
		||||
			field.DefaultValue = strings.Trim(field.DefaultValue, `"`)
 | 
			
		||||
			field.DefaultValueInterface = field.DefaultValue
 | 
			
		||||
		}
 | 
			
		||||
	case reflect.Struct:
 | 
			
		||||
@ -392,7 +392,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			schema.err = fmt.Errorf("invalid embedded struct for %v's field %v, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType)
 | 
			
		||||
			schema.err = fmt.Errorf("invalid embedded struct for %s's field %s, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -423,12 +423,12 @@ func (field *Field) setupValuerAndSetter() {
 | 
			
		||||
				} else {
 | 
			
		||||
					v = v.Field(-idx - 1)
 | 
			
		||||
 | 
			
		||||
					if v.Type().Elem().Kind() == reflect.Struct {
 | 
			
		||||
						if !v.IsNil() {
 | 
			
		||||
							v = v.Elem()
 | 
			
		||||
						} else {
 | 
			
		||||
							return nil, true
 | 
			
		||||
						}
 | 
			
		||||
					if v.Type().Elem().Kind() != reflect.Struct {
 | 
			
		||||
						return nil, true
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					if !v.IsNil() {
 | 
			
		||||
						v = v.Elem()
 | 
			
		||||
					} else {
 | 
			
		||||
						return nil, true
 | 
			
		||||
					}
 | 
			
		||||
@ -736,7 +736,7 @@ func (field *Field) setupValuerAndSetter() {
 | 
			
		||||
					if t, err := now.Parse(data); err == nil {
 | 
			
		||||
						field.ReflectValueOf(value).Set(reflect.ValueOf(t))
 | 
			
		||||
					} else {
 | 
			
		||||
						return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
 | 
			
		||||
						return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err)
 | 
			
		||||
					}
 | 
			
		||||
				default:
 | 
			
		||||
					return fallbackSetter(value, v, field.Set)
 | 
			
		||||
@ -765,7 +765,7 @@ func (field *Field) setupValuerAndSetter() {
 | 
			
		||||
						}
 | 
			
		||||
						fieldValue.Elem().Set(reflect.ValueOf(t))
 | 
			
		||||
					} else {
 | 
			
		||||
						return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err)
 | 
			
		||||
						return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err)
 | 
			
		||||
					}
 | 
			
		||||
				default:
 | 
			
		||||
					return fallbackSetter(value, v, field.Set)
 | 
			
		||||
 | 
			
		||||
@ -74,7 +74,9 @@ func (ns NamingStrategy) IndexName(table, column string) string {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ns NamingStrategy) formatName(prefix, table, name string) string {
 | 
			
		||||
	formattedName := strings.Replace(fmt.Sprintf("%v_%v_%v", prefix, table, name), ".", "_", -1)
 | 
			
		||||
	formattedName := strings.Replace(strings.Join([]string{
 | 
			
		||||
		prefix, table, name,
 | 
			
		||||
	}, "_"), ".", "_", -1)
 | 
			
		||||
 | 
			
		||||
	if utf8.RuneCountInString(formattedName) > 64 {
 | 
			
		||||
		h := sha1.New()
 | 
			
		||||
 | 
			
		||||
@ -85,7 +85,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
 | 
			
		||||
		case reflect.Slice:
 | 
			
		||||
			schema.guessRelation(relation, field, guessHas)
 | 
			
		||||
		default:
 | 
			
		||||
			schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name)
 | 
			
		||||
			schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -143,11 +143,11 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if relation.Polymorphic.PolymorphicType == nil {
 | 
			
		||||
		schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"Type")
 | 
			
		||||
		schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if relation.Polymorphic.PolymorphicID == nil {
 | 
			
		||||
		schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %v, missing field %v", relation.FieldSchema, schema, field.Name, polymorphic+"ID")
 | 
			
		||||
		schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if schema.err == nil {
 | 
			
		||||
@ -159,7 +159,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
 | 
			
		||||
		primaryKeyField := schema.PrioritizedPrimaryField
 | 
			
		||||
		if len(relation.foreignKeys) > 0 {
 | 
			
		||||
			if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
 | 
			
		||||
				schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name)
 | 
			
		||||
				schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
@ -203,7 +203,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
 | 
			
		||||
			if field := schema.LookUpField(foreignKey); field != nil {
 | 
			
		||||
				ownForeignFields = append(ownForeignFields, field)
 | 
			
		||||
			} else {
 | 
			
		||||
				schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey)
 | 
			
		||||
				schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
@ -215,7 +215,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
 | 
			
		||||
			if field := relation.FieldSchema.LookUpField(foreignKey); field != nil {
 | 
			
		||||
				refForeignFields = append(refForeignFields, field)
 | 
			
		||||
			} else {
 | 
			
		||||
				schema.err = fmt.Errorf("invalid foreign key: %v", foreignKey)
 | 
			
		||||
				schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
@ -379,7 +379,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
 | 
			
		||||
			schema.guessRelation(relation, field, guessEmbeddedHas)
 | 
			
		||||
		// case guessEmbeddedHas:
 | 
			
		||||
		default:
 | 
			
		||||
			schema.err = fmt.Errorf("invalid field found for struct %v's field %v, need to define a valid foreign key for relations or it need to implement the Valuer/Scanner interface", schema, field.Name)
 | 
			
		||||
			schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -45,9 +45,9 @@ type Schema struct {
 | 
			
		||||
 | 
			
		||||
func (schema Schema) String() string {
 | 
			
		||||
	if schema.ModelType.Name() == "" {
 | 
			
		||||
		return fmt.Sprintf("%v(%v)", schema.Name, schema.Table)
 | 
			
		||||
		return fmt.Sprintf("%s(%s)", schema.Name, schema.Table)
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Sprintf("%v.%v", schema.ModelType.PkgPath(), schema.ModelType.Name())
 | 
			
		||||
	return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (schema Schema) MakeSlice() reflect.Value {
 | 
			
		||||
@ -86,7 +86,7 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
 | 
			
		||||
		if modelType.PkgPath() == "" {
 | 
			
		||||
			return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
 | 
			
		||||
		}
 | 
			
		||||
		return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
 | 
			
		||||
		return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if v, ok := cacheStore.Load(modelType); ok {
 | 
			
		||||
@ -275,7 +275,7 @@ func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, e
 | 
			
		||||
		if modelType.PkgPath() == "" {
 | 
			
		||||
			return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
 | 
			
		||||
		}
 | 
			
		||||
		return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
 | 
			
		||||
		return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if v, ok := cacheStore.Load(modelType); ok {
 | 
			
		||||
 | 
			
		||||
@ -178,17 +178,18 @@ func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interfa
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return clause.Column{Table: table, Name: foreignKeys[0]}, queryValues
 | 
			
		||||
	} else {
 | 
			
		||||
		columns := make([]clause.Column, len(foreignKeys))
 | 
			
		||||
		for idx, key := range foreignKeys {
 | 
			
		||||
			columns[idx] = clause.Column{Table: table, Name: key}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for idx, r := range foreignValues {
 | 
			
		||||
			queryValues[idx] = r
 | 
			
		||||
		}
 | 
			
		||||
		return columns, queryValues
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	columns := make([]clause.Column, len(foreignKeys))
 | 
			
		||||
	for idx, key := range foreignKeys {
 | 
			
		||||
		columns[idx] = clause.Column{Table: table, Name: key}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for idx, r := range foreignValues {
 | 
			
		||||
		queryValues[idx] = r
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return columns, queryValues
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type embeddedNamer struct {
 | 
			
		||||
 | 
			
		||||
@ -84,6 +84,32 @@ func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface {
 | 
			
		||||
	return []clause.Interface{SoftDeleteUpdateClause{Field: f}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type SoftDeleteUpdateClause struct {
 | 
			
		||||
	Field *schema.Field
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (sd SoftDeleteUpdateClause) Name() string {
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (sd SoftDeleteUpdateClause) Build(clause.Builder) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
 | 
			
		||||
	if stmt.SQL.String() == "" {
 | 
			
		||||
		if _, ok := stmt.Clauses["WHERE"]; stmt.DB.AllowGlobalUpdate || ok {
 | 
			
		||||
			SoftDeleteQueryClause(sd).ModifyStatement(stmt)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
 | 
			
		||||
	return []clause.Interface{SoftDeleteDeleteClause{Field: f}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										17
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								statement.go
									
									
									
									
									
								
							@ -57,12 +57,12 @@ type StatementModifier interface {
 | 
			
		||||
	ModifyStatement(*Statement)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Write write string
 | 
			
		||||
// WriteString write string
 | 
			
		||||
func (stmt *Statement) WriteString(str string) (int, error) {
 | 
			
		||||
	return stmt.SQL.WriteString(str)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Write write string
 | 
			
		||||
// WriteByte write byte
 | 
			
		||||
func (stmt *Statement) WriteByte(c byte) error {
 | 
			
		||||
	return stmt.SQL.WriteByte(c)
 | 
			
		||||
}
 | 
			
		||||
@ -152,7 +152,7 @@ func (stmt *Statement) Quote(field interface{}) string {
 | 
			
		||||
	return builder.String()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Write write string
 | 
			
		||||
// AddVar add var
 | 
			
		||||
func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
 | 
			
		||||
	for idx, v := range vars {
 | 
			
		||||
		if idx > 0 {
 | 
			
		||||
@ -506,7 +506,6 @@ func (stmt *Statement) clone() *Statement {
 | 
			
		||||
	return newStmt
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Helpers
 | 
			
		||||
// SetColumn set column's value
 | 
			
		||||
//   stmt.SetColumn("Name", "jinzhu") // Hooks Method
 | 
			
		||||
//   stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
 | 
			
		||||
@ -540,11 +539,6 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks .
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if !stmt.ReflectValue.CanAddr() {
 | 
			
		||||
				stmt.AddError(ErrInvalidValue)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			switch stmt.ReflectValue.Kind() {
 | 
			
		||||
			case reflect.Slice, reflect.Array:
 | 
			
		||||
				if len(fromCallbacks) > 0 {
 | 
			
		||||
@ -555,6 +549,11 @@ func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks .
 | 
			
		||||
					field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value)
 | 
			
		||||
				}
 | 
			
		||||
			case reflect.Struct:
 | 
			
		||||
				if !stmt.ReflectValue.CanAddr() {
 | 
			
		||||
					stmt.AddError(ErrInvalidValue)
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				field.Set(stmt.ReflectValue, value)
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
 | 
			
		||||
@ -64,7 +64,7 @@ func TestAssociationNotNullClear(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Model(member).Association("Profiles").Clear(); err == nil {
 | 
			
		||||
		t.Fatalf("No error occured during clearind not null association")
 | 
			
		||||
		t.Fatalf("No error occurred during clearind not null association")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -124,7 +124,6 @@ func TestCount(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	var count9 int64
 | 
			
		||||
	if err := DB.Debug().Scopes(func(tx *gorm.DB) *gorm.DB {
 | 
			
		||||
		fmt.Println("kdkdkdkdk")
 | 
			
		||||
		return tx.Table("users")
 | 
			
		||||
	}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Count(&count9).Find(&users).Error; err != nil || count9 != 3 {
 | 
			
		||||
		t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err))
 | 
			
		||||
 | 
			
		||||
@ -3,10 +3,9 @@ module gorm.io/gorm/tests
 | 
			
		||||
go 1.14
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	github.com/google/uuid v1.1.1
 | 
			
		||||
	github.com/google/uuid v1.2.0
 | 
			
		||||
	github.com/jinzhu/now v1.1.2
 | 
			
		||||
	github.com/lib/pq v1.6.0
 | 
			
		||||
	github.com/stretchr/testify v1.5.1
 | 
			
		||||
	gorm.io/driver/mysql v1.0.5
 | 
			
		||||
	gorm.io/driver/postgres v1.1.0
 | 
			
		||||
	gorm.io/driver/sqlite v1.1.4
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,7 @@
 | 
			
		||||
package tests_test
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
@ -62,4 +63,12 @@ func TestScopes(t *testing.T) {
 | 
			
		||||
	if result.RowsAffected != 2 {
 | 
			
		||||
		t.Errorf("Should found two users's name in 1, 2, but got %v", result.RowsAffected)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var maxId int64
 | 
			
		||||
	userTable := func(db *gorm.DB) *gorm.DB {
 | 
			
		||||
		return db.WithContext(context.Background()).Table("users")
 | 
			
		||||
	}
 | 
			
		||||
	if err := DB.Scopes(userTable).Select("max(id)").Scan(&maxId).Error; err != nil {
 | 
			
		||||
		t.Errorf("select max(id)")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -11,6 +11,7 @@ then
 | 
			
		||||
  cd tests
 | 
			
		||||
  go get -u ./...
 | 
			
		||||
  go mod download
 | 
			
		||||
  go mod tidy
 | 
			
		||||
  cd ..
 | 
			
		||||
fi
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -15,17 +15,20 @@ var gormSourceDir string
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	_, file, _, _ := runtime.Caller(0)
 | 
			
		||||
	// compatible solution to get gorm source directory with various operating systems
 | 
			
		||||
	gormSourceDir = regexp.MustCompile(`utils.utils\.go`).ReplaceAllString(file, "")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// FileWithLineNum return the file name and line number of the current file
 | 
			
		||||
func FileWithLineNum() string {
 | 
			
		||||
	// the second caller usually from gorm internal, so set i start from 2
 | 
			
		||||
	for i := 2; i < 15; i++ {
 | 
			
		||||
		_, file, line, ok := runtime.Caller(i)
 | 
			
		||||
 | 
			
		||||
		if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) {
 | 
			
		||||
			return file + ":" + strconv.FormatInt(int64(line), 10)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user