Merge branch 'master' into master

This commit is contained in:
Jinzhu 2018-08-19 07:16:01 +08:00 committed by GitHub
commit 79c826ccad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 145 additions and 47 deletions

View File

@ -14,8 +14,14 @@ func preloadCallback(scope *Scope) {
return
}
if _, ok := scope.Get("gorm:auto_preload"); ok {
autoPreload(scope)
if ap, ok := scope.Get("gorm:auto_preload"); ok {
// If gorm:auto_preload IS NOT a bool then auto preload.
// Else if it IS a bool, use the value
if apb, ok := ap.(bool); !ok {
autoPreload(scope)
} else if apb {
autoPreload(scope)
}
}
if scope.Search.preload == nil || scope.HasError() {

View File

@ -1,12 +1,16 @@
package mssql
import (
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"time"
// Importing mssql driver package only in dialect file, otherwide not needed
_ "github.com/denisenkom/go-mssqldb"
"github.com/jinzhu/gorm"
)
@ -201,3 +205,27 @@ func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, st
}
return dialect.CurrentDatabase(), tableName
}
// JSON type to support easy handling of JSON data in character table fields
// using golang json.RawMessage for deferred decoding/encoding
type JSON struct {
json.RawMessage
}
// Value get value of JSON
func (j JSON) Value() (driver.Value, error) {
if len(j.RawMessage) == 0 {
return nil, nil
}
return j.MarshalJSON()
}
// Scan scan value into JSON
func (j *JSON) Scan(value interface{}) error {
str, ok := value.(string)
if !ok {
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value (strcast):", value))
}
bytes := []byte(str)
return json.Unmarshal(bytes, j)
}

View File

