join table handler
This commit is contained in:
parent
b803039c31
commit
6a862765b6
155
join_table_handler.go
Normal file
155
join_table_handler.go
Normal file
@ -0,0 +1,155 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type JoinTableHandlerInterface interface {
|
||||
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
|
||||
Table(db *DB) string
|
||||
Add(db *DB, source interface{}, destination interface{}) error
|
||||
Delete(db *DB, sources ...interface{}) error
|
||||
JoinWith(db *DB, source interface{}) *DB
|
||||
}
|
||||
|
||||
type JoinTableForeignKey struct {
|
||||
DBName string
|
||||
AssociationDBName string
|
||||
}
|
||||
|
||||
type JoinTableSource struct {
|
||||
ModelType reflect.Type
|
||||
ForeignKeys []JoinTableForeignKey
|
||||
}
|
||||
|
||||
type JoinTableHandler struct {
|
||||
TableName string `sql:"-"`
|
||||
Source JoinTableSource `sql:"-"`
|
||||
Destination JoinTableSource `sql:"-"`
|
||||
}
|
||||
|
||||
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
|
||||
s.TableName = tableName
|
||||
|
||||
s.Source = JoinTableSource{ModelType: source}
|
||||
sourceScope := &Scope{Value: reflect.New(source).Interface()}
|
||||
for _, primaryField := range sourceScope.GetModelStruct().PrimaryFields {
|
||||
if relationship.ForeignDBName == "" {
|
||||
relationship.ForeignFieldName = source.Name() + primaryField.Name
|
||||
relationship.ForeignDBName = ToDBName(relationship.ForeignFieldName)
|
||||
}
|
||||
s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
|
||||
DBName: relationship.ForeignDBName,
|
||||
AssociationDBName: primaryField.DBName,
|
||||
})
|
||||
}
|
||||
|
||||
s.Destination = JoinTableSource{ModelType: destination}
|
||||
destinationScope := &Scope{Value: reflect.New(destination).Interface()}
|
||||
for _, primaryField := range destinationScope.GetModelStruct().PrimaryFields {
|
||||
if relationship.AssociationForeignDBName == "" {
|
||||
relationship.AssociationForeignFieldName = destination.Name() + primaryField.Name
|
||||
relationship.AssociationForeignDBName = ToDBName(relationship.AssociationForeignFieldName)
|
||||
}
|
||||
s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
|
||||
DBName: relationship.AssociationForeignDBName,
|
||||
AssociationDBName: primaryField.DBName,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s JoinTableHandler) Table(*DB) string {
|
||||
return s.TableName
|
||||
}
|
||||
|
||||
func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
|
||||
values := map[string]interface{}{}
|
||||
|
||||
for _, source := range sources {
|
||||
scope := db.NewScope(source)
|
||||
modelType := scope.GetModelStruct().ModelType
|
||||
|
||||
if s.Source.ModelType == modelType {
|
||||
for _, foreignKey := range s.Source.ForeignKeys {
|
||||
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
|
||||
}
|
||||
} else if s.Destination.ModelType == modelType {
|
||||
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) error {
|
||||
scope := db.NewScope("")
|
||||
searchMap := s.GetSearchMap(db, source1, source2)
|
||||
|
||||
var assignColumns, binVars, conditions []string
|
||||
var values []interface{}
|
||||
for key, value := range searchMap {
|
||||
assignColumns = append(assignColumns, key)
|
||||
binVars = append(binVars, `?`)
|
||||
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
quotedTable := s.Table(db)
|
||||
sql := fmt.Sprintf(
|
||||
"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v);",
|
||||
quotedTable,
|
||||
strings.Join(assignColumns, ","),
|
||||
strings.Join(binVars, ","),
|
||||
scope.Dialect().SelectFromDummyTable(),
|
||||
quotedTable,
|
||||
strings.Join(conditions, " AND "),
|
||||
)
|
||||
|
||||
return db.Exec(sql, values...).Error
|
||||
}
|
||||
|
||||
func (s JoinTableHandler) Delete(db *DB, sources ...interface{}) error {
|
||||
var conditions []string
|
||||
var values []interface{}
|
||||
|
||||
for key, value := range s.GetSearchMap(db, sources...) {
|
||||
conditions = append(conditions, fmt.Sprintf("%v = ?", key))
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
return db.Table(s.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
|
||||
}
|
||||
|
||||
func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB {
|
||||
quotedTable := s.Table(db)
|
||||
|
||||
scope := db.NewScope(source)
|
||||
modelType := scope.GetModelStruct().ModelType
|
||||
var joinConditions []string
|
||||
var queryConditions []string
|
||||
var values []interface{}
|
||||
if s.Source.ModelType == modelType {
|
||||
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||
destinationTableName := scope.New(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
|
||||
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
|
||||
}
|
||||
|
||||
for _, foreignKey := range s.Source.ForeignKeys {
|
||||
queryConditions = append(queryConditions, fmt.Sprintf("%v.%v = ?", quotedTable, scope.Quote(foreignKey.DBName)))
|
||||
values = append(values, scope.Fields()[foreignKey.AssociationDBName].Field.Interface())
|
||||
}
|
||||
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))).
|
||||
Where(strings.Join(queryConditions, " AND "), values...)
|
||||
} else {
|
||||
db.Error = errors.New("wrong source type for join table handler")
|
||||
return db
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user