Merge branch 'go-gorm:master' into master
This commit is contained in:
		
						commit
						da07675bae
					
				
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@ -3,3 +3,4 @@ documents
 | 
				
			|||||||
coverage.txt
 | 
					coverage.txt
 | 
				
			||||||
_book
 | 
					_book
 | 
				
			||||||
.idea
 | 
					.idea
 | 
				
			||||||
 | 
					vendor
 | 
				
			||||||
@ -201,13 +201,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
 | 
				
			|||||||
		switch stmt.ReflectValue.Kind() {
 | 
							switch stmt.ReflectValue.Kind() {
 | 
				
			||||||
		case reflect.Slice, reflect.Array:
 | 
							case reflect.Slice, reflect.Array:
 | 
				
			||||||
			rValLen := stmt.ReflectValue.Len()
 | 
								rValLen := stmt.ReflectValue.Len()
 | 
				
			||||||
			stmt.SQL.Grow(rValLen * 18)
 | 
					 | 
				
			||||||
			values.Values = make([][]interface{}, rValLen)
 | 
					 | 
				
			||||||
			if rValLen == 0 {
 | 
								if rValLen == 0 {
 | 
				
			||||||
				stmt.AddError(gorm.ErrEmptySlice)
 | 
									stmt.AddError(gorm.ErrEmptySlice)
 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								stmt.SQL.Grow(rValLen * 18)
 | 
				
			||||||
 | 
								stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns))
 | 
				
			||||||
 | 
								values.Values = make([][]interface{}, rValLen)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
 | 
								defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
 | 
				
			||||||
			for i := 0; i < rValLen; i++ {
 | 
								for i := 0; i < rValLen; i++ {
 | 
				
			||||||
				rv := reflect.Indirect(stmt.ReflectValue.Index(i))
 | 
									rv := reflect.Indirect(stmt.ReflectValue.Index(i))
 | 
				
			||||||
 | 
				
			|||||||
@ -255,6 +255,7 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
 | 
					// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions)
 | 
				
			||||||
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
 | 
					func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
 | 
				
			||||||
	queryTx := db.Limit(1).Order(clause.OrderByColumn{
 | 
						queryTx := db.Limit(1).Order(clause.OrderByColumn{
 | 
				
			||||||
@ -599,11 +600,12 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
 | 
				
			|||||||
		opt = opts[0]
 | 
							opt = opts[0]
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok {
 | 
						switch beginner := tx.Statement.ConnPool.(type) {
 | 
				
			||||||
 | 
						case TxBeginner:
 | 
				
			||||||
		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
 | 
							tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
 | 
				
			||||||
	} else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok {
 | 
						case ConnPoolBeginner:
 | 
				
			||||||
		tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
 | 
							tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
 | 
				
			||||||
	} else {
 | 
						default:
 | 
				
			||||||
		err = ErrInvalidTransaction
 | 
							err = ErrInvalidTransaction
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										2
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								gorm.go
									
									
									
									
									
								
							@ -462,7 +462,7 @@ func (db *DB) Use(plugin Plugin) error {
 | 
				
			|||||||
//			.First(&User{})
 | 
					//			.First(&User{})
 | 
				
			||||||
// })
 | 
					// })
 | 
				
			||||||
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
 | 
					func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
 | 
				
			||||||
	tx := queryFn(db.Session(&Session{DryRun: true}))
 | 
						tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
 | 
				
			||||||
	stmt := tx.Statement
 | 
						stmt := tx.Statement
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
 | 
						return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
 | 
				
			||||||
 | 
				
			|||||||
@ -56,6 +56,13 @@ type TxCommitter interface {
 | 
				
			|||||||
	Rollback() error
 | 
						Rollback() error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Tx sql.Tx interface
 | 
				
			||||||
 | 
					type Tx interface {
 | 
				
			||||||
 | 
						ConnPool
 | 
				
			||||||
 | 
						TxCommitter
 | 
				
			||||||
 | 
						StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Valuer gorm valuer interface
 | 
					// Valuer gorm valuer interface
 | 
				
			||||||
type Valuer interface {
 | 
					type Valuer interface {
 | 
				
			||||||
	GormValue(context.Context, *DB) clause.Expr
 | 
						GormValue(context.Context, *DB) clause.Expr
 | 
				
			||||||
 | 
				
			|||||||
@ -115,7 +115,7 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type PreparedStmtTX struct {
 | 
					type PreparedStmtTX struct {
 | 
				
			||||||
	*sql.Tx
 | 
						Tx
 | 
				
			||||||
	PreparedStmtDB *PreparedStmtDB
 | 
						PreparedStmtDB *PreparedStmtDB
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -151,7 +151,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
 | 
				
			|||||||
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
 | 
					func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
 | 
				
			||||||
	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
 | 
						stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
 | 
				
			||||||
	if err == nil {
 | 
						if err == nil {
 | 
				
			||||||
		rows, err = tx.Tx.Stmt(stmt.Stmt).QueryContext(ctx, args...)
 | 
							rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			tx.PreparedStmtDB.Mux.Lock()
 | 
								tx.PreparedStmtDB.Mux.Lock()
 | 
				
			||||||
			defer tx.PreparedStmtDB.Mux.Unlock()
 | 
								defer tx.PreparedStmtDB.Mux.Unlock()
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										111
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										111
									
								
								scan.go
									
									
									
									
									
								
							@ -50,54 +50,40 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue reflect.Value, values []interface{}, columns []string, fields []*schema.Field, joinFields [][2]*schema.Field) {
 | 
					func (db *DB) scanIntoStruct(rows *sql.Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) {
 | 
				
			||||||
	for idx, column := range columns {
 | 
						for idx, field := range fields {
 | 
				
			||||||
		if sch == nil {
 | 
							if field != nil {
 | 
				
			||||||
			values[idx] = reflectValue.Interface()
 | 
					 | 
				
			||||||
		} else if field := sch.LookUpField(column); field != nil && field.Readable {
 | 
					 | 
				
			||||||
			values[idx] = field.NewValuePool.Get()
 | 
								values[idx] = field.NewValuePool.Get()
 | 
				
			||||||
			defer field.NewValuePool.Put(values[idx])
 | 
							} else if len(fields) == 1 {
 | 
				
			||||||
		} else if names := strings.Split(column, "__"); len(names) > 1 {
 | 
								if reflectValue.CanAddr() {
 | 
				
			||||||
			if rel, ok := sch.Relationships.Relations[names[0]]; ok {
 | 
									values[idx] = reflectValue.Addr().Interface()
 | 
				
			||||||
				if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
 | 
								} else {
 | 
				
			||||||
					values[idx] = field.NewValuePool.Get()
 | 
									values[idx] = reflectValue.Interface()
 | 
				
			||||||
					defer field.NewValuePool.Put(values[idx])
 | 
					 | 
				
			||||||
					continue
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			values[idx] = &sql.RawBytes{}
 | 
					 | 
				
			||||||
		} else if len(columns) == 1 {
 | 
					 | 
				
			||||||
			sch = nil
 | 
					 | 
				
			||||||
			values[idx] = reflectValue.Interface()
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			values[idx] = &sql.RawBytes{}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	db.RowsAffected++
 | 
						db.RowsAffected++
 | 
				
			||||||
	db.AddError(rows.Scan(values...))
 | 
						db.AddError(rows.Scan(values...))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if sch != nil {
 | 
						for idx, field := range fields {
 | 
				
			||||||
		for idx, column := range columns {
 | 
							if field != nil {
 | 
				
			||||||
			if field := sch.LookUpField(column); field != nil && field.Readable {
 | 
								if len(joinFields) == 0 || joinFields[idx][0] == nil {
 | 
				
			||||||
				field.Set(db.Statement.Context, reflectValue, values[idx])
 | 
									field.Set(db.Statement.Context, reflectValue, values[idx])
 | 
				
			||||||
			} else if names := strings.Split(column, "__"); len(names) > 1 {
 | 
								} else {
 | 
				
			||||||
				if rel, ok := sch.Relationships.Relations[names[0]]; ok {
 | 
									relValue := joinFields[idx][0].ReflectValueOf(db.Statement.Context, reflectValue)
 | 
				
			||||||
					if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
 | 
									if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
 | 
				
			||||||
						relValue := rel.Field.ReflectValueOf(db.Statement.Context, reflectValue)
 | 
										if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
 | 
				
			||||||
 | 
											return
 | 
				
			||||||
						if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
 | 
					 | 
				
			||||||
							if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
 | 
					 | 
				
			||||||
								continue
 | 
					 | 
				
			||||||
							}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
							relValue.Set(reflect.New(relValue.Type().Elem()))
 | 
					 | 
				
			||||||
						}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
						field.Set(db.Statement.Context, relValue, values[idx])
 | 
					 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										relValue.Set(reflect.New(relValue.Type().Elem()))
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
									joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// release data to pool
 | 
				
			||||||
 | 
								field.NewValuePool.Put(values[idx])
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -176,7 +162,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
 | 
				
			|||||||
			reflectValue = db.Statement.ReflectValue
 | 
								reflectValue = db.Statement.ReflectValue
 | 
				
			||||||
		)
 | 
							)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		for reflectValue.Kind() == reflect.Interface {
 | 
							if reflectValue.Kind() == reflect.Interface {
 | 
				
			||||||
			reflectValue = reflectValue.Elem()
 | 
								reflectValue = reflectValue.Elem()
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -195,35 +181,38 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
 | 
				
			|||||||
				sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
 | 
									sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			for idx, column := range columns {
 | 
					 | 
				
			||||||
				if field := sch.LookUpField(column); field != nil && field.Readable {
 | 
					 | 
				
			||||||
					fields[idx] = field
 | 
					 | 
				
			||||||
				} else if names := strings.Split(column, "__"); len(names) > 1 {
 | 
					 | 
				
			||||||
					if rel, ok := sch.Relationships.Relations[names[0]]; ok {
 | 
					 | 
				
			||||||
						if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
 | 
					 | 
				
			||||||
							fields[idx] = field
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
							if len(joinFields) == 0 {
 | 
					 | 
				
			||||||
								joinFields = make([][2]*schema.Field, len(columns))
 | 
					 | 
				
			||||||
							}
 | 
					 | 
				
			||||||
							joinFields[idx] = [2]*schema.Field{rel.Field, field}
 | 
					 | 
				
			||||||
							continue
 | 
					 | 
				
			||||||
						}
 | 
					 | 
				
			||||||
					}
 | 
					 | 
				
			||||||
					values[idx] = &sql.RawBytes{}
 | 
					 | 
				
			||||||
				} else {
 | 
					 | 
				
			||||||
					values[idx] = &sql.RawBytes{}
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if len(columns) == 1 {
 | 
								if len(columns) == 1 {
 | 
				
			||||||
				// isPluck
 | 
									// Is Pluck
 | 
				
			||||||
				if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
 | 
									if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
 | 
				
			||||||
					reflectValueType.Kind() != reflect.Struct || // is not struct
 | 
										reflectValueType.Kind() != reflect.Struct || // is not struct
 | 
				
			||||||
					sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
 | 
										sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
 | 
				
			||||||
					sch = nil
 | 
										sch = nil
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Not Pluck
 | 
				
			||||||
 | 
								if sch != nil {
 | 
				
			||||||
 | 
									for idx, column := range columns {
 | 
				
			||||||
 | 
										if field := sch.LookUpField(column); field != nil && field.Readable {
 | 
				
			||||||
 | 
											fields[idx] = field
 | 
				
			||||||
 | 
										} else if names := strings.Split(column, "__"); len(names) > 1 {
 | 
				
			||||||
 | 
											if rel, ok := sch.Relationships.Relations[names[0]]; ok {
 | 
				
			||||||
 | 
												if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
 | 
				
			||||||
 | 
													fields[idx] = field
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
													if len(joinFields) == 0 {
 | 
				
			||||||
 | 
														joinFields = make([][2]*schema.Field, len(columns))
 | 
				
			||||||
 | 
													}
 | 
				
			||||||
 | 
													joinFields[idx] = [2]*schema.Field{rel.Field, field}
 | 
				
			||||||
 | 
													continue
 | 
				
			||||||
 | 
												}
 | 
				
			||||||
 | 
											}
 | 
				
			||||||
 | 
											values[idx] = &sql.RawBytes{}
 | 
				
			||||||
 | 
										} else {
 | 
				
			||||||
 | 
											values[idx] = &sql.RawBytes{}
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		switch reflectValue.Kind() {
 | 
							switch reflectValue.Kind() {
 | 
				
			||||||
@ -256,7 +245,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
 | 
				
			|||||||
					elem = reflect.New(reflectValueType)
 | 
										elem = reflect.New(reflectValueType)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				db.scanIntoStruct(sch, rows, elem, values, columns, fields, joinFields)
 | 
									db.scanIntoStruct(rows, elem, values, fields, joinFields)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				if !update {
 | 
									if !update {
 | 
				
			||||||
					if isPtr {
 | 
										if isPtr {
 | 
				
			||||||
@ -272,7 +261,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
		case reflect.Struct, reflect.Ptr:
 | 
							case reflect.Struct, reflect.Ptr:
 | 
				
			||||||
			if initialized || rows.Next() {
 | 
								if initialized || rows.Next() {
 | 
				
			||||||
				db.scanIntoStruct(sch, rows, reflectValue, values, columns, fields, joinFields)
 | 
									db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		default:
 | 
							default:
 | 
				
			||||||
			db.AddError(rows.Scan(dest))
 | 
								db.AddError(rows.Scan(dest))
 | 
				
			||||||
 | 
				
			|||||||
@ -465,24 +465,33 @@ func (field *Field) setupValuerAndSetter() {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// ValueOf returns field's value and if it is zero
 | 
						// ValueOf returns field's value and if it is zero
 | 
				
			||||||
	field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
 | 
						fieldIndex := field.StructField.Index[0]
 | 
				
			||||||
		v = reflect.Indirect(v)
 | 
						switch {
 | 
				
			||||||
		for _, fieldIdx := range field.StructField.Index {
 | 
						case len(field.StructField.Index) == 1 && fieldIndex > 0:
 | 
				
			||||||
			if fieldIdx >= 0 {
 | 
							field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
 | 
				
			||||||
				v = v.Field(fieldIdx)
 | 
								fieldValue := reflect.Indirect(value).Field(fieldIndex)
 | 
				
			||||||
			} else {
 | 
								return fieldValue.Interface(), fieldValue.IsZero()
 | 
				
			||||||
				v = v.Field(-fieldIdx - 1)
 | 
							}
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
				if !v.IsNil() {
 | 
							field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
 | 
				
			||||||
					v = v.Elem()
 | 
								v = reflect.Indirect(v)
 | 
				
			||||||
 | 
								for _, fieldIdx := range field.StructField.Index {
 | 
				
			||||||
 | 
									if fieldIdx >= 0 {
 | 
				
			||||||
 | 
										v = v.Field(fieldIdx)
 | 
				
			||||||
				} else {
 | 
									} else {
 | 
				
			||||||
					return nil, true
 | 
										v = v.Field(-fieldIdx - 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										if !v.IsNil() {
 | 
				
			||||||
 | 
											v = v.Elem()
 | 
				
			||||||
 | 
										} else {
 | 
				
			||||||
 | 
											return nil, true
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
		fv, zero := v.Interface(), v.IsZero()
 | 
								fv, zero := v.Interface(), v.IsZero()
 | 
				
			||||||
		return fv, zero
 | 
								return fv, zero
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if field.Serializer != nil {
 | 
						if field.Serializer != nil {
 | 
				
			||||||
@ -509,24 +518,31 @@ func (field *Field) setupValuerAndSetter() {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// ReflectValueOf returns field's reflect value
 | 
						// ReflectValueOf returns field's reflect value
 | 
				
			||||||
	field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
 | 
						switch {
 | 
				
			||||||
		v = reflect.Indirect(v)
 | 
						case len(field.StructField.Index) == 1 && fieldIndex > 0:
 | 
				
			||||||
		for idx, fieldIdx := range field.StructField.Index {
 | 
							field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
 | 
				
			||||||
			if fieldIdx >= 0 {
 | 
								return reflect.Indirect(value).Field(fieldIndex)
 | 
				
			||||||
				v = v.Field(fieldIdx)
 | 
							}
 | 
				
			||||||
			} else {
 | 
						default:
 | 
				
			||||||
				v = v.Field(-fieldIdx - 1)
 | 
							field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
 | 
				
			||||||
 | 
								v = reflect.Indirect(v)
 | 
				
			||||||
 | 
								for idx, fieldIdx := range field.StructField.Index {
 | 
				
			||||||
 | 
									if fieldIdx >= 0 {
 | 
				
			||||||
 | 
										v = v.Field(fieldIdx)
 | 
				
			||||||
 | 
									} else {
 | 
				
			||||||
 | 
										v = v.Field(-fieldIdx - 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				if v.IsNil() {
 | 
										if v.IsNil() {
 | 
				
			||||||
					v.Set(reflect.New(v.Type().Elem()))
 | 
											v.Set(reflect.New(v.Type().Elem()))
 | 
				
			||||||
				}
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				if idx < len(field.StructField.Index)-1 {
 | 
										if idx < len(field.StructField.Index)-1 {
 | 
				
			||||||
					v = v.Elem()
 | 
											v = v.Elem()
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
								return v
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		return v
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) {
 | 
						fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) {
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										171
									
								
								tests/connpool_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										171
									
								
								tests/connpool_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,171 @@
 | 
				
			|||||||
 | 
					package tests_test
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"database/sql"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"gorm.io/driver/mysql"
 | 
				
			||||||
 | 
						"gorm.io/gorm"
 | 
				
			||||||
 | 
						. "gorm.io/gorm/utils/tests"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type wrapperTx struct {
 | 
				
			||||||
 | 
						*sql.Tx
 | 
				
			||||||
 | 
						conn *wrapperConnPool
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *wrapperTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
 | 
				
			||||||
 | 
						c.conn.got = append(c.conn.got, query)
 | 
				
			||||||
 | 
						return c.Tx.PrepareContext(ctx, query)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *wrapperTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
 | 
				
			||||||
 | 
						c.conn.got = append(c.conn.got, query)
 | 
				
			||||||
 | 
						return c.Tx.ExecContext(ctx, query, args...)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *wrapperTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
 | 
				
			||||||
 | 
						c.conn.got = append(c.conn.got, query)
 | 
				
			||||||
 | 
						return c.Tx.QueryContext(ctx, query, args...)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *wrapperTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
 | 
				
			||||||
 | 
						c.conn.got = append(c.conn.got, query)
 | 
				
			||||||
 | 
						return c.Tx.QueryRowContext(ctx, query, args...)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type wrapperConnPool struct {
 | 
				
			||||||
 | 
						db     *sql.DB
 | 
				
			||||||
 | 
						got    []string
 | 
				
			||||||
 | 
						expect []string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *wrapperConnPool) Ping() error {
 | 
				
			||||||
 | 
						return c.db.Ping()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction.
 | 
				
			||||||
 | 
					// func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
 | 
				
			||||||
 | 
					//	 return c.db.BeginTx(ctx, opts)
 | 
				
			||||||
 | 
					// }
 | 
				
			||||||
 | 
					// You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries.
 | 
				
			||||||
 | 
					func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) {
 | 
				
			||||||
 | 
						tx, err := c.db.BeginTx(ctx, opts)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return &wrapperTx{Tx: tx, conn: c}, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *wrapperConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
 | 
				
			||||||
 | 
						c.got = append(c.got, query)
 | 
				
			||||||
 | 
						return c.db.PrepareContext(ctx, query)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *wrapperConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
 | 
				
			||||||
 | 
						c.got = append(c.got, query)
 | 
				
			||||||
 | 
						return c.db.ExecContext(ctx, query, args...)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *wrapperConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
 | 
				
			||||||
 | 
						c.got = append(c.got, query)
 | 
				
			||||||
 | 
						return c.db.QueryContext(ctx, query, args...)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *wrapperConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
 | 
				
			||||||
 | 
						c.got = append(c.got, query)
 | 
				
			||||||
 | 
						return c.db.QueryRowContext(ctx, query, args...)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestConnPoolWrapper(t *testing.T) {
 | 
				
			||||||
 | 
						dialect := os.Getenv("GORM_DIALECT")
 | 
				
			||||||
 | 
						if dialect != "mysql" {
 | 
				
			||||||
 | 
							t.SkipNow()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						dbDSN := os.Getenv("GORM_DSN")
 | 
				
			||||||
 | 
						if dbDSN == "" {
 | 
				
			||||||
 | 
							dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local"
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						nativeDB, err := sql.Open("mysql", dbDSN)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("Should open db success, but got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						conn := &wrapperConnPool{
 | 
				
			||||||
 | 
							db: nativeDB,
 | 
				
			||||||
 | 
							expect: []string{
 | 
				
			||||||
 | 
								"SELECT VERSION()",
 | 
				
			||||||
 | 
								"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
 | 
				
			||||||
 | 
								"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
				
			||||||
 | 
								"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
 | 
				
			||||||
 | 
								"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
				
			||||||
 | 
								"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
				
			||||||
 | 
								"INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)",
 | 
				
			||||||
 | 
								"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
				
			||||||
 | 
								"SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT 1",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						defer func() {
 | 
				
			||||||
 | 
							if !reflect.DeepEqual(conn.got, conn.expect) {
 | 
				
			||||||
 | 
								t.Errorf("expect %#v but got %#v", conn.expect, conn.got)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn}))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("Should open db success, but got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tx := db.Begin()
 | 
				
			||||||
 | 
						user := *GetUser("transaction", Config{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err = tx.Save(&user).Error; err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("No error should raise, but got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err = tx.First(&User{}, "name = ?", "transaction").Error; err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("Should find saved record, but got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						user1 := *GetUser("transaction1-1", Config{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err = tx.Save(&user1).Error; err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("No error should raise, but got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err = tx.First(&User{}, "name = ?", user1.Name).Error; err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("Should find saved record, but got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil {
 | 
				
			||||||
 | 
							t.Fatalf("Should return the underlying sql.Tx")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tx.Rollback()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err = db.First(&User{}, "name = ?", "transaction").Error; err == nil {
 | 
				
			||||||
 | 
							t.Fatalf("Should not find record after rollback, but got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						txDB := db.Where("fake_name = ?", "fake_name")
 | 
				
			||||||
 | 
						tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin()
 | 
				
			||||||
 | 
						user2 := *GetUser("transaction-2", Config{})
 | 
				
			||||||
 | 
						if err = tx2.Save(&user2).Error; err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("No error should raise, but got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err = tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("Should find saved record, but got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tx2.Commit()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err = db.First(&User{}, "name = ?", "transaction-2").Error; err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("Should be able to find committed record, but got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -3,11 +3,12 @@ module gorm.io/gorm/tests
 | 
				
			|||||||
go 1.14
 | 
					go 1.14
 | 
				
			||||||
 | 
					
 | 
				
			||||||
require (
 | 
					require (
 | 
				
			||||||
 | 
						github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
 | 
				
			||||||
	github.com/google/uuid v1.3.0
 | 
						github.com/google/uuid v1.3.0
 | 
				
			||||||
	github.com/jackc/pgx/v4 v4.15.0 // indirect
 | 
						github.com/jackc/pgx/v4 v4.15.0 // indirect
 | 
				
			||||||
	github.com/jinzhu/now v1.1.4
 | 
						github.com/jinzhu/now v1.1.4
 | 
				
			||||||
	github.com/lib/pq v1.10.4
 | 
						github.com/lib/pq v1.10.4
 | 
				
			||||||
	github.com/mattn/go-sqlite3 v1.14.11 // indirect
 | 
						github.com/mattn/go-sqlite3 v1.14.12 // indirect
 | 
				
			||||||
	golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect
 | 
						golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect
 | 
				
			||||||
	gorm.io/driver/mysql v1.3.2
 | 
						gorm.io/driver/mysql v1.3.2
 | 
				
			||||||
	gorm.io/driver/postgres v1.3.1
 | 
						gorm.io/driver/postgres v1.3.1
 | 
				
			||||||
 | 
				
			|||||||
@ -1158,3 +1158,39 @@ func TestQueryWithTableAndConditionsAndAllFields(t *testing.T) {
 | 
				
			|||||||
		t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
 | 
							t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type DoubleInt64 struct {
 | 
				
			||||||
 | 
						data int64
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (t *DoubleInt64) Scan(val interface{}) error {
 | 
				
			||||||
 | 
						switch v := val.(type) {
 | 
				
			||||||
 | 
						case int64:
 | 
				
			||||||
 | 
							t.data = v * 2
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							return fmt.Errorf("DoubleInt64 cant not scan with:%v", v)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// https://github.com/go-gorm/gorm/issues/5091
 | 
				
			||||||
 | 
					func TestQueryScannerWithSingleColumn(t *testing.T) {
 | 
				
			||||||
 | 
						user := User{Name: "scanner_raw_1", Age: 10}
 | 
				
			||||||
 | 
						DB.Create(&user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var result1 DoubleInt64
 | 
				
			||||||
 | 
						if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Pluck(
 | 
				
			||||||
 | 
							"age", &result1).Error; err != nil {
 | 
				
			||||||
 | 
							t.Errorf("Failed, got error: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						AssertEqual(t, result1.data, 20)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var result2 DoubleInt64
 | 
				
			||||||
 | 
						if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Select(
 | 
				
			||||||
 | 
							"age").Scan(&result2).Error; err != nil {
 | 
				
			||||||
 | 
							t.Errorf("Failed, got error: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						AssertEqual(t, result2.data, 20)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user