Add SetupJoinTable support
This commit is contained in:
		
							parent
							
								
									db03616993
								
							
						
					
					
						commit
						e490e09db5
					
				| @ -4,6 +4,7 @@ import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm/clause" | ||||
| 	"github.com/jinzhu/gorm/schema" | ||||
| @ -44,7 +45,16 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro | ||||
| 			tx         = association.DB.Model(out) | ||||
| 		) | ||||
| 
 | ||||
| 		if association.Relationship.JoinTable != nil && !tx.Statement.Unscoped { | ||||
| 		if association.Relationship.JoinTable != nil { | ||||
| 			if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { | ||||
| 				joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} | ||||
| 				for _, queryClause := range association.Relationship.JoinTable.QueryClauses { | ||||
| 					joinStmt.AddClause(queryClause) | ||||
| 				} | ||||
| 				joinStmt.Build("WHERE", "LIMIT") | ||||
| 				tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) | ||||
| 			} | ||||
| 
 | ||||
| 			tx.Clauses(clause.From{Joins: []clause.Join{{ | ||||
| 				Table: clause.Table{Name: association.Relationship.JoinTable.Table}, | ||||
| 				ON:    clause.Where{Exprs: queryConds}, | ||||
| @ -321,10 +331,13 @@ func (association *Association) Count() (count int64) { | ||||
| 		) | ||||
| 
 | ||||
| 		if association.Relationship.JoinTable != nil { | ||||
| 			if !tx.Statement.Unscoped { | ||||
| 			if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { | ||||
| 				joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} | ||||
| 				for _, queryClause := range association.Relationship.JoinTable.QueryClauses { | ||||
| 					tx.Clauses(queryClause) | ||||
| 					joinStmt.AddClause(queryClause) | ||||
| 				} | ||||
| 				joinStmt.Build("WHERE", "LIMIT") | ||||
| 				tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) | ||||
| 			} | ||||
| 
 | ||||
| 			tx.Clauses(clause.From{Joins: []clause.Join{{ | ||||
|  | ||||
| @ -169,12 +169,12 @@ func CreateWithReturning(db *gorm.DB) { | ||||
| 				if err != nil { | ||||
| 					db.AddError(err) | ||||
| 				} | ||||
| 			} | ||||
| 		} else if !db.DryRun { | ||||
| 			if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { | ||||
| 				db.RowsAffected, _ = result.RowsAffected() | ||||
| 			} else { | ||||
| 				if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { | ||||
| 					db.RowsAffected, _ = result.RowsAffected() | ||||
| 				} else { | ||||
| 					db.AddError(err) | ||||
| 				} | ||||
| 				db.AddError(err) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
							
								
								
									
										37
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										37
									
								
								gorm.go
									
									
									
									
									
								
							| @ -108,6 +108,7 @@ func (db *DB) Session(config *Session) *DB { | ||||
| 	if config.Context != nil { | ||||
| 		if tx.Statement != nil { | ||||
| 			tx.Statement = tx.Statement.clone() | ||||
| 			tx.Statement.DB = tx | ||||
| 		} else { | ||||
| 			tx.Statement = &Statement{ | ||||
| 				DB:       tx, | ||||
| @ -181,6 +182,42 @@ func (db *DB) InstanceGet(key string) (interface{}, bool) { | ||||
| 	return nil, false | ||||
| } | ||||
| 
 | ||||
| func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error { | ||||
| 	var ( | ||||
| 		tx                      = db.getInstance() | ||||
| 		stmt                    = tx.Statement | ||||
| 		modelSchema, joinSchema *schema.Schema | ||||
| 	) | ||||
| 
 | ||||
| 	if err := stmt.Parse(model); err == nil { | ||||
| 		modelSchema = stmt.Schema | ||||
| 	} else { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	if err := stmt.Parse(joinTable); err == nil { | ||||
| 		joinSchema = stmt.Schema | ||||
| 	} else { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil { | ||||
| 		for _, ref := range relation.References { | ||||
| 			if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil { | ||||
| 				ref.ForeignKey = f | ||||
| 			} else { | ||||
| 				return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName) | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		relation.JoinTable = joinSchema | ||||
| 	} else { | ||||
| 		return fmt.Errorf("failed to found relation: %v", field) | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // Callback returns callback manager
 | ||||
| func (db *DB) Callback() *callbacks { | ||||
| 	return db.callbacks | ||||
|  | ||||
| @ -33,7 +33,7 @@ type Relationship struct { | ||||
| 	Type                     RelationshipType | ||||
| 	Field                    *Field | ||||
| 	Polymorphic              *Polymorphic | ||||
| 	References               []Reference | ||||
| 	References               []*Reference | ||||
| 	Schema                   *Schema | ||||
| 	FieldSchema              *Schema | ||||
| 	JoinTable                *Schema | ||||
| @ -139,7 +139,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi | ||||
| 	} | ||||
| 
 | ||||
| 	if schema.err == nil { | ||||
| 		relation.References = append(relation.References, Reference{ | ||||
| 		relation.References = append(relation.References, &Reference{ | ||||
| 			PrimaryValue: relation.Polymorphic.Value, | ||||
| 			ForeignKey:   relation.Polymorphic.PolymorphicType, | ||||
| 		}) | ||||
| @ -150,7 +150,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi | ||||
| 				schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name) | ||||
| 			} | ||||
| 		} | ||||
| 		relation.References = append(relation.References, Reference{ | ||||
| 		relation.References = append(relation.References, &Reference{ | ||||
| 			PrimaryKey:    primaryKeyField, | ||||
| 			ForeignKey:    relation.Polymorphic.PolymorphicID, | ||||
| 			OwnPrimaryKey: true, | ||||
| @ -246,7 +246,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel | ||||
| 
 | ||||
| 	// build references
 | ||||
| 	for _, f := range relation.JoinTable.Fields { | ||||
| 		relation.References = append(relation.References, Reference{ | ||||
| 		relation.References = append(relation.References, &Reference{ | ||||
| 			PrimaryKey:    fieldsMap[f.Name], | ||||
| 			ForeignKey:    f, | ||||
| 			OwnPrimaryKey: schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name], | ||||
| @ -326,7 +326,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH | ||||
| 
 | ||||
| 	// build references
 | ||||
| 	for idx, foreignField := range foreignFields { | ||||
| 		relation.References = append(relation.References, Reference{ | ||||
| 		relation.References = append(relation.References, &Reference{ | ||||
| 			PrimaryKey:    primaryFields[idx], | ||||
| 			ForeignKey:    foreignField, | ||||
| 			OwnPrimaryKey: schema == primarySchema && guessHas, | ||||
|  | ||||
| @ -158,9 +158,11 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { | ||||
| 				writer.WriteString("(NULL)") | ||||
| 			} | ||||
| 		case *DB: | ||||
| 			result := v.Session(&Session{DryRun: true, WithConditions: true}).Find(nil).Statement | ||||
| 			writer.WriteString(result.SQL.String()) | ||||
| 			stmt.Vars = append(stmt.Vars, result.Vars...) | ||||
| 			subdb := v.Session(&Session{DryRun: true, WithConditions: true}).getInstance() | ||||
| 			subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...) | ||||
| 			subdb.callbacks.Query().Execute(subdb) | ||||
| 			writer.WriteString(subdb.Statement.SQL.String()) | ||||
| 			stmt.Vars = subdb.Statement.Vars | ||||
| 		default: | ||||
| 			switch rv := reflect.ValueOf(v); rv.Kind() { | ||||
| 			case reflect.Slice, reflect.Array: | ||||
|  | ||||
							
								
								
									
										99
									
								
								tests/joins_table_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										99
									
								
								tests/joins_table_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,99 @@ | ||||
| package tests_test | ||||
| 
 | ||||
| import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/jinzhu/gorm" | ||||
| 	. "github.com/jinzhu/gorm/tests" | ||||
| ) | ||||
| 
 | ||||
| type Person struct { | ||||
| 	ID        int | ||||
| 	Name      string | ||||
| 	Addresses []Address `gorm:"many2many:person_addresses;"` | ||||
| } | ||||
| 
 | ||||
| type Address struct { | ||||
| 	ID   uint | ||||
| 	Name string | ||||
| } | ||||
| 
 | ||||
| type PersonAddress struct { | ||||
| 	PersonID  int | ||||
| 	AddressID int | ||||
| 	CreatedAt time.Time | ||||
| 	DeletedAt gorm.DeletedAt | ||||
| } | ||||
| 
 | ||||
| func TestOverrideJoinTable(t *testing.T) { | ||||
| 	DB.Migrator().DropTable(&Person{}, &Address{}, &PersonAddress{}) | ||||
| 
 | ||||
| 	if err := DB.SetupJoinTable(&Person{}, "Addresses", &PersonAddress{}); err != nil { | ||||
| 		t.Fatalf("Failed to setup join table for person, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.AutoMigrate(&Person{}, &Address{}); err != nil { | ||||
| 		t.Fatalf("Failed to migrate, got %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	address1 := Address{Name: "address 1"} | ||||
| 	address2 := Address{Name: "address 2"} | ||||
| 	person := Person{Name: "person", Addresses: []Address{address1, address2}} | ||||
| 	DB.Create(&person) | ||||
| 
 | ||||
| 	var addresses1 []Address | ||||
| 	if err := DB.Model(&person).Association("Addresses").Find(&addresses1); err != nil || len(addresses1) != 2 { | ||||
| 		t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses1)) | ||||
| 	} | ||||
| 
 | ||||
| 	if err := DB.Model(&person).Association("Addresses").Delete(&person.Addresses[0]); err != nil { | ||||
| 		t.Fatalf("Failed to delete address, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if len(person.Addresses) != 1 { | ||||
| 		t.Fatalf("Should have one address left") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Find(&[]PersonAddress{}, "person_id = ?", person.ID).RowsAffected != 1 { | ||||
| 		t.Fatalf("Should found one address") | ||||
| 	} | ||||
| 
 | ||||
| 	var addresses2 []Address | ||||
| 	if err := DB.Model(&person).Association("Addresses").Find(&addresses2); err != nil || len(addresses2) != 1 { | ||||
| 		t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses2)) | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Model(&person).Association("Addresses").Count() != 1 { | ||||
| 		t.Fatalf("Should found one address") | ||||
| 	} | ||||
| 
 | ||||
| 	var addresses3 []Address | ||||
| 	if err := DB.Unscoped().Model(&person).Association("Addresses").Find(&addresses3); err != nil || len(addresses3) != 2 { | ||||
| 		t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses3)) | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.ID).RowsAffected != 2 { | ||||
| 		t.Fatalf("Should found soft deleted addresses with unscoped") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Unscoped().Model(&person).Association("Addresses").Count() != 2 { | ||||
| 		t.Fatalf("Should found soft deleted addresses with unscoped") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Model(&person).Association("Addresses").Clear() | ||||
| 
 | ||||
| 	if DB.Model(&person).Association("Addresses").Count() != 0 { | ||||
| 		t.Fatalf("Should deleted all addresses") | ||||
| 	} | ||||
| 
 | ||||
| 	if DB.Unscoped().Model(&person).Association("Addresses").Count() != 2 { | ||||
| 		t.Fatalf("Should found soft deleted addresses with unscoped") | ||||
| 	} | ||||
| 
 | ||||
| 	DB.Unscoped().Model(&person).Association("Addresses").Clear() | ||||
| 
 | ||||
| 	if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 { | ||||
| 		t.Fatalf("address should be deleted when clear with unscoped") | ||||
| 	} | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu