Merge remote-tracking branch 'upstream/master'

This commit is contained in:
Pär Karlsson 2015-08-02 01:30:04 +02:00
commit 9e131e5ec0
17 changed files with 446 additions and 248 deletions

View File

@ -138,12 +138,14 @@ db.SingularTable(true)
```go ```go
// Create table // Create table
db.CreateTable(&User{}) db.CreateTable(&User{})
db.Set("gorm:table_options", "ENGINE=InnoDB").CreateTable(&User{})
// Drop table // Drop table
db.DropTable(&User{}) db.DropTable(&User{})
// Automating Migration // Automating Migration
db.AutoMigrate(&User{}) db.AutoMigrate(&User{})
db.Set("gorm:table_options", "ENGINE=InnoDB").AutoMigrate(&User{})
db.AutoMigrate(&User{}, &Product{}, &Order{}) db.AutoMigrate(&User{}, &Product{}, &Order{})
// Feel free to change your struct, AutoMigrate will keep your database up-to-date. // Feel free to change your struct, AutoMigrate will keep your database up-to-date.
// AutoMigrate will ONLY add *new columns* and *new indexes*, // AutoMigrate will ONLY add *new columns* and *new indexes*,
@ -1126,7 +1128,7 @@ type Product struct {
// 2nd param : destination table(id) // 2nd param : destination table(id)
// 3rd param : ONDELETE // 3rd param : ONDELETE
// 4th param : ONUPDATE // 4th param : ONUPDATE
db.Model(&User{}).AddForeignKey("role_id", "roles", "CASCADE", "RESTRICT") db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
// Add index // Add index
db.Model(&User{}).AddIndex("idx_user_name", "name") db.Model(&User{}).AddIndex("idx_user_name", "name")

View File

@ -4,11 +4,11 @@ import (
"errors" "errors"
"fmt" "fmt"
"reflect" "reflect"
"strings"
) )
type Association struct { type Association struct {
Scope *Scope Scope *Scope
PrimaryKey interface{}
Column string Column string
Error error Error error
Field *Field Field *Field
@ -45,53 +45,36 @@ func (association *Association) Append(values ...interface{}) *Association {
return association.setErr(scope.db.Error) return association.setErr(scope.db.Error)
} }
func (association *Association) getPrimaryKeys(values ...interface{}) []interface{} {
primaryKeys := []interface{}{}
scope := association.Scope
for _, value := range values {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
if reflectValue.Kind() == reflect.Slice {
for i := 0; i < reflectValue.Len(); i++ {
if primaryField := scope.New(reflectValue.Index(i).Interface()).PrimaryField(); !primaryField.IsBlank {
primaryKeys = append(primaryKeys, primaryField.Field.Interface())
}
}
} else if reflectValue.Kind() == reflect.Struct {
if primaryField := scope.New(value).PrimaryField(); !primaryField.IsBlank {
primaryKeys = append(primaryKeys, primaryField.Field.Interface())
}
}
}
return primaryKeys
}
func (association *Association) Delete(values ...interface{}) *Association { func (association *Association) Delete(values ...interface{}) *Association {
primaryKeys := association.getPrimaryKeys(values...)
if len(primaryKeys) == 0 {
association.setErr(errors.New("no primary key found"))
} else {
scope := association.Scope scope := association.Scope
relationship := association.Field.Relationship relationship := association.Field.Relationship
// many to many // many to many
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
sql := fmt.Sprintf("%v = ? AND %v IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) query := scope.NewDB()
query := scope.NewDB().Where(sql, association.PrimaryKey, primaryKeys) for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
}
}
primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...)
sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys))
query = query.Where(sql, toQueryValues(primaryKeys)...)
if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil {
leftValues := reflect.Zero(association.Field.Field.Type()) leftValues := reflect.Zero(association.Field.Field.Type())
for i := 0; i < association.Field.Field.Len(); i++ { for i := 0; i < association.Field.Field.Len(); i++ {
value := association.Field.Field.Index(i) reflectValue := association.Field.Field.Index(i)
if primaryField := association.Scope.New(value.Interface()).PrimaryField(); primaryField != nil { primaryKey := association.getPrimaryKeys(relationship.ForeignFieldNames, reflectValue.Interface())[0]
var included = false var included = false
for _, primaryKey := range primaryKeys { for _, pk := range primaryKeys {
if equalAsString(primaryKey, primaryField.Field.Interface()) { if equalAsString(primaryKey, pk) {
included = true included = true
} }
} }
if !included { if !included {
leftValues = reflect.Append(leftValues, value) leftValues = reflect.Append(leftValues, reflectValue)
}
} }
} }
association.Field.Set(leftValues) association.Field.Set(leftValues)
@ -99,7 +82,6 @@ func (association *Association) Delete(values ...interface{}) *Association {
} else { } else {
association.setErr(errors.New("delete only support many to many")) association.setErr(errors.New("delete only support many to many"))
} }
}
return association return association
} }
@ -109,16 +91,16 @@ func (association *Association) Replace(values ...interface{}) *Association {
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
field := association.Field.Field field := association.Field.Field
oldPrimaryKeys := association.getPrimaryKeys(field.Interface()) oldPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface())
association.Field.Set(reflect.Zero(association.Field.Field.Type())) association.Field.Set(reflect.Zero(association.Field.Field.Type()))
association.Append(values...) association.Append(values...)
newPrimaryKeys := association.getPrimaryKeys(field.Interface()) newPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface())
var addedPrimaryKeys = []interface{}{} var addedPrimaryKeys = [][]interface{}{}
for _, newKey := range newPrimaryKeys { for _, newKey := range newPrimaryKeys {
hasEqual := false hasEqual := false
for _, oldKey := range oldPrimaryKeys { for _, oldKey := range oldPrimaryKeys {
if reflect.DeepEqual(newKey, oldKey) { if equalAsString(newKey, oldKey) {
hasEqual = true hasEqual = true
break break
} }
@ -127,13 +109,21 @@ func (association *Association) Replace(values ...interface{}) *Association {
addedPrimaryKeys = append(addedPrimaryKeys, newKey) addedPrimaryKeys = append(addedPrimaryKeys, newKey)
} }
} }
for _, primaryKey := range association.getPrimaryKeys(values...) {
for _, primaryKey := range association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) {
addedPrimaryKeys = append(addedPrimaryKeys, primaryKey) addedPrimaryKeys = append(addedPrimaryKeys, primaryKey)
} }
if len(addedPrimaryKeys) > 0 { if len(addedPrimaryKeys) > 0 {
sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) query := scope.NewDB()
query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys) for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
}
}
sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(addedPrimaryKeys))
query = query.Where(sql, toQueryValues(addedPrimaryKeys)...)
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship)) association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship))
} }
} else { } else {
@ -146,8 +136,13 @@ func (association *Association) Clear() *Association {
relationship := association.Field.Relationship relationship := association.Field.Relationship
scope := association.Scope scope := association.Scope
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName)) query := scope.NewDB()
query := scope.NewDB().Where(sql, association.PrimaryKey) for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
}
}
if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil {
association.Field.Set(reflect.Zero(association.Field.Field.Type())) association.Field.Set(reflect.Zero(association.Field.Field.Type()))
} else { } else {
@ -168,18 +163,104 @@ func (association *Association) Count() int {
if relationship.Kind == "many_to_many" { if relationship.Kind == "many_to_many" {
relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count) relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) query := scope.DB()
countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey) for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)),
field.Field.Interface())
}
}
if relationship.PolymorphicType != "" { if relationship.PolymorphicType != "" {
countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName()) query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName())
} }
countScope.Count(&count) query.Table(newScope.TableName()).Count(&count)
} else if relationship.Kind == "belongs_to" { } else if relationship.Kind == "belongs_to" {
if v, ok := scope.FieldByName(association.Column); ok { query := scope.DB()
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) for idx, foreignKey := range relationship.ForeignDBNames {
scope.DB().Table(newScope.TableName()).Where(whereSql, v).Count(&count) if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)),
field.Field.Interface())
} }
} }
query.Table(newScope.TableName()).Count(&count)
}
return count return count
} }
func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) [][]interface{} {
results := [][]interface{}{}
scope := association.Scope
for _, value := range values {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
if reflectValue.Kind() == reflect.Slice {
for i := 0; i < reflectValue.Len(); i++ {
primaryKeys := []interface{}{}
newScope := scope.New(reflectValue.Index(i).Interface())
for _, column := range columns {
if field, ok := newScope.FieldByName(column); ok {
primaryKeys = append(primaryKeys, field.Field.Interface())
} else {
primaryKeys = append(primaryKeys, "")
}
}
results = append(results, primaryKeys)
}
} else if reflectValue.Kind() == reflect.Struct {
newScope := scope.New(value)
var primaryKeys []interface{}
for _, column := range columns {
if field, ok := newScope.FieldByName(column); ok {
primaryKeys = append(primaryKeys, field.Field.Interface())
} else {
primaryKeys = append(primaryKeys, "")
}
}
results = append(results, primaryKeys)
}
}
return results
}
func toQueryMarks(primaryValues [][]interface{}) string {
var results []string
for _, primaryValue := range primaryValues {
var marks []string
for _,_ = range primaryValue {
marks = append(marks, "?")
}
if len(marks) > 1 {
results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ",")))
} else {
results = append(results, strings.Join(marks, ""))
}
}
return strings.Join(results, ",")
}
func toQueryCondition(scope *Scope, columns []string) string {
var newColumns []string
for _, column := range columns {
newColumns = append(newColumns, scope.Quote(column))
}
if len(columns) > 1 {
return fmt.Sprintf("(%v)", strings.Join(newColumns, ","))
} else {
return strings.Join(columns, ",")
}
}
func toQueryValues(primaryValues [][]interface{}) (values []interface{}) {
for _, primaryValue := range primaryValues {
for _, value := range primaryValue {
values = append(values, value)
}
}
return values
}

