Add SetupJoinTable support
This commit is contained in:
		
							parent
							
								
									db03616993
								
							
						
					
					
						commit
						e490e09db5
					
				| @ -4,6 +4,7 @@ import ( | |||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"reflect" | 	"reflect" | ||||||
|  | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	"github.com/jinzhu/gorm/clause" | 	"github.com/jinzhu/gorm/clause" | ||||||
| 	"github.com/jinzhu/gorm/schema" | 	"github.com/jinzhu/gorm/schema" | ||||||
| @ -44,7 +45,16 @@ func (association *Association) Find(out interface{}, conds ...interface{}) erro | |||||||
| 			tx         = association.DB.Model(out) | 			tx         = association.DB.Model(out) | ||||||
| 		) | 		) | ||||||
| 
 | 
 | ||||||
| 		if association.Relationship.JoinTable != nil && !tx.Statement.Unscoped { | 		if association.Relationship.JoinTable != nil { | ||||||
|  | 			if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { | ||||||
|  | 				joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} | ||||||
|  | 				for _, queryClause := range association.Relationship.JoinTable.QueryClauses { | ||||||
|  | 					joinStmt.AddClause(queryClause) | ||||||
|  | 				} | ||||||
|  | 				joinStmt.Build("WHERE", "LIMIT") | ||||||
|  | 				tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
| 			tx.Clauses(clause.From{Joins: []clause.Join{{ | 			tx.Clauses(clause.From{Joins: []clause.Join{{ | ||||||
| 				Table: clause.Table{Name: association.Relationship.JoinTable.Table}, | 				Table: clause.Table{Name: association.Relationship.JoinTable.Table}, | ||||||
| 				ON:    clause.Where{Exprs: queryConds}, | 				ON:    clause.Where{Exprs: queryConds}, | ||||||
| @ -321,10 +331,13 @@ func (association *Association) Count() (count int64) { | |||||||
| 		) | 		) | ||||||
| 
 | 
 | ||||||
| 		if association.Relationship.JoinTable != nil { | 		if association.Relationship.JoinTable != nil { | ||||||
| 			if !tx.Statement.Unscoped { | 			if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { | ||||||
|  | 				joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} | ||||||
| 				for _, queryClause := range association.Relationship.JoinTable.QueryClauses { | 				for _, queryClause := range association.Relationship.JoinTable.QueryClauses { | ||||||
| 					tx.Clauses(queryClause) | 					joinStmt.AddClause(queryClause) | ||||||
| 				} | 				} | ||||||
|  | 				joinStmt.Build("WHERE", "LIMIT") | ||||||
|  | 				tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			tx.Clauses(clause.From{Joins: []clause.Join{{ | 			tx.Clauses(clause.From{Joins: []clause.Join{{ | ||||||
|  | |||||||
| @ -169,12 +169,12 @@ func CreateWithReturning(db *gorm.DB) { | |||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					db.AddError(err) | 					db.AddError(err) | ||||||
| 				} | 				} | ||||||
|  | 			} | ||||||
|  | 		} else if !db.DryRun { | ||||||
|  | 			if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { | ||||||
|  | 				db.RowsAffected, _ = result.RowsAffected() | ||||||
| 			} else { | 			} else { | ||||||
| 				if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { | 				db.AddError(err) | ||||||
| 					db.RowsAffected, _ = result.RowsAffected() |  | ||||||
| 				} else { |  | ||||||
| 					db.AddError(err) |  | ||||||
| 				} |  | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | |||||||
							
								
								
									
										37
									
								
								gorm.go
									
									
									
									
									
								
							
							
						
						
									
										37
									
								
								gorm.go
									
									
									
									
									
								
							| @ -108,6 +108,7 @@ func (db *DB) Session(config *Session) *DB { | |||||||
| 	if config.Context != nil { | 	if config.Context != nil { | ||||||
| 		if tx.Statement != nil { | 		if tx.Statement != nil { | ||||||
| 			tx.Statement = tx.Statement.clone() | 			tx.Statement = tx.Statement.clone() | ||||||
|  | 			tx.Statement.DB = tx | ||||||
| 		} else { | 		} else { | ||||||
| 			tx.Statement = &Statement{ | 			tx.Statement = &Statement{ | ||||||
| 				DB:       tx, | 				DB:       tx, | ||||||
| @ -181,6 +182,42 @@ func (db *DB) InstanceGet(key string) (interface{}, bool) { | |||||||
| 	return nil, false | 	return nil, false | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error { | ||||||
|  | 	var ( | ||||||
|  | 		tx                      = db.getInstance() | ||||||
|  | 		stmt                    = tx.Statement | ||||||
|  | 		modelSchema, joinSchema *schema.Schema | ||||||
|  | 	) | ||||||
|  | 
 | ||||||
|  | 	if err := stmt.Parse(model); err == nil { | ||||||
|  | 		modelSchema = stmt.Schema | ||||||
|  | 	} else { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := stmt.Parse(joinTable); err == nil { | ||||||
|  | 		joinSchema = stmt.Schema | ||||||
|  | 	} else { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if relation, ok := modelSchema.Relationships.Relations[field]; ok && relation.JoinTable != nil { | ||||||
|  | 		for _, ref := range relation.References { | ||||||
|  | 			if f := joinSchema.LookUpField(ref.ForeignKey.DBName); f != nil { | ||||||
|  | 				ref.ForeignKey = f | ||||||
|  | 			} else { | ||||||
|  | 				return fmt.Errorf("missing field %v for join table", ref.ForeignKey.DBName) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		relation.JoinTable = joinSchema | ||||||
|  | 	} else { | ||||||
|  | 		return fmt.Errorf("failed to found relation: %v", field) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // Callback returns callback manager
 | // Callback returns callback manager
 | ||||||
| func (db *DB) Callback() *callbacks { | func (db *DB) Callback() *callbacks { | ||||||
| 	return db.callbacks | 	return db.callbacks | ||||||
|  | |||||||
| @ -33,7 +33,7 @@ type Relationship struct { | |||||||
| 	Type                     RelationshipType | 	Type                     RelationshipType | ||||||
| 	Field                    *Field | 	Field                    *Field | ||||||
| 	Polymorphic              *Polymorphic | 	Polymorphic              *Polymorphic | ||||||
| 	References               []Reference | 	References               []*Reference | ||||||
| 	Schema                   *Schema | 	Schema                   *Schema | ||||||
| 	FieldSchema              *Schema | 	FieldSchema              *Schema | ||||||
| 	JoinTable                *Schema | 	JoinTable                *Schema | ||||||
| @ -139,7 +139,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if schema.err == nil { | 	if schema.err == nil { | ||||||
| 		relation.References = append(relation.References, Reference{ | 		relation.References = append(relation.References, &Reference{ | ||||||
| 			PrimaryValue: relation.Polymorphic.Value, | 			PrimaryValue: relation.Polymorphic.Value, | ||||||
| 			ForeignKey:   relation.Polymorphic.PolymorphicType, | 			ForeignKey:   relation.Polymorphic.PolymorphicType, | ||||||
| 		}) | 		}) | ||||||
| @ -150,7 +150,7 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi | |||||||
| 				schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name) | 				schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %v", relation.foreignKeys, schema, field.Name) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		relation.References = append(relation.References, Reference{ | 		relation.References = append(relation.References, &Reference{ | ||||||
| 			PrimaryKey:    primaryKeyField, | 			PrimaryKey:    primaryKeyField, | ||||||
| 			ForeignKey:    relation.Polymorphic.PolymorphicID, | 			ForeignKey:    relation.Polymorphic.PolymorphicID, | ||||||
| 			OwnPrimaryKey: true, | 			OwnPrimaryKey: true, | ||||||
| @ -246,7 +246,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel | |||||||
| 
 | 
 | ||||||
| 	// build references
 | 	// build references
 | ||||||
| 	for _, f := range relation.JoinTable.Fields { | 	for _, f := range relation.JoinTable.Fields { | ||||||
| 		relation.References = append(relation.References, Reference{ | 		relation.References = append(relation.References, &Reference{ | ||||||
| 			PrimaryKey:    fieldsMap[f.Name], | 			PrimaryKey:    fieldsMap[f.Name], | ||||||
| 			ForeignKey:    f, | 			ForeignKey:    f, | ||||||
| 			OwnPrimaryKey: schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name], | 			OwnPrimaryKey: schema == fieldsMap[f.Name].Schema && ownFieldsMap[f.Name], | ||||||
| @ -326,7 +326,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH | |||||||
| 
 | 
 | ||||||
| 	// build references
 | 	// build references
 | ||||||
| 	for idx, foreignField := range foreignFields { | 	for idx, foreignField := range foreignFields { | ||||||
| 		relation.References = append(relation.References, Reference{ | 		relation.References = append(relation.References, &Reference{ | ||||||
| 			PrimaryKey:    primaryFields[idx], | 			PrimaryKey:    primaryFields[idx], | ||||||
| 			ForeignKey:    foreignField, | 			ForeignKey:    foreignField, | ||||||
| 			OwnPrimaryKey: schema == primarySchema && guessHas, | 			OwnPrimaryKey: schema == primarySchema && guessHas, | ||||||
|  | |||||||
| @ -158,9 +158,11 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { | |||||||
| 				writer.WriteString("(NULL)") | 				writer.WriteString("(NULL)") | ||||||
| 			} | 			} | ||||||
| 		case *DB: | 		case *DB: | ||||||
| 			result := v.Session(&Session{DryRun: true, WithConditions: true}).Find(nil).Statement | 			subdb := v.Session(&Session{DryRun: true, WithConditions: true}).getInstance() | ||||||
| 			writer.WriteString(result.SQL.String()) | 			subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...) | ||||||
| 			stmt.Vars = append(stmt.Vars, result.Vars...) | 			subdb.callbacks.Query().Execute(subdb) | ||||||
|  | 			writer.WriteString(subdb.Statement.SQL.String()) | ||||||
|  | 			stmt.Vars = subdb.Statement.Vars | ||||||
| 		default: | 		default: | ||||||
| 			switch rv := reflect.ValueOf(v); rv.Kind() { | 			switch rv := reflect.ValueOf(v); rv.Kind() { | ||||||
| 			case reflect.Slice, reflect.Array: | 			case reflect.Slice, reflect.Array: | ||||||
|  | |||||||
							
								
								
									
										99
									
								
								tests/joins_table_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										99
									
								
								tests/joins_table_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,99 @@ | |||||||
|  | package tests_test | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/jinzhu/gorm" | ||||||
|  | 	. "github.com/jinzhu/gorm/tests" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | type Person struct { | ||||||
|  | 	ID        int | ||||||
|  | 	Name      string | ||||||
|  | 	Addresses []Address `gorm:"many2many:person_addresses;"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type Address struct { | ||||||
|  | 	ID   uint | ||||||
|  | 	Name string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type PersonAddress struct { | ||||||
|  | 	PersonID  int | ||||||
|  | 	AddressID int | ||||||
|  | 	CreatedAt time.Time | ||||||
|  | 	DeletedAt gorm.DeletedAt | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestOverrideJoinTable(t *testing.T) { | ||||||
|  | 	DB.Migrator().DropTable(&Person{}, &Address{}, &PersonAddress{}) | ||||||
|  | 
 | ||||||
|  | 	if err := DB.SetupJoinTable(&Person{}, "Addresses", &PersonAddress{}); err != nil { | ||||||
|  | 		t.Fatalf("Failed to setup join table for person, got error %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.AutoMigrate(&Person{}, &Address{}); err != nil { | ||||||
|  | 		t.Fatalf("Failed to migrate, got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	address1 := Address{Name: "address 1"} | ||||||
|  | 	address2 := Address{Name: "address 2"} | ||||||
|  | 	person := Person{Name: "person", Addresses: []Address{address1, address2}} | ||||||
|  | 	DB.Create(&person) | ||||||
|  | 
 | ||||||
|  | 	var addresses1 []Address | ||||||
|  | 	if err := DB.Model(&person).Association("Addresses").Find(&addresses1); err != nil || len(addresses1) != 2 { | ||||||
|  | 		t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses1)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Model(&person).Association("Addresses").Delete(&person.Addresses[0]); err != nil { | ||||||
|  | 		t.Fatalf("Failed to delete address, got error %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if len(person.Addresses) != 1 { | ||||||
|  | 		t.Fatalf("Should have one address left") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if DB.Find(&[]PersonAddress{}, "person_id = ?", person.ID).RowsAffected != 1 { | ||||||
|  | 		t.Fatalf("Should found one address") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var addresses2 []Address | ||||||
|  | 	if err := DB.Model(&person).Association("Addresses").Find(&addresses2); err != nil || len(addresses2) != 1 { | ||||||
|  | 		t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses2)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if DB.Model(&person).Association("Addresses").Count() != 1 { | ||||||
|  | 		t.Fatalf("Should found one address") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var addresses3 []Address | ||||||
|  | 	if err := DB.Unscoped().Model(&person).Association("Addresses").Find(&addresses3); err != nil || len(addresses3) != 2 { | ||||||
|  | 		t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses3)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.ID).RowsAffected != 2 { | ||||||
|  | 		t.Fatalf("Should found soft deleted addresses with unscoped") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if DB.Unscoped().Model(&person).Association("Addresses").Count() != 2 { | ||||||
|  | 		t.Fatalf("Should found soft deleted addresses with unscoped") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DB.Model(&person).Association("Addresses").Clear() | ||||||
|  | 
 | ||||||
|  | 	if DB.Model(&person).Association("Addresses").Count() != 0 { | ||||||
|  | 		t.Fatalf("Should deleted all addresses") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if DB.Unscoped().Model(&person).Association("Addresses").Count() != 2 { | ||||||
|  | 		t.Fatalf("Should found soft deleted addresses with unscoped") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DB.Unscoped().Model(&person).Association("Addresses").Clear() | ||||||
|  | 
 | ||||||
|  | 	if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 { | ||||||
|  | 		t.Fatalf("address should be deleted when clear with unscoped") | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu