Merge branch 'master' into master

This commit is contained in:
Jinzhu 2019-01-02 21:33:17 +08:00 committed by GitHub
commit cdce5f5e83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 99 additions and 55 deletions

View File

@ -368,6 +368,7 @@ func (association *Association) saveAssociations(values ...interface{}) *Associa
return association return association
} }
// setErr set error when the error is not nil. And return Association.
func (association *Association) setErr(err error) *Association { func (association *Association) setErr(err error) *Association {
if err != nil { if err != nil {
association.Error = err association.Error = err

View File

@ -75,7 +75,7 @@ func updateCallback(scope *Scope) {
} else { } else {
for _, field := range scope.Fields() { for _, field := range scope.Fields() {
if scope.changeableField(field) { if scope.changeableField(field) {
if !field.IsPrimaryKey && field.IsNormal { if !field.IsPrimaryKey && field.IsNormal && (field.Name != "CreatedAt" || !field.IsBlank) {
if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue { if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
} }

View File

@ -6,11 +6,11 @@ import (
) )
var ( var (
// ErrRecordNotFound record not found error, happens when only haven't find any matched data when looking up with a struct, finding a slice won't return this error // ErrRecordNotFound returns a "record not found error". Occurs only when attempting to query the database with a struct; querying with a slice won't return this error
ErrRecordNotFound = errors.New("record not found") ErrRecordNotFound = errors.New("record not found")
// ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL // ErrInvalidSQL occurs when you attempt a query with invalid SQL
ErrInvalidSQL = errors.New("invalid SQL") ErrInvalidSQL = errors.New("invalid SQL")
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` // ErrInvalidTransaction occurs when you are trying to `Commit` or `Rollback`
ErrInvalidTransaction = errors.New("no valid transaction") ErrInvalidTransaction = errors.New("no valid transaction")
// ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin` // ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin`
ErrCantStartTransaction = errors.New("can't start transaction") ErrCantStartTransaction = errors.New("can't start transaction")
@ -21,7 +21,7 @@ var (
// Errors contains all happened errors // Errors contains all happened errors
type Errors []error type Errors []error
// IsRecordNotFoundError returns current error has record not found error or not // IsRecordNotFoundError returns true if error contains a RecordNotFound error
func IsRecordNotFoundError(err error) bool { func IsRecordNotFoundError(err error) bool {
if errs, ok := err.(Errors); ok { if errs, ok := err.(Errors); ok {
for _, err := range errs { for _, err := range errs {
@ -33,12 +33,12 @@ func IsRecordNotFoundError(err error) bool {
return err == ErrRecordNotFound return err == ErrRecordNotFound
} }
// GetErrors gets all happened errors // GetErrors gets all errors that have occurred and returns a slice of errors (Error type)
func (errs Errors) GetErrors() []error { func (errs Errors) GetErrors() []error {
return errs return errs
} }
// Add adds an error // Add adds an error to a given slice of errors
func (errs Errors) Add(newErrors ...error) Errors { func (errs Errors) Add(newErrors ...error) Errors {
for _, err := range newErrors { for _, err := range newErrors {
if err == nil { if err == nil {
@ -62,7 +62,7 @@ func (errs Errors) Add(newErrors ...error) Errors {
return errs return errs
} }
// Error format happened errors // Error takes a slice of all errors that have occurred and returns it as a formatted string
func (errs Errors) Error() string { func (errs Errors) Error() string {
var errList = []string{} var errList = []string{}
for _, e := range errs { for _, e := range errs {

View File

@ -2,6 +2,7 @@ package gorm
import ( import (
"database/sql" "database/sql"
"database/sql/driver"
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
@ -44,7 +45,14 @@ func (field *Field) Set(value interface{}) (err error) {
if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { if reflectValue.Type().ConvertibleTo(fieldValue.Type()) {
fieldValue.Set(reflectValue.Convert(fieldValue.Type())) fieldValue.Set(reflectValue.Convert(fieldValue.Type()))
} else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok {
err = scanner.Scan(reflectValue.Interface()) v := reflectValue.Interface()
if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
err = scanner.Scan(v)
}
} else {
err = scanner.Scan(v)
}
} else { } else {
err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type()) err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type())
} }

View File

@ -3,6 +3,7 @@ package gorm_test
import ( import (
"testing" "testing"
"github.com/gofrs/uuid"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
) )
@ -47,3 +48,20 @@ func TestCalculateField(t *testing.T) {
t.Errorf("should find embedded field's tag settings") t.Errorf("should find embedded field's tag settings")
} }
} }
func TestFieldSet(t *testing.T) {
type TestFieldSetNullUUID struct {
NullUUID uuid.NullUUID
}
scope := DB.NewScope(&TestFieldSetNullUUID{})
field := scope.Fields()[0]
err := field.Set(uuid.FromStringOrNil("3034d44a-da03-11e8-b366-4a00070b9f00"))
if err != nil {
t.Fatal(err)
}
if id, ok := field.Field.Addr().Interface().(*uuid.NullUUID); !ok {
t.Fatal()
} else if !id.Valid || id.UUID.String() != "3034d44a-da03-11e8-b366-4a00070b9f00" {
t.Fatal(id)
}
}

22
main.go
View File

@ -19,7 +19,7 @@ type DB struct {
// single db // single db
db SQLCommon db SQLCommon
blockGlobalUpdate bool blockGlobalUpdate bool
logMode int logMode logModeValue
logger logger logger logger
search *search search *search
values sync.Map values sync.Map
@ -31,6 +31,14 @@ type DB struct {
singularTable bool singularTable bool
} }
type logModeValue int
const (
defaultLogMode logModeValue = iota
noLogMode
detailedLogMode
)
// Open initialize a new db connection, need to import driver first, e.g: // Open initialize a new db connection, need to import driver first, e.g:
// //
// import _ "github.com/go-sql-driver/mysql" // import _ "github.com/go-sql-driver/mysql"
@ -141,9 +149,9 @@ func (s *DB) SetLogger(log logger) {
// LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs // LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs
func (s *DB) LogMode(enable bool) *DB { func (s *DB) LogMode(enable bool) *DB {
if enable { if enable {
s.logMode = 2 s.logMode = detailedLogMode
} else { } else {
s.logMode = 1 s.logMode = noLogMode
} }
return s return s
} }
@ -170,7 +178,7 @@ func (s *DB) SingularTable(enable bool) {
func (s *DB) NewScope(value interface{}) *Scope { func (s *DB) NewScope(value interface{}) *Scope {
dbClone := s.clone() dbClone := s.clone()
dbClone.Value = value dbClone.Value = value
return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value} return &Scope{db: dbClone, Search: dbClone.search, Value: value}
} }
// QueryExpr returns the query as expr object // QueryExpr returns the query as expr object
@ -716,7 +724,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join
func (s *DB) AddError(err error) error { func (s *DB) AddError(err error) error {
if err != nil { if err != nil {
if err != ErrRecordNotFound { if err != ErrRecordNotFound {
if s.logMode == 0 { if s.logMode == defaultLogMode {
go s.print(fileWithLineNum(), err) go s.print(fileWithLineNum(), err)
} else { } else {
s.log(err) s.log(err)
@ -780,13 +788,13 @@ func (s *DB) print(v ...interface{}) {
} }
func (s *DB) log(v ...interface{}) { func (s *DB) log(v ...interface{}) {
if s != nil && s.logMode == 2 { if s != nil && s.logMode == detailedLogMode {
s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...) s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...)
} }
} }
func (s *DB) slog(sql string, t time.Time, vars ...interface{}) { func (s *DB) slog(sql string, t time.Time, vars ...interface{}) {
if s.logMode == 2 { if s.logMode == detailedLogMode {
s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected) s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected)
} }
} }

View File

@ -40,7 +40,7 @@ func (s *ModelStruct) TableName(db *DB) string {
s.defaultTableName = tabler.TableName() s.defaultTableName = tabler.TableName()
} else { } else {
tableName := ToTableName(s.ModelType.Name()) tableName := ToTableName(s.ModelType.Name())
if db == nil || !db.parent.singularTable { if db == nil || (db.parent != nil && !db.parent.singularTable) {
tableName = inflection.Plural(tableName) tableName = inflection.Plural(tableName)
} }
s.defaultTableName = tableName s.defaultTableName = tableName
@ -70,52 +70,52 @@ type StructField struct {
} }
// TagSettingsSet Sets a tag in the tag settings map // TagSettingsSet Sets a tag in the tag settings map
func (s *StructField) TagSettingsSet(key, val string) { func (sf *StructField) TagSettingsSet(key, val string) {
s.tagSettingsLock.Lock() sf.tagSettingsLock.Lock()
defer s.tagSettingsLock.Unlock() defer sf.tagSettingsLock.Unlock()
s.TagSettings[key] = val sf.TagSettings[key] = val
} }
// TagSettingsGet returns a tag from the tag settings // TagSettingsGet returns a tag from the tag settings
func (s *StructField) TagSettingsGet(key string) (string, bool) { func (sf *StructField) TagSettingsGet(key string) (string, bool) {
s.tagSettingsLock.RLock() sf.tagSettingsLock.RLock()
defer s.tagSettingsLock.RUnlock() defer sf.tagSettingsLock.RUnlock()
val, ok := s.TagSettings[key] val, ok := sf.TagSettings[key]
return val, ok return val, ok
} }
// TagSettingsDelete deletes a tag // TagSettingsDelete deletes a tag
func (s *StructField) TagSettingsDelete(key string) { func (sf *StructField) TagSettingsDelete(key string) {
s.tagSettingsLock.Lock() sf.tagSettingsLock.Lock()
defer s.tagSettingsLock.Unlock() defer sf.tagSettingsLock.Unlock()
delete(s.TagSettings, key) delete(sf.TagSettings, key)
} }
func (structField *StructField) clone() *StructField { func (sf *StructField) clone() *StructField {
clone := &StructField{ clone := &StructField{
DBName: structField.DBName, DBName: sf.DBName,
Name: structField.Name, Name: sf.Name,
Names: structField.Names, Names: sf.Names,
IsPrimaryKey: structField.IsPrimaryKey, IsPrimaryKey: sf.IsPrimaryKey,
IsNormal: structField.IsNormal, IsNormal: sf.IsNormal,
IsIgnored: structField.IsIgnored, IsIgnored: sf.IsIgnored,
IsScanner: structField.IsScanner, IsScanner: sf.IsScanner,
HasDefaultValue: structField.HasDefaultValue, HasDefaultValue: sf.HasDefaultValue,
Tag: structField.Tag, Tag: sf.Tag,
TagSettings: map[string]string{}, TagSettings: map[string]string{},
Struct: structField.Struct, Struct: sf.Struct,
IsForeignKey: structField.IsForeignKey, IsForeignKey: sf.IsForeignKey,
} }
if structField.Relationship != nil { if sf.Relationship != nil {
relationship := *structField.Relationship relationship := *sf.Relationship
clone.Relationship = &relationship clone.Relationship = &relationship
} }
// copy the struct field tagSettings, they should be read-locked while they are copied // copy the struct field tagSettings, they should be read-locked while they are copied
structField.tagSettingsLock.Lock() sf.tagSettingsLock.Lock()
defer structField.tagSettingsLock.Unlock() defer sf.tagSettingsLock.Unlock()
for key, value := range structField.TagSettings { for key, value := range sf.TagSettings {
clone.TagSettings[key] = value clone.TagSettings[key] = value
} }

View File

@ -68,7 +68,7 @@ func (scope *Scope) Dialect() Dialect {
// Quote used to quote string to escape them for database // Quote used to quote string to escape them for database
func (scope *Scope) Quote(str string) string { func (scope *Scope) Quote(str string) string {
if strings.Index(str, ".") != -1 { if strings.Contains(str, ".") {
newStrs := []string{} newStrs := []string{}
for _, str := range strings.Split(str, ".") { for _, str := range strings.Split(str, ".") {
newStrs = append(newStrs, scope.Dialect().Quote(str)) newStrs = append(newStrs, scope.Dialect().Quote(str))
@ -330,7 +330,7 @@ func (scope *Scope) TableName() string {
// QuotedTableName return quoted table name // QuotedTableName return quoted table name
func (scope *Scope) QuotedTableName() (name string) { func (scope *Scope) QuotedTableName() (name string) {
if scope.Search != nil && len(scope.Search.tableName) > 0 { if scope.Search != nil && len(scope.Search.tableName) > 0 {
if strings.Index(scope.Search.tableName, " ") != -1 { if strings.Contains(scope.Search.tableName, " ") {
return scope.Search.tableName return scope.Search.tableName
} }
return scope.Quote(scope.Search.tableName) return scope.Quote(scope.Search.tableName)
@ -1309,6 +1309,7 @@ func (scope *Scope) autoIndex() *Scope {
} }
func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) { func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) {
resultMap := make(map[string][]interface{})
for _, value := range values { for _, value := range values {
indirectValue := indirect(reflect.ValueOf(value)) indirectValue := indirect(reflect.ValueOf(value))
@ -1327,7 +1328,10 @@ func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (r
} }
if hasValue { if hasValue {
results = append(results, result) h := fmt.Sprint(result...)
if _, exist := resultMap[h]; !exist {
resultMap[h] = result
}
} }
} }
case reflect.Struct: case reflect.Struct:
@ -1342,11 +1346,16 @@ func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (r
} }
if hasValue { if hasValue {
results = append(results, result) h := fmt.Sprint(result...)
if _, exist := resultMap[h]; !exist {
resultMap[h] = result
} }
} }
} }
}
for _, v := range resultMap {
results = append(results, v)
}
return return
} }