From 8848fc476dc93dfc4db337744462a1675f993f7f Mon Sep 17 00:00:00 2001 From: Gabriel Date: Sun, 19 Jul 2015 22:37:08 +0000 Subject: [PATCH 01/26] Table suffix to create tables with InnoDB engine with mysql. Alter table is not affected yet, only create table and auto migration --- README.md | 1 + main.go | 10 ++++++++-- scope_private.go | 16 +++++++++++++--- 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index ccab06db..5919b188 100644 --- a/README.md +++ b/README.md @@ -113,6 +113,7 @@ import ( db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") // db, err := gorm.Open("foundation", "dbname=gorm") // FoundationDB. // db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") +// db, err := gorm.OpenWithTableSuffix("mysql", "ENGINE=InnoDB", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") // db, err := gorm.Open("sqlite3", "/tmp/gorm.db") // You can also use an existing database connection handle diff --git a/main.go b/main.go index aba51fc4..b2c625f5 100644 --- a/main.go +++ b/main.go @@ -32,11 +32,16 @@ type DB struct { dialect Dialect singularTable bool source string + tableSuffix string values map[string]interface{} joinTableHandlers map[string]JoinTableHandler } func Open(dialect string, args ...interface{}) (DB, error) { + return OpenWithTableSuffix(dialect, "", args) +} + +func OpenWithTableSuffix(dialect, tableSuffix string, args ...interface{}) (DB, error) { var db DB var err error @@ -69,6 +74,7 @@ func Open(dialect string, args ...interface{}) (DB, error) { logger: defaultLogger, callback: DefaultCallback, source: source, + tableSuffix:tableSuffix, values: map[string]interface{}{}, db: dbSql, } @@ -370,7 +376,7 @@ func (s *DB) RecordNotFound() bool { // Migrations func (s *DB) CreateTable(value interface{}) *DB { - return s.clone().NewScope(value).createTable().db + return s.clone().NewScope(value).Set("gorm:table_suffix", s.tableSuffix).createTable().db } func (s *DB) DropTable(value interface{}) *DB { @@ -390,7 +396,7 @@ func (s *DB) HasTable(value interface{}) bool { func (s *DB) AutoMigrate(values ...interface{}) *DB { db := s.clone() for _, value := range values { - db = db.NewScope(value).NeedPtr().autoMigrate().db + db = db.NewScope(value).NeedPtr().Set("gorm:table_suffix", s.tableSuffix).autoMigrate().db } return db } diff --git a/scope_private.go b/scope_private.go index 85f07e99..2ab340b2 100644 --- a/scope_private.go +++ b/scope_private.go @@ -442,6 +442,17 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { return scope } +/** + Return the table suffix string or an empty string if the table suffix does not exist +*/ +func (scope *Scope) getTableSuffix() string{ + tableSuffix, ok := scope.Get("gorm:table_suffix") + if !ok { + return "" + } + return tableSuffix.(string) +} + func (scope *Scope) createJoinTable(field *StructField) { if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { joinTableHandler := relationship.JoinTableHandler @@ -458,8 +469,7 @@ func (scope *Scope) createJoinTable(field *StructField) { sqlTypes = append(sqlTypes, scope.Quote(dbName)+" "+primaryKeySqlType) } } - - scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)", scope.Quote(joinTable), strings.Join(sqlTypes, ","))).Error) + scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), scope.getTableSuffix())).Error) } scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) } @@ -484,7 +494,7 @@ func (scope *Scope) createTable() *Scope { if len(primaryKeys) > 0 { primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) } - scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr)).Exec() + scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableSuffix())).Exec() return scope } From 260000d00f244674ce23b5a960465db1a508e768 Mon Sep 17 00:00:00 2001 From: Gabriel Date: Mon, 20 Jul 2015 22:46:04 +0000 Subject: [PATCH 02/26] Propagate argument in open function with table options --- main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.go b/main.go index b2c625f5..8470a8c7 100644 --- a/main.go +++ b/main.go @@ -38,7 +38,7 @@ type DB struct { } func Open(dialect string, args ...interface{}) (DB, error) { - return OpenWithTableSuffix(dialect, "", args) + return OpenWithTableSuffix(dialect, "", args...) } func OpenWithTableSuffix(dialect, tableSuffix string, args ...interface{}) (DB, error) { From a9cdf1dc7f9920a015a8e4e9dab046c4f466f5a0 Mon Sep 17 00:00:00 2001 From: kiwih Date: Wed, 22 Jul 2015 15:00:20 +1200 Subject: [PATCH 03/26] Add basic support for multiple HAVING clauses. All clauses will be ANDed together. --- scope_private.go | 18 ++++++++++++++++-- search.go | 38 +++++++++++++++++++------------------- 2 files changed, 35 insertions(+), 21 deletions(-) diff --git a/scope_private.go b/scope_private.go index edd0dbe9..f8620f34 100644 --- a/scope_private.go +++ b/scope_private.go @@ -265,10 +265,24 @@ func (scope *Scope) groupSql() string { } func (scope *Scope) havingSql() string { - if scope.Search.havingCondition == nil { + if scope.Search.havingConditions == nil { return "" } - return " HAVING " + scope.buildWhereCondition(scope.Search.havingCondition) + + var andConditions []string + + for _, clause := range scope.Search.havingConditions { + if sql := scope.buildWhereCondition(clause); sql != "" { + andConditions = append(andConditions, sql) + } + } + + combinedSql := strings.Join(andConditions, " AND ") + if len(combinedSql) == 0 { + return "" + } + + return " HAVING " + combinedSql } func (scope *Scope) joinsSql() string { diff --git a/search.go b/search.go index 9411af43..130415ef 100644 --- a/search.go +++ b/search.go @@ -3,24 +3,24 @@ package gorm import "fmt" type search struct { - db *DB - whereConditions []map[string]interface{} - orConditions []map[string]interface{} - notConditions []map[string]interface{} - havingCondition map[string]interface{} - initAttrs []interface{} - assignAttrs []interface{} - selects map[string]interface{} - omits []string - orders []string - joins string - preload []searchPreload - offset string - limit string - group string - tableName string - raw bool - Unscoped bool + db *DB + whereConditions []map[string]interface{} + orConditions []map[string]interface{} + notConditions []map[string]interface{} + havingConditions []map[string]interface{} + initAttrs []interface{} + assignAttrs []interface{} + selects map[string]interface{} + omits []string + orders []string + joins string + preload []searchPreload + offset string + limit string + group string + tableName string + raw bool + Unscoped bool } type searchPreload struct { @@ -93,7 +93,7 @@ func (s *search) Group(query string) *search { } func (s *search) Having(query string, values ...interface{}) *search { - s.havingCondition = map[string]interface{}{"query": query, "args": values} + s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) return s } From 2fe185eb77a22a3dbdb42cebed976598162ec294 Mon Sep 17 00:00:00 2001 From: liudan Date: Tue, 28 Jul 2015 16:33:32 +0800 Subject: [PATCH 04/26] fix panic in function databaseName when there are special characters in password, such as '?', '/' --- common_dialect.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common_dialect.go b/common_dialect.go index 281df8a7..3b646869 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -70,8 +70,8 @@ func (commonDialect) Quote(key string) string { } func (commonDialect) databaseName(scope *Scope) string { - from := strings.Index(scope.db.parent.source, "/") + 1 - to := strings.Index(scope.db.parent.source, "?") + from := strings.LastIndex(scope.db.parent.source, "/") + 1 + to := strings.LastIndex(scope.db.parent.source, "?") if to == -1 { to = len(scope.db.parent.source) } From a29230c86f0849b51bbb3c19b71f108fbb74541d Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 30 Jul 2015 14:26:48 +0800 Subject: [PATCH 05/26] multpile foreign keys --- callback_create.go | 8 ++- callback_shared.go | 27 +++++++-- callback_update.go | 7 ++- join_table_handler.go | 35 ++--------- main.go | 2 +- model_struct.go | 134 ++++++++++++++++++++++++++++-------------- 6 files changed, 127 insertions(+), 86 deletions(-) diff --git a/callback_create.go b/callback_create.go index 7f21ed6a..bded5324 100644 --- a/callback_create.go +++ b/callback_create.go @@ -35,9 +35,11 @@ func Create(scope *Scope) { } } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - if relationField := fields[relationship.ForeignDBName]; !scope.changeableField(relationField) { - columns = append(columns, scope.Quote(relationField.DBName)) - sqls = append(sqls, scope.AddToVars(relationField.Field.Interface())) + for _, dbName := range relationship.ForeignDBNames { + if relationField := fields[dbName]; !scope.changeableField(relationField) { + columns = append(columns, scope.Quote(relationField.DBName)) + sqls = append(sqls, scope.AddToVars(relationField.Field.Interface())) + } } } } diff --git a/callback_shared.go b/callback_shared.go index c1b9bd00..1e9d320f 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -19,8 +19,13 @@ func SaveBeforeAssociations(scope *Scope) { if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { value := field.Field scope.Err(scope.NewDB().Save(value.Addr().Interface()).Error) - if relationship.ForeignFieldName != "" { - scope.Err(scope.SetColumn(relationship.ForeignFieldName, scope.New(value.Addr().Interface()).PrimaryKeyValue())) + if len(relationship.ForeignFieldNames) != 0 { + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if f, ok := scope.New(value.Addr().Interface()).FieldByName(associationForeignName); ok { + scope.Err(scope.SetColumn(fieldName, f.Field.Interface())) + } + } } } } @@ -44,8 +49,13 @@ func SaveAfterAssociations(scope *Scope) { elem := value.Index(i).Addr().Interface() newScope := newDB.NewScope(elem) - if relationship.JoinTableHandler == nil && relationship.ForeignFieldName != "" { - scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue())) + if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if f, ok := scope.New(value.Addr().Interface()).FieldByName(associationForeignName); ok { + scope.Err(scope.SetColumn(fieldName, f.Field.Interface())) + } + } } if relationship.PolymorphicType != "" { @@ -61,8 +71,13 @@ func SaveAfterAssociations(scope *Scope) { default: elem := value.Addr().Interface() newScope := scope.New(elem) - if relationship.ForeignFieldName != "" { - scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue())) + if len(relationship.ForeignFieldNames) != 0 { + for idx, fieldName := range relationship.ForeignFieldNames { + associationForeignName := relationship.AssociationForeignDBNames[idx] + if f, ok := scope.New(value.Addr().Interface()).FieldByName(associationForeignName); ok { + scope.Err(scope.SetColumn(fieldName, f.Field.Interface())) + } + } } if relationship.PolymorphicType != "" { diff --git a/callback_update.go b/callback_update.go index c3f7b4b6..6090ee6b 100644 --- a/callback_update.go +++ b/callback_update.go @@ -55,9 +55,10 @@ func Update(scope *Scope) { sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - if relationField := fields[relationship.ForeignDBName]; !scope.changeableField(relationField) { - if !relationField.IsBlank { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface()))) + for _, dbName := range relationship.ForeignDBNames { + if relationField := fields[dbName]; !scope.changeableField(relationField) && !relationField.IsBlank { + sql := fmt.Sprintf("%v = %v", scope.Quote(relationField.DBName), scope.AddToVars(relationField.Field.Interface())) + sqls = append(sqls, sql) } } } diff --git a/join_table_handler.go b/join_table_handler.go index 07ecee2e..10e1e848 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -45,41 +45,18 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s s.TableName = tableName s.Source = JoinTableSource{ModelType: source} - sourceScope := &Scope{Value: reflect.New(source).Interface()} - sourcePrimaryFields := sourceScope.GetModelStruct().PrimaryFields - for _, primaryField := range sourcePrimaryFields { - if relationship.ForeignDBName == "" { - relationship.ForeignFieldName = source.Name() + primaryField.Name - relationship.ForeignDBName = ToDBName(relationship.ForeignFieldName) - } - - var dbName string - if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" { - dbName = relationship.ForeignDBName - } else { - dbName = ToDBName(source.Name() + primaryField.Name) - } - + for idx, dbName := range relationship.ForeignFieldNames { s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ - DBName: dbName, - AssociationDBName: primaryField.DBName, + DBName: relationship.ForeignDBNames[idx], + AssociationDBName: dbName, }) } s.Destination = JoinTableSource{ModelType: destination} - destinationScope := &Scope{Value: reflect.New(destination).Interface()} - destinationPrimaryFields := destinationScope.GetModelStruct().PrimaryFields - for _, primaryField := range destinationPrimaryFields { - var dbName string - if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" { - dbName = relationship.AssociationForeignDBName - } else { - dbName = ToDBName(destinationScope.GetModelStruct().ModelType.Name() + primaryField.Name) - } - + for idx, dbName := range relationship.AssociationForeignFieldNames { s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ - DBName: dbName, - AssociationDBName: primaryField.DBName, + DBName: relationship.AssociationForeignDBNames[idx], + AssociationDBName: dbName, }) } } diff --git a/main.go b/main.go index aba51fc4..7c4c4df4 100644 --- a/main.go +++ b/main.go @@ -445,7 +445,7 @@ func (s *DB) Association(column string) *Association { err = errors.New("primary key can't be nil") } else { if field, ok := scope.FieldByName(column); ok { - if field.Relationship == nil || field.Relationship.ForeignFieldName == "" { + if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 { err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()) } else { return &Association{Scope: scope, Column: column, PrimaryKey: primaryField.Field.Interface(), Field: field} diff --git a/model_struct.go b/model_struct.go index 10423ae2..119e6dc9 100644 --- a/model_struct.go +++ b/model_struct.go @@ -61,14 +61,14 @@ func (structField *StructField) clone() *StructField { } type Relationship struct { - Kind string - PolymorphicType string - PolymorphicDBName string - ForeignFieldName string - ForeignDBName string - AssociationForeignFieldName string - AssociationForeignDBName string - JoinTableHandler JoinTableHandlerInterface + Kind string + PolymorphicType string + PolymorphicDBName string + ForeignFieldNames []string + ForeignDBNames []string + AssociationForeignFieldNames []string + AssociationForeignDBNames []string + JoinTableHandler JoinTableHandlerInterface } var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")} @@ -190,12 +190,11 @@ func (scope *Scope) GetModelStruct() *ModelStruct { var relationship = &Relationship{} - foreignKey := gormSettings["FOREIGNKEY"] if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" { if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil { if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { - relationship.ForeignFieldName = polymorphicField.Name - relationship.ForeignDBName = polymorphicField.DBName + relationship.ForeignFieldNames = []string{polymorphicField.Name} + relationship.ForeignDBNames = []string{polymorphicField.DBName} relationship.PolymorphicType = polymorphicType.Name relationship.PolymorphicDBName = polymorphicType.DBName polymorphicType.IsForeignKey = true @@ -204,6 +203,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } + var foreignKeys []string + if foreignKey, ok := gormSettings["FOREIGNKEY"]; ok { + foreignKeys := append(foreignKeys, gormSettings["FOREIGNKEY"]) + } switch indirectType.Kind() { case reflect.Slice: elemType := indirectType.Elem() @@ -212,34 +215,63 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } if elemType.Kind() == reflect.Struct { - if foreignKey == "" { - foreignKey = scopeType.Name() + "Id" - } - if many2many := gormSettings["MANY2MANY"]; many2many != "" { relationship.Kind = "many_to_many" - associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"] - if associationForeignKey == "" { - associationForeignKey = elemType.Name() + "Id" + + // foreign keys + if len(foreignKeys) == 0 { + for _, field := range scope.PrimaryFields() { + foreignKeys = append(foreignKeys, field.DBName) + } } - relationship.ForeignFieldName = foreignKey - relationship.ForeignDBName = ToDBName(foreignKey) - relationship.AssociationForeignFieldName = associationForeignKey - relationship.AssociationForeignDBName = ToDBName(associationForeignKey) + for _, foreignKey := range foreignKeys { + if field, ok := scope.FieldByName(foreignKey); ok { + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, field.DBName) + joinTableDBName := ToDBName(scopeType.Name()) + "_" + field.DBName + relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) + } + } + + // association foreign keys + var associationForeignKeys []string + if foreignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { + associationForeignKeys = []string{gormSettings["ASSOCIATIONFOREIGNKEY"]} + } else { + for _, field := range toScope.PrimaryFields() { + associationForeignKeys = append(associationForeignKeys, field.DBName) + } + } + + for _, name := range associationForeignKeys { + if field, ok := toScope.FieldByName(name); ok { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, name) + joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) + } + } joinTableHandler := JoinTableHandler{} joinTableHandler.Setup(relationship, many2many, scopeType, elemType) relationship.JoinTableHandler = &joinTableHandler field.Relationship = relationship } else { + if len(foreignKeys) == 0 { + for _, field := range scope.PrimaryFields() { + foreignKeys = append(foreignKeys, scopeType.Name()+field.Name) + } + } + relationship.Kind = "has_many" - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - relationship.ForeignFieldName = foreignField.Name - relationship.ForeignDBName = foreignField.DBName - foreignField.IsForeignKey = true - field.Relationship = relationship - } else if relationship.ForeignFieldName != "" { + for _, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } + } + + if len(relationship.ForeignFieldNames) != 0 { field.Relationship = relationship } } @@ -258,28 +290,42 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } continue } else { - belongsToForeignKey := foreignKey - if belongsToForeignKey == "" { - belongsToForeignKey = field.Name + "Id" + belongsToForeignKeys := foreignKeys + if len(belongsToForeignKeys) == 0 { + for _, field := range toScope.PrimaryFields() { + belongsToForeignKeys = append(belongsToForeignKeys, field.Name+field.Name) + } } - if foreignField := getForeignField(belongsToForeignKey, fields); foreignField != nil { + for _, foreignKey := range belongsToForeignKeys { + if foreignField := getForeignField(foreignKey, fields); foreignField != nil { + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } + } + + if len(relationship.ForeignFieldNames) != 0 { relationship.Kind = "belongs_to" - relationship.ForeignFieldName = foreignField.Name - relationship.ForeignDBName = foreignField.DBName - foreignField.IsForeignKey = true field.Relationship = relationship } else { - if foreignKey == "" { - foreignKey = modelStruct.ModelType.Name() + "Id" + hasOneForeignKeys := foreignKeys + if len(hasOneForeignKeys) == 0 { + for _, field := range toScope.PrimaryFields() { + hasOneForeignKeys = append(hasOneForeignKeys, modelStruct.ModelType.Name()+field.Name) + } } - relationship.Kind = "has_one" - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - relationship.ForeignFieldName = foreignField.Name - relationship.ForeignDBName = foreignField.DBName - foreignField.IsForeignKey = true - field.Relationship = relationship - } else if relationship.ForeignFieldName != "" { + + for _, foreignKey := range hasOneForeignKeys { + if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } + } + + if len(relationship.ForeignFieldNames) != 0 { + relationship.Kind = "has_one" field.Relationship = relationship } } From dc428d2364789513875de77ecee30d3af9efbe1c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 30 Jul 2015 17:26:10 +0800 Subject: [PATCH 06/26] Fix compile error for association --- association.go | 207 ++++++++++++++++++++++++++++++++++--------------- main.go | 2 +- 2 files changed, 145 insertions(+), 64 deletions(-) diff --git a/association.go b/association.go index e34a10bd..f62e712b 100644 --- a/association.go +++ b/association.go @@ -4,14 +4,14 @@ import ( "errors" "fmt" "reflect" + "strings" ) type Association struct { - Scope *Scope - PrimaryKey interface{} - Column string - Error error - Field *Field + Scope *Scope + Column string + Error error + Field *Field } func (association *Association) setErr(err error) *Association { @@ -45,60 +45,43 @@ func (association *Association) Append(values ...interface{}) *Association { return association.setErr(scope.db.Error) } -func (association *Association) getPrimaryKeys(values ...interface{}) []interface{} { - primaryKeys := []interface{}{} +func (association *Association) Delete(values ...interface{}) *Association { scope := association.Scope + relationship := association.Field.Relationship - for _, value := range values { - reflectValue := reflect.Indirect(reflect.ValueOf(value)) - if reflectValue.Kind() == reflect.Slice { - for i := 0; i < reflectValue.Len(); i++ { - if primaryField := scope.New(reflectValue.Index(i).Interface()).PrimaryField(); !primaryField.IsBlank { - primaryKeys = append(primaryKeys, primaryField.Field.Interface()) - } - } - } else if reflectValue.Kind() == reflect.Struct { - if primaryField := scope.New(value).PrimaryField(); !primaryField.IsBlank { - primaryKeys = append(primaryKeys, primaryField.Field.Interface()) + // many to many + if relationship.Kind == "many_to_many" { + query := scope.NewDB() + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) } } - } - return primaryKeys -} -func (association *Association) Delete(values ...interface{}) *Association { - primaryKeys := association.getPrimaryKeys(values...) + primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignDBNames, values...) + sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)) + query = query.Where(sql, toQueryValues(primaryKeys)...) - if len(primaryKeys) == 0 { - association.setErr(errors.New("no primary key found")) - } else { - scope := association.Scope - relationship := association.Field.Relationship - // many to many - if relationship.Kind == "many_to_many" { - sql := fmt.Sprintf("%v = ? AND %v IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) - query := scope.NewDB().Where(sql, association.PrimaryKey, primaryKeys) - if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { - leftValues := reflect.Zero(association.Field.Field.Type()) - for i := 0; i < association.Field.Field.Len(); i++ { - value := association.Field.Field.Index(i) - if primaryField := association.Scope.New(value.Interface()).PrimaryField(); primaryField != nil { - var included = false - for _, primaryKey := range primaryKeys { - if equalAsString(primaryKey, primaryField.Field.Interface()) { - included = true - } - } - if !included { - leftValues = reflect.Append(leftValues, value) + if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { + leftValues := reflect.Zero(association.Field.Field.Type()) + for i := 0; i < association.Field.Field.Len(); i++ { + value := association.Field.Field.Index(i) + if primaryField := association.Scope.New(value.Interface()).PrimaryField(); primaryField != nil { + var included = false + for _, primaryKey := range primaryKeys { + if equalAsString(primaryKey, primaryField.Field.Interface()) { + included = true } } + if !included { + leftValues = reflect.Append(leftValues, value) + } } - association.Field.Set(leftValues) } - } else { - association.setErr(errors.New("delete only support many to many")) + association.Field.Set(leftValues) } + } else { + association.setErr(errors.New("delete only support many to many")) } return association } @@ -109,12 +92,12 @@ func (association *Association) Replace(values ...interface{}) *Association { if relationship.Kind == "many_to_many" { field := association.Field.Field - oldPrimaryKeys := association.getPrimaryKeys(field.Interface()) + oldPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignDBNames, field.Interface()) association.Field.Set(reflect.Zero(association.Field.Field.Type())) association.Append(values...) - newPrimaryKeys := association.getPrimaryKeys(field.Interface()) + newPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignDBNames, field.Interface()) - var addedPrimaryKeys = []interface{}{} + var addedPrimaryKeys = [][]interface{}{} for _, newKey := range newPrimaryKeys { hasEqual := false for _, oldKey := range oldPrimaryKeys { @@ -127,13 +110,21 @@ func (association *Association) Replace(values ...interface{}) *Association { addedPrimaryKeys = append(addedPrimaryKeys, newKey) } } - for _, primaryKey := range association.getPrimaryKeys(values...) { + + for _, primaryKey := range association.getPrimaryKeys(relationship.AssociationForeignDBNames, values...) { addedPrimaryKeys = append(addedPrimaryKeys, primaryKey) } if len(addedPrimaryKeys) > 0 { - sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) - query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys) + query := scope.NewDB() + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + + sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(addedPrimaryKeys)) + query = query.Where(sql, toQueryValues(addedPrimaryKeys)...) association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship)) } } else { @@ -146,8 +137,13 @@ func (association *Association) Clear() *Association { relationship := association.Field.Relationship scope := association.Scope if relationship.Kind == "many_to_many" { - sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName)) - query := scope.NewDB().Where(sql, association.PrimaryKey) + query := scope.NewDB() + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.ForeignFieldNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { association.Field.Set(reflect.Zero(association.Field.Field.Type())) } else { @@ -168,18 +164,103 @@ func (association *Association) Count() int { if relationship.Kind == "many_to_many" { relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) - countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey) + query := scope.DB() + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)), + field.Field.Interface()) + } + } + if relationship.PolymorphicType != "" { - countScope = countScope.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName()) + query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName()) } - countScope.Count(&count) + query.Count(&count) } else if relationship.Kind == "belongs_to" { - if v, ok := scope.FieldByName(association.Column); ok { - whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) - scope.DB().Table(newScope.TableName()).Where(whereSql, v).Count(&count) + query := scope.DB() + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), scope.Quote(foreignKey)), + field.Field.Interface()) + } } + query.Table(newScope.TableName()).Count(&count) } return count } + +func (association *Association) getPrimaryKeys(columns []string, values ...interface{}) [][]interface{} { + results := [][]interface{}{} + scope := association.Scope + + for _, value := range values { + primaryKeys := []interface{}{} + + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + if reflectValue.Kind() == reflect.Slice { + for i := 0; i < reflectValue.Len(); i++ { + newScope := scope.New(reflectValue.Index(i).Interface()) + for _, column := range columns { + if field, ok := newScope.FieldByName(column); ok { + primaryKeys = append(primaryKeys, field.Field.Interface()) + } else { + primaryKeys = append(primaryKeys, "") + } + } + } + } else if reflectValue.Kind() == reflect.Struct { + newScope := scope.New(value) + for _, column := range columns { + if field, ok := newScope.FieldByName(column); ok { + primaryKeys = append(primaryKeys, field.Field.Interface()) + } else { + primaryKeys = append(primaryKeys, "") + } + } + } + + results = append(results, primaryKeys) + } + return results +} + +func toQueryMarks(primaryValues [][]interface{}) string { + var results []string + + for _, primaryValue := range primaryValues { + var marks []string + for range primaryValue { + marks = append(marks, "?") + } + + if len(marks) > 1 { + results = append(results, fmt.Sprintf("(%v)", strings.Join(marks, ","))) + } else { + results = append(results, strings.Join(marks, "")) + } + } + return strings.Join(results, ",") +} + +func toQueryCondition(scope *Scope, columns []string) string { + var newColumns []string + for _, column := range columns { + newColumns = append(newColumns, scope.Quote(column)) + } + + if len(columns) > 1 { + return fmt.Sprintf("(%v)", strings.Join(newColumns, ",")) + } else { + return strings.Join(columns, ",") + } +} + +func toQueryValues(primaryValues [][]interface{}) (values []interface{}) { + for _, primaryValue := range primaryValues { + for _, value := range primaryValue { + values = append(values, value) + } + } + return values +} diff --git a/main.go b/main.go index 7c4c4df4..e7f93a02 100644 --- a/main.go +++ b/main.go @@ -448,7 +448,7 @@ func (s *DB) Association(column string) *Association { if field.Relationship == nil || len(field.Relationship.ForeignFieldNames) == 0 { err = fmt.Errorf("invalid association %v for %v", column, scope.IndirectValue().Type()) } else { - return &Association{Scope: scope, Column: column, PrimaryKey: primaryField.Field.Interface(), Field: field} + return &Association{Scope: scope, Column: column, Field: field} } } else { err = fmt.Errorf("%v doesn't have column %v", scope.IndirectValue().Type(), column) From fea291e796f796d9b9d4ff0305f768c0b3c320d2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 30 Jul 2015 17:58:49 +0800 Subject: [PATCH 07/26] Fix compile error for scope_private --- model_struct.go | 75 ++++++++++++++++++++++++++++++++---------------- scope_private.go | 19 ++++++++---- 2 files changed, 64 insertions(+), 30 deletions(-) diff --git a/model_struct.go b/model_struct.go index 119e6dc9..6e1ff055 100644 --- a/model_struct.go +++ b/model_struct.go @@ -256,18 +256,27 @@ func (scope *Scope) GetModelStruct() *ModelStruct { relationship.JoinTableHandler = &joinTableHandler field.Relationship = relationship } else { + relationship.Kind = "has_many" + if len(foreignKeys) == 0 { for _, field := range scope.PrimaryFields() { - foreignKeys = append(foreignKeys, scopeType.Name()+field.Name) + if foreignField := getForeignField(scopeType.Name()+field.Name, toScope.GetStructFields()); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } } - } - - relationship.Kind = "has_many" - for _, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - foreignField.IsForeignKey = true + } else { + for _, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } } } @@ -293,15 +302,23 @@ func (scope *Scope) GetModelStruct() *ModelStruct { belongsToForeignKeys := foreignKeys if len(belongsToForeignKeys) == 0 { for _, field := range toScope.PrimaryFields() { - belongsToForeignKeys = append(belongsToForeignKeys, field.Name+field.Name) + if foreignField := getForeignField(field.Name+field.Name, fields); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } } - } - - for _, foreignKey := range belongsToForeignKeys { - if foreignField := getForeignField(foreignKey, fields); foreignField != nil { - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - foreignField.IsForeignKey = true + } else { + for _, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, fields); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, toScope.PrimaryField().Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, toScope.PrimaryField().DBName) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } } } @@ -312,15 +329,23 @@ func (scope *Scope) GetModelStruct() *ModelStruct { hasOneForeignKeys := foreignKeys if len(hasOneForeignKeys) == 0 { for _, field := range toScope.PrimaryFields() { - hasOneForeignKeys = append(hasOneForeignKeys, modelStruct.ModelType.Name()+field.Name) + if foreignField := getForeignField(modelStruct.ModelType.Name()+field.Name, fields); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } } - } - - for _, foreignKey := range hasOneForeignKeys { - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) - relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) - foreignField.IsForeignKey = true + } else { + for _, foreignKey := range hasOneForeignKeys { + if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) + relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) + foreignField.IsForeignKey = true + } } } diff --git a/scope_private.go b/scope_private.go index edd0dbe9..931db3de 100644 --- a/scope_private.go +++ b/scope_private.go @@ -415,12 +415,21 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { joinTableHandler := relationship.JoinTableHandler scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error) } else if relationship.Kind == "belongs_to" { - sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) - foreignKeyValue := fromFields[relationship.ForeignDBName].Field.Interface() - scope.Err(toScope.db.Where(sql, foreignKeyValue).Find(value).Error) + query := toScope.db + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + scope.Err(query.Find(value).Error) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName)) - query := toScope.db.Where(sql, scope.PrimaryKeyValue()) + query := toScope.db + for idx, foreignKey := range relationship.ForeignDBNames { + if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + } + } + if relationship.PolymorphicType != "" { query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) } From 4e8272cf9d764be31566aaea0e691816b67df8e0 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 30 Jul 2015 18:19:49 +0800 Subject: [PATCH 08/26] Fix compile error for preload --- model_struct.go | 2 +- preload.go | 63 ++++++++++++++++++++++++++----------------------- 2 files changed, 34 insertions(+), 31 deletions(-) diff --git a/model_struct.go b/model_struct.go index 6e1ff055..7e4b683c 100644 --- a/model_struct.go +++ b/model_struct.go @@ -205,7 +205,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { var foreignKeys []string if foreignKey, ok := gormSettings["FOREIGNKEY"]; ok { - foreignKeys := append(foreignKeys, gormSettings["FOREIGNKEY"]) + foreignKeys = append(foreignKeys, foreignKey) } switch indirectType.Kind() { case reflect.Slice: diff --git a/preload.go b/preload.go index 03910c44..0a302ab2 100644 --- a/preload.go +++ b/preload.go @@ -8,12 +8,15 @@ import ( "strings" ) -func getRealValue(value reflect.Value, field string) interface{} { - result := reflect.Indirect(value).FieldByName(field).Interface() - if r, ok := result.(driver.Valuer); ok { - result, _ = r.Value() +func getRealValue(value reflect.Value, columns []string) (results []interface{}) { + for _, column := range columns { + result := reflect.Indirect(value).FieldByName(column).Interface() + if r, ok := result.(driver.Valuer); ok { + result, _ = r.Value() + } + results = append(results, result) } - return result + return } func equalAsString(a interface{}, b interface{}) bool { @@ -97,26 +100,23 @@ func makeSlice(typ reflect.Type) interface{} { } func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { - primaryName := scope.PrimaryField().Name - primaryKeys := scope.getColumnAsArray(primaryName) + relation := field.Relationship + primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames) if len(primaryKeys) == 0 { return } results := makeSlice(field.Struct.Type) - relation := field.Relationship - condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) - - scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error) + scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) resultValues := reflect.Indirect(reflect.ValueOf(results)) for i := 0; i < resultValues.Len(); i++ { result := resultValues.Index(i) if scope.IndirectValue().Kind() == reflect.Slice { - value := getRealValue(result, relation.ForeignFieldName) + value := getRealValue(result, relation.ForeignFieldNames) objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { - if equalAsString(getRealValue(objects.Index(j), primaryName), value) { + if equalAsString(getRealValue(objects.Index(j), relation.AssociationForeignFieldNames), value) { reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) break } @@ -131,27 +131,24 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) } func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { - primaryName := scope.PrimaryField().Name - primaryKeys := scope.getColumnAsArray(primaryName) + relation := field.Relationship + primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames) if len(primaryKeys) == 0 { return } results := makeSlice(field.Struct.Type) - relation := field.Relationship - condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) - - scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error) + scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) resultValues := reflect.Indirect(reflect.ValueOf(results)) if scope.IndirectValue().Kind() == reflect.Slice { for i := 0; i < resultValues.Len(); i++ { result := resultValues.Index(i) - value := getRealValue(result, relation.ForeignFieldName) + value := getRealValue(result, relation.ForeignFieldNames) objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { object := reflect.Indirect(objects.Index(j)) - if equalAsString(getRealValue(object, primaryName), value) { + if equalAsString(getRealValue(object, relation.AssociationForeignFieldNames), value) { f := object.FieldByName(field.Name) f.Set(reflect.Append(f, result)) break @@ -165,25 +162,23 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { relation := field.Relationship - primaryKeys := scope.getColumnAsArray(relation.ForeignFieldName) + primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames) if len(primaryKeys) == 0 { return } results := makeSlice(field.Struct.Type) - associationPrimaryKey := scope.New(results).PrimaryField().Name - - scope.Err(scope.NewDB().Where(primaryKeys).Find(results, conditions...).Error) + scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) resultValues := reflect.Indirect(reflect.ValueOf(results)) for i := 0; i < resultValues.Len(); i++ { result := resultValues.Index(i) if scope.IndirectValue().Kind() == reflect.Slice { - value := getRealValue(result, associationPrimaryKey) + value := getRealValue(result, relation.AssociationForeignFieldNames) objects := scope.IndirectValue() for j := 0; j < objects.Len(); j++ { object := reflect.Indirect(objects.Index(j)) - if equalAsString(getRealValue(object, relation.ForeignFieldName), value) { + if equalAsString(getRealValue(object, relation.ForeignFieldNames), value) { object.FieldByName(field.Name).Set(result) } } @@ -193,15 +188,23 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ } } -func (scope *Scope) getColumnAsArray(column string) (columns []interface{}) { +func (scope *Scope) getColumnAsArray(columns []string) (results [][]interface{}) { values := scope.IndirectValue() switch values.Kind() { case reflect.Slice: for i := 0; i < values.Len(); i++ { - columns = append(columns, reflect.Indirect(values.Index(i)).FieldByName(column).Interface()) + var result []interface{} + for _, column := range columns { + result = append(result, reflect.Indirect(values.Index(i)).FieldByName(column).Interface()) + } + results = append(results, result) } case reflect.Struct: - return []interface{}{values.FieldByName(column).Interface()} + var result []interface{} + for _, column := range columns { + result = append(result, values.FieldByName(column).Interface()) + } + return [][]interface{}{result} } return } From 7decf73356c289d71ed2bf74bbc854777eede3a5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 30 Jul 2015 18:41:43 +0800 Subject: [PATCH 09/26] Fix test TestHasOneAndHasManyAssociation --- association_test.go | 6 +++--- callback_shared.go | 6 +++--- model_struct.go | 11 +++++------ 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/association_test.go b/association_test.go index ea5b1b80..205a929e 100644 --- a/association_test.go +++ b/association_test.go @@ -23,11 +23,11 @@ func TestHasOneAndHasManyAssociation(t *testing.T) { } if err := DB.Save(&post).Error; err != nil { - t.Errorf("Got errors when save post") + t.Errorf("Got errors when save post", err.Error()) } - if DB.First(&Category{}, "name = ?", "Category 1").Error != nil { - t.Errorf("Category should be saved") + if err := DB.First(&Category{}, "name = ?", "Category 1").Error; err != nil { + t.Errorf("Category should be saved", err.Error()) } var p Post diff --git a/callback_shared.go b/callback_shared.go index 1e9d320f..fc6b23b3 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -52,8 +52,8 @@ func SaveAfterAssociations(scope *Scope) { if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { for idx, fieldName := range relationship.ForeignFieldNames { associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.New(value.Addr().Interface()).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, f.Field.Interface())) + if f, ok := scope.FieldByName(associationForeignName); ok { + scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) } } } @@ -75,7 +75,7 @@ func SaveAfterAssociations(scope *Scope) { for idx, fieldName := range relationship.ForeignFieldNames { associationForeignName := relationship.AssociationForeignDBNames[idx] if f, ok := scope.New(value.Addr().Interface()).FieldByName(associationForeignName); ok { - scope.Err(scope.SetColumn(fieldName, f.Field.Interface())) + scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) } } } diff --git a/model_struct.go b/model_struct.go index 7e4b683c..902bc2cd 100644 --- a/model_struct.go +++ b/model_struct.go @@ -299,12 +299,11 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } continue } else { - belongsToForeignKeys := foreignKeys - if len(belongsToForeignKeys) == 0 { - for _, field := range toScope.PrimaryFields() { - if foreignField := getForeignField(field.Name+field.Name, fields); foreignField != nil { - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName) + if len(foreignKeys) == 0 { + for _, f := range toScope.PrimaryFields() { + if foreignField := getForeignField(field.Name+f.Name, fields); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName) relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) foreignField.IsForeignKey = true From 9c52c29e903aaa402a2ff28dc865cbe367ec6e3a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 30 Jul 2015 18:56:05 +0800 Subject: [PATCH 10/26] Fix test TestRelated --- callback_shared.go | 2 +- model_struct.go | 7 +++---- scope_private.go | 4 ++-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/callback_shared.go b/callback_shared.go index fc6b23b3..547059e3 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -74,7 +74,7 @@ func SaveAfterAssociations(scope *Scope) { if len(relationship.ForeignFieldNames) != 0 { for idx, fieldName := range relationship.ForeignFieldNames { associationForeignName := relationship.AssociationForeignDBNames[idx] - if f, ok := scope.New(value.Addr().Interface()).FieldByName(associationForeignName); ok { + if f, ok := scope.FieldByName(associationForeignName); ok { scope.Err(newScope.SetColumn(fieldName, f.Field.Interface())) } } diff --git a/model_struct.go b/model_struct.go index 902bc2cd..aa124625 100644 --- a/model_struct.go +++ b/model_struct.go @@ -325,10 +325,9 @@ func (scope *Scope) GetModelStruct() *ModelStruct { relationship.Kind = "belongs_to" field.Relationship = relationship } else { - hasOneForeignKeys := foreignKeys - if len(hasOneForeignKeys) == 0 { + if len(foreignKeys) == 0 { for _, field := range toScope.PrimaryFields() { - if foreignField := getForeignField(modelStruct.ModelType.Name()+field.Name, fields); foreignField != nil { + if foreignField := getForeignField(modelStruct.ModelType.Name()+field.Name, toScope.GetStructFields()); foreignField != nil { relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName) relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) @@ -337,7 +336,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } } else { - for _, foreignKey := range hasOneForeignKeys { + for _, foreignKey := range foreignKeys { if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName) diff --git a/scope_private.go b/scope_private.go index 931db3de..1d58e6a2 100644 --- a/scope_private.go +++ b/scope_private.go @@ -417,8 +417,8 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { } else if relationship.Kind == "belongs_to" { query := toScope.db for idx, foreignKey := range relationship.ForeignDBNames { - if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + if field, ok := scope.FieldByName(foreignKey); ok { + query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) } } scope.Err(query.Find(value).Error) From ebbeecd10f2ce44a37aa00b628db4c8ce4b9e088 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 30 Jul 2015 22:18:56 +0800 Subject: [PATCH 11/26] Fix test TestManyToMany --- association.go | 40 ++++++++++++++++++++-------------------- association_test.go | 1 + 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/association.go b/association.go index f62e712b..b088c1dd 100644 --- a/association.go +++ b/association.go @@ -58,25 +58,24 @@ func (association *Association) Delete(values ...interface{}) *Association { } } - primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignDBNames, values...) + primaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)) query = query.Where(sql, toQueryValues(primaryKeys)...) if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { leftValues := reflect.Zero(association.Field.Field.Type()) for i := 0; i < association.Field.Field.Len(); i++ { - value := association.Field.Field.Index(i) - if primaryField := association.Scope.New(value.Interface()).PrimaryField(); primaryField != nil { - var included = false - for _, primaryKey := range primaryKeys { - if equalAsString(primaryKey, primaryField.Field.Interface()) { - included = true - } - } - if !included { - leftValues = reflect.Append(leftValues, value) + reflectValue := association.Field.Field.Index(i) + primaryKey := association.getPrimaryKeys(relationship.ForeignFieldNames, reflectValue.Interface())[0] + var included = false + for _, pk := range primaryKeys { + if equalAsString(primaryKey, pk) { + included = true } } + if !included { + leftValues = reflect.Append(leftValues, reflectValue) + } } association.Field.Set(leftValues) } @@ -92,16 +91,16 @@ func (association *Association) Replace(values ...interface{}) *Association { if relationship.Kind == "many_to_many" { field := association.Field.Field - oldPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignDBNames, field.Interface()) + oldPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface()) association.Field.Set(reflect.Zero(association.Field.Field.Type())) association.Append(values...) - newPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignDBNames, field.Interface()) + newPrimaryKeys := association.getPrimaryKeys(relationship.AssociationForeignFieldNames, field.Interface()) var addedPrimaryKeys = [][]interface{}{} for _, newKey := range newPrimaryKeys { hasEqual := false for _, oldKey := range oldPrimaryKeys { - if reflect.DeepEqual(newKey, oldKey) { + if equalAsString(newKey, oldKey) { hasEqual = true break } @@ -111,7 +110,7 @@ func (association *Association) Replace(values ...interface{}) *Association { } } - for _, primaryKey := range association.getPrimaryKeys(relationship.AssociationForeignDBNames, values...) { + for _, primaryKey := range association.getPrimaryKeys(relationship.AssociationForeignFieldNames, values...) { addedPrimaryKeys = append(addedPrimaryKeys, primaryKey) } @@ -123,7 +122,7 @@ func (association *Association) Replace(values ...interface{}) *Association { } } - sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(addedPrimaryKeys)) + sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(addedPrimaryKeys)) query = query.Where(sql, toQueryValues(addedPrimaryKeys)...) association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship)) } @@ -195,11 +194,10 @@ func (association *Association) getPrimaryKeys(columns []string, values ...inter scope := association.Scope for _, value := range values { - primaryKeys := []interface{}{} - reflectValue := reflect.Indirect(reflect.ValueOf(value)) if reflectValue.Kind() == reflect.Slice { for i := 0; i < reflectValue.Len(); i++ { + primaryKeys := []interface{}{} newScope := scope.New(reflectValue.Index(i).Interface()) for _, column := range columns { if field, ok := newScope.FieldByName(column); ok { @@ -208,9 +206,11 @@ func (association *Association) getPrimaryKeys(columns []string, values ...inter primaryKeys = append(primaryKeys, "") } } + results = append(results, primaryKeys) } } else if reflectValue.Kind() == reflect.Struct { newScope := scope.New(value) + var primaryKeys []interface{} for _, column := range columns { if field, ok := newScope.FieldByName(column); ok { primaryKeys = append(primaryKeys, field.Field.Interface()) @@ -218,9 +218,9 @@ func (association *Association) getPrimaryKeys(columns []string, values ...inter primaryKeys = append(primaryKeys, "") } } - } - results = append(results, primaryKeys) + results = append(results, primaryKeys) + } } return results } diff --git a/association_test.go b/association_test.go index 205a929e..dfda46a5 100644 --- a/association_test.go +++ b/association_test.go @@ -186,6 +186,7 @@ func TestManyToMany(t *testing.T) { var language Language DB.Where("name = ?", "EE").First(&language) DB.Model(&user).Association("Languages").Delete(language, &language) + if DB.Model(&user).Association("Languages").Count() != len(totalLanguages)-1 || len(user.Languages) != len(totalLanguages)-1 { t.Errorf("Relations should be deleted with Delete") } From f50956cfef11fffd7df785f5571b377fae042b6f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 30 Jul 2015 22:36:04 +0800 Subject: [PATCH 12/26] Fix test TestSelectWithCreate --- model_struct.go | 8 ++++---- preload.go | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/model_struct.go b/model_struct.go index aa124625..50437778 100644 --- a/model_struct.go +++ b/model_struct.go @@ -326,10 +326,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct { field.Relationship = relationship } else { if len(foreignKeys) == 0 { - for _, field := range toScope.PrimaryFields() { - if foreignField := getForeignField(modelStruct.ModelType.Name()+field.Name, toScope.GetStructFields()); foreignField != nil { - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, field.DBName) + for _, f := range scope.PrimaryFields() { + if foreignField := getForeignField(modelStruct.ModelType.Name()+f.Name, toScope.GetStructFields()); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName) relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) foreignField.IsForeignKey = true diff --git a/preload.go b/preload.go index 0a302ab2..0db6fbde 100644 --- a/preload.go +++ b/preload.go @@ -101,6 +101,7 @@ func makeSlice(typ reflect.Type) interface{} { func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { relation := field.Relationship + primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames) if len(primaryKeys) == 0 { return @@ -168,7 +169,7 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ } results := makeSlice(field.Struct.Type) - scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) + scope.Err(scope.NewDB().Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, conditions...).Error) resultValues := reflect.Indirect(reflect.ValueOf(results)) for i := 0; i < resultValues.Len(); i++ { From f00b95d305e086d2644f26ef1f445f4df391470e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 30 Jul 2015 22:59:25 +0800 Subject: [PATCH 13/26] Passed all tests for multiple primary keys --- association.go | 2 +- model_struct.go | 26 ++++++++++++++------------ 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/association.go b/association.go index b088c1dd..4d3fb15f 100644 --- a/association.go +++ b/association.go @@ -174,7 +174,7 @@ func (association *Association) Count() int { if relationship.PolymorphicType != "" { query = query.Where(fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.PolymorphicDBName)), scope.TableName()) } - query.Count(&count) + query.Table(newScope.TableName()).Count(&count) } else if relationship.Kind == "belongs_to" { query := scope.DB() for idx, foreignKey := range relationship.ForeignDBNames { diff --git a/model_struct.go b/model_struct.go index 50437778..468002d5 100644 --- a/model_struct.go +++ b/model_struct.go @@ -195,6 +195,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { relationship.ForeignFieldNames = []string{polymorphicField.Name} relationship.ForeignDBNames = []string{polymorphicField.DBName} + relationship.AssociationForeignFieldNames = []string{scope.PrimaryField().Name} + relationship.AssociationForeignDBNames = []string{scope.PrimaryField().DBName} relationship.PolymorphicType = polymorphicType.Name relationship.PolymorphicDBName = polymorphicType.DBName polymorphicType.IsForeignKey = true @@ -300,8 +302,8 @@ func (scope *Scope) GetModelStruct() *ModelStruct { continue } else { if len(foreignKeys) == 0 { - for _, f := range toScope.PrimaryFields() { - if foreignField := getForeignField(field.Name+f.Name, fields); foreignField != nil { + for _, f := range scope.PrimaryFields() { + if foreignField := getForeignField(modelStruct.ModelType.Name()+f.Name, toScope.GetStructFields()); foreignField != nil { relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName) relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) @@ -311,9 +313,9 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } else { for _, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, fields); foreignField != nil { - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, toScope.PrimaryField().Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, toScope.PrimaryField().DBName) + if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName) relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) foreignField.IsForeignKey = true @@ -322,12 +324,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } if len(relationship.ForeignFieldNames) != 0 { - relationship.Kind = "belongs_to" + relationship.Kind = "has_one" field.Relationship = relationship } else { if len(foreignKeys) == 0 { - for _, f := range scope.PrimaryFields() { - if foreignField := getForeignField(modelStruct.ModelType.Name()+f.Name, toScope.GetStructFields()); foreignField != nil { + for _, f := range toScope.PrimaryFields() { + if foreignField := getForeignField(field.Name+f.Name, fields); foreignField != nil { relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, f.Name) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, f.DBName) relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) @@ -337,9 +339,9 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } else { for _, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scope.PrimaryField().Name) - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scope.PrimaryField().DBName) + if foreignField := getForeignField(foreignKey, fields); foreignField != nil { + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, toScope.PrimaryField().Name) + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, toScope.PrimaryField().DBName) relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.Name) relationship.ForeignDBNames = append(relationship.ForeignDBNames, foreignField.DBName) foreignField.IsForeignKey = true @@ -348,7 +350,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } if len(relationship.ForeignFieldNames) != 0 { - relationship.Kind = "has_one" + relationship.Kind = "belongs_to" field.Relationship = relationship } } From 6a7dda9a32e187c044178aadb0a4510f053a73fa Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 31 Jul 2015 15:25:33 +0800 Subject: [PATCH 14/26] Fix AssociationForeignFieldNames for many2many relations --- model_struct.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model_struct.go b/model_struct.go index 468002d5..9c07db9b 100644 --- a/model_struct.go +++ b/model_struct.go @@ -247,7 +247,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for _, name := range associationForeignKeys { if field, ok := toScope.FieldByName(name); ok { - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, name) + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) } From 25ba9487aa744c5c484c16af15221b440a90c98c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 31 Jul 2015 16:33:44 +0800 Subject: [PATCH 15/26] Create join table with computed foreign keys --- scope_private.go | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/scope_private.go b/scope_private.go index 1d58e6a2..e440f7a4 100644 --- a/scope_private.go +++ b/scope_private.go @@ -459,12 +459,19 @@ func (scope *Scope) createJoinTable(field *StructField) { toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} var sqlTypes []string - for _, s := range []*Scope{scope, toScope} { - for _, primaryField := range s.GetModelStruct().PrimaryFields { - value := reflect.Indirect(reflect.New(primaryField.Struct.Type)) + for idx, fieldName := range relationship.ForeignFieldNames { + if field, ok := scope.Fields()[fieldName]; ok { + value := reflect.Indirect(reflect.New(field.Struct.Type)) primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false) - dbName := ToDBName(s.GetModelStruct().ModelType.Name() + primaryField.Name) - sqlTypes = append(sqlTypes, scope.Quote(dbName)+" "+primaryKeySqlType) + sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+primaryKeySqlType) + } + } + + for idx, fieldName := range relationship.AssociationForeignFieldNames { + if field, ok := toScope.Fields()[fieldName]; ok { + value := reflect.Indirect(reflect.New(field.Struct.Type)) + primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false) + sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+primaryKeySqlType) } } From 1c227d4243c5083689d7e467c42fdea9af93d24e Mon Sep 17 00:00:00 2001 From: Rahul Ghose Date: Fri, 31 Jul 2015 16:52:46 +0530 Subject: [PATCH 16/26] this fixes syntax error relevant issue: https://github.com/jinzhu/gorm/issues/588 --- association.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/association.go b/association.go index 4d3fb15f..342dd6cd 100644 --- a/association.go +++ b/association.go @@ -230,7 +230,7 @@ func toQueryMarks(primaryValues [][]interface{}) string { for _, primaryValue := range primaryValues { var marks []string - for range primaryValue { + for _,_ = range primaryValue { marks = append(marks, "?") } From fa864331428595f6d43232c3c2043e626b79aa7b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 1 Aug 2015 09:06:06 +0800 Subject: [PATCH 17/26] Overwrite initiliazed slices as empty slices --- callback_query.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/callback_query.go b/callback_query.go index 4de911e8..387e813d 100644 --- a/callback_query.go +++ b/callback_query.go @@ -30,7 +30,7 @@ func Query(scope *Scope) { if kind := dest.Kind(); kind == reflect.Slice { isSlice = true destType = dest.Type().Elem() - dest.Set(reflect.Indirect(reflect.New(reflect.SliceOf(destType)))) + dest.Set(reflect.MakeSlice(dest.Type(), 0, 0)) if destType.Kind() == reflect.Ptr { isPtr = true From 8a88d665d5a7c7b6a03e8cf8b6b300aa57e70be1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 1 Aug 2015 09:25:06 +0800 Subject: [PATCH 18/26] Add QuoteIfPossible for Scope --- scope.go | 8 ++++++++ scope_private.go | 8 ++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/scope.go b/scope.go index cd6b235d..104a3728 100644 --- a/scope.go +++ b/scope.go @@ -3,6 +3,7 @@ package gorm import ( "errors" "fmt" + "regexp" "strings" "time" @@ -87,6 +88,13 @@ func (scope *Scope) Quote(str string) string { } } +func (scope *Scope) QuoteIfPossible(str string) string { + if regexp.MustCompile("^[a-zA-Z]+(.[a-zA-Z]+)*$").MatchString(str) { + return scope.Quote(str) + } + return str +} + // Dialect get dialect func (scope *Scope) Dialect() Dialect { return scope.db.parent.dialect diff --git a/scope_private.go b/scope_private.go index e440f7a4..e22a726a 100644 --- a/scope_private.go +++ b/scope_private.go @@ -531,11 +531,7 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { var columns []string for _, name := range column { - if regexp.MustCompile("^[a-zA-Z]+$").MatchString(name) { - columns = append(columns, scope.Quote(name)) - } else { - columns = append(columns, name) - } + columns = append(columns, scope.QuoteIfPossible(name)) } sqlCreate := "CREATE INDEX" @@ -550,7 +546,7 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on var table = scope.TableName() var keyName = fmt.Sprintf("%s_%s_foreign", table, field) var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` - scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.Quote(keyName), scope.Quote(field), scope.Quote(dest), onDelete, onUpdate)).Exec() + scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.QuoteIfPossible(keyName), scope.QuoteIfPossible(field), scope.QuoteIfPossible(dest), onDelete, onUpdate)).Exec() } func (scope *Scope) removeIndex(indexName string) { From 6f30170fec05c648b530155d54aa302144527e6f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 1 Aug 2015 11:09:17 +0800 Subject: [PATCH 19/26] Use copy logger into itself instead of using parent's --- main.go | 2 +- main_private.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/main.go b/main.go index e7f93a02..4e93ba28 100644 --- a/main.go +++ b/main.go @@ -113,7 +113,7 @@ func (s *DB) Callback() *callback { } func (s *DB) SetLogger(l logger) { - s.parent.logger = l + s.logger = l } func (s *DB) LogMode(enable bool) *DB { diff --git a/main_private.go b/main_private.go index 914f7007..84f10e35 100644 --- a/main_private.go +++ b/main_private.go @@ -3,7 +3,7 @@ package gorm import "time" func (s *DB) clone() *DB { - db := DB{db: s.db, parent: s.parent, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error} + db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error} for key, value := range s.values { db.values[key] = value @@ -34,7 +34,7 @@ func (s *DB) err(err error) error { } func (s *DB) print(v ...interface{}) { - s.parent.logger.(logger).Print(v...) + s.logger.(logger).Print(v...) } func (s *DB) log(v ...interface{}) { From 05b3f036f8bbb96b22336bbe3510253879ac0ab1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 1 Aug 2015 11:20:08 +0800 Subject: [PATCH 20/26] Change plural engine to github.com/qor/inflection --- model_struct.go | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/model_struct.go b/model_struct.go index 9c07db9b..72caec24 100644 --- a/model_struct.go +++ b/model_struct.go @@ -5,10 +5,11 @@ import ( "fmt" "go/ast" "reflect" - "regexp" "strconv" "strings" "time" + + "github.com/qor/inflection" ) var modelStructs = map[reflect.Type]*ModelStruct{} @@ -71,9 +72,6 @@ type Relationship struct { JoinTableHandler JoinTableHandlerInterface } -var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")} -var pluralMapValues = []string{"ches", "sses", "shes", "days", "ies", "xes", "${1}s"} - func (scope *Scope) GetModelStruct() *ModelStruct { var modelStruct ModelStruct @@ -113,11 +111,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } else { name := ToDBName(scopeType.Name()) if scope.db == nil || !scope.db.parent.singularTable { - for index, reg := range pluralMapKeys { - if reg.MatchString(name) { - name = reg.ReplaceAllString(name, pluralMapValues[index]) - } - } + name = inflection.Plural(name) } modelStruct.defaultTableName = name From 393d8a3a524598430ddb32bf0d130bc1ab1a9940 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 1 Aug 2015 11:27:01 +0800 Subject: [PATCH 21/26] Fix possible duplciated foreign key name --- scope_private.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scope_private.go b/scope_private.go index 4b653dba..fdf14dc7 100644 --- a/scope_private.go +++ b/scope_private.go @@ -558,7 +558,8 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { var table = scope.TableName() - var keyName = fmt.Sprintf("%s_%s_foreign", table, field) + var keyName = fmt.Sprintf("%s_%s_%s_foreign", table, field, regexp.MustCompile("[^a-zA-Z]").ReplaceAllString(dest, "_")) + keyName = regexp.MustCompile("_+").ReplaceAllString(keyName, "_") var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.QuoteIfPossible(keyName), scope.QuoteIfPossible(field), scope.QuoteIfPossible(dest), onDelete, onUpdate)).Exec() } From 85a682e820e07d3594fc141a7f0d50cc588917fa Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 1 Aug 2015 11:28:31 +0800 Subject: [PATCH 22/26] Update README for AddForeignKey example --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index db6270c5..03fbee15 100644 --- a/README.md +++ b/README.md @@ -1126,7 +1126,7 @@ type Product struct { // 2nd param : destination table(id) // 3rd param : ONDELETE // 4th param : ONUPDATE -db.Model(&User{}).AddForeignKey("role_id", "roles", "CASCADE", "RESTRICT") +db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") // Add index db.Model(&User{}).AddIndex("idx_user_name", "name") From f07216e90e3718d101ed8a6ed346707e2bae2460 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 1 Aug 2015 11:31:36 +0800 Subject: [PATCH 23/26] Allow pass blank string to Order --- search.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/search.go b/search.go index 130415ef..2c3df2d1 100644 --- a/search.go +++ b/search.go @@ -60,8 +60,12 @@ func (s *search) Assign(attrs ...interface{}) *search { func (s *search) Order(value string, reorder ...bool) *search { if len(reorder) > 0 && reorder[0] { - s.orders = []string{value} - } else { + if value != "" { + s.orders = []string{value} + } else { + s.orders = []string{} + } + } else if value != "" { s.orders = append(s.orders, value) } return s From e2e417a8c28526667a54311d42980e966e5e2809 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 1 Aug 2015 11:49:34 +0800 Subject: [PATCH 24/26] Fix complicated Select --- query_test.go | 3 ++- scope_private.go | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/query_test.go b/query_test.go index b15d01ba..580d06c4 100644 --- a/query_test.go +++ b/query_test.go @@ -4,6 +4,7 @@ import ( "fmt" "reflect" + "github.com/jinzhu/gorm" "github.com/jinzhu/now" "testing" @@ -556,7 +557,7 @@ func TestSelectWithEscapedFieldName(t *testing.T) { func TestSelectWithVariables(t *testing.T) { DB.Save(&User{Name: "jinzhu"}) - rows, _ := DB.Table("users").Select("? as fake", "name").Rows() + rows, _ := DB.Table("users").Select("? as fake", gorm.Expr("name")).Rows() if !rows.Next() { t.Errorf("Should have returned at least one row") diff --git a/scope_private.go b/scope_private.go index fdf14dc7..4344f22e 100644 --- a/scope_private.go +++ b/scope_private.go @@ -149,7 +149,7 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) if valuer, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = valuer.Value() } - str = strings.Replace(str, "?", scope.Dialect().Quote(fmt.Sprintf("%v", arg)), 1) + str = strings.Replace(str, "?", scope.AddToVars(arg), 1) } } return From a7762ea7d6a6d2c24c6b3541cb16347e15f69bcc Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 1 Aug 2015 12:00:35 +0800 Subject: [PATCH 25/26] Return error happend in Create/Update when using FirstOrCreate --- main.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/main.go b/main.go index 4e93ba28..dbac3b33 100644 --- a/main.go +++ b/main.go @@ -254,9 +254,9 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { if !result.RecordNotFound() { return result } - c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates) + c.err(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates).db.Error) } else if len(c.search.assignAttrs) > 0 { - c.NewScope(out).InstanceSet("gorm:update_interface", s.search.assignAttrs).callCallbacks(s.parent.callback.updates) + c.err(c.NewScope(out).InstanceSet("gorm:update_interface", s.search.assignAttrs).callCallbacks(s.parent.callback.updates).db.Error) } return c } From eef40a06ff56eca675bf36d176812156f46d8d3a Mon Sep 17 00:00:00 2001 From: Gabriel Date: Sat, 1 Aug 2015 22:46:38 +0000 Subject: [PATCH 26/26] Rename the parameter to table_options and avoid introduction of new API function OpenWithTableSuffix --- README.md | 3 ++- main.go | 10 ++-------- scope_private.go | 12 ++++++------ 3 files changed, 10 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 5919b188..4aa1137a 100644 --- a/README.md +++ b/README.md @@ -113,7 +113,6 @@ import ( db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") // db, err := gorm.Open("foundation", "dbname=gorm") // FoundationDB. // db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") -// db, err := gorm.OpenWithTableSuffix("mysql", "ENGINE=InnoDB", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") // db, err := gorm.Open("sqlite3", "/tmp/gorm.db") // You can also use an existing database connection handle @@ -137,12 +136,14 @@ db.SingularTable(true) ```go // Create table db.CreateTable(&User{}) +db.Set("gorm:table_options", "ENGINE=InnoDB").CreateTable(&User{}) // Drop table db.DropTable(&User{}) // Automating Migration db.AutoMigrate(&User{}) +db.Set("gorm:table_options", "ENGINE=InnoDB").AutoMigrate(&User{}) db.AutoMigrate(&User{}, &Product{}, &Order{}) // Feel free to change your struct, AutoMigrate will keep your database up-to-date. // AutoMigrate will ONLY add *new columns* and *new indexes*, diff --git a/main.go b/main.go index 8470a8c7..aba51fc4 100644 --- a/main.go +++ b/main.go @@ -32,16 +32,11 @@ type DB struct { dialect Dialect singularTable bool source string - tableSuffix string values map[string]interface{} joinTableHandlers map[string]JoinTableHandler } func Open(dialect string, args ...interface{}) (DB, error) { - return OpenWithTableSuffix(dialect, "", args...) -} - -func OpenWithTableSuffix(dialect, tableSuffix string, args ...interface{}) (DB, error) { var db DB var err error @@ -74,7 +69,6 @@ func OpenWithTableSuffix(dialect, tableSuffix string, args ...interface{}) (DB, logger: defaultLogger, callback: DefaultCallback, source: source, - tableSuffix:tableSuffix, values: map[string]interface{}{}, db: dbSql, } @@ -376,7 +370,7 @@ func (s *DB) RecordNotFound() bool { // Migrations func (s *DB) CreateTable(value interface{}) *DB { - return s.clone().NewScope(value).Set("gorm:table_suffix", s.tableSuffix).createTable().db + return s.clone().NewScope(value).createTable().db } func (s *DB) DropTable(value interface{}) *DB { @@ -396,7 +390,7 @@ func (s *DB) HasTable(value interface{}) bool { func (s *DB) AutoMigrate(values ...interface{}) *DB { db := s.clone() for _, value := range values { - db = db.NewScope(value).NeedPtr().Set("gorm:table_suffix", s.tableSuffix).autoMigrate().db + db = db.NewScope(value).NeedPtr().autoMigrate().db } return db } diff --git a/scope_private.go b/scope_private.go index 2ab340b2..e7efd9a8 100644 --- a/scope_private.go +++ b/scope_private.go @@ -443,14 +443,14 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { } /** - Return the table suffix string or an empty string if the table suffix does not exist + Return the table options string or an empty string if the table options does not exist */ -func (scope *Scope) getTableSuffix() string{ - tableSuffix, ok := scope.Get("gorm:table_suffix") +func (scope *Scope) getTableOptions() string{ + tableOptions, ok := scope.Get("gorm:table_options") if !ok { return "" } - return tableSuffix.(string) + return tableOptions.(string) } func (scope *Scope) createJoinTable(field *StructField) { @@ -469,7 +469,7 @@ func (scope *Scope) createJoinTable(field *StructField) { sqlTypes = append(sqlTypes, scope.Quote(dbName)+" "+primaryKeySqlType) } } - scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), scope.getTableSuffix())).Error) + scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), scope.getTableOptions())).Error) } scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) } @@ -494,7 +494,7 @@ func (scope *Scope) createTable() *Scope { if len(primaryKeys) > 0 { primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) } - scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableSuffix())).Exec() + scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() return scope }