diff --git a/README.md b/README.md index db6270c5..dbfbedd5 100644 --- a/README.md +++ b/README.md @@ -138,12 +138,14 @@ db.SingularTable(true) ```go // Create table db.CreateTable(&User{}) +db.Set("gorm:table_options", "ENGINE=InnoDB").CreateTable(&User{}) // Drop table db.DropTable(&User{}) // Automating Migration db.AutoMigrate(&User{}) +db.Set("gorm:table_options", "ENGINE=InnoDB").AutoMigrate(&User{}) db.AutoMigrate(&User{}, &Product{}, &Order{}) // Feel free to change your struct, AutoMigrate will keep your database up-to-date. // AutoMigrate will ONLY add *new columns* and *new indexes*, @@ -1126,7 +1128,7 @@ type Product struct { // 2nd param : destination table(id) // 3rd param : ONDELETE // 4th param : ONUPDATE -db.Model(&User{}).AddForeignKey("role_id", "roles", "CASCADE", "RESTRICT") +db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") // Add index db.Model(&User{}).AddIndex("idx_user_name", "name") diff --git a/association.go b/association.go index e34a10bd..342dd6cd 100644 --- a/association.go +++ b/association.go @@ -4,14 +4,14 @@ import ( "errors" "fmt" "reflect" + "strings" ) type Association struct { - Scope *Scope - PrimaryKey interface{} - Column string - Error error - Field *Field + Scope *Scope + Column string + Error error + Field *Field } func (association *Association) setErr(err error) *Association { @@ -45,60 +45,42 @@ func (association *Association) Append(values ...interface{}) *Association { return association.setErr(scope.db.Error) } -func (association *Association) getPrimaryKeys(values ...interface{}) []interface{} { - primaryKeys := []interface{}{} +func (association *Association) Delete(values ...interface{}) *Association { scope := association.Scope + relationship := association.Field.Relationship - for _, value := range values { - reflectValue := reflect.Indirect(reflect.ValueOf(value)) - if reflectValue.Kind() == reflect.Slice { - for i := 0; i < reflectValue.Len(); i++ { - if primaryField := scope.New(reflectValue.Index(i).Interface()).PrimaryField(); !primaryField.IsBlank { - primaryKeys = append(primaryKeys, primaryField.Field.Interface()) - } - } - } else if reflectValue.Kind() == reflect.Struct { - if primaryField := scope.New(value).PrimaryField(); !primaryField.IsBlank { - primaryKeys = append(primaryKeys, primaryField.Field.Interface()) + // many to many + if relationship.Kind == "many_to_many" { + query := scope.NewDB() + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) } } - } - return primaryKeys -} -func (association *Association) Delete(values ...interface{}) *Association { - primaryKeys := association.getPrimaryKeys(values...) + primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) + sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)) + query = query.Where(sql, toQueryValues(primaryKeys)...) - if len(primaryKeys) == 0 { - association.setErr(errors.New("no primary key found")) - } else { - scope := association.Scope - relationship := association.Field.Relationship - // many to many - if relationship.Kind == "many_to_many" { - sql := fmt.Sprintf("%v = ? AND %v IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) - query := scope.NewDB().Where(sql, association.PrimaryKey, primaryKeys) - if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { - leftValues := reflect.Zero(association.Field.Field.Type()) - for i := 0; i < association.Field.Field.Len(); i++ { - value := association.Field.Field.Index(i) - if primaryField := association.Scope.New(value.Interface()).PrimaryField(); primaryField != nil { - var included = false - for _, primaryKey := range primaryKeys { - if equalAsString(primaryKey, primaryField.Field.Interface()) { - included = true - } - } - if !included { - leftValues = reflect.Append(leftValues, value) - } + if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { + leftValues := reflect.Zero(association.Field.Field.Type()) + for i := 0; i < association.Field.Field.Len(); i++ { + reflectValue := association.Field.Field.Index(i) + primaryKey := association.getPrimaryKeys(relationship.ForeignFieldNames, reflectValue.Interface())[0] + var included = false + for _, pk := range primaryKeys { + if equalAsString(primaryKey, pk) { + included = true } } - association.Field.Set(leftValues) + if !included { + leftValues = reflect.Append(leftValues, reflectValue) + } } - } else { - association.setErr(errors.New("delete only support many to many")) + association.Field.Set(leftValues) } + } else { + association.setErr(errors.New("delete only support many to many")) } return association } @@ -109,16 +91,16 @@ func (association *Association) Replace(values ...interface{}) *Association { if relationship.Kind == "many_to_many" { field := association.Field.Field - oldPrimaryKeys := association.getPrimaryKeys(field.Interface()) + oldPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface()) association.Field.Set(reflect.Zero(association.Field.Field.Type())) association.Append(values...) - newPrimaryKeys := association.getPrimaryKeys(field.Interface()) + newPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface()) - var addedPrimaryKeys = []interface{}{} + var addedPrimaryKeys = [][]interface{}{} for _, newKey := range newPrimaryKeys { hasEqual := false for _, oldKey := range oldPrimaryKeys { - if reflect.DeepEqual(newKey, oldKey) { + if equalAsString(newKey, oldKey) { hasEqual = true break } @@ -127,13 +109,21 @@ func (association *Association) Replace(values ...interface{}) *Association { addedPrimaryKeys = append(addedPrimaryKeys, newKey) } } - for _, primaryKey := range association.getPrimaryKeys(values...) { + + for _, primaryKey := range association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) { addedPrimaryKeys = append(addedPrimaryKeys, primaryKey) } if len(addedPrimaryKeys) > 0 { - sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) - query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys) + query := scope.NewDB() + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + + sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(addedPrimaryKeys)) + query = query.Where(sql, toQueryValues(addedPrimaryKeys)...) association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship)) } } else { @@ -146,8 +136,13 @@ func (association *Association) Clear() *Association { relationship := association.Field.Relationship scope := association.Scope if relationship.Kind == "many_to_many" { - sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName)) - query := scope.NewDB().Where(sql, association.PrimaryKey) + query := scope.NewDB() + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { association.Field.Set(reflect.Zero(association.Field.Field.Type())) } else { @@ -168,18 +163,104 @@ func (association *Association) Count() int { if relationship.Kind == "many_to_many" { relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) - countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey) + query := scope.DB() + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)), + field.Field.Interface()) + } + } + if relationship.PolymorphicType != "" { - countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName()) + query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName()) } - countScope.Count(&count) + query.Table(newScope.TableName()).Count(&count) } else if relationship.Kind == "belongs_to" { - if v, ok := scope.FieldByName(association.Column); ok { - whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) - scope.DB().Table(newScope.TableName()).Where(whereSql, v).Count(&count) + query := scope.DB() + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)), + field.Field.Interface()) + } } + query.Table(newScope.TableName()).Count(&count) } return count } + +func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) [][]interface{} { + results := [][]interface{}{} + scope := association.Scope + + for _, value := range values { + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + if reflectValue.Kind() == reflect.Slice { + for i := 0; i < reflectValue.Len(); i++ { + primaryKeys := []interface{}{} + newScope := scope.New(reflectValue.Index(i).Interface()) + for _, column := range columns { + if field, ok := newScope.FieldByName(column); ok { + primaryKeys = append(primaryKeys, field.Field.Interface()) + } else { + primaryKeys = append(primaryKeys, "") + } + } + results = append(results, primaryKeys) + } + } else if reflectValue.Kind() == reflect.Struct { + newScope := scope.New(value) + var primaryKeys []interface{} + for _, column := range columns { + if field, ok := newScope.FieldByName(column); ok { + primaryKeys = append(primaryKeys, field.Field.Interface()) + } else { + primaryKeys = append(primaryKeys, "") + } + } + + results = append(results, primaryKeys) + } + } + return results +} + +func toQueryMarks(primaryValues [][]interface{}) string { + var results []string + + for _, primaryValue := range primaryValues { + var marks []string + for _,_ = range primaryValue { + marks = append(marks, "?") + } + + if len(marks) > 1 { + results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) + } else { + results = append(results, strings.Join(marks, "")) + } + } + return strings.Join(results, ",") +} + +func toQueryCondition(scope *Scope, columns []string) string { + var newColumns []string + for _, column := range columns { + newColumns = append(newColumns, scope.Quote(column)) + } + + if len(columns) > 1 { + return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) + } else { + return strings.Join(columns, ",") + } +} + +func toQueryValues(primaryValues [][]interface{}) (values []interface{}) { + for _, primaryValue := range primaryValues { + for _, value := range primaryValue { + values = append(values, value) + } + } + return values +} diff --git a/association_test.go b/association_test.go index ea5b1b80..dfda46a5 100644 --- a/association_test.go +++ b/association_test.go @@ -23,11 +23,11 @@ func TestHasOneAndHasManyAssociation(t *testing.T) { } if err := DB.Save(&post).Error; err != nil { - t.Errorf("Got errors when save post") + t.Errorf("Got errors when save post", err.Error()) } - if DB.First(&Category{}, "name = ?", "Category 1").Error != nil { - t.Errorf("Category should be saved") + if err := DB.First(&Category{}, "name = ?", "Category 1").Error; err != nil { + t.Errorf("Category should be saved", err.Error()) } var p Post @@ -186,6 +186,7 @@ func TestManyToMany(t *testing.T) { var language Language DB.Where("name = ?", "EE").First(&language) DB.Model(&user).Association("Languages").Delete(language, &language) + if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 { t.Errorf("Relations should be deleted with Delete") } diff --git a/callback_create.go b/callback_create.go index 7f21ed6a..bded5324 100644 --- a/callback_create.go +++ b/callback_create.go @@ -35,9 +35,11 @@ func Create(scope *Scope) { } } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - if relationField := fields[relationship.ForeignDBName]; !scope.changeableField(relationField) { - columns = append(columns, scope.Quote(relationField.DBName)) - sqls = append(sqls, scope.AddToVars(relationField.Field.Interface())) + for _, dbName := range relationship.ForeignDBNames { + if relationField := fields[dbName]; !scope.changeableField(relationField) { + columns = append(columns, scope.Quote(relationField.DBName)) + sqls = append(sqls, scope.AddToVars(relationField.Field.Interface())) + } } } } diff --git a/callback_query.go b/callback_query.go index 4de911e8..387e813d 100644 --- a/callback_query.go +++ b/callback_query.go @@ -30,7 +30,7 @@ func Query(scope *Scope) { if kind := dest.Kind(); kind == reflect.Slice { isSlice = true destType = dest.Type().Elem() - dest.Set(reflect.Indirect(reflect.New(reflect.SliceOf(destType)))) + dest.Set(reflect.MakeSlice(dest.Type(), 0, 0)) if destType.Kind() == reflect.Ptr { isPtr = true diff --git a/callback_shared.go b/callback_shared.go index c1b9bd00..547059e3 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -19,8 +19,13 @@ func SaveBeforeAssociations(scope *Scope) { if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { value := field.Field scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error) - if relationship.ForeignFieldName != "" { - scope.Err(scope.SetColumn(relationship.ForeignFieldName, scope.New(value.Addr().Interface()).PrimaryKeyValue())) + if len(relationship.ForeignFieldNames) != 0 { + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if f, ok := scope.New(value.Addr().Interface()).FieldByName(associationForeignName); ok { + scope.Err(scope.SetColumn(fieldName, f.Field.Interface())) + } + } } } } @@ -44,8 +49,13 @@ func SaveAfterAssociations(scope *Scope) { elem := value.Index(i).Addr().Interface() newScope := newDB.NewScope(elem) - if relationship.JoinTableHandler == nil && relationship.ForeignFieldName != "" { - scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue())) + if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if f, ok := scope.FieldByName(associationForeignName); ok { + scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) + } + } } if relationship.PolymorphicType != "" { @@ -61,8 +71,13 @@ func SaveAfterAssociations(scope *Scope) { default: elem := value.Addr().Interface() newScope := scope.New(elem) - if relationship.ForeignFieldName != "" { - scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue())) + if len(relationship.ForeignFieldNames) != 0 { + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if f, ok := scope.FieldByName(associationForeignName); ok { + scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) + } + } } if relationship.PolymorphicType != "" { diff --git a/callback_update.go b/callback_update.go index c3f7b4b6..6090ee6b 100644 --- a/callback_update.go +++ b/callback_update.go @@ -55,9 +55,10 @@ func Update(scope *Scope) { sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - if relationField := fields[relationship.ForeignDBName]; !scope.changeableField(relationField) { - if !relationField.IsBlank { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface()))) + for _, dbName := range relationship.ForeignDBNames { + if relationField := fields[dbName]; !scope.changeableField(relationField) && !relationField.IsBlank { + sql := fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface())) + sqls = append(sqls, sql) } } } diff --git a/common_dialect.go b/common_dialect.go index 281df8a7..3b646869 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -70,8 +70,8 @@ func (commonDialect) Quote(key string) string { } func (commonDialect) databaseName(scope *Scope) string { - from := strings.Index(scope.db.parent.source, "/") + 1 - to := strings.Index(scope.db.parent.source, "?") + from := strings.LastIndex(scope.db.parent.source, "/") + 1 + to := strings.LastIndex(scope.db.parent.source, "?") if to == -1 { to = len(scope.db.parent.source) } diff --git a/join_table_handler.go b/join_table_handler.go index 07ecee2e..10e1e848 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -45,41 +45,18 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s s.TableName = tableName s.Source = JoinTableSource{ModelType: source} - sourceScope := &Scope{Value: reflect.New(source).Interface()} - sourcePrimaryFields := sourceScope.GetModelStruct().PrimaryFields - for _, primaryField := range sourcePrimaryFields { - if relationship.ForeignDBName == "" { - relationship.ForeignFieldName = source.Name() + primaryField.Name - relationship.ForeignDBName = ToDBName(relationship.ForeignFieldName) - } - - var dbName string - if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" { - dbName = relationship.ForeignDBName - } else { - dbName = ToDBName(source.Name() + primaryField.Name) - } - + for idx, dbName := range relationship.ForeignFieldNames { s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ - DBName: dbName, - AssociationDBName: primaryField.DBName, + DBName: relationship.ForeignDBNames[idx], + AssociationDBName: dbName, }) } s.Destination = JoinTableSource{ModelType: destination} - destinationScope := &Scope{Value: reflect.New(destination).Interface()} - destinationPrimaryFields := destinationScope.GetModelStruct().PrimaryFields - for _, primaryField := range destinationPrimaryFields { - var dbName string - if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" { - dbName = relationship.AssociationForeignDBName - } else { - dbName = ToDBName(destinationScope.GetModelStruct().ModelType.Name() + primaryField.Name) - } - + for idx, dbName := range relationship.AssociationForeignFieldNames { s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ - DBName: dbName, - AssociationDBName: primaryField.DBName, + DBName: relationship.AssociationForeignDBNames[idx], + AssociationDBName: dbName, }) } } diff --git a/main.go b/main.go index bdaf0d71..4c5cad11 100644 --- a/main.go +++ b/main.go @@ -118,7 +118,7 @@ func (s *DB) Callback() *callback { } func (s *DB) SetLogger(l logger) { - s.parent.logger = l + s.logger = l } func (s *DB) LogMode(enable bool) *DB { @@ -259,9 +259,9 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { if !result.RecordNotFound() { return result } - c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates) + c.err(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates).db.Error) } else if len(c.search.assignAttrs) > 0 { - c.NewScope(out).InstanceSet("gorm:update_interface", s.search.assignAttrs).callCallbacks(s.parent.callback.updates) + c.err(c.NewScope(out).InstanceSet("gorm:update_interface", s.search.assignAttrs).callCallbacks(s.parent.callback.updates).db.Error) } return c } @@ -450,10 +450,10 @@ func (s *DB) Association(column string) *Association { err = errors.New("primary key can't be nil") } else { if field, ok := scope.FieldByName(column); ok { - if field.Relationship == nil || field.Relationship.ForeignFieldName == "" { + if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 { err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()) } else { - return &Association{Scope: scope, Column: column, PrimaryKey: primaryField.Field.Interface(), Field: field} + return &Association{Scope: scope, Column: column, Field: field} } } else { err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column) diff --git a/main_private.go b/main_private.go index 914f7007..84f10e35 100644 --- a/main_private.go +++ b/main_private.go @@ -3,7 +3,7 @@ package gorm import "time" func (s *DB) clone() *DB { - db := DB{db: s.db, parent: s.parent, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error} + db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error} for key, value := range s.values { db.values[key] = value @@ -34,7 +34,7 @@ func (s *DB) err(err error) error { } func (s *DB) print(v ...interface{}) { - s.parent.logger.(logger).Print(v...) + s.logger.(logger).Print(v...) } func (s *DB) log(v ...interface{}) { diff --git a/model_struct.go b/model_struct.go index 10423ae2..72caec24 100644 --- a/model_struct.go +++ b/model_struct.go @@ -5,10 +5,11 @@ import ( "fmt" "go/ast" "reflect" - "regexp" "strconv" "strings" "time" + + "github.com/qor/inflection" ) var modelStructs = map[reflect.Type]*ModelStruct{} @@ -61,19 +62,16 @@ func (structField *StructField) clone() *StructField { } type Relationship struct { - Kind string - PolymorphicType string - PolymorphicDBName string - ForeignFieldName string - ForeignDBName string - AssociationForeignFieldName string - AssociationForeignDBName string - JoinTableHandler JoinTableHandlerInterface + Kind string + PolymorphicType string + PolymorphicDBName string + ForeignFieldNames []string + ForeignDBNames []string + AssociationForeignFieldNames []string + AssociationForeignDBNames []string + JoinTableHandler JoinTableHandlerInterface } -var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")} -var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${1}s"} - func (scope *Scope) GetModelStruct() *ModelStruct { var modelStruct ModelStruct @@ -113,11 +111,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } else { name := ToDBName(scopeType.Name()) if scope.db == nil || !scope.db.parent.singularTable { - for index, reg := range pluralMapKeys { - if reg.MatchString(name) { - name = reg.ReplaceAllString(name, pluralMapValues[index]) - } - } + name = inflection.Plural(name) } modelStruct.defaultTableName = name @@ -190,12 +184,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { var relationship = &Relationship{} - foreignKey := gormSettings["FOREIGNKEY"] if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" { if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil { if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { - relationship.ForeignFieldName = polymorphicField.Name - relationship.ForeignDBName = polymorphicField.DBName + relationship.ForeignFieldNames = []string{polymorphicField.Name} + relationship.ForeignDBNames = []string{polymorphicField.DBName} + relationship.AssociationForeignFieldNames = []string{scope.PrimaryField().Name} + relationship.AssociationForeignDBNames = []string{scope.PrimaryField().DBName} relationship.PolymorphicType = polymorphicType.Name relationship.PolymorphicDBName = polymorphicType.DBName polymorphicType.IsForeignKey = true @@ -204,6 +199,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } + var foreignKeys []string + if foreignKey, ok := gormSettings["FOREIGNKEY"]; ok { + foreignKeys = append(foreignKeys, foreignKey) + } switch indirectType.Kind() { case reflect.Slice: elemType := indirectType.Elem() @@ -212,21 +211,41 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } if elemType.Kind() == reflect.Struct { - if foreignKey == "" { - foreignKey = scopeType.Name() + "Id" - } - if many2many := gormSettings["MANY2MANY"]; many2many != "" { relationship.Kind = "many_to_many" - associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"] - if associationForeignKey == "" { - associationForeignKey = elemType.Name() + "Id" + + // foreign keys + if len(foreignKeys) == 0 { + for _, field := range scope.PrimaryFields() { + foreignKeys = append(foreignKeys, field.DBName) + } } - relationship.ForeignFieldName = foreignKey - relationship.ForeignDBName = ToDBName(foreignKey) - relationship.AssociationForeignFieldName = associationForeignKey - relationship.AssociationForeignDBName = ToDBName(associationForeignKey) + for _, foreignKey := range foreignKeys { + if field, ok := scope.FieldByName(foreignKey); ok { + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName) + joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName + relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) + } + } + + // association foreign keys + var associationForeignKeys []string + if foreignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + associationForeignKeys = []string{gormSettings["ASSOCIATIONFOREIGNKEY"]} + } else { + for _, field := range toScope.PrimaryFields() { + associationForeignKeys = append(associationForeignKeys, field.DBName) + } + } + + for _, name := range associationForeignKeys { + if field, ok := toScope.FieldByName(name); ok { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) + joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) + } + } joinTableHandler := JoinTableHandler{} joinTableHandler.Setup(relationship, many2many, scopeType, elemType) @@ -234,12 +253,30 @@ func (scope *Scope) GetModelStruct() *ModelStruct { field.Relationship = relationship } else { relationship.Kind = "has_many" - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - relationship.ForeignFieldName = foreignField.Name - relationship.ForeignDBName = foreignField.DBName - foreignField.IsForeignKey = true - field.Relationship = relationship - } else if relationship.ForeignFieldName != "" { + + if len(foreignKeys) == 0 { + for _, field := range scope.PrimaryFields() { + if foreignField := getForeignField(scopeType.Name()+field.Name, toScope.GetStructFields()); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } + } + } else { + for _, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } + } + } + + if len(relationship.ForeignFieldNames) != 0 { field.Relationship = relationship } } @@ -258,28 +295,56 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } continue } else { - belongsToForeignKey := foreignKey - if belongsToForeignKey == "" { - belongsToForeignKey = field.Name + "Id" + if len(foreignKeys) == 0 { + for _, f := range scope.PrimaryFields() { + if foreignField := getForeignField(modelStruct.ModelType.Name()+f.Name, toScope.GetStructFields()); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } + } + } else { + for _, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } + } } - if foreignField := getForeignField(belongsToForeignKey, fields); foreignField != nil { - relationship.Kind = "belongs_to" - relationship.ForeignFieldName = foreignField.Name - relationship.ForeignDBName = foreignField.DBName - foreignField.IsForeignKey = true + if len(relationship.ForeignFieldNames) != 0 { + relationship.Kind = "has_one" field.Relationship = relationship } else { - if foreignKey == "" { - foreignKey = modelStruct.ModelType.Name() + "Id" + if len(foreignKeys) == 0 { + for _, f := range toScope.PrimaryFields() { + if foreignField := getForeignField(field.Name+f.Name, fields); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } + } + } else { + for _, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, fields); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, toScope.PrimaryField().Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, toScope.PrimaryField().DBName) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } + } } - relationship.Kind = "has_one" - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - relationship.ForeignFieldName = foreignField.Name - relationship.ForeignDBName = foreignField.DBName - foreignField.IsForeignKey = true - field.Relationship = relationship - } else if relationship.ForeignFieldName != "" { + + if len(relationship.ForeignFieldNames) != 0 { + relationship.Kind = "belongs_to" field.Relationship = relationship } } diff --git a/preload.go b/preload.go index 03910c44..0db6fbde 100644 --- a/preload.go +++ b/preload.go @@ -8,12 +8,15 @@ import ( "strings" ) -func getRealValue(value reflect.Value, field string) interface{} { - result := reflect.Indirect(value).FieldByName(field).Interface() - if r, ok := result.(driver.Valuer); ok { - result, _ = r.Value() +func getRealValue(value reflect.Value, columns []string) (results []interface{}) { + for _, column := range columns { + result := reflect.Indirect(value).FieldByName(column).Interface() + if r, ok := result.(driver.Valuer); ok { + result, _ = r.Value() + } + results = append(results, result) } - return result + return } func equalAsString(a interface{}, b interface{}) bool { @@ -97,26 +100,24 @@ func makeSlice(typ reflect.Type) interface{} { } func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { - primaryName := scope.PrimaryField().Name - primaryKeys := scope.getColumnAsArray(primaryName) + relation := field.Relationship + + primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames) if len(primaryKeys) == 0 { return } results := makeSlice(field.Struct.Type) - relation := field.Relationship - condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) - - scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error) + scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) resultValues := reflect.Indirect(reflect.ValueOf(results)) for i := 0; i < resultValues.Len(); i++ { result := resultValues.Index(i) if scope.IndirectValue().Kind() == reflect.Slice { - value := getRealValue(result, relation.ForeignFieldName) + value := getRealValue(result, relation.ForeignFieldNames) objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { - if equalAsString(getRealValue(objects.Index(j), primaryName), value) { + if equalAsString(getRealValue(objects.Index(j), relation.AssociationForeignFieldNames), value) { reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) break } @@ -131,27 +132,24 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) } func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { - primaryName := scope.PrimaryField().Name - primaryKeys := scope.getColumnAsArray(primaryName) + relation := field.Relationship + primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames) if len(primaryKeys) == 0 { return } results := makeSlice(field.Struct.Type) - relation := field.Relationship - condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) - - scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error) + scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) resultValues := reflect.Indirect(reflect.ValueOf(results)) if scope.IndirectValue().Kind() == reflect.Slice { for i := 0; i < resultValues.Len(); i++ { result := resultValues.Index(i) - value := getRealValue(result, relation.ForeignFieldName) + value := getRealValue(result, relation.ForeignFieldNames) objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { object := reflect.Indirect(objects.Index(j)) - if equalAsString(getRealValue(object, primaryName), value) { + if equalAsString(getRealValue(object, relation.AssociationForeignFieldNames), value) { f := object.FieldByName(field.Name) f.Set(reflect.Append(f, result)) break @@ -165,25 +163,23 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { relation := field.Relationship - primaryKeys := scope.getColumnAsArray(relation.ForeignFieldName) + primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames) if len(primaryKeys) == 0 { return } results := makeSlice(field.Struct.Type) - associationPrimaryKey := scope.New(results).PrimaryField().Name - - scope.Err(scope.NewDB().Where(primaryKeys).Find(results, conditions...).Error) + scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) resultValues := reflect.Indirect(reflect.ValueOf(results)) for i := 0; i < resultValues.Len(); i++ { result := resultValues.Index(i) if scope.IndirectValue().Kind() == reflect.Slice { - value := getRealValue(result, associationPrimaryKey) + value := getRealValue(result, relation.AssociationForeignFieldNames) objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { object := reflect.Indirect(objects.Index(j)) - if equalAsString(getRealValue(object, relation.ForeignFieldName), value) { + if equalAsString(getRealValue(object, relation.ForeignFieldNames), value) { object.FieldByName(field.Name).Set(result) } } @@ -193,15 +189,23 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ } } -func (scope *Scope) getColumnAsArray(column string) (columns []interface{}) { +func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) { values := scope.IndirectValue() switch values.Kind() { case reflect.Slice: for i := 0; i < values.Len(); i++ { - columns = append(columns, reflect.Indirect(values.Index(i)).FieldByName(column).Interface()) + var result []interface{} + for _, column := range columns { + result = append(result, reflect.Indirect(values.Index(i)).FieldByName(column).Interface()) + } + results = append(results, result) } case reflect.Struct: - return []interface{}{values.FieldByName(column).Interface()} + var result []interface{} + for _, column := range columns { + result = append(result, values.FieldByName(column).Interface()) + } + return [][]interface{}{result} } return } diff --git a/query_test.go b/query_test.go index b15d01ba..580d06c4 100644 --- a/query_test.go +++ b/query_test.go @@ -4,6 +4,7 @@ import ( "fmt" "reflect" + "github.com/jinzhu/gorm" "github.com/jinzhu/now" "testing" @@ -556,7 +557,7 @@ func TestSelectWithEscapedFieldName(t *testing.T) { func TestSelectWithVariables(t *testing.T) { DB.Save(&User{Name: "jinzhu"}) - rows, _ := DB.Table("users").Select("? as fake", "name").Rows() + rows, _ := DB.Table("users").Select("? as fake", gorm.Expr("name")).Rows() if !rows.Next() { t.Errorf("Should have returned at least one row") diff --git a/scope.go b/scope.go index cd6b235d..104a3728 100644 --- a/scope.go +++ b/scope.go @@ -3,6 +3,7 @@ package gorm import ( "errors" "fmt" + "regexp" "strings" "time" @@ -87,6 +88,13 @@ func (scope *Scope) Quote(str string) string { } } +func (scope *Scope) QuoteIfPossible(str string) string { + if regexp.MustCompile("^[a-zA-Z]+(.[a-zA-Z]+)*$").MatchString(str) { + return scope.Quote(str) + } + return str +} + // Dialect get dialect func (scope *Scope) Dialect() Dialect { return scope.db.parent.dialect diff --git a/scope_private.go b/scope_private.go index edd0dbe9..1d906e76 100644 --- a/scope_private.go +++ b/scope_private.go @@ -149,7 +149,7 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) if valuer, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = valuer.Value() } - str = strings.Replace(str, "?", scope.Dialect().Quote(fmt.Sprintf("%v", arg)), 1) + str = strings.Replace(str, "?", scope.AddToVars(arg), 1) } } return @@ -265,10 +265,24 @@ func (scope *Scope) groupSql() string { } func (scope *Scope) havingSql() string { - if scope.Search.havingCondition == nil { + if scope.Search.havingConditions == nil { return "" } - return " HAVING " + scope.buildWhereCondition(scope.Search.havingCondition) + + var andConditions []string + + for _, clause := range scope.Search.havingConditions { + if sql := scope.buildWhereCondition(clause); sql != "" { + andConditions = append(andConditions, sql) + } + } + + combinedSql := strings.Join(andConditions, " AND ") + if len(combinedSql) == 0 { + return "" + } + + return " HAVING " + combinedSql } func (scope *Scope) joinsSql() string { @@ -415,12 +429,21 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { joinTableHandler := relationship.JoinTableHandler scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error) } else if relationship.Kind == "belongs_to" { - sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) - foreignKeyValue := fromFields[relationship.ForeignDBName].Field.Interface() - scope.Err(toScope.db.Where(sql, foreignKeyValue).Find(value).Error) + query := toScope.db + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(foreignKey); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) + } + } + scope.Err(query.Find(value).Error) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName)) - query := toScope.db.Where(sql, scope.PrimaryKeyValue()) + query := toScope.db + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + if relationship.PolymorphicType != "" { query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) } @@ -442,6 +465,17 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { return scope } +/** + Return the table options string or an empty string if the table options does not exist +*/ +func (scope *Scope) getTableOptions() string{ + tableOptions, ok := scope.Get("gorm:table_options") + if !ok { + return "" + } + return tableOptions.(string) +} + func (scope *Scope) createJoinTable(field *StructField) { if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { joinTableHandler := relationship.JoinTableHandler @@ -450,16 +484,22 @@ func (scope *Scope) createJoinTable(field *StructField) { toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} var sqlTypes []string - for _, s := range []*Scope{scope, toScope} { - for _, primaryField := range s.GetModelStruct().PrimaryFields { - value := reflect.Indirect(reflect.New(primaryField.Struct.Type)) + for idx, fieldName := range relationship.ForeignFieldNames { + if field, ok := scope.Fields()[fieldName]; ok { + value := reflect.Indirect(reflect.New(field.Struct.Type)) primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false) - dbName := ToDBName(s.GetModelStruct().ModelType.Name() + primaryField.Name) - sqlTypes = append(sqlTypes, scope.Quote(dbName)+" "+primaryKeySqlType) + sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType) } } - scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)", scope.Quote(joinTable), strings.Join(sqlTypes, ","))).Error) + for idx, fieldName := range relationship.AssociationForeignFieldNames { + if field, ok := toScope.Fields()[fieldName]; ok { + value := reflect.Indirect(reflect.New(field.Struct.Type)) + primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false) + sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType) + } + } + scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), scope.getTableOptions())).Error) } scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) } @@ -484,7 +524,7 @@ func (scope *Scope) createTable() *Scope { if len(primaryKeys) > 0 { primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) } - scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr)).Exec() + scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() return scope } @@ -515,11 +555,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { var columns []string for _, name := range column { - if regexp.MustCompile("^[a-zA-Z]+$").MatchString(name) { - columns = append(columns, scope.Quote(name)) - } else { - columns = append(columns, name) - } + columns = append(columns, scope.QuoteIfPossible(name)) } sqlCreate := "CREATE INDEX" @@ -532,9 +568,10 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { var table = scope.TableName() - var keyName = fmt.Sprintf("%s_%s_foreign", table, field) + var keyName = fmt.Sprintf("%s_%s_%s_foreign", table, field, regexp.MustCompile("[^a-zA-Z]").ReplaceAllString(dest, "_")) + keyName = regexp.MustCompile("_+").ReplaceAllString(keyName, "_") var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` - scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.Quote(keyName), scope.Quote(field), scope.Quote(dest), onDelete, onUpdate)).Exec() + scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.QuoteIfPossible(keyName), scope.QuoteIfPossible(field), scope.QuoteIfPossible(dest), onDelete, onUpdate)).Exec() } func (scope *Scope) removeIndex(indexName string) { diff --git a/search.go b/search.go index 9411af43..2c3df2d1 100644 --- a/search.go +++ b/search.go @@ -3,24 +3,24 @@ package gorm import "fmt" type search struct { - db *DB - whereConditions []map[string]interface{} - orConditions []map[string]interface{} - notConditions []map[string]interface{} - havingCondition map[string]interface{} - initAttrs []interface{} - assignAttrs []interface{} - selects map[string]interface{} - omits []string - orders []string - joins string - preload []searchPreload - offset string - limit string - group string - tableName string - raw bool - Unscoped bool + db *DB + whereConditions []map[string]interface{} + orConditions []map[string]interface{} + notConditions []map[string]interface{} + havingConditions []map[string]interface{} + initAttrs []interface{} + assignAttrs []interface{} + selects map[string]interface{} + omits []string + orders []string + joins string + preload []searchPreload + offset string + limit string + group string + tableName string + raw bool + Unscoped bool } type searchPreload struct { @@ -60,8 +60,12 @@ func (s *search) Assign(attrs ...interface{}) *search { func (s *search) Order(value string, reorder ...bool) *search { if len(reorder) > 0 && reorder[0] { - s.orders = []string{value} - } else { + if value != "" { + s.orders = []string{value} + } else { + s.orders = []string{} + } + } else if value != "" { s.orders = append(s.orders, value) } return s @@ -93,7 +97,7 @@ func (s *search) Group(query string) *search { } func (s *search) Having(query string, values ...interface{}) *search { - s.havingCondition = map[string]interface{}{"query": query, "args": values} + s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) return s }