Support sql.Scanner

This commit is contained in:
Jinzhu 2013-11-10 18:33:37 +08:00
parent dc15849313
commit 8e0b125cb1
3 changed files with 18 additions and 31 deletions

19
do.go
View File

@ -2,6 +2,7 @@ package gorm
import ( import (
"database/sql" "database/sql"
"database/sql/driver"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
@ -560,13 +561,10 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
} }
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1) str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
default: default:
switch arg.(type) { if scanner, ok := interface{}(arg).(driver.Valuer); ok {
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString: arg, _ = scanner.Value()
value := reflect.ValueOf(arg).Field(0).Interface()
str = strings.Replace(str, "?", s.addToVars(value), 1)
default:
str = strings.Replace(str, "?", s.addToVars(arg), 1)
} }
str = strings.Replace(str, "?", s.addToVars(arg), 1)
} }
} }
return return
@ -624,13 +622,10 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
} }
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1) str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
default: default:
switch arg.(type) { if scanner, ok := interface{}(arg).(driver.Valuer); ok {
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString: arg, _ = scanner.Value()
value := reflect.ValueOf(arg).Field(0).Interface()
str = strings.Replace(not_equal_sql, "?", s.addToVars(value), 1)
default:
str = strings.Replace(not_equal_sql, "?", s.addToVars(arg), 1)
} }
str = strings.Replace(not_equal_sql, "?", s.addToVars(arg), 1)
} }
} }
return return

View File

@ -14,6 +14,7 @@ import (
type Model struct { type Model struct {
data interface{} data interface{}
driver string driver string
debug bool
_cache_fields map[string][]Field _cache_fields map[string][]Field
} }
@ -106,11 +107,13 @@ func (m *Model) fields(operation string) (fields []Field) {
if is_time { if is_time {
field.IsBlank = time_value.IsZero() field.IsBlank = time_value.IsZero()
} else { } else {
switch value.Interface().(type) { _, is_scanner := reflect.New(value.Type()).Interface().(sql.Scanner)
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString:
if is_scanner {
field.IsBlank = !value.FieldByName("Valid").Interface().(bool) field.IsBlank = !value.FieldByName("Valid").Interface().(bool)
default: } else {
m := &Model{data: value.Interface(), driver: m.driver} m := &Model{data: value.Interface(), driver: m.driver}
fields := m.columnsHasValue("other") fields := m.columnsHasValue("other")
if len(fields) == 0 { if len(fields) == 0 {
field.IsBlank = true field.IsBlank = true
@ -370,25 +373,14 @@ func setFieldValue(field reflect.Value, value interface{}) bool {
} }
field.SetInt(reflect.ValueOf(value).Int()) field.SetInt(reflect.ValueOf(value).Int())
default: default:
field_type := field.Type() if scanner, ok := field.Addr().Interface().(sql.Scanner); ok {
if field_type == reflect.TypeOf(value) { scanner.Scan(value)
field.Set(reflect.ValueOf(value))
} else if value == nil {
field.Set(reflect.Zero(field.Type()))
} else if field_type == reflect.TypeOf(sql.NullBool{}) {
field.Set(reflect.ValueOf(sql.NullBool{value.(bool), true}))
} else if field_type == reflect.TypeOf(sql.NullFloat64{}) {
field.Set(reflect.ValueOf(sql.NullFloat64{value.(float64), true}))
} else if field_type == reflect.TypeOf(sql.NullInt64{}) {
field.Set(reflect.ValueOf(sql.NullInt64{value.(int64), true}))
} else if field_type == reflect.TypeOf(sql.NullString{}) {
field.Set(reflect.ValueOf(sql.NullString{value.(string), true}))
} else { } else {
field.Set(reflect.ValueOf(value)) field.Set(reflect.ValueOf(value))
} }
} }
return true return true
} else {
return false
} }
return false
} }

View File

@ -2,8 +2,8 @@ package gorm
import ( import (
"bytes" "bytes"
"fmt"
"fmt"
"strings" "strings"
) )