diff --git a/callback_query_preload.go b/callback_query_preload.go index 30f6b585..481bfbe3 100644 --- a/callback_query_preload.go +++ b/callback_query_preload.go @@ -14,8 +14,14 @@ func preloadCallback(scope *Scope) { return } - if _, ok := scope.Get("gorm:auto_preload"); ok { - autoPreload(scope) + if ap, ok := scope.Get("gorm:auto_preload"); ok { + // If gorm:auto_preload IS NOT a bool then auto preload. + // Else if it IS a bool, use the value + if apb, ok := ap.(bool); !ok { + autoPreload(scope) + } else if apb { + autoPreload(scope) + } } if scope.Search.preload == nil || scope.HasError() { diff --git a/dialects/mssql/mssql.go b/dialects/mssql/mssql.go index a8d3c45a..731721cb 100644 --- a/dialects/mssql/mssql.go +++ b/dialects/mssql/mssql.go @@ -1,12 +1,16 @@ package mssql import ( + "database/sql/driver" + "encoding/json" + "errors" "fmt" "reflect" "strconv" "strings" "time" + // Importing mssql driver package only in dialect file, otherwide not needed _ "github.com/denisenkom/go-mssqldb" "github.com/jinzhu/gorm" ) @@ -201,3 +205,27 @@ func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, st } return dialect.CurrentDatabase(), tableName } + +// JSON type to support easy handling of JSON data in character table fields +// using golang json.RawMessage for deferred decoding/encoding +type JSON struct { + json.RawMessage +} + +// Value get value of JSON +func (j JSON) Value() (driver.Value, error) { + if len(j.RawMessage) == 0 { + return nil, nil + } + return j.MarshalJSON() +} + +// Scan scan value into JSON +func (j *JSON) Scan(value interface{}) error { + str, ok := value.(string) + if !ok { + return errors.New(fmt.Sprint("Failed to unmarshal JSONB value (strcast):", value)) + } + bytes := []byte(str) + return json.Unmarshal(bytes, j) +} diff --git a/errors.go b/errors.go index da2cf13c..27c9a92d 100644 --- a/errors.go +++ b/errors.go @@ -6,7 +6,7 @@ import ( ) var ( - // ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct + // ErrRecordNotFound record not found error, happens when only haven't find any matched data when looking up with a struct, finding a slice won't return this error ErrRecordNotFound = errors.New("record not found") // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL ErrInvalidSQL = errors.New("invalid SQL") diff --git a/main.go b/main.go index 25c3a06b..993e19b1 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "strings" + "sync" "time" ) @@ -48,6 +49,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { } var source string var dbSQL SQLCommon + var ownDbSQL bool switch value := args[0].(type) { case string: @@ -59,8 +61,10 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { source = args[1].(string) } dbSQL, err = sql.Open(driver, source) + ownDbSQL = true case SQLCommon: dbSQL = value + ownDbSQL = false default: return nil, fmt.Errorf("invalid database source: %v is not a valid type", value) } @@ -78,7 +82,7 @@ func Open(dialect string, args ...interface{}) (db *DB, err error) { } // Send a ping to make sure the database connection is alive. if d, ok := dbSQL.(*sql.DB); ok { - if err = d.Ping(); err != nil { + if err = d.Ping(); err != nil && ownDbSQL { d.Close() } } @@ -119,7 +123,7 @@ func (s *DB) CommonDB() SQLCommon { // Dialect get dialect func (s *DB) Dialect() Dialect { - return s.parent.dialect + return s.dialect } // Callback return `Callbacks` container, you could add/change/delete callbacks with it @@ -159,7 +163,7 @@ func (s *DB) HasBlockGlobalUpdate() bool { // SingularTable use singular table by default func (s *DB) SingularTable(enable bool) { - modelStructsMap = newModelStructsMap() + modelStructsMap = sync.Map{} s.parent.singularTable = enable } @@ -484,6 +488,8 @@ func (s *DB) Begin() *DB { if db, ok := c.db.(sqlDb); ok && db != nil { tx, err := db.Begin() c.db = interface{}(tx).(SQLCommon) + + c.dialect.SetDB(c.db) c.AddError(err) } else { c.AddError(ErrCantStartTransaction) @@ -748,6 +754,7 @@ func (s *DB) clone() *DB { Value: s.Value, Error: s.Error, blockGlobalUpdate: s.blockGlobalUpdate, + dialect: newDialect(s.dialect.GetName(), s.db), } for key, value := range s.values { diff --git a/migration_test.go b/migration_test.go index 7c694485..78555dcc 100644 --- a/migration_test.go +++ b/migration_test.go @@ -398,6 +398,53 @@ func TestAutoMigration(t *testing.T) { } } +func TestCreateAndAutomigrateTransaction(t *testing.T) { + tx := DB.Begin() + + func() { + type Bar struct { + ID uint + } + DB.DropTableIfExists(&Bar{}) + + if ok := DB.HasTable("bars"); ok { + t.Errorf("Table should not exist, but does") + } + + if ok := tx.HasTable("bars"); ok { + t.Errorf("Table should not exist, but does") + } + }() + + func() { + type Bar struct { + Name string + } + err := tx.CreateTable(&Bar{}).Error + + if err != nil { + t.Errorf("Should have been able to create the table, but couldn't: %s", err) + } + + if ok := tx.HasTable(&Bar{}); !ok { + t.Errorf("The transaction should be able to see the table") + } + }() + + func() { + type Bar struct { + Stuff string + } + + err := tx.AutoMigrate(&Bar{}).Error + if err != nil { + t.Errorf("Should have been able to alter the table, but couldn't") + } + }() + + tx.Rollback() +} + type MultipleIndexes struct { ID int64 UserID int64 `sql:"unique_index:uix_multipleindexes_user_name,uix_multipleindexes_user_email;index:idx_multipleindexes_user_other"` diff --git a/model_struct.go b/model_struct.go index f571e2e8..8506fe87 100644 --- a/model_struct.go +++ b/model_struct.go @@ -17,28 +17,7 @@ var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { return defaultTableName } -type safeModelStructsMap struct { - m map[reflect.Type]*ModelStruct - l *sync.RWMutex -} - -func (s *safeModelStructsMap) Set(key reflect.Type, value *ModelStruct) { - s.l.Lock() - defer s.l.Unlock() - s.m[key] = value -} - -func (s *safeModelStructsMap) Get(key reflect.Type) *ModelStruct { - s.l.RLock() - defer s.l.RUnlock() - return s.m[key] -} - -func newModelStructsMap() *safeModelStructsMap { - return &safeModelStructsMap{l: new(sync.RWMutex), m: make(map[reflect.Type]*ModelStruct)} -} - -var modelStructsMap = newModelStructsMap() +var modelStructsMap sync.Map // ModelStruct model definition type ModelStruct struct { @@ -48,7 +27,7 @@ type ModelStruct struct { defaultTableName string } -// TableName get model's table name +// TableName returns model's table name func (s *ModelStruct) TableName(db *DB) string { if s.defaultTableName == "" && db != nil && s.ModelType != nil { // Set default table name @@ -152,8 +131,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // Get Cached model struct - if value := modelStructsMap.Get(reflectType); value != nil { - return value + if value, ok := modelStructsMap.Load(reflectType); ok && value != nil { + return value.(*ModelStruct) } modelStruct.ModelType = reflectType @@ -601,7 +580,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } - modelStructsMap.Set(reflectType, &modelStruct) + modelStructsMap.Store(reflectType, &modelStruct) return &modelStruct } diff --git a/preload_test.go b/preload_test.go index 311ad0be..1db625c9 100644 --- a/preload_test.go +++ b/preload_test.go @@ -123,6 +123,31 @@ func TestAutoPreload(t *testing.T) { } } +func TestAutoPreloadFalseDoesntPreload(t *testing.T) { + user1 := getPreloadUser("auto_user1") + DB.Save(user1) + + preloadDB := DB.Set("gorm:auto_preload", false).Where("role = ?", "Preload") + var user User + preloadDB.Find(&user) + + if user.BillingAddress.Address1 != "" { + t.Error("AutoPreload was set to fasle, but still fetched data") + } + + user2 := getPreloadUser("auto_user2") + DB.Save(user2) + + var users []User + preloadDB.Find(&users) + + for _, user := range users { + if user.BillingAddress.Address1 != "" { + t.Error("AutoPreload was set to fasle, but still fetched data") + } + } +} + func TestNestedPreload1(t *testing.T) { type ( Level1 struct { diff --git a/query_test.go b/query_test.go index fac7d4d8..15bf8b3c 100644 --- a/query_test.go +++ b/query_test.go @@ -181,17 +181,17 @@ func TestSearchWithPlainSQL(t *testing.T) { scopedb.Where("birthday > ?", parseTime("2000-1-1")).Find(&users) if len(users) != 2 { - t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users)) + t.Errorf("Should found 2 users' birthday > 2000-1-1, but got %v", len(users)) } scopedb.Where("birthday > ?", "2002-10-10").Find(&users) if len(users) != 2 { - t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users)) + t.Errorf("Should found 2 users' birthday >= 2002-10-10, but got %v", len(users)) } scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users) if len(users) != 1 { - t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) + t.Errorf("Should found 1 users' birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) } DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users) @@ -532,28 +532,28 @@ func TestNot(t *testing.T) { DB.Table("users").Where("name = ?", "user3").Count(&name3Count) DB.Not("name", "user3").Find(&users4) if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not("name = ?", "user3").Find(&users4) if len(users1)-len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not("name <> ?", "user3").Find(&users4) if len(users4) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not(User{Name: "user3"}).Find(&users5) if len(users1)-len(users5) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not(map[string]interface{}{"name": "user3"}).Find(&users6) if len(users1)-len(users6) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } DB.Not(map[string]interface{}{"name": "user3", "company_id": nil}).Find(&users7) @@ -563,14 +563,14 @@ func TestNot(t *testing.T) { DB.Not("name", []string{"user3"}).Find(&users8) if len(users1)-len(users8) != int(name3Count) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } var name2Count int64 DB.Table("users").Where("name = ?", "user2").Count(&name2Count) DB.Not("name", []string{"user3", "user2"}).Find(&users9) if len(users1)-len(users9) != (int(name3Count) + int(name2Count)) { - t.Errorf("Should find all users's name not equal 3") + t.Errorf("Should find all users' name not equal 3") } } diff --git a/scope.go b/scope.go index 397ccf0b..ca861d8a 100644 --- a/scope.go +++ b/scope.go @@ -63,7 +63,7 @@ func (scope *Scope) SQLDB() SQLCommon { // Dialect get dialect func (scope *Scope) Dialect() Dialect { - return scope.db.parent.dialect + return scope.db.dialect } // Quote used to quote string to escape them for database @@ -586,10 +586,10 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) scope.Err(fmt.Errorf("invalid query condition: %v", value)) return } - + scopeQuotedTableName := newScope.QuotedTableName() for _, field := range newScope.Fields() { if !field.IsIgnored && !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) + sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) } } return strings.Join(sqls, " AND ") @@ -1216,11 +1216,17 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on func (scope *Scope) removeForeignKey(field string, dest string) { keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") - if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { return } - var query = `ALTER TABLE %s DROP CONSTRAINT %s;` + var mysql mysql + var query string + if scope.Dialect().GetName() == mysql.GetName() { + query = `ALTER TABLE %s DROP FOREIGN KEY %s;` + } else { + query = `ALTER TABLE %s DROP CONSTRAINT %s;` + } + scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec() }