View File

@ -23,11 +23,11 @@ func TestHasOneAndHasManyAssociation(t *testing.T) {
} }
if err := DB.Save(&post).Error; err != nil { if err := DB.Save(&post).Error; err != nil {
t.Errorf("Got errors when save post") t.Errorf("Got errors when save post", err.Error())
} }
if DB.First(&Category{}, "name = ?", "Category 1").Error != nil { if err := DB.First(&Category{}, "name = ?", "Category 1").Error; err != nil {
t.Errorf("Category should be saved") t.Errorf("Category should be saved", err.Error())
} }
var p Post var p Post
@ -186,6 +186,7 @@ func TestManyToMany(t *testing.T) {
var language Language var language Language
DB.Where("name = ?", "EE").First(&language) DB.Where("name = ?", "EE").First(&language)
DB.Model(&user).Association("Languages").Delete(language, &language) DB.Model(&user).Association("Languages").Delete(language, &language)
if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 { if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 {
t.Errorf("Relations should be deleted with Delete") t.Errorf("Relations should be deleted with Delete")
} }

View File

@ -35,13 +35,15 @@ func Create(scope *Scope) {
} }
} }
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
if relationField := fields[relationship.ForeignDBName]; !scope.changeableField(relationField) { for _, dbName := range relationship.ForeignDBNames {
if relationField := fields[dbName]; !scope.changeableField(relationField) {
columns = append(columns, scope.Quote(relationField.DBName)) columns = append(columns, scope.Quote(relationField.DBName))
sqls = append(sqls, scope.AddToVars(relationField.Field.Interface())) sqls = append(sqls, scope.AddToVars(relationField.Field.Interface()))
} }
} }
} }
} }
}
returningKey := "*" returningKey := "*"
primaryField := scope.PrimaryField() primaryField := scope.PrimaryField()

View File

@ -30,7 +30,7 @@ func Query(scope *Scope) {
if kind := dest.Kind(); kind == reflect.Slice { if kind := dest.Kind(); kind == reflect.Slice {
isSlice = true isSlice = true
destType = dest.Type().Elem() destType = dest.Type().Elem()
dest.Set(reflect.Indirect(reflect.New(reflect.SliceOf(destType)))) dest.Set(reflect.MakeSlice(dest.Type(), 0, 0))
if destType.Kind() == reflect.Ptr { if destType.Kind() == reflect.Ptr {
isPtr = true isPtr = true

View File

@ -19,8 +19,13 @@ func SaveBeforeAssociations(scope *Scope) {
if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
value := field.Field value := field.Field
scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error) scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error)
if relationship.ForeignFieldName != "" { if len(relationship.ForeignFieldNames) != 0 {
scope.Err(scope.SetColumn(relationship.ForeignFieldName, scope.New(value.Addr().Interface()).PrimaryKeyValue())) for idx, fieldName := range relationship.ForeignFieldNames {
associationForeignName := relationship.AssociationForeignDBNames[idx]
if f, ok := scope.New(value.Addr().Interface()).FieldByName(associationForeignName); ok {
scope.Err(scope.SetColumn(fieldName, f.Field.Interface()))
}
}
} }
} }
} }
@ -44,8 +49,13 @@ func SaveAfterAssociations(scope *Scope) {
elem := value.Index(i).Addr().Interface() elem := value.Index(i).Addr().Interface()
newScope := newDB.NewScope(elem) newScope := newDB.NewScope(elem)
if relationship.JoinTableHandler == nil && relationship.ForeignFieldName != "" { if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 {
scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue())) for idx, fieldName := range relationship.ForeignFieldNames {
associationForeignName := relationship.AssociationForeignDBNames[idx]
if f, ok := scope.FieldByName(associationForeignName); ok {
scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
}
}
} }
if relationship.PolymorphicType != "" { if relationship.PolymorphicType != "" {
@ -61,8 +71,13 @@ func SaveAfterAssociations(scope *Scope) {
default: default:
elem := value.Addr().Interface() elem := value.Addr().Interface()
newScope := scope.New(elem) newScope := scope.New(elem)
if relationship.ForeignFieldName != "" { if len(relationship.ForeignFieldNames) != 0 {
scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue())) for idx, fieldName := range relationship.ForeignFieldNames {
associationForeignName := relationship.AssociationForeignDBNames[idx]
if f, ok := scope.FieldByName(associationForeignName); ok {
scope.Err(newScope.SetColumn(fieldName, f.Field.Interface()))
}
}
} }
if relationship.PolymorphicType != "" { if relationship.PolymorphicType != "" {

View File

@ -55,9 +55,10 @@ func Update(scope *Scope) {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
} }
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
if relationField := fields[relationship.ForeignDBName]; !scope.changeableField(relationField) { for _, dbName := range relationship.ForeignDBNames {
if !relationField.IsBlank { if relationField := fields[dbName]; !scope.changeableField(relationField) && !relationField.IsBlank {
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface()))) sql := fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface()))
sqls = append(sqls, sql)
} }
} }
} }

View File

@ -70,8 +70,8 @@ func (commonDialect) Quote(key string) string {
} }
func (commonDialect) databaseName(scope *Scope) string { func (commonDialect) databaseName(scope *Scope) string {
from := strings.Index(scope.db.parent.source, "/") + 1 from := strings.LastIndex(scope.db.parent.source, "/") + 1
to := strings.Index(scope.db.parent.source, "?") to := strings.LastIndex(scope.db.parent.source, "?")
if to == -1 { if to == -1 {
to = len(scope.db.parent.source) to = len(scope.db.parent.source)
} }

View File

@ -45,41 +45,18 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s
s.TableName = tableName s.TableName = tableName
s.Source = JoinTableSource{ModelType: source} s.Source = JoinTableSource{ModelType: source}
sourceScope := &Scope{Value: reflect.New(source).Interface()} for idx, dbName := range relationship.ForeignFieldNames {
sourcePrimaryFields := sourceScope.GetModelStruct().PrimaryFields
for _, primaryField := range sourcePrimaryFields {
if relationship.ForeignDBName == "" {
relationship.ForeignFieldName = source.Name() + primaryField.Name
relationship.ForeignDBName = ToDBName(relationship.ForeignFieldName)
}
var dbName string
if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" {
dbName = relationship.ForeignDBName
} else {
dbName = ToDBName(source.Name() + primaryField.Name)
}
s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
DBName: dbName, DBName: relationship.ForeignDBNames[idx],
AssociationDBName: primaryField.DBName, AssociationDBName: dbName,
}) })
} }
s.Destination = JoinTableSource{ModelType: destination} s.Destination = JoinTableSource{ModelType: destination}
destinationScope := &Scope{Value: reflect.New(destination).Interface()} for idx, dbName := range relationship.AssociationForeignFieldNames {
destinationPrimaryFields := destinationScope.GetModelStruct().PrimaryFields
for _, primaryField := range destinationPrimaryFields {
var dbName string
if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" {
dbName = relationship.AssociationForeignDBName
} else {
dbName = ToDBName(destinationScope.GetModelStruct().ModelType.Name() + primaryField.Name)
}
s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
DBName: dbName, DBName: relationship.AssociationForeignDBNames[idx],
AssociationDBName: primaryField.DBName, AssociationDBName: dbName,
}) })
} }
} }

10
main.go
View File

@ -118,7 +118,7 @@ func (s *DB) Callback() *callback {
} }
func (s *DB) SetLogger(l logger) { func (s *DB) SetLogger(l logger) {
s.parent.logger = l s.logger = l
} }
func (s *DB) LogMode(enable bool) *DB { func (s *DB) LogMode(enable bool) *DB {
@ -259,9 +259,9 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
if !result.RecordNotFound() { if !result.RecordNotFound() {
return result return result
} }
c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates) c.err(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates).db.Error)
} else if len(c.search.assignAttrs) > 0 { } else if len(c.search.assignAttrs) > 0 {
c.NewScope(out).InstanceSet("gorm:update_interface", s.search.assignAttrs).callCallbacks(s.parent.callback.updates) c.err(c.NewScope(out).InstanceSet("gorm:update_interface", s.search.assignAttrs).callCallbacks(s.parent.callback.updates).db.Error)
} }
return c return c
} }
@ -450,10 +450,10 @@ func (s *DB) Association(column string) *Association {
err = errors.New("primary key can't be nil") err = errors.New("primary key can't be nil")
} else { } else {
if field, ok := scope.FieldByName(column); ok { if field, ok := scope.FieldByName(column); ok {
if field.Relationship == nil || field.Relationship.ForeignFieldName == "" { if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 {
err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()) err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type())
} else { } else {
return &Association{Scope: scope, Column: column, PrimaryKey: primaryField.Field.Interface(), Field: field} return &Association{Scope: scope, Column: column, Field: field}
} }
} else { } else {
err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column) err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column)

View File

@ -3,7 +3,7 @@ package gorm
import "time" import "time"
func (s *DB) clone() *DB { func (s *DB) clone() *DB {
db := DB{db: s.db, parent: s.parent, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error} db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error}
for key, value := range s.values { for key, value := range s.values {
db.values[key] = value db.values[key] = value
@ -34,7 +34,7 @@ func (s *DB) err(err error) error {
} }
func (s *DB) print(v ...interface{}) { func (s *DB) print(v ...interface{}) {
s.parent.logger.(logger).Print(v...) s.logger.(logger).Print(v...)
} }
func (s *DB) log(v ...interface{}) { func (s *DB) log(v ...interface{}) {

View File

@ -5,10 +5,11 @@ import (
"fmt" "fmt"
"go/ast" "go/ast"
"reflect" "reflect"
"regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/qor/inflection"
) )
var modelStructs = map[reflect.Type]*ModelStruct{} var modelStructs = map[reflect.Type]*ModelStruct{}
@ -64,16 +65,13 @@ type Relationship struct {
Kind string Kind string
PolymorphicType string PolymorphicType string
PolymorphicDBName string PolymorphicDBName string
ForeignFieldName string ForeignFieldNames []string
ForeignDBName string ForeignDBNames []string
AssociationForeignFieldName string AssociationForeignFieldNames []string
AssociationForeignDBName string AssociationForeignDBNames []string
JoinTableHandler JoinTableHandlerInterface JoinTableHandler JoinTableHandlerInterface
} }
var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")}
var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${1}s"}
func (scope *Scope) GetModelStruct() *ModelStruct { func (scope *Scope) GetModelStruct() *ModelStruct {
var modelStruct ModelStruct var modelStruct ModelStruct
@ -113,11 +111,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} else { } else {
name := ToDBName(scopeType.Name()) name := ToDBName(scopeType.Name())
if scope.db == nil || !scope.db.parent.singularTable { if scope.db == nil || !scope.db.parent.singularTable {
for index, reg := range pluralMapKeys { name = inflection.Plural(name)
if reg.MatchString(name) {
name = reg.ReplaceAllString(name, pluralMapValues[index])
}
}
} }
modelStruct.defaultTableName = name modelStruct.defaultTableName = name
@ -190,12 +184,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
var relationship = &Relationship{} var relationship = &Relationship{}
foreignKey := gormSettings["FOREIGNKEY"]
if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" { if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" {
if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil { if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil {
if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil {
relationship.ForeignFieldName = polymorphicField.Name relationship.ForeignFieldNames = []string{polymorphicField.Name}
relationship.ForeignDBName = polymorphicField.DBName relationship.ForeignDBNames = []string{polymorphicField.DBName}
relationship.AssociationForeignFieldNames = []string{scope.PrimaryField().Name}
relationship.AssociationForeignDBNames = []string{scope.PrimaryField().DBName}
relationship.PolymorphicType = polymorphicType.Name relationship.PolymorphicType = polymorphicType.Name
relationship.PolymorphicDBName = polymorphicType.DBName relationship.PolymorphicDBName = polymorphicType.DBName
polymorphicType.IsForeignKey = true polymorphicType.IsForeignKey = true
@ -204,6 +199,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} }
} }
var foreignKeys []string
if foreignKey, ok := gormSettings["FOREIGNKEY"]; ok {
foreignKeys = append(foreignKeys, foreignKey)
}
switch indirectType.Kind() { switch indirectType.Kind() {
case reflect.Slice: case reflect.Slice:
elemType := indirectType.Elem() elemType := indirectType.Elem()
@ -212,21 +211,41 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} }
if elemType.Kind() == reflect.Struct { if elemType.Kind() == reflect.Struct {
if foreignKey == "" {
foreignKey = scopeType.Name() + "Id"
}
if many2many := gormSettings["MANY2MANY"]; many2many != "" { if many2many := gormSettings["MANY2MANY"]; many2many != "" {
relationship.Kind = "many_to_many" relationship.Kind = "many_to_many"
associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]
if associationForeignKey == "" { // foreign keys
associationForeignKey = elemType.Name() + "Id" if len(foreignKeys) == 0 {
for _, field := range scope.PrimaryFields() {
foreignKeys = append(foreignKeys, field.DBName)
}
} }
relationship.ForeignFieldName = foreignKey for _, foreignKey := range foreignKeys {
relationship.ForeignDBName = ToDBName(foreignKey) if field, ok := scope.FieldByName(foreignKey); ok {
relationship.AssociationForeignFieldName = associationForeignKey relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName)
relationship.AssociationForeignDBName = ToDBName(associationForeignKey) joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName
relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName)
}
}
// association foreign keys
var associationForeignKeys []string
if foreignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" {
associationForeignKeys = []string{gormSettings["ASSOCIATIONFOREIGNKEY"]}
} else {
for _, field := range toScope.PrimaryFields() {
associationForeignKeys = append(associationForeignKeys, field.DBName)
}
}
for _, name := range associationForeignKeys {
if field, ok := toScope.FieldByName(name); ok {
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName)
joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName)
}
}
joinTableHandler := JoinTableHandler{} joinTableHandler := JoinTableHandler{}
joinTableHandler.Setup(relationship, many2many, scopeType, elemType) joinTableHandler.Setup(relationship, many2many, scopeType, elemType)
@ -234,12 +253,30 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
field.Relationship = relationship field.Relationship = relationship
} else { } else {
relationship.Kind = "has_many" relationship.Kind = "has_many"
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
relationship.ForeignFieldName = foreignField.Name if len(foreignKeys) == 0 {
relationship.ForeignDBName = foreignField.DBName for _, field := range scope.PrimaryFields() {
if foreignField := getForeignField(scopeType.Name()+field.Name, toScope.GetStructFields()); foreignField != nil {
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name)
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName)
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
foreignField.IsForeignKey = true foreignField.IsForeignKey = true
field.Relationship = relationship }
} else if relationship.ForeignFieldName != "" { }
} else {
for _, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name)
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName)
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
foreignField.IsForeignKey = true
}
}
}
if len(relationship.ForeignFieldNames) != 0 {
field.Relationship = relationship field.Relationship = relationship
} }
} }
@ -258,28 +295,56 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
} }
continue continue
} else { } else {
belongsToForeignKey := foreignKey if len(foreignKeys) == 0 {
if belongsToForeignKey == "" { for _, f := range scope.PrimaryFields() {
belongsToForeignKey = field.Name + "Id" if foreignField := getForeignField(modelStruct.ModelType.Name()+f.Name, toScope.GetStructFields()); foreignField != nil {
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name)
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName)
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
foreignField.IsForeignKey = true
}
}
} else {
for _, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name)
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName)
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
foreignField.IsForeignKey = true
}
}
} }
if foreignField := getForeignField(belongsToForeignKey, fields); foreignField != nil { if len(relationship.ForeignFieldNames) != 0 {
relationship.Kind = "belongs_to" relationship.Kind = "has_one"
relationship.ForeignFieldName = foreignField.Name
relationship.ForeignDBName = foreignField.DBName
foreignField.IsForeignKey = true
field.Relationship = relationship field.Relationship = relationship
} else { } else {
if foreignKey == "" { if len(foreignKeys) == 0 {
foreignKey = modelStruct.ModelType.Name() + "Id" for _, f := range toScope.PrimaryFields() {
} if foreignField := getForeignField(field.Name+f.Name, fields); foreignField != nil {
relationship.Kind = "has_one" relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name)
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName)
relationship.ForeignFieldName = foreignField.Name relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
relationship.ForeignDBName = foreignField.DBName relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
foreignField.IsForeignKey = true foreignField.IsForeignKey = true
field.Relationship = relationship }
} else if relationship.ForeignFieldName != "" { }
} else {
for _, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, fields); foreignField != nil {
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, toScope.PrimaryField().Name)
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, toScope.PrimaryField().DBName)
relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name)
relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName)
foreignField.IsForeignKey = true
}
}
}
if len(relationship.ForeignFieldNames) != 0 {
relationship.Kind = "belongs_to"
field.Relationship = relationship field.Relationship = relationship
} }
} }

View File

@ -8,12 +8,15 @@ import (
"strings" "strings"
) )
func getRealValue(value reflect.Value, field string) interface{} { func getRealValue(value reflect.Value, columns []string) (results []interface{}) {
result := reflect.Indirect(value).FieldByName(field).Interface() for _, column := range columns {
result := reflect.Indirect(value).FieldByName(column).Interface()
if r, ok := result.(driver.Valuer); ok { if r, ok := result.(driver.Valuer); ok {
result, _ = r.Value() result, _ = r.Value()
} }
return result results = append(results, result)
}
return
} }
func equalAsString(a interface{}, b interface{}) bool { func equalAsString(a interface{}, b interface{}) bool {
@ -97,26 +100,24 @@ func makeSlice(typ reflect.Type) interface{} {
} }
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
primaryName := scope.PrimaryField().Name relation := field.Relationship
primaryKeys := scope.getColumnAsArray(primaryName)
primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
if len(primaryKeys) == 0 { if len(primaryKeys) == 0 {
return return
} }
results := makeSlice(field.Struct.Type) results := makeSlice(field.Struct.Type)
relation := field.Relationship scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName))
scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error)
resultValues := reflect.Indirect(reflect.ValueOf(results)) resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ { for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i) result := resultValues.Index(i)
if scope.IndirectValue().Kind() == reflect.Slice { if scope.IndirectValue().Kind() == reflect.Slice {
value := getRealValue(result, relation.ForeignFieldName) value := getRealValue(result, relation.ForeignFieldNames)
objects := scope.IndirectValue() objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ { for j := 0; j < objects.Len(); j++ {
if equalAsString(getRealValue(objects.Index(j), primaryName), value) { if equalAsString(getRealValue(objects.Index(j), relation.AssociationForeignFieldNames), value) {
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
break break
} }
@ -131,27 +132,24 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{})
} }
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
primaryName := scope.PrimaryField().Name relation := field.Relationship
primaryKeys := scope.getColumnAsArray(primaryName) primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames)
if len(primaryKeys) == 0 { if len(primaryKeys) == 0 {
return return
} }
results := makeSlice(field.Struct.Type) results := makeSlice(field.Struct.Type)
relation := field.Relationship scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName))
scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error)
resultValues := reflect.Indirect(reflect.ValueOf(results)) resultValues := reflect.Indirect(reflect.ValueOf(results))
if scope.IndirectValue().Kind() == reflect.Slice { if scope.IndirectValue().Kind() == reflect.Slice {
for i := 0; i < resultValues.Len(); i++ { for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i) result := resultValues.Index(i)
value := getRealValue(result, relation.ForeignFieldName) value := getRealValue(result, relation.ForeignFieldNames)
objects := scope.IndirectValue() objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ { for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j)) object := reflect.Indirect(objects.Index(j))
if equalAsString(getRealValue(object, primaryName), value) { if equalAsString(getRealValue(object, relation.AssociationForeignFieldNames), value) {
f := object.FieldByName(field.Name) f := object.FieldByName(field.Name)
f.Set(reflect.Append(f, result)) f.Set(reflect.Append(f, result))
break break
@ -165,25 +163,23 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{})
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
relation := field.Relationship relation := field.Relationship
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldName) primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames)
if len(primaryKeys) == 0 { if len(primaryKeys) == 0 {
return return
} }
results := makeSlice(field.Struct.Type) results := makeSlice(field.Struct.Type)
associationPrimaryKey := scope.New(results).PrimaryField().Name scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error)
scope.Err(scope.NewDB().Where(primaryKeys).Find(results, conditions...).Error)
resultValues := reflect.Indirect(reflect.ValueOf(results)) resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ { for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i) result := resultValues.Index(i)
if scope.IndirectValue().Kind() == reflect.Slice { if scope.IndirectValue().Kind() == reflect.Slice {
value := getRealValue(result, associationPrimaryKey) value := getRealValue(result, relation.AssociationForeignFieldNames)
objects := scope.IndirectValue() objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ { for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j)) object := reflect.Indirect(objects.Index(j))
if equalAsString(getRealValue(object, relation.ForeignFieldName), value) { if equalAsString(getRealValue(object, relation.ForeignFieldNames), value) {
object.FieldByName(field.Name).Set(result) object.FieldByName(field.Name).Set(result)
} }
} }
@ -193,15 +189,23 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{
} }
} }
func (scope *Scope) getColumnAsArray(column string) (columns []interface{}) { func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) {
values := scope.IndirectValue() values := scope.IndirectValue()
switch values.Kind() { switch values.Kind() {
case reflect.Slice: case reflect.Slice:
for i := 0; i < values.Len(); i++ { for i := 0; i < values.Len(); i++ {
columns = append(columns, reflect.Indirect(values.Index(i)).FieldByName(column).Interface()) var result []interface{}
for _, column := range columns {
result = append(result, reflect.Indirect(values.Index(i)).FieldByName(column).Interface())
}
results = append(results, result)
} }
case reflect.Struct: case reflect.Struct:
return []interface{}{values.FieldByName(column).Interface()} var result []interface{}
for _, column := range columns {
result = append(result, values.FieldByName(column).Interface())
}
return [][]interface{}{result}
} }
return return
} }

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"github.com/jinzhu/gorm"
"github.com/jinzhu/now" "github.com/jinzhu/now"
"testing" "testing"
@ -556,7 +557,7 @@ func TestSelectWithEscapedFieldName(t *testing.T) {
func TestSelectWithVariables(t *testing.T) { func TestSelectWithVariables(t *testing.T) {
DB.Save(&User{Name: "jinzhu"}) DB.Save(&User{Name: "jinzhu"})
rows, _ := DB.Table("users").Select("? as fake", "name").Rows() rows, _ := DB.Table("users").Select("? as fake", gorm.Expr("name")).Rows()
if !rows.Next() { if !rows.Next() {
t.Errorf("Should have returned at least one row") t.Errorf("Should have returned at least one row")

View File

@ -3,6 +3,7 @@ package gorm
import ( import (
"errors" "errors"
"fmt" "fmt"
"regexp"
"strings" "strings"
"time" "time"
@ -87,6 +88,13 @@ func (scope *Scope) Quote(str string) string {
} }
} }
func (scope *Scope) QuoteIfPossible(str string) string {
if regexp.MustCompile("^[a-zA-Z]+(.[a-zA-Z]+)*$").MatchString(str) {
return scope.Quote(str)
}
return str
}
// Dialect get dialect // Dialect get dialect
func (scope *Scope) Dialect() Dialect { func (scope *Scope) Dialect() Dialect {
return scope.db.parent.dialect return scope.db.parent.dialect

View File

@ -149,7 +149,7 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
if valuer, ok := interface{}(arg).(driver.Valuer); ok { if valuer, ok := interface{}(arg).(driver.Valuer); ok {
arg, _ = valuer.Value() arg, _ = valuer.Value()
} }
str = strings.Replace(str, "?", scope.Dialect().Quote(fmt.Sprintf("%v", arg)), 1) str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
} }
} }
return return
@ -265,10 +265,24 @@ func (scope *Scope) groupSql() string {
} }
func (scope *Scope) havingSql() string { func (scope *Scope) havingSql() string {
if scope.Search.havingCondition == nil { if scope.Search.havingConditions == nil {
return "" return ""
} }
return " HAVING " + scope.buildWhereCondition(scope.Search.havingCondition)
var andConditions []string
for _, clause := range scope.Search.havingConditions {
if sql := scope.buildWhereCondition(clause); sql != "" {
andConditions = append(andConditions, sql)
}
}
combinedSql := strings.Join(andConditions, " AND ")
if len(combinedSql) == 0 {
return ""
}
return " HAVING " + combinedSql
} }
func (scope *Scope) joinsSql() string { func (scope *Scope) joinsSql() string {
@ -415,12 +429,21 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
joinTableHandler := relationship.JoinTableHandler joinTableHandler := relationship.JoinTableHandler
scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error) scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error)
} else if relationship.Kind == "belongs_to" { } else if relationship.Kind == "belongs_to" {
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) query := toScope.db
foreignKeyValue := fromFields[relationship.ForeignDBName].Field.Interface() for idx, foreignKey := range relationship.ForeignDBNames {
scope.Err(toScope.db.Where(sql, foreignKeyValue).Find(value).Error) if field, ok := scope.FieldByName(foreignKey); ok {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface())
}
}
scope.Err(query.Find(value).Error)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName)) query := toScope.db
query := toScope.db.Where(sql, scope.PrimaryKeyValue()) for idx, foreignKey := range relationship.ForeignDBNames {
if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface())
}
}
if relationship.PolymorphicType != "" { if relationship.PolymorphicType != "" {
query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName())
} }
@ -442,6 +465,17 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
return scope return scope
} }
/**
Return the table options string or an empty string if the table options does not exist
*/
func (scope *Scope) getTableOptions() string{
tableOptions, ok := scope.Get("gorm:table_options")
if !ok {
return ""
}
return tableOptions.(string)
}
func (scope *Scope) createJoinTable(field *StructField) { func (scope *Scope) createJoinTable(field *StructField) {
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
joinTableHandler := relationship.JoinTableHandler joinTableHandler := relationship.JoinTableHandler
@ -450,16 +484,22 @@ func (scope *Scope) createJoinTable(field *StructField) {
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
var sqlTypes []string var sqlTypes []string
for _, s := range []*Scope{scope, toScope} { for idx, fieldName := range relationship.ForeignFieldNames {
for _, primaryField := range s.GetModelStruct().PrimaryFields { if field, ok := scope.Fields()[fieldName]; ok {
value := reflect.Indirect(reflect.New(primaryField.Struct.Type)) value := reflect.Indirect(reflect.New(field.Struct.Type))
primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false) primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false)
dbName := ToDBName(s.GetModelStruct().ModelType.Name() + primaryField.Name) sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType)
sqlTypes = append(sqlTypes, scope.Quote(dbName)+" "+primaryKeySqlType)
} }
} }
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)", scope.Quote(joinTable), strings.Join(sqlTypes, ","))).Error) for idx, fieldName := range relationship.AssociationForeignFieldNames {
if field, ok := toScope.Fields()[fieldName]; ok {
value := reflect.Indirect(reflect.New(field.Struct.Type))
primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false)
sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType)
}
}
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), scope.getTableOptions())).Error)
} }
scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
} }
@ -484,7 +524,7 @@ func (scope *Scope) createTable() *Scope {
if len(primaryKeys) > 0 { if len(primaryKeys) > 0 {
primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ","))
} }
scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr)).Exec() scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec()
return scope return scope
} }
@ -515,11 +555,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
var columns []string var columns []string
for _, name := range column { for _, name := range column {
if regexp.MustCompile("^[a-zA-Z]+$").MatchString(name) { columns = append(columns, scope.QuoteIfPossible(name))
columns = append(columns, scope.Quote(name))
} else {
columns = append(columns, name)
}
} }
sqlCreate := "CREATE INDEX" sqlCreate := "CREATE INDEX"
@ -532,9 +568,10 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) {
func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) {
var table = scope.TableName() var table = scope.TableName()
var keyName = fmt.Sprintf("%s_%s_foreign", table, field) var keyName = fmt.Sprintf("%s_%s_%s_foreign", table, field, regexp.MustCompile("[^a-zA-Z]").ReplaceAllString(dest, "_"))
keyName = regexp.MustCompile("_+").ReplaceAllString(keyName, "_")
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.Quote(keyName), scope.Quote(field), scope.Quote(dest), onDelete, onUpdate)).Exec() scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.QuoteIfPossible(keyName), scope.QuoteIfPossible(field), scope.QuoteIfPossible(dest), onDelete, onUpdate)).Exec()
} }
func (scope *Scope) removeIndex(indexName string) { func (scope *Scope) removeIndex(indexName string) {

View File

@ -7,7 +7,7 @@ type search struct {
whereConditions []map[string]interface{} whereConditions []map[string]interface{}
orConditions []map[string]interface{} orConditions []map[string]interface{}
notConditions []map[string]interface{} notConditions []map[string]interface{}
havingCondition map[string]interface{} havingConditions []map[string]interface{}
initAttrs []interface{} initAttrs []interface{}
assignAttrs []interface{} assignAttrs []interface{}
selects map[string]interface{} selects map[string]interface{}
@ -60,8 +60,12 @@ func (s *search) Assign(attrs ...interface{}) *search {
func (s *search) Order(value string, reorder ...bool) *search { func (s *search) Order(value string, reorder ...bool) *search {
if len(reorder) > 0 && reorder[0] { if len(reorder) > 0 && reorder[0] {
if value != "" {
s.orders = []string{value} s.orders = []string{value}
} else { } else {
s.orders = []string{}
}
} else if value != "" {
s.orders = append(s.orders, value) s.orders = append(s.orders, value)
} }
return s return s
@ -93,7 +97,7 @@ func (s *search) Group(query string) *search {
} }
func (s *search) Having(query string, values ...interface{}) *search { func (s *search) Having(query string, values ...interface{}) *search {
s.havingCondition = map[string]interface{}{"query": query, "args": values} s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values})
return s return s
} }