Support sql.Scanner
This commit is contained in:
parent
dc15849313
commit
8e0b125cb1
19
do.go
19
do.go
@ -2,6 +2,7 @@ package gorm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
@ -560,13 +561,10 @@ func (s *Do) buildWhereCondition(clause map[string]interface{}) (str string) {
|
|||||||
}
|
}
|
||||||
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
|
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
|
||||||
default:
|
default:
|
||||||
switch arg.(type) {
|
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
||||||
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString:
|
arg, _ = scanner.Value()
|
||||||
value := reflect.ValueOf(arg).Field(0).Interface()
|
|
||||||
str = strings.Replace(str, "?", s.addToVars(value), 1)
|
|
||||||
default:
|
|
||||||
str = strings.Replace(str, "?", s.addToVars(arg), 1)
|
|
||||||
}
|
}
|
||||||
|
str = strings.Replace(str, "?", s.addToVars(arg), 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@ -624,13 +622,10 @@ func (s *Do) buildNotCondition(clause map[string]interface{}) (str string) {
|
|||||||
}
|
}
|
||||||
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
|
str = strings.Replace(str, "?", strings.Join(temp_marks, ","), 1)
|
||||||
default:
|
default:
|
||||||
switch arg.(type) {
|
if scanner, ok := interface{}(arg).(driver.Valuer); ok {
|
||||||
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString:
|
arg, _ = scanner.Value()
|
||||||
value := reflect.ValueOf(arg).Field(0).Interface()
|
|
||||||
str = strings.Replace(not_equal_sql, "?", s.addToVars(value), 1)
|
|
||||||
default:
|
|
||||||
str = strings.Replace(not_equal_sql, "?", s.addToVars(arg), 1)
|
|
||||||
}
|
}
|
||||||
|
str = strings.Replace(not_equal_sql, "?", s.addToVars(arg), 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
28
model.go
28
model.go
@ -14,6 +14,7 @@ import (
|
|||||||
type Model struct {
|
type Model struct {
|
||||||
data interface{}
|
data interface{}
|
||||||
driver string
|
driver string
|
||||||
|
debug bool
|
||||||
_cache_fields map[string][]Field
|
_cache_fields map[string][]Field
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -106,11 +107,13 @@ func (m *Model) fields(operation string) (fields []Field) {
|
|||||||
if is_time {
|
if is_time {
|
||||||
field.IsBlank = time_value.IsZero()
|
field.IsBlank = time_value.IsZero()
|
||||||
} else {
|
} else {
|
||||||
switch value.Interface().(type) {
|
_, is_scanner := reflect.New(value.Type()).Interface().(sql.Scanner)
|
||||||
case sql.NullInt64, sql.NullFloat64, sql.NullBool, sql.NullString:
|
|
||||||
|
if is_scanner {
|
||||||
field.IsBlank = !value.FieldByName("Valid").Interface().(bool)
|
field.IsBlank = !value.FieldByName("Valid").Interface().(bool)
|
||||||
default:
|
} else {
|
||||||
m := &Model{data: value.Interface(), driver: m.driver}
|
m := &Model{data: value.Interface(), driver: m.driver}
|
||||||
|
|
||||||
fields := m.columnsHasValue("other")
|
fields := m.columnsHasValue("other")
|
||||||
if len(fields) == 0 {
|
if len(fields) == 0 {
|
||||||
field.IsBlank = true
|
field.IsBlank = true
|
||||||
@ -370,25 +373,14 @@ func setFieldValue(field reflect.Value, value interface{}) bool {
|
|||||||
}
|
}
|
||||||
field.SetInt(reflect.ValueOf(value).Int())
|
field.SetInt(reflect.ValueOf(value).Int())
|
||||||
default:
|
default:
|
||||||
field_type := field.Type()
|
if scanner, ok := field.Addr().Interface().(sql.Scanner); ok {
|
||||||
if field_type == reflect.TypeOf(value) {
|
scanner.Scan(value)
|
||||||
field.Set(reflect.ValueOf(value))
|
|
||||||
} else if value == nil {
|
|
||||||
field.Set(reflect.Zero(field.Type()))
|
|
||||||
} else if field_type == reflect.TypeOf(sql.NullBool{}) {
|
|
||||||
field.Set(reflect.ValueOf(sql.NullBool{value.(bool), true}))
|
|
||||||
} else if field_type == reflect.TypeOf(sql.NullFloat64{}) {
|
|
||||||
field.Set(reflect.ValueOf(sql.NullFloat64{value.(float64), true}))
|
|
||||||
} else if field_type == reflect.TypeOf(sql.NullInt64{}) {
|
|
||||||
field.Set(reflect.ValueOf(sql.NullInt64{value.(int64), true}))
|
|
||||||
} else if field_type == reflect.TypeOf(sql.NullString{}) {
|
|
||||||
field.Set(reflect.ValueOf(sql.NullString{value.(string), true}))
|
|
||||||
} else {
|
} else {
|
||||||
field.Set(reflect.ValueOf(value))
|
field.Set(reflect.ValueOf(value))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
} else {
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user