Merge branch 'master' into master

This commit is contained in:
Tung Vu 2018-12-08 12:16:42 +07:00 committed by GitHub
commit 7b9c9e4706
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 34 additions and 6 deletions

View File

@ -267,15 +267,16 @@ func (association *Association) Count() int {
query = scope.DB() query = scope.DB()
) )
if relationship.Kind == "many_to_many" { switch relationship.Kind {
case "many_to_many":
query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { case "has_many", "has_one":
primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value)
query = query.Where( query = query.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)),
toQueryValues(primaryKeys)..., toQueryValues(primaryKeys)...,
) )
} else if relationship.Kind == "belongs_to" { case "belongs_to":
primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value)
query = query.Where( query = query.Where(
fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)), fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)),
@ -367,6 +368,7 @@ func (association *Association) saveAssociations(values ...interface{}) *Associa
return association return association
} }
// setErr set error when the error is not nil. And return Association.
func (association *Association) setErr(err error) *Association { func (association *Association) setErr(err error) *Association {
if err != nil { if err != nil {
association.Error = err association.Error = err

View File

@ -2,6 +2,7 @@ package gorm
import ( import (
"database/sql" "database/sql"
"database/sql/driver"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
@ -44,7 +45,14 @@ func (field *Field) Set(value interface{}) (err error) {
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
fieldValue.Set(reflectValue.Convert(fieldValue.Type())) fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
} else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
err = scanner.Scan(reflectValue.Interface()) v := reflectValue.Interface()
if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
err = scanner.Scan(v)
}
} else {
err = scanner.Scan(v)
}
} else { } else {
err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type()) err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
} }

View File

@ -3,6 +3,7 @@ package gorm_test
import ( import (
"testing" "testing"
"github.com/gofrs/uuid"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
) )
@ -47,3 +48,20 @@ func TestCalculateField(t *testing.T) {
t.Errorf("should find embedded field's tag settings") t.Errorf("should find embedded field's tag settings")
} }
} }
func TestFieldSet(t *testing.T) {
type TestFieldSetNullUUID struct {
NullUUID uuid.NullUUID
}
scope := DB.NewScope(&TestFieldSetNullUUID{})
field := scope.Fields()[0]
err := field.Set(uuid.FromStringOrNil("3034d44a-da03-11e8-b366-4a00070b9f00"))
if err != nil {
t.Fatal(err)
}
if id, ok := field.Field.Addr().Interface().(*uuid.NullUUID); !ok {
t.Fatal()
} else if !id.Valid || id.UUID.String() != "3034d44a-da03-11e8-b366-4a00070b9f00" {
t.Fatal(id)
}
}

View File

@ -68,7 +68,7 @@ func (scope *Scope) Dialect() Dialect {
// Quote used to quote string to escape them for database // Quote used to quote string to escape them for database
func (scope *Scope) Quote(str string) string { func (scope *Scope) Quote(str string) string {
if strings.Index(str, ".") != -1 { if strings.Contains(str, ".") {
newStrs := []string{} newStrs := []string{}
for _, str := range strings.Split(str, ".") { for _, str := range strings.Split(str, ".") {
newStrs = append(newStrs, scope.Dialect().Quote(str)) newStrs = append(newStrs, scope.Dialect().Quote(str))
@ -330,7 +330,7 @@ func (scope *Scope) TableName() string {
// QuotedTableName return quoted table name // QuotedTableName return quoted table name
func (scope *Scope) QuotedTableName() (name string) { func (scope *Scope) QuotedTableName() (name string) {
if scope.Search != nil && len(scope.Search.tableName) > 0 { if scope.Search != nil && len(scope.Search.tableName) > 0 {
if strings.Index(scope.Search.tableName, " ") != -1 { if strings.Contains(scope.Search.tableName, " ") {
return scope.Search.tableName return scope.Search.tableName
} }
return scope.Quote(scope.Search.tableName) return scope.Quote(scope.Search.tableName)