This commit is contained in:
Paolo Galeone 2015-12-21 13:41:46 +01:00
commit 387c00d2b5
4 changed files with 269 additions and 24 deletions

View File

@ -1064,10 +1064,10 @@ func (u *User) AfterCreate() (err error) {
} }
``` ```
As you know, save/delete operations in gorm are running in a transaction, Save/delete operations in gorm are running in a transaction.
This is means if changes made in the transaction is not visiable unless it is commited, Changes made in that transaction are not visible unless it is commited.
So if you want to use those changes in your callbacks, you need to run SQL in same transaction. So if you want to use those changes in your callbacks, you need to run your SQL in the same transaction.
Fortunately, gorm support pass transaction to callbacks as you needed, you could do it like this: For this Gorm supports passing transactions to callbacks like this:
```go ```go
func (u *User) AfterCreate(tx *gorm.DB) (err error) { func (u *User) AfterCreate(tx *gorm.DB) (err error) {

View File

@ -392,6 +392,10 @@ func (s *DB) DropTable(values ...interface{}) *DB {
func (s *DB) DropTableIfExists(values ...interface{}) *DB { func (s *DB) DropTableIfExists(values ...interface{}) *DB {
db := s.clone() db := s.clone()
for _, value := range values { for _, value := range values {
if tableName, ok := value.(string); ok {
db = db.Table(tableName)
}
db = db.NewScope(value).dropTableIfExists().db db = db.NewScope(value).dropTableIfExists().db
} }
return db return db

View File

@ -9,15 +9,20 @@ import (
) )
func getRealValue(value reflect.Value, columns []string) (results []interface{}) { func getRealValue(value reflect.Value, columns []string) (results []interface{}) {
// If value is a nil pointer, Indirect returns a zero Value!
// Therefor we need to check for a zero value,
// as FieldByName could panic
if pointedValue := reflect.Indirect(value); pointedValue.IsValid() {
for _, column := range columns { for _, column := range columns {
if reflect.Indirect(value).FieldByName(column).IsValid() { if pointedValue.FieldByName(column).IsValid() {
result := reflect.Indirect(value).FieldByName(column).Interface() result := pointedValue.FieldByName(column).Interface()
if r, ok := result.(driver.Valuer); ok { if r, ok := result.(driver.Valuer); ok {
result, _ = r.Value() result, _ = r.Value()
} }
results = append(results, result) results = append(results, result)
} }
} }
}
return return
} }
@ -290,13 +295,14 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
} }
} }
} else { } else {
object := scope.IndirectValue() if object := scope.IndirectValue(); object.IsValid() {
source := getRealValue(object, associationForeignStructFieldNames) source := getRealValue(object, associationForeignStructFieldNames)
field := object.FieldByName(field.Name) field := object.FieldByName(field.Name)
for _, link := range linkHash[toString(source)] { for _, link := range linkHash[toString(source)] {
field.Set(reflect.Append(field, link)) field.Set(reflect.Append(field, link))
} }
} }
}
} }
func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) { func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) {

View File

@ -1,10 +1,13 @@
package gorm_test package gorm_test
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"os" "os"
"reflect" "reflect"
"testing" "testing"
"github.com/jinzhu/gorm"
) )
func getPreloadUser(name string) *User { func getPreloadUser(name string) *User {
@ -129,6 +132,10 @@ func TestNestedPreload1(t *testing.T) {
if !reflect.DeepEqual(got, want) { if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
} }
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got, "name = ?", "not_found").Error; err != gorm.RecordNotFound {
panic(err)
}
} }
func TestNestedPreload2(t *testing.T) { func TestNestedPreload2(t *testing.T) {
@ -625,7 +632,7 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{}) DB.DropTableIfExists(&Level1{})
DB.Table("levels").DropTableIfExists("levels") DB.DropTableIfExists("levels")
if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
panic(err) panic(err)
@ -698,11 +705,11 @@ func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) {
func TestManyToManyPreloadForPointer(t *testing.T) { func TestManyToManyPreloadForPointer(t *testing.T) {
type ( type (
Level1 struct { Level1 struct {
ID uint `gorm:"primary_key;"` ID uint
Value string Value string
} }
Level2 struct { Level2 struct {
ID uint `gorm:"primary_key;"` ID uint
Value string Value string
Level1s []*Level1 `gorm:"many2many:levels;"` Level1s []*Level1 `gorm:"many2many:levels;"`
} }
@ -710,7 +717,7 @@ func TestManyToManyPreloadForPointer(t *testing.T) {
DB.DropTableIfExists(&Level2{}) DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{}) DB.DropTableIfExists(&Level1{})
DB.Table("levels").DropTableIfExists("levels") DB.DropTableIfExists("levels")
if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil { if err := DB.AutoMigrate(&Level2{}, &Level1{}).Error; err != nil {
panic(err) panic(err)
@ -764,6 +771,9 @@ func TestManyToManyPreloadForPointer(t *testing.T) {
panic(err) panic(err)
} }
var got5 Level2
DB.Preload("Level1s").First(&got5, "value = ?", "bogus")
var ruLevel1 Level1 var ruLevel1 Level1
var zhLevel1 Level1 var zhLevel1 Level1
DB.First(&ruLevel1, "value = ?", "ru") DB.First(&ruLevel1, "value = ?", "ru")
@ -776,20 +786,245 @@ func TestManyToManyPreloadForPointer(t *testing.T) {
} }
} }
func TestNilPointerSlice(t *testing.T) { func TestManyToManyPreloadForNestedPointer(t *testing.T) {
type ( type (
Level3 struct { Level1 struct {
ID uint `gorm:"primary_key;"` ID uint
Value string Value string
} }
Level2 struct { Level2 struct {
ID uint `gorm:"primary_key;"` ID uint
Value string
Level1s []*Level1 `gorm:"many2many:levels;"`
}
Level3 struct {
ID uint
Value string
Level2ID sql.NullInt64
Level2 *Level2
}
)
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
DB.DropTableIfExists("levels")
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
panic(err)
}
want := Level3{
Value: "Bob",
Level2: &Level2{
Value: "Foo",
Level1s: []*Level1{
{Value: "ru"},
{Value: "en"},
},
},
}
if err := DB.Save(&want).Error; err != nil {
panic(err)
}
want2 := Level3{
Value: "Tom",
Level2: &Level2{
Value: "Bar",
Level1s: []*Level1{
{Value: "zh"},
{Value: "de"},
},
},
}
if err := DB.Save(&want2).Error; err != nil {
panic(err)
}
var got Level3
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Bob").Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
var got2 Level3
if err := DB.Preload("Level2.Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got2, want2) {
t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2))
}
var got3 []Level3
if err := DB.Preload("Level2.Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got3, []Level3{got, got2}) {
t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level3{got, got2}))
}
var got4 []Level3
if err := DB.Preload("Level2.Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
panic(err)
}
var got5 Level3
DB.Preload("Level2.Level1s").Find(&got5, "value = ?", "bogus")
var ruLevel1 Level1
var zhLevel1 Level1
DB.First(&ruLevel1, "value = ?", "ru")
DB.First(&zhLevel1, "value = ?", "zh")
got.Level2.Level1s = []*Level1{&ruLevel1}
got2.Level2.Level1s = []*Level1{&zhLevel1}
if !reflect.DeepEqual(got4, []Level3{got, got2}) {
t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level3{got, got2}))
}
}
func TestNestedManyToManyPreload(t *testing.T) {
type (
Level1 struct {
ID uint
Value string
}
Level2 struct {
ID uint
Value string
Level1s []*Level1 `gorm:"many2many:level1_level2;"`
}
Level3 struct {
ID uint
Value string
Level2s []Level2 `gorm:"many2many:level2_level3;"`
}
)
DB.DropTableIfExists(&Level1{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists("level1_level2")
DB.DropTableIfExists("level2_level3")
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
panic(err)
}
want := Level3{
Value: "Level3",
Level2s: []Level2{
{
Value: "Bob",
Level1s: []*Level1{
{Value: "ru"},
{Value: "en"},
},
}, {
Value: "Tom",
Level1s: []*Level1{
{Value: "zh"},
{Value: "de"},
},
},
},
}
if err := DB.Save(&want).Error; err != nil {
panic(err)
}
var got Level3
if err := DB.Preload("Level2s").Preload("Level2s.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
if err := DB.Preload("Level2s.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound {
panic(err)
}
}
func TestNestedManyToManyPreload2(t *testing.T) {
type (
Level1 struct {
ID uint
Value string
}
Level2 struct {
ID uint
Value string
Level1s []*Level1 `gorm:"many2many:level1_level2;"`
}
Level3 struct {
ID uint
Value string
Level2ID sql.NullInt64
Level2 *Level2
}
)
DB.DropTableIfExists(&Level1{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists("level1_level2")
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
panic(err)
}
want := Level3{
Value: "Level3",
Level2: &Level2{
Value: "Bob",
Level1s: []*Level1{
{Value: "ru"},
{Value: "en"},
},
},
}
if err := DB.Save(&want).Error; err != nil {
panic(err)
}
var got Level3
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "not_found").Error; err != gorm.RecordNotFound {
panic(err)
}
}
func TestNilPointerSlice(t *testing.T) {
type (
Level3 struct {
ID uint
Value string
}
Level2 struct {
ID uint
Value string Value string
Level3ID uint Level3ID uint
Level3 *Level3 Level3 *Level3
} }
Level1 struct { Level1 struct {
ID uint `gorm:"primary_key;"` ID uint
Value string Value string
Level2ID uint Level2ID uint
Level2 *Level2 Level2 *Level2