Merge remote-tracking branch 'upstream/master'
This commit is contained in:
commit
9e131e5ec0
@ -138,12 +138,14 @@ db.SingularTable(true)
|
||||
```go
|
||||
// Create table
|
||||
db.CreateTable(&User{})
|
||||
db.Set("gorm:table_options", "ENGINE=InnoDB").CreateTable(&User{})
|
||||
|
||||
// Drop table
|
||||
db.DropTable(&User{})
|
||||
|
||||
// Automating Migration
|
||||
db.AutoMigrate(&User{})
|
||||
db.Set("gorm:table_options", "ENGINE=InnoDB").AutoMigrate(&User{})
|
||||
db.AutoMigrate(&User{}, &Product{}, &Order{})
|
||||
// Feel free to change your struct, AutoMigrate will keep your database up-to-date.
|
||||
// AutoMigrate will ONLY add *new columns* and *new indexes*,
|
||||
@ -1126,7 +1128,7 @@ type Product struct {
|
||||
// 2nd param : destination table(id)
|
||||
// 3rd param : ONDELETE
|
||||
// 4th param : ONUPDATE
|
||||
db.Model(&User{}).AddForeignKey("role_id", "roles", "CASCADE", "RESTRICT")
|
||||
db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
|
||||
|
||||
// Add index
|
||||
db.Model(&User{}).AddIndex("idx_user_name", "name")
|
||||
|
211
association.go
211
association.go
@ -4,14 +4,14 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Association struct {
|
||||
Scope *Scope
|
||||
PrimaryKey interface{}
|
||||
Column string
|
||||
Error error
|
||||
Field *Field
|
||||
Scope *Scope
|
||||
Column string
|
||||
Error error
|
||||
Field *Field
|
||||
}
|
||||
|
||||
func (association *Association) setErr(err error) *Association {
|
||||
@ -45,60 +45,42 @@ func (association *Association) Append(values ...interface{}) *Association {
|
||||
return association.setErr(scope.db.Error)
|
||||
}
|
||||
|
||||
func (association *Association) getPrimaryKeys(values ...interface{}) []interface{} {
|
||||
primaryKeys := []interface{}{}
|
||||
func (association *Association) Delete(values ...interface{}) *Association {
|
||||
scope := association.Scope
|
||||
relationship := association.Field.Relationship
|
||||
|
||||
for _, value := range values {
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||
if reflectValue.Kind() == reflect.Slice {
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
if primaryField := scope.New(reflectValue.Index(i).Interface()).PrimaryField(); !primaryField.IsBlank {
|
||||
primaryKeys = append(primaryKeys, primaryField.Field.Interface())
|
||||
}
|
||||
}
|
||||
} else if reflectValue.Kind() == reflect.Struct {
|
||||
if primaryField := scope.New(value).PrimaryField(); !primaryField.IsBlank {
|
||||
primaryKeys = append(primaryKeys, primaryField.Field.Interface())
|
||||
// many to many
|
||||
if relationship.Kind == "many_to_many" {
|
||||
query := scope.NewDB()
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
|
||||
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
return primaryKeys
|
||||
}
|
||||
|
||||
func (association *Association) Delete(values ...interface{}) *Association {
|
||||
primaryKeys := association.getPrimaryKeys(values...)
|
||||
primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...)
|
||||
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys))
|
||||
query = query.Where(sql, toQueryValues(primaryKeys)...)
|
||||
|
||||
if len(primaryKeys) == 0 {
|
||||
association.setErr(errors.New("no primary key found"))
|
||||
} else {
|
||||
scope := association.Scope
|
||||
relationship := association.Field.Relationship
|
||||
// many to many
|
||||
if relationship.Kind == "many_to_many" {
|
||||
sql := fmt.Sprintf("%v = ? AND %v IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
|
||||
query := scope.NewDB().Where(sql, association.PrimaryKey, primaryKeys)
|
||||
if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil {
|
||||
leftValues := reflect.Zero(association.Field.Field.Type())
|
||||
for i := 0; i < association.Field.Field.Len(); i++ {
|
||||
value := association.Field.Field.Index(i)
|
||||
if primaryField := association.Scope.New(value.Interface()).PrimaryField(); primaryField != nil {
|
||||
var included = false
|
||||
for _, primaryKey := range primaryKeys {
|
||||
if equalAsString(primaryKey, primaryField.Field.Interface()) {
|
||||
included = true
|
||||
}
|
||||
}
|
||||
if !included {
|
||||
leftValues = reflect.Append(leftValues, value)
|
||||
}
|
||||
if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil {
|
||||
leftValues := reflect.Zero(association.Field.Field.Type())
|
||||
for i := 0; i < association.Field.Field.Len(); i++ {
|
||||
reflectValue := association.Field.Field.Index(i)
|
||||
primaryKey := association.getPrimaryKeys(relationship.ForeignFieldNames, reflectValue.Interface())[0]
|
||||
var included = false
|
||||
for _, pk := range primaryKeys {
|
||||
if equalAsString(primaryKey, pk) {
|
||||
included = true
|
||||
}
|
||||
}
|
||||
association.Field.Set(leftValues)
|
||||
if !included {
|
||||
leftValues = reflect.Append(leftValues, reflectValue)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
association.setErr(errors.New("delete only support many to many"))
|
||||
association.Field.Set(leftValues)
|
||||
}
|
||||
} else {
|
||||
association.setErr(errors.New("delete only support many to many"))
|
||||
}
|
||||
return association
|
||||
}
|
||||
@ -109,16 +91,16 @@ func (association *Association) Replace(values ...interface{}) *Association {
|
||||
if relationship.Kind == "many_to_many" {
|
||||
field := association.Field.Field
|
||||
|
||||
oldPrimaryKeys := association.getPrimaryKeys(field.Interface())
|
||||
oldPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface())
|
||||
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
|
||||
association.Append(values...)
|
||||
newPrimaryKeys := association.getPrimaryKeys(field.Interface())
|
||||
newPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface())
|
||||
|
||||
var addedPrimaryKeys = []interface{}{}
|
||||
var addedPrimaryKeys = [][]interface{}{}
|
||||
for _, newKey := range newPrimaryKeys {
|
||||
hasEqual := false
|
||||
for _, oldKey := range oldPrimaryKeys {
|
||||
if reflect.DeepEqual(newKey, oldKey) {
|
||||
if equalAsString(newKey, oldKey) {
|
||||
hasEqual = true
|
||||
break
|
||||
}
|
||||
@ -127,13 +109,21 @@ func (association *Association) Replace(values ...interface{}) *Association {
|
||||
addedPrimaryKeys = append(addedPrimaryKeys, newKey)
|
||||
}
|
||||
}
|
||||
for _, primaryKey := range association.getPrimaryKeys(values...) {
|
||||
|
||||
for _, primaryKey := range association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) {
|
||||
addedPrimaryKeys = append(addedPrimaryKeys, primaryKey)
|
||||
}
|
||||
|
||||
if len(addedPrimaryKeys) > 0 {
|
||||
sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
|
||||
query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys)
|
||||
query := scope.NewDB()
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
|
||||
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(addedPrimaryKeys))
|
||||
query = query.Where(sql, toQueryValues(addedPrimaryKeys)...)
|
||||
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship))
|
||||
}
|
||||
} else {
|
||||
@ -146,8 +136,13 @@ func (association *Association) Clear() *Association {
|
||||
relationship := association.Field.Relationship
|
||||
scope := association.Scope
|
||||
if relationship.Kind == "many_to_many" {
|
||||
sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName))
|
||||
query := scope.NewDB().Where(sql, association.PrimaryKey)
|
||||
query := scope.NewDB()
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
|
||||
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil {
|
||||
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
|
||||
} else {
|
||||
@ -168,18 +163,104 @@ func (association *Association) Count() int {
|
||||
if relationship.Kind == "many_to_many" {
|
||||
relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count)
|
||||
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
||||
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
|
||||
countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey)
|
||||
query := scope.DB()
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
|
||||
query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)),
|
||||
field.Field.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
if relationship.PolymorphicType != "" {
|
||||
countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName())
|
||||
query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName())
|
||||
}
|
||||
countScope.Count(&count)
|
||||
query.Table(newScope.TableName()).Count(&count)
|
||||
} else if relationship.Kind == "belongs_to" {
|
||||
if v, ok := scope.FieldByName(association.Column); ok {
|
||||
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
|
||||
scope.DB().Table(newScope.TableName()).Where(whereSql, v).Count(&count)
|
||||
query := scope.DB()
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
|
||||
query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)),
|
||||
field.Field.Interface())
|
||||
}
|
||||
}
|
||||
query.Table(newScope.TableName()).Count(&count)
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) [][]interface{} {
|
||||
results := [][]interface{}{}
|
||||
scope := association.Scope
|
||||
|
||||
for _, value := range values {
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||
if reflectValue.Kind() == reflect.Slice {
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
primaryKeys := []interface{}{}
|
||||
newScope := scope.New(reflectValue.Index(i).Interface())
|
||||
for _, column := range columns {
|
||||
if field, ok := newScope.FieldByName(column); ok {
|
||||
primaryKeys = append(primaryKeys, field.Field.Interface())
|
||||
} else {
|
||||
primaryKeys = append(primaryKeys, "")
|
||||
}
|
||||
}
|
||||
results = append(results, primaryKeys)
|
||||
}
|
||||
} else if reflectValue.Kind() == reflect.Struct {
|
||||
newScope := scope.New(value)
|
||||
var primaryKeys []interface{}
|
||||
for _, column := range columns {
|
||||
if field, ok := newScope.FieldByName(column); ok {
|
||||
primaryKeys = append(primaryKeys, field.Field.Interface())
|
||||
} else {
|
||||
primaryKeys = append(primaryKeys, "")
|
||||
}
|
||||
}
|
||||
|
||||
results = append(results, primaryKeys)
|
||||
}
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
func toQueryMarks(primaryValues [][]interface{}) string {
|
||||
var results []string
|
||||
|
||||
for _, primaryValue := range primaryValues {
|
||||
var marks []string
|
||||
for _,_ = range primaryValue {
|
||||
marks = append(marks, "?")
|
||||
}
|
||||
|
||||
if len(marks) > 1 {
|
||||
results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
|
||||
} else {
|
||||
results = append(results, strings.Join(marks, ""))
|
||||
}
|
||||
}
|
||||
return strings.Join(results, ",")
|
||||
}
|
||||
|
||||
func toQueryCondition(scope *Scope, columns []string) string {
|
||||
var newColumns []string
|
||||
for _, column := range columns {
|
||||
newColumns = append(newColumns, scope.Quote(column))
|
||||
}
|
||||
|
||||
if len(columns) > 1 {
|
||||
return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
|
||||
} else {
|
||||
return strings.Join(columns, ",")
|
||||
}
|
||||
}
|
||||
|
||||
func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
|
||||
for _, primaryValue := range primaryValues {
|
||||
for _, value := range primaryValue {
|
||||
values = append(values, value)
|
||||
}
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
@ -23,11 +23,11 @@ func TestHasOneAndHasManyAssociation(t *testing.T) {
|
||||
}
|
||||
|
||||
if err := DB.Save(&post).Error; err != nil {
|
||||
t.Errorf("Got errors when save post")
|
||||
t.Errorf("Got errors when save post", err.Error())
|
||||
}
|
||||
|
||||
if DB.First(&Category{}, "name = ?", "Category 1").Error != nil {
|
||||
t.Errorf("Category should be saved")
|
||||
if err := DB.First(&Category{}, "name = ?", "Category 1").Error; err != nil {
|
||||
t.Errorf("Category should be saved", err.Error())
|
||||
}
|
||||
|
||||
var p Post
|
||||
@ -186,6 +186,7 @@ func TestManyToMany(t *testing.T) {
|
||||
var language Language
|
||||
DB.Where("name = ?", "EE").First(&language)
|
||||
DB.Model(&user).Association("Languages").Delete(language, &language)
|
||||
|
||||
if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 {
|
||||
t.Errorf("Relations should be deleted with Delete")
|
||||
}
|
||||
|
@ -35,9 +35,11 @@ func Create(scope *Scope) {
|
||||
}
|
||||
}
|
||||
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||
if relationField := fields[relationship.ForeignDBName]; !scope.changeableField(relationField) {
|
||||
columns = append(columns, scope.Quote(relationField.DBName))
|
||||
sqls = append(sqls, scope.AddToVars(relationField.Field.Interface()))
|
||||
for _, dbName := range relationship.ForeignDBNames {
|
||||
if relationField := fields[dbName]; !scope.changeableField(relationField) {
|
||||
columns = append(columns, scope.Quote(relationField.DBName))
|
||||
sqls = append(sqls, scope.AddToVars(relationField.Field.Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -30,7 +30,7 @@ func Query(scope *Scope) {
|
||||
if kind := dest.Kind(); kind == reflect.Slice {
|
||||
isSlice = true
|
||||
destType = dest.Type().Elem()
|
||||
dest.Set(reflect.Indirect(reflect.New(reflect.SliceOf(destType))))
|
||||
dest.Set(reflect.MakeSlice(dest.Type(), 0, 0))
|
||||
|
||||
if destType.Kind() == reflect.Ptr {
|
||||
isPtr = true
|
||||
|
@ -19,8 +19,13 @@ func SaveBeforeAssociations(scope *Scope) {
|
||||
if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||
value := field.Field
|
||||
scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error)
|
||||
if relationship.ForeignFieldName != "" {
|
||||
scope.Err(scope.SetColumn(relationship.ForeignFieldName, scope.New(value.Addr().Interface()).PrimaryKeyValue()))
|
||||
if len(relationship.ForeignFieldNames) != 0 {
|
||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
||||
if f, ok := scope.New(value.Addr().Interface()).FieldByName(associationForeignName); ok {
|
||||
scope.Err(scope.SetColumn(fieldName, f.Field.Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -44,8 +49,13 @@ func SaveAfterAssociations(scope *Scope) {
|
||||
elem := value.Index(i).Addr().Interface()
|
||||
newScope := newDB.NewScope(elem)
|
||||
|
||||
if relationship.JoinTableHandler == nil && relationship.ForeignFieldName != "" {
|
||||
scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue()))
|
||||
if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 {
|
||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
||||
if f, ok := scope.FieldByName(associationForeignName); ok {
|
||||
scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if relationship.PolymorphicType != "" {
|
||||
@ -61,8 +71,13 @@ func SaveAfterAssociations(scope *Scope) {
|
||||
default:
|
||||
elem := value.Addr().Interface()
|
||||
newScope := scope.New(elem)
|
||||
if relationship.ForeignFieldName != "" {
|
||||
scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue()))
|
||||
if len(relationship.ForeignFieldNames) != 0 {
|
||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||
associationForeignName := relationship.AssociationForeignDBNames[idx]
|
||||
if f, ok := scope.FieldByName(associationForeignName); ok {
|
||||
scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if relationship.PolymorphicType != "" {
|
||||
|
@ -55,9 +55,10 @@ func Update(scope *Scope) {
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||
}
|
||||
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||
if relationField := fields[relationship.ForeignDBName]; !scope.changeableField(relationField) {
|
||||
if !relationField.IsBlank {
|
||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface())))
|
||||
for _, dbName := range relationship.ForeignDBNames {
|
||||
if relationField := fields[dbName]; !scope.changeableField(relationField) && !relationField.IsBlank {
|
||||
sql := fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface()))
|
||||
sqls = append(sqls, sql)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -70,8 +70,8 @@ func (commonDialect) Quote(key string) string {
|
||||
}
|
||||
|
||||
func (commonDialect) databaseName(scope *Scope) string {
|
||||
from := strings.Index(scope.db.parent.source, "/") + 1
|
||||
to := strings.Index(scope.db.parent.source, "?")
|
||||
from := strings.LastIndex(scope.db.parent.source, "/") + 1
|
||||
to := strings.LastIndex(scope.db.parent.source, "?")
|
||||
if to == -1 {
|
||||
to = len(scope.db.parent.source)
|
||||
}
|
||||
|
@ -45,41 +45,18 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s
|
||||
s.TableName = tableName
|
||||
|
||||
s.Source = JoinTableSource{ModelType: source}
|
||||
sourceScope := &Scope{Value: reflect.New(source).Interface()}
|
||||
sourcePrimaryFields := sourceScope.GetModelStruct().PrimaryFields
|
||||
for _, primaryField := range sourcePrimaryFields {
|
||||
if relationship.ForeignDBName == "" {
|
||||
relationship.ForeignFieldName = source.Name() + primaryField.Name
|
||||
relationship.ForeignDBName = ToDBName(relationship.ForeignFieldName)
|
||||
}
|
||||
|
||||
var dbName string
|
||||
if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" {
|
||||
dbName = relationship.ForeignDBName
|
||||
} else {
|
||||
dbName = ToDBName(source.Name() + primaryField.Name)
|
||||
}
|
||||
|
||||
for idx, dbName := range relationship.ForeignFieldNames {
|
||||
s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
|
||||
DBName: dbName,
|
||||
AssociationDBName: primaryField.DBName,
|
||||
DBName: relationship.ForeignDBNames[idx],
|
||||
AssociationDBName: dbName,
|
||||
})
|
||||
}
|
||||
|
||||
s.Destination = JoinTableSource{ModelType: destination}
|
||||
destinationScope := &Scope{Value: reflect.New(destination).Interface()}
|
||||
destinationPrimaryFields := destinationScope.GetModelStruct().PrimaryFields
|
||||
for _, primaryField := range destinationPrimaryFields {
|
||||
var dbName string
|
||||
if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" {
|
||||
dbName = relationship.AssociationForeignDBName
|
||||
} else {
|
||||
dbName = ToDBName(destinationScope.GetModelStruct().ModelType.Name() + primaryField.Name)
|
||||
}
|
||||
|
||||
for idx, dbName := range relationship.AssociationForeignFieldNames {
|
||||
s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
|
||||
DBName: dbName,
|
||||
AssociationDBName: primaryField.DBName,
|
||||
DBName: relationship.AssociationForeignDBNames[idx],
|
||||
AssociationDBName: dbName,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
10
main.go
10
main.go
@ -118,7 +118,7 @@ func (s *DB) Callback() *callback {
|
||||
}
|
||||
|
||||
func (s *DB) SetLogger(l logger) {
|
||||
s.parent.logger = l
|
||||
s.logger = l
|
||||
}
|
||||
|
||||
func (s *DB) LogMode(enable bool) *DB {
|
||||
@ -259,9 +259,9 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
|
||||
if !result.RecordNotFound() {
|
||||
return result
|
||||
}
|
||||
c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates)
|
||||
c.err(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates).db.Error)
|
||||
} else if len(c.search.assignAttrs) > 0 {
|
||||
c.NewScope(out).InstanceSet("gorm:update_interface", s.search.assignAttrs).callCallbacks(s.parent.callback.updates)
|
||||
c.err(c.NewScope(out).InstanceSet("gorm:update_interface", s.search.assignAttrs).callCallbacks(s.parent.callback.updates).db.Error)
|
||||
}
|
||||
return c
|
||||
}
|
||||
@ -450,10 +450,10 @@ func (s *DB) Association(column string) *Association {
|
||||
err = errors.New("primary key can't be nil")
|
||||
} else {
|
||||
if field, ok := scope.FieldByName(column); ok {
|
||||
if field.Relationship == nil || field.Relationship.ForeignFieldName == "" {
|
||||
if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 {
|
||||
err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
|
||||
} else {
|
||||
return &Association{Scope: scope, Column: column, PrimaryKey: primaryField.Field.Interface(), Field: field}
|
||||
return &Association{Scope: scope, Column: column, Field: field}
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)
|
||||
|
@ -3,7 +3,7 @@ package gorm
|
||||
import "time"
|
||||
|
||||
func (s *DB) clone() *DB {
|
||||
db := DB{db: s.db, parent: s.parent, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error}
|
||||
db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error}
|
||||
|
||||
for key, value := range s.values {
|
||||
db.values[key] = value
|
||||
@ -34,7 +34,7 @@ func (s *DB) err(err error) error {
|
||||
}
|
||||
|
||||
func (s *DB) print(v ...interface{}) {
|
||||
s.parent.logger.(logger).Print(v...)
|
||||
s.logger.(logger).Print(v...)
|
||||
}
|
||||
|
||||
func (s *DB) log(v ...interface{}) {
|
||||
|
173
model_struct.go
173
model_struct.go
@ -5,10 +5,11 @@ import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/qor/inflection"
|
||||
)
|
||||
|
||||
var modelStructs = map[reflect.Type]*ModelStruct{}
|
||||
@ -61,19 +62,16 @@ func (structField *StructField) clone() *StructField {
|
||||
}
|
||||
|
||||
type Relationship struct {
|
||||
Kind string
|
||||
PolymorphicType string
|
||||
PolymorphicDBName string
|
||||
ForeignFieldName string
|
||||
ForeignDBName string
|
||||
AssociationForeignFieldName string
|
||||
AssociationForeignDBName string
|
||||
JoinTableHandler JoinTableHandlerInterface
|
||||
Kind string
|
||||
PolymorphicType string
|
||||
PolymorphicDBName string
|
||||
ForeignFieldNames []string
|
||||
ForeignDBNames []string
|
||||
AssociationForeignFieldNames []string
|
||||
AssociationForeignDBNames []string
|
||||
JoinTableHandler JoinTableHandlerInterface
|
||||
}
|
||||
|
||||
var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")}
|
||||
var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${1}s"}
|
||||
|
||||
func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||
var modelStruct ModelStruct
|
||||
|
||||
@ -113,11 +111,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||
} else {
|
||||
name := ToDBName(scopeType.Name())
|
||||
if scope.db == nil || !scope.db.parent.singularTable {
|
||||
for index, reg := range pluralMapKeys {
|
||||
if reg.MatchString(name) {
|
||||
name = reg.ReplaceAllString(name, pluralMapValues[index])
|
||||
}
|
||||
}
|
||||
name = inflection.Plural(name)
|
||||
}
|
||||
|
||||
modelStruct.defaultTableName = name
|
||||
@ -190,12 +184,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||
|
||||
var relationship = &Relationship{}
|
||||
|
||||
foreignKey := gormSettings["FOREIGNKEY"]
|
||||
if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" {
|
||||
if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil {
|
||||
if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil {
|
||||
relationship.ForeignFieldName = polymorphicField.Name
|
||||
relationship.ForeignDBName = polymorphicField.DBName
|
||||
relationship.ForeignFieldNames = []string{polymorphicField.Name}
|
||||
relationship.ForeignDBNames = []string{polymorphicField.DBName}
|
||||
relationship.AssociationForeignFieldNames = []string{scope.PrimaryField().Name}
|
||||
relationship.AssociationForeignDBNames = []string{scope.PrimaryField().DBName}
|
||||
relationship.PolymorphicType = polymorphicType.Name
|
||||
relationship.PolymorphicDBName = polymorphicType.DBName
|
||||
polymorphicType.IsForeignKey = true
|
||||
@ -204,6 +199,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||
}
|
||||
}
|
||||
|
||||
var foreignKeys []string
|
||||
if foreignKey, ok := gormSettings["FOREIGNKEY"]; ok {
|
||||
foreignKeys = append(foreignKeys, foreignKey)
|
||||
}
|
||||
switch indirectType.Kind() {
|
||||
case reflect.Slice:
|
||||
elemType := indirectType.Elem()
|
||||
@ -212,21 +211,41 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||
}
|
||||
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
if foreignKey == "" {
|
||||
foreignKey = scopeType.Name() + "Id"
|
||||
}
|
||||
|
||||
if many2many := gormSettings["MANY2MANY"]; many2many != "" {
|
||||
relationship.Kind = "many_to_many"
|
||||
associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]
|
||||
if associationForeignKey == "" {
|
||||
associationForeignKey = elemType.Name() + "Id"
|
||||
|
||||
// foreign keys
|
||||
if len(foreignKeys) == 0 {
|
||||
for _, field := range scope.PrimaryFields() {
|
||||
foreignKeys = append(foreignKeys, field.DBName)
|
||||
}
|
||||
}
|
||||
|
||||
relationship.ForeignFieldName = foreignKey
|
||||
relationship.ForeignDBName = ToDBName(foreignKey)
|
||||
relationship.AssociationForeignFieldName = associationForeignKey
|
||||
relationship.AssociationForeignDBName = ToDBName(associationForeignKey)
|
||||
for _, foreignKey := range foreignKeys {
|
||||
if field, ok := scope.FieldByName(foreignKey); ok {
|
||||
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName)
|
||||
joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName
|
||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName)
|
||||
}
|
||||
}
|
||||
|
||||
// association foreign keys
|
||||
var associationForeignKeys []string
|
||||
if foreignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
|
||||
associationForeignKeys = []string{gormSettings["ASSOCIATIONFOREIGNKEY"]}
|
||||
} else {
|
||||
for _, field := range toScope.PrimaryFields() {
|
||||
associationForeignKeys = append(associationForeignKeys, field.DBName)
|
||||
}
|
||||
}
|
||||
|
||||
for _, name := range associationForeignKeys {
|
||||
if field, ok := toScope.FieldByName(name); ok {
|
||||
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName)
|
||||
joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName
|
||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
|
||||
}
|
||||
}
|
||||
|
||||
joinTableHandler := JoinTableHandler{}
|
||||
joinTableHandler.Setup(relationship, many2many, scopeType, elemType)
|
||||
@ -234,12 +253,30 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||
field.Relationship = relationship
|
||||
} else {
|
||||
relationship.Kind = "has_many"
|
||||
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
|
||||
relationship.ForeignFieldName = foreignField.Name
|
||||
relationship.ForeignDBName = foreignField.DBName
|
||||
foreignField.IsForeignKey = true
|
||||
field.Relationship = relationship
|
||||
} else if relationship.ForeignFieldName != "" {
|
||||
|
||||
if len(foreignKeys) == 0 {
|
||||
for _, field := range scope.PrimaryFields() {
|
||||
if foreignField := getForeignField(scopeType.Name()+field.Name, toScope.GetStructFields()); foreignField != nil {
|
||||
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name)
|
||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName)
|
||||
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
|
||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
|
||||
foreignField.IsForeignKey = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, foreignKey := range foreignKeys {
|
||||
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
|
||||
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name)
|
||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName)
|
||||
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
|
||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
|
||||
foreignField.IsForeignKey = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(relationship.ForeignFieldNames) != 0 {
|
||||
field.Relationship = relationship
|
||||
}
|
||||
}
|
||||
@ -258,28 +295,56 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||
}
|
||||
continue
|
||||
} else {
|
||||
belongsToForeignKey := foreignKey
|
||||
if belongsToForeignKey == "" {
|
||||
belongsToForeignKey = field.Name + "Id"
|
||||
if len(foreignKeys) == 0 {
|
||||
for _, f := range scope.PrimaryFields() {
|
||||
if foreignField := getForeignField(modelStruct.ModelType.Name()+f.Name, toScope.GetStructFields()); foreignField != nil {
|
||||
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name)
|
||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName)
|
||||
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
|
||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
|
||||
foreignField.IsForeignKey = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, foreignKey := range foreignKeys {
|
||||
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
|
||||
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name)
|
||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName)
|
||||
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
|
||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
|
||||
foreignField.IsForeignKey = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if foreignField := getForeignField(belongsToForeignKey, fields); foreignField != nil {
|
||||
relationship.Kind = "belongs_to"
|
||||
relationship.ForeignFieldName = foreignField.Name
|
||||
relationship.ForeignDBName = foreignField.DBName
|
||||
foreignField.IsForeignKey = true
|
||||
if len(relationship.ForeignFieldNames) != 0 {
|
||||
relationship.Kind = "has_one"
|
||||
field.Relationship = relationship
|
||||
} else {
|
||||
if foreignKey == "" {
|
||||
foreignKey = modelStruct.ModelType.Name() + "Id"
|
||||
if len(foreignKeys) == 0 {
|
||||
for _, f := range toScope.PrimaryFields() {
|
||||
if foreignField := getForeignField(field.Name+f.Name, fields); foreignField != nil {
|
||||
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name)
|
||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName)
|
||||
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
|
||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
|
||||
foreignField.IsForeignKey = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, foreignKey := range foreignKeys {
|
||||
if foreignField := getForeignField(foreignKey, fields); foreignField != nil {
|
||||
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, toScope.PrimaryField().Name)
|
||||
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, toScope.PrimaryField().DBName)
|
||||
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
|
||||
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
|
||||
foreignField.IsForeignKey = true
|
||||
}
|
||||
}
|
||||
}
|
||||
relationship.Kind = "has_one"
|
||||
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
|
||||
relationship.ForeignFieldName = foreignField.Name
|
||||
relationship.ForeignDBName = foreignField.DBName
|
||||
foreignField.IsForeignKey = true
|
||||
field.Relationship = relationship
|
||||
} else if relationship.ForeignFieldName != "" {
|
||||
|
||||
if len(relationship.ForeignFieldNames) != 0 {
|
||||
relationship.Kind = "belongs_to"
|
||||
field.Relationship = relationship
|
||||
}
|
||||
}
|
||||
|
64
preload.go
64
preload.go
@ -8,12 +8,15 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func getRealValue(value reflect.Value, field string) interface{} {
|
||||
result := reflect.Indirect(value).FieldByName(field).Interface()
|
||||
if r, ok := result.(driver.Valuer); ok {
|
||||
result, _ = r.Value()
|
||||
func getRealValue(value reflect.Value, columns []string) (results []interface{}) {
|
||||
for _, column := range columns {
|
||||
result := reflect.Indirect(value).FieldByName(column).Interface()
|
||||
if r, ok := result.(driver.Valuer); ok {
|
||||
result, _ = r.Value()
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
return result
|
||||
return
|
||||
}
|
||||
|
||||
func equalAsString(a interface{}, b interface{}) bool {
|
||||
@ -97,26 +100,24 @@ func makeSlice(typ reflect.Type) interface{} {
|
||||
}
|
||||
|
||||
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
|
||||
primaryName := scope.PrimaryField().Name
|
||||
primaryKeys := scope.getColumnAsArray(primaryName)
|
||||
relation := field.Relationship
|
||||
|
||||
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
|
||||
if len(primaryKeys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
results := makeSlice(field.Struct.Type)
|
||||
relation := field.Relationship
|
||||
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName))
|
||||
|
||||
scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error)
|
||||
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
|
||||
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
||||
|
||||
for i := 0; i < resultValues.Len(); i++ {
|
||||
result := resultValues.Index(i)
|
||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||
value := getRealValue(result, relation.ForeignFieldName)
|
||||
value := getRealValue(result, relation.ForeignFieldNames)
|
||||
objects := scope.IndirectValue()
|
||||
for j := 0; j < objects.Len(); j++ {
|
||||
if equalAsString(getRealValue(objects.Index(j), primaryName), value) {
|
||||
if equalAsString(getRealValue(objects.Index(j), relation.AssociationForeignFieldNames), value) {
|
||||
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
|
||||
break
|
||||
}
|
||||
@ -131,27 +132,24 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
|
||||
}
|
||||
|
||||
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
|
||||
primaryName := scope.PrimaryField().Name
|
||||
primaryKeys := scope.getColumnAsArray(primaryName)
|
||||
relation := field.Relationship
|
||||
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
|
||||
if len(primaryKeys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
results := makeSlice(field.Struct.Type)
|
||||
relation := field.Relationship
|
||||
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName))
|
||||
|
||||
scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error)
|
||||
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
|
||||
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
||||
|
||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||
for i := 0; i < resultValues.Len(); i++ {
|
||||
result := resultValues.Index(i)
|
||||
value := getRealValue(result, relation.ForeignFieldName)
|
||||
value := getRealValue(result, relation.ForeignFieldNames)
|
||||
objects := scope.IndirectValue()
|
||||
for j := 0; j < objects.Len(); j++ {
|
||||
object := reflect.Indirect(objects.Index(j))
|
||||
if equalAsString(getRealValue(object, primaryName), value) {
|
||||
if equalAsString(getRealValue(object, relation.AssociationForeignFieldNames), value) {
|
||||
f := object.FieldByName(field.Name)
|
||||
f.Set(reflect.Append(f, result))
|
||||
break
|
||||
@ -165,25 +163,23 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{})
|
||||
|
||||
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
|
||||
relation := field.Relationship
|
||||
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldName)
|
||||
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames)
|
||||
if len(primaryKeys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
results := makeSlice(field.Struct.Type)
|
||||
associationPrimaryKey := scope.New(results).PrimaryField().Name
|
||||
|
||||
scope.Err(scope.NewDB().Where(primaryKeys).Find(results, conditions...).Error)
|
||||
scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
|
||||
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
||||
|
||||
for i := 0; i < resultValues.Len(); i++ {
|
||||
result := resultValues.Index(i)
|
||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||
value := getRealValue(result, associationPrimaryKey)
|
||||
value := getRealValue(result, relation.AssociationForeignFieldNames)
|
||||
objects := scope.IndirectValue()
|
||||
for j := 0; j < objects.Len(); j++ {
|
||||
object := reflect.Indirect(objects.Index(j))
|
||||
if equalAsString(getRealValue(object, relation.ForeignFieldName), value) {
|
||||
if equalAsString(getRealValue(object, relation.ForeignFieldNames), value) {
|
||||
object.FieldByName(field.Name).Set(result)
|
||||
}
|
||||
}
|
||||
@ -193,15 +189,23 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
|
||||
}
|
||||
}
|
||||
|
||||
func (scope *Scope) getColumnAsArray(column string) (columns []interface{}) {
|
||||
func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) {
|
||||
values := scope.IndirectValue()
|
||||
switch values.Kind() {
|
||||
case reflect.Slice:
|
||||
for i := 0; i < values.Len(); i++ {
|
||||
columns = append(columns, reflect.Indirect(values.Index(i)).FieldByName(column).Interface())
|
||||
var result []interface{}
|
||||
for _, column := range columns {
|
||||
result = append(result, reflect.Indirect(values.Index(i)).FieldByName(column).Interface())
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
case reflect.Struct:
|
||||
return []interface{}{values.FieldByName(column).Interface()}
|
||||
var result []interface{}
|
||||
for _, column := range columns {
|
||||
result = append(result, values.FieldByName(column).Interface())
|
||||
}
|
||||
return [][]interface{}{result}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/jinzhu/now"
|
||||
|
||||
"testing"
|
||||
@ -556,7 +557,7 @@ func TestSelectWithEscapedFieldName(t *testing.T) {
|
||||
func TestSelectWithVariables(t *testing.T) {
|
||||
DB.Save(&User{Name: "jinzhu"})
|
||||
|
||||
rows, _ := DB.Table("users").Select("? as fake", "name").Rows()
|
||||
rows, _ := DB.Table("users").Select("? as fake", gorm.Expr("name")).Rows()
|
||||
|
||||
if !rows.Next() {
|
||||
t.Errorf("Should have returned at least one row")
|
||||
|
8
scope.go
8
scope.go
@ -3,6 +3,7 @@ package gorm
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -87,6 +88,13 @@ func (scope *Scope) Quote(str string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func (scope *Scope) QuoteIfPossible(str string) string {
|
||||
if regexp.MustCompile("^[a-zA-Z]+(.[a-zA-Z]+)*$").MatchString(str) {
|
||||
return scope.Quote(str)
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
// Dialect get dialect
|
||||
func (scope *Scope) Dialect() Dialect {
|
||||
return scope.db.parent.dialect
|
||||
|
@ -149,7 +149,7 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
|
||||
if valuer, ok := interface{}(arg).(driver.Valuer); ok {
|
||||
arg, _ = valuer.Value()
|
||||
}
|
||||
str = strings.Replace(str, "?", scope.Dialect().Quote(fmt.Sprintf("%v", arg)), 1)
|
||||
str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
|
||||
}
|
||||
}
|
||||
return
|
||||
@ -265,10 +265,24 @@ func (scope *Scope) groupSql() string {
|
||||
}
|
||||
|
||||
func (scope *Scope) havingSql() string {
|
||||
if scope.Search.havingCondition == nil {
|
||||
if scope.Search.havingConditions == nil {
|
||||
return ""
|
||||
}
|
||||
return " HAVING " + scope.buildWhereCondition(scope.Search.havingCondition)
|
||||
|
||||
var andConditions []string
|
||||
|
||||
for _, clause := range scope.Search.havingConditions {
|
||||
if sql := scope.buildWhereCondition(clause); sql != "" {
|
||||
andConditions = append(andConditions, sql)
|
||||
}
|
||||
}
|
||||
|
||||
combinedSql := strings.Join(andConditions, " AND ")
|
||||
if len(combinedSql) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return " HAVING " + combinedSql
|
||||
}
|
||||
|
||||
func (scope *Scope) joinsSql() string {
|
||||
@ -415,12 +429,21 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
||||
joinTableHandler := relationship.JoinTableHandler
|
||||
scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error)
|
||||
} else if relationship.Kind == "belongs_to" {
|
||||
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
|
||||
foreignKeyValue := fromFields[relationship.ForeignDBName].Field.Interface()
|
||||
scope.Err(toScope.db.Where(sql, foreignKeyValue).Find(value).Error)
|
||||
query := toScope.db
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
if field, ok := scope.FieldByName(foreignKey); ok {
|
||||
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface())
|
||||
}
|
||||
}
|
||||
scope.Err(query.Find(value).Error)
|
||||
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
||||
sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName))
|
||||
query := toScope.db.Where(sql, scope.PrimaryKeyValue())
|
||||
query := toScope.db
|
||||
for idx, foreignKey := range relationship.ForeignDBNames {
|
||||
if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
|
||||
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
if relationship.PolymorphicType != "" {
|
||||
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
|
||||
}
|
||||
@ -442,6 +465,17 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
||||
return scope
|
||||
}
|
||||
|
||||
/**
|
||||
Return the table options string or an empty string if the table options does not exist
|
||||
*/
|
||||
func (scope *Scope) getTableOptions() string{
|
||||
tableOptions, ok := scope.Get("gorm:table_options")
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return tableOptions.(string)
|
||||
}
|
||||
|
||||
func (scope *Scope) createJoinTable(field *StructField) {
|
||||
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
|
||||
joinTableHandler := relationship.JoinTableHandler
|
||||
@ -450,16 +484,22 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
||||
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
|
||||
|
||||
var sqlTypes []string
|
||||
for _, s := range []*Scope{scope, toScope} {
|
||||
for _, primaryField := range s.GetModelStruct().PrimaryFields {
|
||||
value := reflect.Indirect(reflect.New(primaryField.Struct.Type))
|
||||
for idx, fieldName := range relationship.ForeignFieldNames {
|
||||
if field, ok := scope.Fields()[fieldName]; ok {
|
||||
value := reflect.Indirect(reflect.New(field.Struct.Type))
|
||||
primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false)
|
||||
dbName := ToDBName(s.GetModelStruct().ModelType.Name() + primaryField.Name)
|
||||
sqlTypes = append(sqlTypes, scope.Quote(dbName)+" "+primaryKeySqlType)
|
||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType)
|
||||
}
|
||||
}
|
||||
|
||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)", scope.Quote(joinTable), strings.Join(sqlTypes, ","))).Error)
|
||||
for idx, fieldName := range relationship.AssociationForeignFieldNames {
|
||||
if field, ok := toScope.Fields()[fieldName]; ok {
|
||||
value := reflect.Indirect(reflect.New(field.Struct.Type))
|
||||
primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false)
|
||||
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType)
|
||||
}
|
||||
}
|
||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), scope.getTableOptions())).Error)
|
||||
}
|
||||
scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
|
||||
}
|
||||
@ -484,7 +524,7 @@ func (scope *Scope) createTable() *Scope {
|
||||
if len(primaryKeys) > 0 {
|
||||
primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
|
||||
}
|
||||
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr)).Exec()
|
||||
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()
|
||||
return scope
|
||||
}
|
||||
|
||||
@ -515,11 +555,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
|
||||
|
||||
var columns []string
|
||||
for _, name := range column {
|
||||
if regexp.MustCompile("^[a-zA-Z]+$").MatchString(name) {
|
||||
columns = append(columns, scope.Quote(name))
|
||||
} else {
|
||||
columns = append(columns, name)
|
||||
}
|
||||
columns = append(columns, scope.QuoteIfPossible(name))
|
||||
}
|
||||
|
||||
sqlCreate := "CREATE INDEX"
|
||||
@ -532,9 +568,10 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
|
||||
|
||||
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
|
||||
var table = scope.TableName()
|
||||
var keyName = fmt.Sprintf("%s_%s_foreign", table, field)
|
||||
var keyName = fmt.Sprintf("%s_%s_%s_foreign", table, field, regexp.MustCompile("[^a-zA-Z]").ReplaceAllString(dest, "_"))
|
||||
keyName = regexp.MustCompile("_+").ReplaceAllString(keyName, "_")
|
||||
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
|
||||
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.Quote(keyName), scope.Quote(field), scope.Quote(dest), onDelete, onUpdate)).Exec()
|
||||
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.QuoteIfPossible(keyName), scope.QuoteIfPossible(field), scope.QuoteIfPossible(dest), onDelete, onUpdate)).Exec()
|
||||
}
|
||||
|
||||
func (scope *Scope) removeIndex(indexName string) {
|
||||
|
46
search.go
46
search.go
@ -3,24 +3,24 @@ package gorm
|
||||
import "fmt"
|
||||
|
||||
type search struct {
|
||||
db *DB
|
||||
whereConditions []map[string]interface{}
|
||||
orConditions []map[string]interface{}
|
||||
notConditions []map[string]interface{}
|
||||
havingCondition map[string]interface{}
|
||||
initAttrs []interface{}
|
||||
assignAttrs []interface{}
|
||||
selects map[string]interface{}
|
||||
omits []string
|
||||
orders []string
|
||||
joins string
|
||||
preload []searchPreload
|
||||
offset string
|
||||
limit string
|
||||
group string
|
||||
tableName string
|
||||
raw bool
|
||||
Unscoped bool
|
||||
db *DB
|
||||
whereConditions []map[string]interface{}
|
||||
orConditions []map[string]interface{}
|
||||
notConditions []map[string]interface{}
|
||||
havingConditions []map[string]interface{}
|
||||
initAttrs []interface{}
|
||||
assignAttrs []interface{}
|
||||
selects map[string]interface{}
|
||||
omits []string
|
||||
orders []string
|
||||
joins string
|
||||
preload []searchPreload
|
||||
offset string
|
||||
limit string
|
||||
group string
|
||||
tableName string
|
||||
raw bool
|
||||
Unscoped bool
|
||||
}
|
||||
|
||||
type searchPreload struct {
|
||||
@ -60,8 +60,12 @@ func (s *search) Assign(attrs ...interface{}) *search {
|
||||
|
||||
func (s *search) Order(value string, reorder ...bool) *search {
|
||||
if len(reorder) > 0 && reorder[0] {
|
||||
s.orders = []string{value}
|
||||
} else {
|
||||
if value != "" {
|
||||
s.orders = []string{value}
|
||||
} else {
|
||||
s.orders = []string{}
|
||||
}
|
||||
} else if value != "" {
|
||||
s.orders = append(s.orders, value)
|
||||
}
|
||||
return s
|
||||
@ -93,7 +97,7 @@ func (s *search) Group(query string) *search {
|
||||
}
|
||||
|
||||
func (s *search) Having(query string, values ...interface{}) *search {
|
||||
s.havingCondition = map[string]interface{}{"query": query, "args": values}
|
||||
s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values})
|
||||
return s
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user