From a7aaf151cf3b710c13f109e336b35a113f958721 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 30 Jul 2014 11:32:18 +0800 Subject: [PATCH] Use JoinTable instead of ForeignKey --- callback_shared.go | 22 +++++++++++++++------- field.go | 8 +++++++- relations_test.go | 16 ++++++++-------- scope.go | 43 +++++++++++++++++++++++++++++++++---------- 4 files changed, 63 insertions(+), 26 deletions(-) diff --git a/callback_shared.go b/callback_shared.go index 1c403ebf..324d121d 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -1,6 +1,9 @@ package gorm -import "reflect" +import ( + "fmt" + "reflect" +) func BeginTransaction(scope *Scope) { scope.Begin() @@ -28,8 +31,8 @@ func SaveBeforeAssociations(scope *Scope) { scope.SetColumn(field.Name, value.Interface()) } - if len(field.ForeignKey) > 0 { - scope.SetColumn(field.ForeignKey, newDB.NewScope(value.Interface()).PrimaryKeyValue()) + if field.JoinTable != nil && field.JoinTable.foreignKey != "" { + scope.SetColumn(field.JoinTable.foreignKey, newDB.NewScope(value.Interface()).PrimaryKeyValue()) } } } @@ -46,16 +49,19 @@ func SaveAfterAssociations(scope *Scope) { newDB := scope.NewDB() elem := value.Index(i).Addr().Interface() - if len(field.ForeignKey) > 0 { - newDB.NewScope(elem).SetColumn(field.ForeignKey, scope.PrimaryKeyValue()) + if field.JoinTable != nil && field.JoinTable.foreignKey != "" { + newDB.NewScope(elem).SetColumn(field.JoinTable.foreignKey, scope.PrimaryKeyValue()) } scope.Err(newDB.Save(elem).Error) + fmt.Sprintf("INSERT INTO %v (%v, %v) SELECT (%v, %v) FROM %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = %v AND %v = %v) limit 1;") } default: newDB := scope.NewDB() if value.CanAddr() { - newDB.NewScope(field.Value).SetColumn(field.ForeignKey, scope.PrimaryKeyValue()) + if field.JoinTable != nil { + newDB.NewScope(field.Value).SetColumn(field.JoinTable.foreignKey, scope.PrimaryKeyValue()) + } scope.Err(newDB.Save(field.Value).Error) } else { destValue := reflect.New(reflect.TypeOf(field.Value)).Elem() @@ -65,7 +71,9 @@ func SaveAfterAssociations(scope *Scope) { } elem := destValue.Addr().Interface() - newDB.NewScope(elem).SetColumn(field.ForeignKey, scope.PrimaryKeyValue()) + if field.JoinTable != nil { + newDB.NewScope(elem).SetColumn(field.JoinTable.foreignKey, scope.PrimaryKeyValue()) + } scope.Err(newDB.Save(elem).Error) scope.SetColumn(field.Name, destValue.Interface()) } diff --git a/field.go b/field.go index c6975312..8452686c 100644 --- a/field.go +++ b/field.go @@ -6,6 +6,12 @@ import ( "time" ) +type joinTable struct { + joinTable string + foreignKey string + associationForeignKey string +} + type Field struct { Name string DBName string @@ -14,10 +20,10 @@ type Field struct { IsIgnored bool Tag reflect.StructTag SqlTag string - ForeignKey string BeforeAssociation bool AfterAssociation bool isPrimaryKey bool + JoinTable *joinTable } func (f *Field) IsScanner() bool { diff --git a/relations_test.go b/relations_test.go index e441712d..6e89c23e 100644 --- a/relations_test.go +++ b/relations_test.go @@ -115,8 +115,8 @@ func TestRelated(t *testing.T) { var creditcard CreditCard var user3 User - db.First(&creditcard, "number = ?", "1234567890") - db.Model(&creditcard).Related(&user3) + db.Debug().First(&creditcard, "number = ?", "1234567890") + db.Debug().Model(&creditcard).Related(&user3) if user3.Id != user.Id || user3.Name != user.Name { t.Errorf("Should get user from credit card correctly") } @@ -131,17 +131,17 @@ func TestQueryManyToManyWithRelated(t *testing.T) { user := User{Name: "Many2Many", Languages: languages} db.Save(&user) - // var newLanguages []Language + var newLanguages []Language // db.Model(&user).Related(&newLanguages, "Languages") // if len(newLanguages) != 3 { // t.Errorf("Query many to many relations") // } - // newLanguages = []Language{} - // db.Model(&user).Many2Many("Languages").Find(&newLanguages) - // if len(newLanguages) != 3 { - // t.Errorf("Query many to many relations") - // } + newLanguages = []Language{} + db.Model(&user).Many2Many("Languages").Find(&newLanguages) + if len(newLanguages) != 3 { + t.Errorf("Query many to many relations") + } // db.Model(&User{}).Many2Many("Languages").Add(&Language{}) // db.Model(&User{}).Many2Many("Languages").Remove(&Language{}) diff --git a/scope.go b/scope.go index 181b0453..ccf50dde 100644 --- a/scope.go +++ b/scope.go @@ -263,33 +263,56 @@ func (scope *Scope) Fields() []*Field { } if scope.db != nil { + indirectValue := reflect.Indirect(value) field.Tag = fieldStruct.Tag field.SqlTag = scope.sqlTagForField(&field) // parse association - elem := reflect.Indirect(value) - typ := elem.Type() + typ := indirectValue.Type() + foreignKey := settings["FOREIGNKEY"] + associationForeignKey := settings["ASSOCIATIONFOREIGNKEY"] + many2many := settings["MANY2MANY"] - switch elem.Kind() { + switch indirectValue.Kind() { case reflect.Slice: typ = typ.Elem() if typ.Kind() == reflect.Struct { - foreignKey := scopeTyp.Name() + "Id" - if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { - field.ForeignKey = foreignKey + if foreignKey == "" { + foreignKey = scopeTyp.Name() + "Id" } + if associationForeignKey == "" { + associationForeignKey = typ.Name() + "Id" + } + + // if not many to many, foreign key could be null + if many2many == "" { + if !reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { + foreignKey = "" + } + } + field.AfterAssociation = true + field.JoinTable = &joinTable{ + joinTable: many2many, + foreignKey: foreignKey, + associationForeignKey: associationForeignKey, + } } case reflect.Struct: if !field.IsTime() && !field.IsScanner() { - if scope.HasColumn(field.Name + "Id") { - field.ForeignKey = field.Name + "Id" + if foreignKey == "" && scope.HasColumn(field.Name+"Id") { + field.JoinTable = &joinTable{foreignKey: field.Name + "Id"} + field.BeforeAssociation = true + } else if scope.HasColumn(foreignKey) { + field.JoinTable = &joinTable{foreignKey: foreignKey} field.BeforeAssociation = true } else { - foreignKey := scopeTyp.Name() + "Id" + if foreignKey == "" { + foreignKey = scopeTyp.Name() + "Id" + } if reflect.New(typ).Elem().FieldByName(foreignKey).IsValid() { - field.ForeignKey = foreignKey + field.JoinTable = &joinTable{foreignKey: foreignKey} } field.AfterAssociation = true }