Support slice value -> sql.Rows
This commit is contained in:
		
							parent
							
								
									b4591d4db8
								
							
						
					
					
						commit
						afe269ec7d
					
				@ -3,7 +3,6 @@ package gorm
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"regexp"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -148,6 +147,6 @@ func (h *Expecter) First(out interface{}, where ...interface{}) ExpectedQuery {
 | 
			
		||||
 | 
			
		||||
// Find triggers a Query
 | 
			
		||||
func (h *Expecter) Find(out interface{}, where ...interface{}) ExpectedQuery {
 | 
			
		||||
	fmt.Printf("Expecting query: %s\n", "some query involving Find")
 | 
			
		||||
	return h.adapter.ExpectQuery("some find condition")
 | 
			
		||||
	h.gorm.Find(out, where...)
 | 
			
		||||
	return h.adapter.ExpectQuery(regexp.QuoteMeta(h.recorder.stmt))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -2,7 +2,10 @@ package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
 | 
			
		||||
	"github.com/davecgh/go-spew/spew"
 | 
			
		||||
	sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -20,36 +23,76 @@ type SqlmockQuery struct {
 | 
			
		||||
	query *sqlmock.ExpectedQuery
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
 | 
			
		||||
	var (
 | 
			
		||||
		columns []string
 | 
			
		||||
		rows    *sqlmock.Rows
 | 
			
		||||
		values  []driver.Value
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	q.scope = &Scope{Value: out}
 | 
			
		||||
	fields := q.scope.Fields()
 | 
			
		||||
 | 
			
		||||
func getRowForFields(fields []*Field) []driver.Value {
 | 
			
		||||
	var values []driver.Value
 | 
			
		||||
	for _, field := range fields {
 | 
			
		||||
		if field.IsNormal {
 | 
			
		||||
			var (
 | 
			
		||||
				column = field.StructField.DBName
 | 
			
		||||
				value  = field.Field.Interface()
 | 
			
		||||
			)
 | 
			
		||||
			value := field.Field
 | 
			
		||||
 | 
			
		||||
			if isValue := driver.IsValue(value); isValue {
 | 
			
		||||
				columns = append(columns, column)
 | 
			
		||||
				values = append(values, value)
 | 
			
		||||
			} else if valuer, ok := value.(driver.Valuer); ok {
 | 
			
		||||
				if underlyingValue, err := valuer.Value(); err == nil {
 | 
			
		||||
					values = append(values, underlyingValue)
 | 
			
		||||
					columns = append(columns, field.StructField.DBName)
 | 
			
		||||
			// dereference pointers
 | 
			
		||||
			if field.Field.Kind() == reflect.Ptr {
 | 
			
		||||
				value = reflect.Indirect(field.Field)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// check if we have a zero Value
 | 
			
		||||
			if !value.IsValid() {
 | 
			
		||||
				values = append(values, nil)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			concreteVal := value.Interface()
 | 
			
		||||
 | 
			
		||||
			// if we already have a driver.Value, just append
 | 
			
		||||
			_, isValuer := concreteVal.(driver.Valuer)
 | 
			
		||||
			spew.Printf("%s: %v\r\n", field.DBName, isValuer)
 | 
			
		||||
 | 
			
		||||
			if driver.IsValue(concreteVal) {
 | 
			
		||||
				values = append(values, concreteVal)
 | 
			
		||||
			} else if valuer, ok := concreteVal.(driver.Valuer); ok {
 | 
			
		||||
				if convertedValue, err := valuer.Value(); err == nil {
 | 
			
		||||
					values = append(values, convertedValue)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows = sqlmock.NewRows(columns).AddRow(values...)
 | 
			
		||||
	return values
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (q *SqlmockQuery) getRowsForOutType(out interface{}) *sqlmock.Rows {
 | 
			
		||||
	var columns []string
 | 
			
		||||
 | 
			
		||||
	for _, field := range (&Scope{}).New(out).GetModelStruct().StructFields {
 | 
			
		||||
		if field.IsNormal {
 | 
			
		||||
			columns = append(columns, field.DBName)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rows := sqlmock.NewRows(columns)
 | 
			
		||||
 | 
			
		||||
	outVal := indirect(reflect.ValueOf(out))
 | 
			
		||||
 | 
			
		||||
	if outVal.Kind() == reflect.Slice {
 | 
			
		||||
		outSlice := []interface{}{}
 | 
			
		||||
		for i := 0; i < outVal.Len(); i++ {
 | 
			
		||||
			outSlice = append(outSlice, outVal.Index(i).Interface())
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for _, outElem := range outSlice {
 | 
			
		||||
			scope := &Scope{Value: outElem}
 | 
			
		||||
			row := getRowForFields(scope.Fields())
 | 
			
		||||
			rows = rows.AddRow(row...)
 | 
			
		||||
		}
 | 
			
		||||
	} else if outVal.Kind() == reflect.Struct {
 | 
			
		||||
		scope := &Scope{Value: out}
 | 
			
		||||
		row := getRowForFields(scope.Fields())
 | 
			
		||||
		rows = rows.AddRow(row...)
 | 
			
		||||
	} else {
 | 
			
		||||
		panic(fmt.Errorf("Can only get rows for slice or struct"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	spew.Dump(columns)
 | 
			
		||||
	spew.Dump(rows)
 | 
			
		||||
 | 
			
		||||
	return rows
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -57,22 +57,67 @@ func TestQueryReturn(t *testing.T) {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	in := &User{Id: 1}
 | 
			
		||||
	expectedOut := User{Id: 1, Name: "jinzhu"}
 | 
			
		||||
	in := User{Id: 1}
 | 
			
		||||
	out := User{Id: 1, Name: "jinzhu"}
 | 
			
		||||
 | 
			
		||||
	expect.First(in).Returns(User{Id: 1, Name: "jinzhu"})
 | 
			
		||||
	expect.First(&in).Returns(out)
 | 
			
		||||
 | 
			
		||||
	db.First(in)
 | 
			
		||||
	db.First(&in)
 | 
			
		||||
 | 
			
		||||
	if e := expect.AssertExpectations(); e != nil {
 | 
			
		||||
		t.Error(e)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if in.Name != "jinzhu" {
 | 
			
		||||
		t.Errorf("Expected %s, got %s", expectedOut.Name, in.Name)
 | 
			
		||||
		t.Errorf("Expected %s, got %s", out.Name, in.Name)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ne := reflect.DeepEqual(*in, expectedOut); !ne {
 | 
			
		||||
	if ne := reflect.DeepEqual(in, out); !ne {
 | 
			
		||||
		t.Errorf("Not equal")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestFindStructDest(t *testing.T) {
 | 
			
		||||
	db, expect, err := gorm.NewDefaultExpecter()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		db.Close()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	in := &User{Id: 1}
 | 
			
		||||
 | 
			
		||||
	expect.Find(in)
 | 
			
		||||
	db.Find(&User{Id: 1})
 | 
			
		||||
 | 
			
		||||
	if e := expect.AssertExpectations(); e != nil {
 | 
			
		||||
		t.Error(e)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestFindSlice(t *testing.T) {
 | 
			
		||||
	db, expect, err := gorm.NewDefaultExpecter()
 | 
			
		||||
	defer func() {
 | 
			
		||||
		db.Close()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	in := []User{}
 | 
			
		||||
	out := []User{User{Id: 1, Name: "jinzhu"}, User{Id: 2, Name: "itwx"}}
 | 
			
		||||
 | 
			
		||||
	expect.Find(&in).Returns(&out)
 | 
			
		||||
	db.Find(&in)
 | 
			
		||||
 | 
			
		||||
	if e := expect.AssertExpectations(); e != nil {
 | 
			
		||||
		t.Error(e)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ne := reflect.DeepEqual(in, out); !ne {
 | 
			
		||||
		t.Error("Expected equal slices")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -139,9 +139,9 @@ func (role Role) IsAdmin() bool {
 | 
			
		||||
 | 
			
		||||
type Num int64
 | 
			
		||||
 | 
			
		||||
func (i *Num) Value() (driver.Value, error) {
 | 
			
		||||
func (i Num) Value() (driver.Value, error) {
 | 
			
		||||
	// guaranteed ok
 | 
			
		||||
	return int64(*i), nil
 | 
			
		||||
	return int64(i), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (i *Num) Scan(src interface{}) error {
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user