Add JoinTableHandler
This commit is contained in:
		
							parent
							
								
									aa0a4012be
								
							
						
					
					
						commit
						6d64e6837b
					
				@ -71,15 +71,15 @@ func (association *Association) Delete(values ...interface{}) *Association {
 | 
				
			|||||||
	if len(primaryKeys) == 0 {
 | 
						if len(primaryKeys) == 0 {
 | 
				
			||||||
		association.setErr(errors.New("no primary key found"))
 | 
							association.setErr(errors.New("no primary key found"))
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
 | 
							scope := association.Scope
 | 
				
			||||||
		relationship := association.Field.Relationship
 | 
							relationship := association.Field.Relationship
 | 
				
			||||||
		// many to many
 | 
							// many to many
 | 
				
			||||||
		if relationship.Kind == "many_to_many" {
 | 
							if relationship.Kind == "many_to_many" {
 | 
				
			||||||
			whereSql := fmt.Sprintf("%v.%v = ? AND %v.%v IN (?)",
 | 
								sql := fmt.Sprintf("%v.%v = ? AND %v.%v IN (?)",
 | 
				
			||||||
				relationship.JoinTable, association.Scope.Quote(relationship.ForeignDBName),
 | 
									scope.Quote(relationship.JoinTable), scope.Quote(relationship.ForeignDBName),
 | 
				
			||||||
				relationship.JoinTable, association.Scope.Quote(relationship.AssociationForeignDBName))
 | 
									scope.Quote(relationship.JoinTable), scope.Quote(relationship.AssociationForeignDBName))
 | 
				
			||||||
 | 
								query := scope.NewDB().Table(relationship.JoinTable).Where(sql, association.PrimaryKey, primaryKeys)
 | 
				
			||||||
			if err := association.Scope.DB().Table(relationship.JoinTable).
 | 
								if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil {
 | 
				
			||||||
				Where(whereSql, association.PrimaryKey, primaryKeys).Delete("").Error; err == nil {
 | 
					 | 
				
			||||||
				leftValues := reflect.Zero(association.Field.Field.Type())
 | 
									leftValues := reflect.Zero(association.Field.Field.Type())
 | 
				
			||||||
				for i := 0; i < association.Field.Field.Len(); i++ {
 | 
									for i := 0; i < association.Field.Field.Len(); i++ {
 | 
				
			||||||
					value := association.Field.Field.Index(i)
 | 
										value := association.Field.Field.Index(i)
 | 
				
			||||||
@ -132,11 +132,9 @@ func (association *Association) Replace(values ...interface{}) *Association {
 | 
				
			|||||||
			addedPrimaryKeys = append(addedPrimaryKeys, primaryKey)
 | 
								addedPrimaryKeys = append(addedPrimaryKeys, primaryKey)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		whereSql := fmt.Sprintf("%v.%v = ? AND %v.%v NOT IN (?)",
 | 
							sql := fmt.Sprintf("%v.%v = ? AND %v.%v NOT IN (?)", scope.Quote(relationship.JoinTable), scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.JoinTable), scope.Quote(relationship.AssociationForeignDBName))
 | 
				
			||||||
			relationship.JoinTable, association.Scope.Quote(relationship.ForeignDBName),
 | 
							query := scope.NewDB().Table(relationship.JoinTable).Where(sql, association.PrimaryKey, addedPrimaryKeys)
 | 
				
			||||||
			relationship.JoinTable, association.Scope.Quote(relationship.AssociationForeignDBName))
 | 
							association.setErr(scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship))
 | 
				
			||||||
 | 
					 | 
				
			||||||
		scope.DB().Table(relationship.JoinTable).Where(whereSql, association.PrimaryKey, addedPrimaryKeys).Delete("")
 | 
					 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		association.setErr(errors.New("replace only support many to many"))
 | 
							association.setErr(errors.New("replace only support many to many"))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -147,8 +145,9 @@ func (association *Association) Clear() *Association {
 | 
				
			|||||||
	relationship := association.Field.Relationship
 | 
						relationship := association.Field.Relationship
 | 
				
			||||||
	scope := association.Scope
 | 
						scope := association.Scope
 | 
				
			||||||
	if relationship.Kind == "many_to_many" {
 | 
						if relationship.Kind == "many_to_many" {
 | 
				
			||||||
		whereSql := fmt.Sprintf("%v.%v = ?", relationship.JoinTable, scope.Quote(relationship.ForeignDBName))
 | 
							sql := fmt.Sprintf("%v.%v = ?", relationship.JoinTable, scope.Quote(relationship.ForeignDBName))
 | 
				
			||||||
		if err := scope.DB().Table(relationship.JoinTable).Where(whereSql, association.PrimaryKey).Delete("").Error; err == nil {
 | 
							query := scope.NewDB().Table(relationship.JoinTable).Where(sql, association.PrimaryKey)
 | 
				
			||||||
 | 
							if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil {
 | 
				
			||||||
			association.Field.Set(reflect.Zero(association.Field.Field.Type()))
 | 
								association.Field.Set(reflect.Zero(association.Field.Field.Type()))
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			association.setErr(err)
 | 
								association.setErr(err)
 | 
				
			||||||
@ -166,9 +165,10 @@ func (association *Association) Count() int {
 | 
				
			|||||||
	newScope := scope.New(association.Field.Field.Interface())
 | 
						newScope := scope.New(association.Field.Field.Interface())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if relationship.Kind == "many_to_many" {
 | 
						if relationship.Kind == "many_to_many" {
 | 
				
			||||||
		scope.DB().Table(relationship.JoinTable).
 | 
							query := scope.DB().Table(relationship.JoinTable).
 | 
				
			||||||
			Select("COUNT(DISTINCT ?)", relationship.AssociationForeignDBName).
 | 
								Select("COUNT(DISTINCT ?)", relationship.AssociationForeignDBName).
 | 
				
			||||||
			Where(relationship.ForeignDBName+" = ?", association.PrimaryKey).Row().Scan(&count)
 | 
								Where(relationship.ForeignDBName+" = ?", association.PrimaryKey)
 | 
				
			||||||
 | 
							scope.db.GetJoinTableHandler(relationship.JoinTable).Scope(query, relationship).Row().Scan(&count)
 | 
				
			||||||
	} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
 | 
						} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
 | 
				
			||||||
		whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
 | 
							whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
 | 
				
			||||||
		countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey)
 | 
							countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey)
 | 
				
			||||||
 | 
				
			|||||||
@ -1,10 +1,6 @@
 | 
				
			|||||||
package gorm
 | 
					package gorm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import "reflect"
 | 
				
			||||||
	"fmt"
 | 
					 | 
				
			||||||
	"reflect"
 | 
					 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
func BeginTransaction(scope *Scope) {
 | 
					func BeginTransaction(scope *Scope) {
 | 
				
			||||||
	scope.Begin()
 | 
						scope.Begin()
 | 
				
			||||||
@ -53,24 +49,8 @@ func SaveAfterAssociations(scope *Scope) {
 | 
				
			|||||||
						scope.Err(newDB.Save(elem).Error)
 | 
											scope.Err(newDB.Save(elem).Error)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
						if joinTable := relationship.JoinTable; joinTable != "" {
 | 
											if joinTable := relationship.JoinTable; joinTable != "" {
 | 
				
			||||||
							quotedForeignDBName := scope.Quote(relationship.ForeignDBName)
 | 
												scope.Err(scope.db.GetJoinTableHandler(joinTable).
 | 
				
			||||||
							foreignValue := scope.PrimaryKeyValue()
 | 
													Add(scope.NewDB(), relationship, scope.PrimaryKeyValue(), newScope.PrimaryKeyValue()))
 | 
				
			||||||
							quoteAssociationForeignDBName := scope.Quote(relationship.AssociationForeignDBName)
 | 
					 | 
				
			||||||
							associationForeignValue := newScope.PrimaryKeyValue()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
							newScope.Raw(fmt.Sprintf(
 | 
					 | 
				
			||||||
								"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = %v AND %v = %v);",
 | 
					 | 
				
			||||||
								joinTable,
 | 
					 | 
				
			||||||
								strings.Join([]string{quotedForeignDBName, quoteAssociationForeignDBName}, ","),
 | 
					 | 
				
			||||||
								strings.Join([]string{newScope.AddToVars(foreignValue), newScope.AddToVars(associationForeignValue)}, ","),
 | 
					 | 
				
			||||||
								scope.Dialect().SelectFromDummyTable(),
 | 
					 | 
				
			||||||
								joinTable,
 | 
					 | 
				
			||||||
								quotedForeignDBName,
 | 
					 | 
				
			||||||
								newScope.AddToVars(foreignValue),
 | 
					 | 
				
			||||||
								quoteAssociationForeignDBName,
 | 
					 | 
				
			||||||
								newScope.AddToVars(associationForeignValue),
 | 
					 | 
				
			||||||
							))
 | 
					 | 
				
			||||||
							scope.Err(scope.NewDB().Exec(newScope.Sql, newScope.SqlVars...).Error)
 | 
					 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				default:
 | 
									default:
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										42
									
								
								join_table.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								join_table.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,42 @@
 | 
				
			|||||||
 | 
					package gorm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type JoinTableHandler interface {
 | 
				
			||||||
 | 
						Add(*DB, *Relationship, interface{}, interface{}) error
 | 
				
			||||||
 | 
						Delete(*DB, *Relationship) error
 | 
				
			||||||
 | 
						Scope(*DB, *Relationship) *DB
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type defaultJoinTableHandler struct{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (*defaultJoinTableHandler) Add(db *DB, relationship *Relationship, foreignValue interface{}, associationValue interface{}) error {
 | 
				
			||||||
 | 
						scope := db.NewScope("")
 | 
				
			||||||
 | 
						quotedForeignDBName := scope.Quote(relationship.ForeignDBName)
 | 
				
			||||||
 | 
						quotedAssociationDBName := scope.Quote(relationship.AssociationForeignDBName)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						sql := fmt.Sprintf(
 | 
				
			||||||
 | 
							"INSERT INTO %v (%v) SELECT ?,? %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = ? AND %v = ?);",
 | 
				
			||||||
 | 
							scope.Quote(relationship.JoinTable),
 | 
				
			||||||
 | 
							strings.Join([]string{quotedForeignDBName, quotedAssociationDBName}, ","),
 | 
				
			||||||
 | 
							scope.Dialect().SelectFromDummyTable(),
 | 
				
			||||||
 | 
							scope.Quote(relationship.JoinTable),
 | 
				
			||||||
 | 
							quotedForeignDBName,
 | 
				
			||||||
 | 
							quotedAssociationDBName,
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return db.Exec(sql, foreignValue, associationValue, foreignValue, associationValue).Error
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (*defaultJoinTableHandler) Delete(db *DB, relationship *Relationship) error {
 | 
				
			||||||
 | 
						return db.Delete("").Error
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (*defaultJoinTableHandler) Scope(db *DB, relationship *Relationship) *DB {
 | 
				
			||||||
 | 
						return db
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var DefaultJoinTableHandler = &defaultJoinTableHandler{}
 | 
				
			||||||
							
								
								
									
										41
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										41
									
								
								main.go
									
									
									
									
									
								
							@ -34,6 +34,7 @@ type DB struct {
 | 
				
			|||||||
	singularTable     bool
 | 
						singularTable     bool
 | 
				
			||||||
	source            string
 | 
						source            string
 | 
				
			||||||
	values            map[string]interface{}
 | 
						values            map[string]interface{}
 | 
				
			||||||
 | 
						joinTableHandlers map[string]JoinTableHandler
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func Open(dialect string, args ...interface{}) (DB, error) {
 | 
					func Open(dialect string, args ...interface{}) (DB, error) {
 | 
				
			||||||
@ -91,20 +92,6 @@ func (db *DB) NewScope(value interface{}) *Scope {
 | 
				
			|||||||
	return &Scope{db: dbClone, Search: dbClone.search, Value: value}
 | 
						return &Scope{db: dbClone, Search: dbClone.search, Value: value}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *DB) FreshDB() *DB {
 | 
					 | 
				
			||||||
	newDB := &DB{
 | 
					 | 
				
			||||||
		dialect:      s.dialect,
 | 
					 | 
				
			||||||
		logger:       s.logger,
 | 
					 | 
				
			||||||
		callback:     s.parent.callback.clone(),
 | 
					 | 
				
			||||||
		source:       s.source,
 | 
					 | 
				
			||||||
		values:       map[string]interface{}{},
 | 
					 | 
				
			||||||
		db:           s.db,
 | 
					 | 
				
			||||||
		ModelStructs: map[reflect.Type]*ModelStruct{},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	newDB.parent = newDB
 | 
					 | 
				
			||||||
	return newDB
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// CommonDB Return the underlying sql.DB or sql.Tx instance.
 | 
					// CommonDB Return the underlying sql.DB or sql.Tx instance.
 | 
				
			||||||
// Use of this method is discouraged. It's mainly intended to allow
 | 
					// Use of this method is discouraged. It's mainly intended to allow
 | 
				
			||||||
// coexistence with legacy non-GORM code.
 | 
					// coexistence with legacy non-GORM code.
 | 
				
			||||||
@ -473,3 +460,29 @@ func (s *DB) Get(name string) (value interface{}, ok bool) {
 | 
				
			|||||||
	value, ok = s.values[name]
 | 
						value, ok = s.values[name]
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *DB) GetJoinTableHandler(table string) JoinTableHandler {
 | 
				
			||||||
 | 
						if s.parent.joinTableHandlers != nil {
 | 
				
			||||||
 | 
							if joinTableHandler, ok := s.parent.joinTableHandlers[table]; ok {
 | 
				
			||||||
 | 
								return joinTableHandler
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if joinTableHandler, ok := s.parent.joinTableHandlers["*"]; ok {
 | 
				
			||||||
 | 
								return joinTableHandler
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return DefaultJoinTableHandler
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *DB) SetJoinTableHandler(joinTableHandler JoinTableHandler, tables ...string) {
 | 
				
			||||||
 | 
						if s.parent.joinTableHandlers == nil {
 | 
				
			||||||
 | 
							s.parent.joinTableHandlers = map[string]JoinTableHandler{}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if len(tables) > 0 {
 | 
				
			||||||
 | 
							for _, table := range tables {
 | 
				
			||||||
 | 
								s.parent.joinTableHandlers[table] = joinTableHandler
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							s.parent.joinTableHandlers["*"] = joinTableHandler
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user