@ -6,7 +6,7 @@ import (
)
var (
// ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct
// ErrRecordNotFound record not found error, happens when only haven't find any matched data when looking up with a struct, finding a slice won't return this error
ErrRecordNotFound = errors.New("record not found")
// ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL
ErrInvalidSQL = errors.New("invalid SQL")

13
main.go
View File

@ -6,6 +6,7 @@ import (
"fmt"
"reflect"
"strings"
"sync"
"time"
)
@ -48,6 +49,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
}
var source string
var dbSQL SQLCommon
var ownDbSQL bool
switch value := args[0].(type) {
case string:
@ -59,8 +61,10 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
source = args[1].(string)
}
dbSQL, err = sql.Open(driver, source)
ownDbSQL = true
case SQLCommon:
dbSQL = value
ownDbSQL = false
default:
return nil, fmt.Errorf("invalid database source: %v is not a valid type", value)
}
@ -78,7 +82,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) {
}
// Send a ping to make sure the database connection is alive.
if d, ok := dbSQL.(*sql.DB); ok {
if err = d.Ping(); err != nil {
if err = d.Ping(); err != nil && ownDbSQL {
d.Close()
}
}
@ -119,7 +123,7 @@ func (s *DB) CommonDB() SQLCommon {
// Dialect get dialect
func (s *DB) Dialect() Dialect {
return s.parent.dialect
return s.dialect
}
// Callback return `Callbacks` container, you could add/change/delete callbacks with it
@ -159,7 +163,7 @@ func (s *DB) HasBlockGlobalUpdate() bool {
// SingularTable use singular table by default
func (s *DB) SingularTable(enable bool) {
modelStructsMap = newModelStructsMap()
modelStructsMap = sync.Map{}
s.parent.singularTable = enable
}
@ -484,6 +488,8 @@ func (s *DB) Begin() *DB {
if db, ok := c.db.(sqlDb); ok && db != nil {
tx, err := db.Begin()
c.db = interface{}(tx).(SQLCommon)
c.dialect.SetDB(c.db)
c.AddError(err)
} else {
c.AddError(ErrCantStartTransaction)
@ -748,6 +754,7 @@ func (s *DB) clone() *DB {
Value: s.Value,
Error: s.Error,
blockGlobalUpdate: s.blockGlobalUpdate,
dialect: newDialect(s.dialect.GetName(), s.db),
}
for key, value := range s.values {

View File

@ -398,6 +398,53 @@ func TestAutoMigration(t *testing.T) {
}
}
func TestCreateAndAutomigrateTransaction(t *testing.T) {
tx := DB.Begin()
func() {
type Bar struct {
ID uint
}
DB.DropTableIfExists(&Bar{})
if ok := DB.HasTable("bars"); ok {
t.Errorf("Table should not exist, but does")
}
if ok := tx.HasTable("bars"); ok {
t.Errorf("Table should not exist, but does")
}
}()
func() {
type Bar struct {
Name string
}
err := tx.CreateTable(&Bar{}).Error
if err != nil {
t.Errorf("Should have been able to create the table, but couldn't: %s", err)
}
if ok := tx.HasTable(&Bar{}); !ok {
t.Errorf("The transaction should be able to see the table")
}
}()
func() {
type Bar struct {
Stuff string
}
err := tx.AutoMigrate(&Bar{}).Error
if err != nil {
t.Errorf("Should have been able to alter the table, but couldn't")
}
}()
tx.Rollback()
}
type MultipleIndexes struct {
ID int64
UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"`

View File

@ -17,28 +17,7 @@ var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
return defaultTableName
}
type safeModelStructsMap struct {
m map[reflect.Type]*ModelStruct
l *sync.RWMutex
}
func (s *safeModelStructsMap) Set(key reflect.Type, value *ModelStruct) {
s.l.Lock()
defer s.l.Unlock()
s.m[key] = value
}
func (s *safeModelStructsMap) Get(key reflect.Type) *ModelStruct {
s.l.RLock()
defer s.l.RUnlock()
return s.m[key]
}
func newModelStructsMap() *safeModelStructsMap {
return &safeModelStructsMap{l: new(sync.RWMutex), m: make(map[reflect.Type]*ModelStruct)}
}
var modelStructsMap = newModelStructsMap()
var modelStructsMap sync.Map
// ModelStruct model definition
type ModelStruct struct {
@ -48,7 +27,7 @@ type ModelStruct struct {
defaultTableName string
}
// TableName get model's table name
// TableName returns model's table name
func (s *ModelStruct) TableName(db *DB) string {
if s.defaultTableName == "" && db != nil && s.ModelType != nil {
// Set default table name
@ -152,8 +131,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
}
// Get Cached model struct
if value := modelStructsMap.Get(reflectType); value != nil {
return value
if value, ok := modelStructsMap.Load(reflectType); ok && value != nil {
return value.(*ModelStruct)
}
modelStruct.ModelType = reflectType
@ -601,7 +580,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
}
}
modelStructsMap.Set(reflectType, &modelStruct)
modelStructsMap.Store(reflectType, &modelStruct)
return &modelStruct
}

View File

@ -123,6 +123,31 @@ func TestAutoPreload(t *testing.T) {
}
}
func TestAutoPreloadFalseDoesntPreload(t *testing.T) {
user1 := getPreloadUser("auto_user1")
DB.Save(user1)
preloadDB := DB.Set("gorm:auto_preload", false).Where("role = ?", "Preload")
var user User
preloadDB.Find(&user)
if user.BillingAddress.Address1 != "" {
t.Error("AutoPreload was set to fasle, but still fetched data")
}
user2 := getPreloadUser("auto_user2")
DB.Save(user2)
var users []User
preloadDB.Find(&users)
for _, user := range users {
if user.BillingAddress.Address1 != "" {
t.Error("AutoPreload was set to fasle, but still fetched data")
}
}
}
func TestNestedPreload1(t *testing.T) {
type (
Level1 struct {

View File

@ -181,17 +181,17 @@ func TestSearchWithPlainSQL(t *testing.T) {
scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users)
if len(users) != 2 {
t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users))
t.Errorf("Should found 2 users' birthday > 2000-1-1, but got %v", len(users))
}
scopedb.Where("birthday > ?", "2002-10-10").Find(&users)
if len(users) != 2 {
t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users))
t.Errorf("Should found 2 users' birthday >= 2002-10-10, but got %v", len(users))
}
scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users)
if len(users) != 1 {
t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users))
t.Errorf("Should found 1 users' birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users))
}
DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users)
@ -532,28 +532,28 @@ func TestNot(t *testing.T) {
DB.Table("users").Where("name = ?", "user3").Count(&name3Count)
DB.Not("name", "user3").Find(&users4)
if len(users1)-len(users4) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3")
t.Errorf("Should find all users' name not equal 3")
}
DB.Not("name = ?", "user3").Find(&users4)
if len(users1)-len(users4) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3")
t.Errorf("Should find all users' name not equal 3")
}
DB.Not("name <> ?", "user3").Find(&users4)
if len(users4) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3")
t.Errorf("Should find all users' name not equal 3")
}
DB.Not(User{Name: "user3"}).Find(&users5)
if len(users1)-len(users5) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3")
t.Errorf("Should find all users' name not equal 3")
}
DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6)
if len(users1)-len(users6) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3")
t.Errorf("Should find all users' name not equal 3")
}
DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7)
@ -563,14 +563,14 @@ func TestNot(t *testing.T) {
DB.Not("name", []string{"user3"}).Find(&users8)
if len(users1)-len(users8) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3")
t.Errorf("Should find all users' name not equal 3")
}
var name2Count int64
DB.Table("users").Where("name = ?", "user2").Count(&name2Count)
DB.Not("name", []string{"user3", "user2"}).Find(&users9)
if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) {
t.Errorf("Should find all users's name not equal 3")
t.Errorf("Should find all users' name not equal 3")
}
}

View File

@ -63,7 +63,7 @@ func (scope *Scope) SQLDB() SQLCommon {
// Dialect get dialect
func (scope *Scope) Dialect() Dialect {
return scope.db.parent.dialect
return scope.db.dialect
}
// Quote used to quote string to escape them for database
@ -586,10 +586,10 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool)
scope.Err(fmt.Errorf("invalid query condition: %v", value))
return
}
scopeQuotedTableName := newScope.QuotedTableName()
for _, field := range newScope.Fields() {
if !field.IsIgnored && !field.IsBlank {
sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface())))
sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface())))
}
}
return strings.Join(sqls, " AND ")
@ -1216,11 +1216,17 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on
func (scope *Scope) removeForeignKey(field string, dest string) {
keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign")
if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) {
return
}
var query = `ALTER TABLE %s DROP CONSTRAINT %s;`
var mysql mysql
var query string
if scope.Dialect().GetName() == mysql.GetName() {
query = `ALTER TABLE %s DROP FOREIGN KEY %s;`
} else {
query = `ALTER TABLE %s DROP CONSTRAINT %s;`
}
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec()
}