diff --git a/README.md b/README.md index a6dd1865..26edcaa0 100644 --- a/README.md +++ b/README.md @@ -96,6 +96,7 @@ import ( ) db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") +// db, err := gorm.Open("foundation", "dbname=gorm") // FoundationDB. // db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local") // db, err := gorm.Open("sqlite3", "/tmp/gorm.db") @@ -360,7 +361,7 @@ db.Model(&user).Updates(User{Name: "hello", Age: 18}) ### Update Without Callbacks -By default, update will call BeforeUpdate, AfterUpdate callbacks, if you want to update w/o callbacks: +By default, update will call BeforeUpdate, AfterUpdate callbacks, if you want to update w/o callbacks and w/o saving associations: ```go db.Model(&user).UpdateColumn("name", "hello") @@ -546,7 +547,7 @@ Supports polymorphic has-many and has-one associations. Id int Name string OwnerId int - OwnerType int + OwnerType string } ``` Note: polymorphic belongs-to and many-to-many are explicitly NOT supported, and will throw errors. diff --git a/association.go b/association.go index b011971a..89bb1bec 100644 --- a/association.go +++ b/association.go @@ -77,7 +77,7 @@ func (association *Association) Delete(values ...interface{}) *Association { if relationship.Kind == "many_to_many" { sql := fmt.Sprintf("%v = ? AND %v IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) query := scope.NewDB().Where(sql, association.PrimaryKey, primaryKeys) - if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil { + if err := relationship.JoinTableHandler.Delete(query, relationship); err == nil { leftValues := reflect.Zero(association.Field.Field.Type()) for i := 0; i < association.Field.Field.Len(); i++ { value := association.Field.Field.Index(i) @@ -132,7 +132,7 @@ func (association *Association) Replace(values ...interface{}) *Association { sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys) - association.setErr(scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship)) + association.setErr(relationship.JoinTableHandler.Delete(query, relationship)) } else { association.setErr(errors.New("replace only support many to many")) } @@ -145,7 +145,7 @@ func (association *Association) Clear() *Association { if relationship.Kind == "many_to_many" { sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName)) query := scope.NewDB().Where(sql, association.PrimaryKey) - if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil { + if err := relationship.JoinTableHandler.Delete(query, relationship); err == nil { association.Field.Set(reflect.Zero(association.Field.Field.Type())) } else { association.setErr(err) @@ -163,9 +163,7 @@ func (association *Association) Count() int { newScope := scope.New(association.Field.Field.Interface()) if relationship.Kind == "many_to_many" { - query := scope.DB().Select("COUNT(DISTINCT ?)", relationship.AssociationForeignDBName). - Where(relationship.ForeignDBName+" = ?", association.PrimaryKey) - scope.db.GetJoinTableHandler(relationship.JoinTable).Scope(query, relationship).Row().Scan(&count) + relationship.JoinTableHandler.JoinWith(scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey) diff --git a/callback_query.go b/callback_query.go index 4538b272..5daa5fec 100644 --- a/callback_query.go +++ b/callback_query.go @@ -21,7 +21,7 @@ func Query(scope *Scope) { dest = reflect.Indirect(reflect.ValueOf(value)) } - if orderBy, ok := scope.InstanceGet("gorm:order_by_primary_key"); ok { + if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok { if primaryKey := scope.PrimaryKey(); primaryKey != "" { scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), primaryKey, orderBy)) } diff --git a/callback_shared.go b/callback_shared.go index 99ad8f50..88158cfc 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -11,6 +11,9 @@ func CommitOrRollbackTransaction(scope *Scope) { } func SaveBeforeAssociations(scope *Scope) { + if !scope.shouldSaveAssociations() { + return + } for _, field := range scope.Fields() { if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { @@ -25,6 +28,9 @@ func SaveBeforeAssociations(scope *Scope) { } func SaveAfterAssociations(scope *Scope) { + if !scope.shouldSaveAssociations() { + return + } for _, field := range scope.Fields() { if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { if relationship := field.Relationship; relationship != nil && @@ -38,7 +44,7 @@ func SaveAfterAssociations(scope *Scope) { elem := value.Index(i).Addr().Interface() newScope := newDB.NewScope(elem) - if relationship.JoinTable == "" && relationship.ForeignFieldName != "" { + if relationship.JoinTableHandler == nil && relationship.ForeignFieldName != "" { scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue())) } @@ -48,9 +54,8 @@ func SaveAfterAssociations(scope *Scope) { scope.Err(newDB.Save(elem).Error) - if joinTable := relationship.JoinTable; joinTable != "" { - scope.Err(scope.db.GetJoinTableHandler(joinTable). - Add(scope.NewDB(), relationship, scope.PrimaryKeyValue(), newScope.PrimaryKeyValue())) + if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { + scope.Err(joinTableHandler.Add(scope.NewDB(), scope.Value, newScope.Value)) } } default: diff --git a/common_dialect.go b/common_dialect.go index 9360cd26..281df8a7 100644 --- a/common_dialect.go +++ b/common_dialect.go @@ -9,19 +9,19 @@ import ( type commonDialect struct{} -func (s *commonDialect) BinVar(i int) string { - return "?" +func (commonDialect) BinVar(i int) string { + return "$$" // ? } -func (s *commonDialect) SupportLastInsertId() bool { +func (commonDialect) SupportLastInsertId() bool { return true } -func (s *commonDialect) HasTop() bool { +func (commonDialect) HasTop() bool { return false } -func (s *commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string { +func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "BOOLEAN" @@ -57,19 +57,19 @@ func (s *commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String())) } -func (s *commonDialect) ReturningStr(tableName, key string) string { +func (commonDialect) ReturningStr(tableName, key string) string { return "" } -func (s *commonDialect) SelectFromDummyTable() string { +func (commonDialect) SelectFromDummyTable() string { return "" } -func (s *commonDialect) Quote(key string) string { - return fmt.Sprintf("`%s`", key) +func (commonDialect) Quote(key string) string { + return fmt.Sprintf(`"%s"`, key) } -func (s *commonDialect) databaseName(scope *Scope) string { +func (commonDialect) databaseName(scope *Scope) string { from := strings.Index(scope.db.parent.source, "/") + 1 to := strings.Index(scope.db.parent.source, "?") if to == -1 { @@ -78,24 +78,24 @@ func (s *commonDialect) databaseName(scope *Scope) string { return scope.db.parent.source[from:to] } -func (s *commonDialect) HasTable(scope *Scope, tableName string) bool { +func (c commonDialect) HasTable(scope *Scope, tableName string) bool { var count int - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, s.databaseName(scope)).Row().Scan(&count) + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, c.databaseName(scope)).Row().Scan(&count) return count > 0 } -func (s *commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool { +func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count) + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", c.databaseName(scope), tableName, columnName).Row().Scan(&count) return count > 0 } -func (s *commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool { +func (commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count) return count > 0 } -func (s *commonDialect) RemoveIndex(scope *Scope, indexName string) { +func (commonDialect) RemoveIndex(scope *Scope, indexName string) { scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())) } diff --git a/dialect.go b/dialect.go index 2e64cca5..f3221075 100644 --- a/dialect.go +++ b/dialect.go @@ -24,6 +24,8 @@ func NewDialect(driver string) Dialect { switch driver { case "postgres": d = &postgres{} + case "foundation": + d = &foundation{} case "mysql": d = &mysql{} case "sqlite3": diff --git a/field.go b/field.go index e122adb4..8f5efa6d 100644 --- a/field.go +++ b/field.go @@ -27,9 +27,13 @@ func (field *Field) Set(value interface{}) error { if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok { if v, ok := value.(reflect.Value); ok { - scanner.Scan(v.Interface()) + if err := scanner.Scan(v.Interface()); err != nil { + return err + } } else { - scanner.Scan(value) + if err := scanner.Scan(value); err != nil { + return err + } } } else { reflectValue, ok := value.(reflect.Value) diff --git a/foundation.go b/foundation.go new file mode 100644 index 00000000..a9c8f500 --- /dev/null +++ b/foundation.go @@ -0,0 +1,78 @@ +package gorm + +import ( + "fmt" + "reflect" + "time" +) + +type foundation struct { + commonDialect +} + +func (foundation) BinVar(i int) string { + return fmt.Sprintf("$%v", i) +} + +func (foundation) SupportLastInsertId() bool { + return false +} + +func (foundation) SqlTag(value reflect.Value, size int, autoIncrease bool) string { + switch value.Kind() { + case reflect.Bool: + return "boolean" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if autoIncrease { + return "serial" + } + return "int" + case reflect.Int64, reflect.Uint64: + if autoIncrease { + return "bigserial" + } + return "bigint" + case reflect.Float32, reflect.Float64: + return "double" + case reflect.String: + if size > 0 && size < 65532 { + return fmt.Sprintf("varchar(%d)", size) + } + return "clob" + case reflect.Struct: + if _, ok := value.Interface().(time.Time); ok { + return "datetime" + } + default: + if _, ok := value.Interface().([]byte); ok { + return "blob" + } + } + panic(fmt.Sprintf("invalid sql type %s (%s) for foundation", value.Type().Name(), value.Kind().String())) +} + +func (f foundation) ReturningStr(tableName, key string) string { + return fmt.Sprintf("RETURNING %v.%v", f.Quote(tableName), key) +} + +func (foundation) HasTable(scope *Scope, tableName string) bool { + var count int + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_schema = current_schema AND table_type = 'TABLE' AND table_name = ?", tableName).Row().Scan(&count) + return count > 0 +} + +func (foundation) HasColumn(scope *Scope, tableName string, columnName string) bool { + var count int + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = current_schema AND table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count) + return count > 0 +} + +func (f foundation) RemoveIndex(scope *Scope, indexName string) { + scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", f.Quote(indexName))) +} + +func (foundation) HasIndex(scope *Scope, tableName string, indexName string) bool { + var count int + scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.indexes WHERE table_schema = current_schema AND table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count) + return count > 0 +} diff --git a/join_table.go b/join_table.go deleted file mode 100644 index 3ffa4f87..00000000 --- a/join_table.go +++ /dev/null @@ -1,48 +0,0 @@ -package gorm - -import ( - "fmt" - "strings" -) - -type JoinTableHandler interface { - Table(*DB, *Relationship) string - Add(*DB, *Relationship, interface{}, interface{}) error - Delete(*DB, *Relationship) error - Scope(*DB, *Relationship) *DB -} - -type defaultJoinTableHandler struct{} - -func (s *defaultJoinTableHandler) Table(db *DB, relationship *Relationship) string { - return relationship.JoinTable -} - -func (s *defaultJoinTableHandler) Add(db *DB, relationship *Relationship, foreignValue interface{}, associationValue interface{}) error { - scope := db.NewScope("") - quotedForeignDBName := scope.Quote(relationship.ForeignDBName) - quotedAssociationDBName := scope.Quote(relationship.AssociationForeignDBName) - table := s.Table(db, relationship) - - sql := fmt.Sprintf( - "INSERT INTO %v (%v) SELECT ?,? %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = ? AND %v = ?);", - scope.Quote(table), - strings.Join([]string{quotedForeignDBName, quotedAssociationDBName}, ","), - scope.Dialect().SelectFromDummyTable(), - scope.Quote(table), - quotedForeignDBName, - quotedAssociationDBName, - ) - - return db.Exec(sql, foreignValue, associationValue, foreignValue, associationValue).Error -} - -func (s *defaultJoinTableHandler) Delete(db *DB, relationship *Relationship) error { - return db.Table(s.Table(db, relationship)).Delete("").Error -} - -func (s *defaultJoinTableHandler) Scope(db *DB, relationship *Relationship) *DB { - return db.Table(s.Table(db, relationship)) -} - -var DefaultJoinTableHandler = &defaultJoinTableHandler{} diff --git a/join_table_handler.go b/join_table_handler.go new file mode 100644 index 00000000..9f705564 --- /dev/null +++ b/join_table_handler.go @@ -0,0 +1,155 @@ +package gorm + +import ( + "errors" + "fmt" + "reflect" + "strings" +) + +type JoinTableHandlerInterface interface { + Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) + Table(db *DB) string + Add(db *DB, source interface{}, destination interface{}) error + Delete(db *DB, sources ...interface{}) error + JoinWith(db *DB, source interface{}) *DB +} + +type JoinTableForeignKey struct { + DBName string + AssociationDBName string +} + +type JoinTableSource struct { + ModelType reflect.Type + ForeignKeys []JoinTableForeignKey +} + +type JoinTableHandler struct { + TableName string `sql:"-"` + Source JoinTableSource `sql:"-"` + Destination JoinTableSource `sql:"-"` +} + +func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) { + s.TableName = tableName + + s.Source = JoinTableSource{ModelType: source} + sourceScope := &Scope{Value: reflect.New(source).Interface()} + for _, primaryField := range sourceScope.GetModelStruct().PrimaryFields { + if relationship.ForeignDBName == "" { + relationship.ForeignFieldName = source.Name() + primaryField.Name + relationship.ForeignDBName = ToDBName(relationship.ForeignFieldName) + } + s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ + DBName: relationship.ForeignDBName, + AssociationDBName: primaryField.DBName, + }) + } + + s.Destination = JoinTableSource{ModelType: destination} + destinationScope := &Scope{Value: reflect.New(destination).Interface()} + for _, primaryField := range destinationScope.GetModelStruct().PrimaryFields { + if relationship.AssociationForeignDBName == "" { + relationship.AssociationForeignFieldName = destination.Name() + primaryField.Name + relationship.AssociationForeignDBName = ToDBName(relationship.AssociationForeignFieldName) + } + s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ + DBName: relationship.AssociationForeignDBName, + AssociationDBName: primaryField.DBName, + }) + } +} + +func (s JoinTableHandler) Table(*DB) string { + return s.TableName +} + +func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} { + values := map[string]interface{}{} + + for _, source := range sources { + scope := db.NewScope(source) + modelType := scope.GetModelStruct().ModelType + + if s.Source.ModelType == modelType { + for _, foreignKey := range s.Source.ForeignKeys { + values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface() + } + } else if s.Destination.ModelType == modelType { + for _, foreignKey := range s.Destination.ForeignKeys { + values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface() + } + } + } + return values +} + +func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) error { + scope := db.NewScope("") + searchMap := s.GetSearchMap(db, source1, source2) + + var assignColumns, binVars, conditions []string + var values []interface{} + for key, value := range searchMap { + assignColumns = append(assignColumns, key) + binVars = append(binVars, `?`) + conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) + values = append(values, value) + } + + for _, value := range values { + values = append(values, value) + } + + quotedTable := s.Table(db) + sql := fmt.Sprintf( + "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)", + quotedTable, + strings.Join(assignColumns, ","), + strings.Join(binVars, ","), + scope.Dialect().SelectFromDummyTable(), + quotedTable, + strings.Join(conditions, " AND "), + ) + + return db.Exec(sql, values...).Error +} + +func (s JoinTableHandler) Delete(db *DB, sources ...interface{}) error { + var conditions []string + var values []interface{} + + for key, value := range s.GetSearchMap(db, sources...) { + conditions = append(conditions, fmt.Sprintf("%v = ?", key)) + values = append(values, value) + } + + return db.Table(s.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error +} + +func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB { + quotedTable := s.Table(db) + + scope := db.NewScope(source) + modelType := scope.GetModelStruct().ModelType + var joinConditions []string + var queryConditions []string + var values []interface{} + if s.Source.ModelType == modelType { + for _, foreignKey := range s.Destination.ForeignKeys { + destinationTableName := scope.New(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() + joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) + } + + for _, foreignKey := range s.Source.ForeignKeys { + queryConditions = append(queryConditions, fmt.Sprintf("%v.%v = ?", quotedTable, scope.Quote(foreignKey.DBName))) + values = append(values, scope.Fields()[foreignKey.AssociationDBName].Field.Interface()) + } + return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))). + Where(strings.Join(queryConditions, " AND "), values...) + } else { + db.Error = errors.New("wrong source type for join table handler") + return db + } +} diff --git a/join_table_test.go b/join_table_test.go index 2624fdb2..f8b097b6 100644 --- a/join_table_test.go +++ b/join_table_test.go @@ -15,40 +15,37 @@ type Person struct { } type PersonAddress struct { + gorm.JoinTableHandler PersonID int AddressID int DeletedAt time.Time CreatedAt time.Time } -func (*PersonAddress) Table(db *gorm.DB, relationship *gorm.Relationship) string { - return relationship.JoinTable -} - -func (*PersonAddress) Add(db *gorm.DB, relationship *gorm.Relationship, foreignValue interface{}, associationValue interface{}) error { +func (*PersonAddress) Add(db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { return db.Where(map[string]interface{}{ - relationship.ForeignDBName: foreignValue, - relationship.AssociationForeignDBName: associationValue, + "person_id": db.NewScope(foreignValue).PrimaryKeyValue(), + "address_id": db.NewScope(associationValue).PrimaryKeyValue(), }).Assign(map[string]interface{}{ - relationship.ForeignFieldName: foreignValue, - relationship.AssociationForeignFieldName: associationValue, - "DeletedAt": gorm.Expr("NULL"), + "person_id": foreignValue, + "address_id": associationValue, + "deleted_at": gorm.Expr("NULL"), }).FirstOrCreate(&PersonAddress{}).Error } -func (*PersonAddress) Delete(db *gorm.DB, relationship *gorm.Relationship) error { +func (*PersonAddress) Delete(db *gorm.DB, sources ...interface{}) error { return db.Delete(&PersonAddress{}).Error } -func (pa *PersonAddress) Scope(db *gorm.DB, relationship *gorm.Relationship) *gorm.DB { - table := pa.Table(db, relationship) - return db.Table(table).Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) +func (pa *PersonAddress) JoinWith(db *gorm.DB, source interface{}) *gorm.DB { + table := pa.Table(db) + return db.Table(table).Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) } func TestJoinTable(t *testing.T) { DB.Exec("drop table person_addresses;") DB.AutoMigrate(&Person{}) - DB.SetJoinTableHandler(&PersonAddress{}, "person_addresses") + DB.SetJoinTableHandler(&Person{}, "Addresses", &PersonAddress{}) address1 := &Address{Address1: "address 1"} address2 := &Address{Address1: "address 2"} diff --git a/main.go b/main.go index 87fecd59..7049675e 100644 --- a/main.go +++ b/main.go @@ -55,6 +55,9 @@ func Open(dialect string, args ...interface{}) (DB, error) { driver = value source = args[1].(string) } + if driver == "foundation" { + driver = "postgres" // FoundationDB speaks a postgres-compatible protocol. + } dbSql, err = sql.Open(driver, source) case sqlCommon: source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String() @@ -193,14 +196,14 @@ func (s *DB) Assign(attrs ...interface{}) *DB { func (s *DB) First(out interface{}, where ...interface{}) *DB { newScope := s.clone().NewScope(out) newScope.Search.Limit(1) - return newScope.InstanceSet("gorm:order_by_primary_key", "ASC"). + return newScope.Set("gorm:order_by_primary_key", "ASC"). inlineCondition(where...).callCallbacks(s.parent.callback.queries).db } func (s *DB) Last(out interface{}, where ...interface{}) *DB { newScope := s.clone().NewScope(out) newScope.Search.Limit(1) - return newScope.InstanceSet("gorm:order_by_primary_key", "DESC"). + return newScope.Set("gorm:order_by_primary_key", "DESC"). inlineCondition(where...).callCallbacks(s.parent.callback.queries).db } @@ -266,6 +269,7 @@ func (s *DB) UpdateColumn(attrs ...interface{}) *DB { func (s *DB) UpdateColumns(values interface{}) *DB { return s.clone().NewScope(s.Value). Set("gorm:update_column", true). + Set("gorm:save_associations", false). InstanceSet("gorm:update_interface", values). callCallbacks(s.parent.callback.updates).db } @@ -470,29 +474,16 @@ func (s *DB) Get(name string) (value interface{}, ok bool) { return } -func (s *DB) GetJoinTableHandler(table string) JoinTableHandler { - if s.parent.joinTableHandlers != nil { - if joinTableHandler, ok := s.parent.joinTableHandlers[table]; ok { - return joinTableHandler +func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { + for _, field := range s.NewScope(source).GetModelStruct().StructFields { + if field.Name == column || field.DBName == column { + if many2many := parseTagSetting(field.Tag.Get("gorm"))["MANY2MANY"]; many2many != "" { + source := (&Scope{Value: source}).GetModelStruct().ModelType + destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType + handler.Setup(field.Relationship, many2many, source, destination) + field.Relationship.JoinTableHandler = handler + s.Table(handler.Table(s)).AutoMigrate(handler) + } } - if joinTableHandler, ok := s.parent.joinTableHandlers["*"]; ok { - return joinTableHandler - } - } - return DefaultJoinTableHandler -} - -func (s *DB) SetJoinTableHandler(joinTableHandler JoinTableHandler, tables ...string) { - if s.parent.joinTableHandlers == nil { - s.parent.joinTableHandlers = map[string]JoinTableHandler{} - } - - if len(tables) > 0 { - for _, table := range tables { - s.parent.joinTableHandlers[table] = joinTableHandler - s.Table(table).AutoMigrate(joinTableHandler) - } - } else { - s.parent.joinTableHandlers["*"] = joinTableHandler } } diff --git a/main_test.go b/main_test.go index 3864fcbd..b547534c 100644 --- a/main_test.go +++ b/main_test.go @@ -36,6 +36,9 @@ func init() { case "postgres": fmt.Println("testing postgres...") DB, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable") + case "foundation": + fmt.Println("testing foundation...") + DB, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable") case "mssql": fmt.Println("testing mssql...") DB, err = gorm.Open("mssql", "server=SERVER_HERE;database=rogue;user id=USER_HERE;password=PW_HERE;port=1433") @@ -445,6 +448,14 @@ func TestHaving(t *testing.T) { } } +func DialectHasTzSupport() bool { + // NB: mssql and FoundationDB do not support time zones. + if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" || dialect == "foundation" { + return false + } + return true +} + func TestTimeWithZone(t *testing.T) { var format = "2006-01-02 15:04:05 -0700" var times []time.Time @@ -456,26 +467,30 @@ func TestTimeWithZone(t *testing.T) { name := "time_with_zone_" + strconv.Itoa(index) user := User{Name: name, Birthday: vtime} - // TODO mssql does not support time zones - if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" { + if !DialectHasTzSupport() { + // If our driver dialect doesn't support TZ's, just use UTC for everything here. user.Birthday = vtime.UTC() } + DB.Save(&user) - if user.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" { - t.Errorf("User's birthday should not be changed after save") + expectedBirthday := "2013-02-18 17:51:49 +0000" + foundBirthday := user.Birthday.UTC().Format(format) + if foundBirthday != expectedBirthday { + t.Errorf("User's birthday should not be changed after save for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday) } var findUser, findUser2, findUser3 User DB.First(&findUser, "name = ?", name) - if findUser.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" { - t.Errorf("User's birthday should not be changed after find") + foundBirthday = findUser.Birthday.UTC().Format(format) + if foundBirthday != expectedBirthday { + t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v or %+v", name, expectedBirthday, foundBirthday) } - if DB.Where("id = ? AND birthday >= ?", findUser.Id, vtime.Add(-time.Minute)).First(&findUser2).RecordNotFound() { + if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() { t.Errorf("User should be found") } - if !DB.Where("id = ? AND birthday >= ?", findUser.Id, vtime.Add(time.Minute)).First(&findUser3).RecordNotFound() { + if !DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(time.Minute)).First(&findUser3).RecordNotFound() { t.Errorf("User should not be found") } } diff --git a/model_struct.go b/model_struct.go index 02e23dfe..f73c902b 100644 --- a/model_struct.go +++ b/model_struct.go @@ -60,7 +60,7 @@ type Relationship struct { ForeignDBName string AssociationForeignFieldName string AssociationForeignDBName string - JoinTable string + JoinTableHandler JoinTableHandlerInterface } var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")} @@ -146,143 +146,147 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } - for _, field := range fields { - if !field.IsIgnored { - fieldStruct := field.Struct - fieldType, indirectType := fieldStruct.Type, fieldStruct.Type - if indirectType.Kind() == reflect.Ptr { - indirectType = indirectType.Elem() - } - - if _, isScanner := reflect.New(fieldType).Interface().(sql.Scanner); isScanner { - field.IsScanner, field.IsNormal = true, true - } - - if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime { - field.IsNormal = true - } - - if !field.IsNormal { - gormSettings := parseTagSetting(field.Tag.Get("gorm")) - toScope := scope.New(reflect.New(fieldStruct.Type).Interface()) - - getForeignField := func(column string, fields []*StructField) *StructField { - for _, field := range fields { - if field.Name == column || field.DBName == ToDBName(column) { - return field - } - } - return nil + defer func() { + for _, field := range fields { + if !field.IsIgnored { + fieldStruct := field.Struct + fieldType, indirectType := fieldStruct.Type, fieldStruct.Type + if indirectType.Kind() == reflect.Ptr { + indirectType = indirectType.Elem() } - var relationship = &Relationship{} - - foreignKey := gormSettings["FOREIGNKEY"] - if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" { - if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil { - if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { - relationship.ForeignFieldName = polymorphicField.Name - relationship.ForeignDBName = polymorphicField.DBName - relationship.PolymorphicType = polymorphicType.Name - relationship.PolymorphicDBName = polymorphicType.DBName - polymorphicType.IsForeignKey = true - polymorphicField.IsForeignKey = true - } - } + if _, isScanner := reflect.New(fieldType).Interface().(sql.Scanner); isScanner { + field.IsScanner, field.IsNormal = true, true } - switch indirectType.Kind() { - case reflect.Slice: - elemType := indirectType.Elem() - if elemType.Kind() == reflect.Ptr { - elemType = elemType.Elem() - } - - if elemType.Kind() == reflect.Struct { - if foreignKey == "" { - foreignKey = scopeType.Name() + "Id" - } - - if many2many := gormSettings["MANY2MANY"]; many2many != "" { - relationship.Kind = "many_to_many" - relationship.JoinTable = many2many - - associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"] - if associationForeignKey == "" { - associationForeignKey = elemType.Name() + "Id" - } - - relationship.ForeignFieldName = foreignKey - relationship.ForeignDBName = ToDBName(foreignKey) - relationship.AssociationForeignFieldName = associationForeignKey - relationship.AssociationForeignDBName = ToDBName(associationForeignKey) - field.Relationship = relationship - } else { - relationship.Kind = "has_many" - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - relationship.ForeignFieldName = foreignField.Name - relationship.ForeignDBName = foreignField.DBName - foreignField.IsForeignKey = true - field.Relationship = relationship - } else if relationship.ForeignFieldName != "" { - field.Relationship = relationship - } - } - } else { - field.IsNormal = true - } - case reflect.Struct: - if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { - for _, toField := range toScope.GetStructFields() { - toField = toField.clone() - toField.Names = append([]string{fieldStruct.Name}, toField.Names...) - modelStruct.StructFields = append(modelStruct.StructFields, toField) - if toField.IsPrimaryKey { - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField) - } - } - continue - } else { - belongsToForeignKey := foreignKey - if belongsToForeignKey == "" { - belongsToForeignKey = field.Name + "Id" - } - - if foreignField := getForeignField(belongsToForeignKey, fields); foreignField != nil { - relationship.Kind = "belongs_to" - relationship.ForeignFieldName = foreignField.Name - relationship.ForeignDBName = foreignField.DBName - foreignField.IsForeignKey = true - field.Relationship = relationship - } else { - if foreignKey == "" { - foreignKey = modelStruct.ModelType.Name() + "Id" - } - relationship.Kind = "has_one" - if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { - relationship.ForeignFieldName = foreignField.Name - relationship.ForeignDBName = foreignField.DBName - foreignField.IsForeignKey = true - field.Relationship = relationship - } else if relationship.ForeignFieldName != "" { - field.Relationship = relationship - } - } - } - default: + if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime { field.IsNormal = true } - } - if field.IsNormal { - if len(modelStruct.PrimaryFields) == 0 && field.DBName == "id" { - field.IsPrimaryKey = true - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) + if !field.IsNormal { + gormSettings := parseTagSetting(field.Tag.Get("gorm")) + toScope := scope.New(reflect.New(fieldStruct.Type).Interface()) + + getForeignField := func(column string, fields []*StructField) *StructField { + for _, field := range fields { + if field.Name == column || field.DBName == ToDBName(column) { + return field + } + } + return nil + } + + var relationship = &Relationship{} + + foreignKey := gormSettings["FOREIGNKEY"] + if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" { + if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil { + if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil { + relationship.ForeignFieldName = polymorphicField.Name + relationship.ForeignDBName = polymorphicField.DBName + relationship.PolymorphicType = polymorphicType.Name + relationship.PolymorphicDBName = polymorphicType.DBName + polymorphicType.IsForeignKey = true + polymorphicField.IsForeignKey = true + } + } + } + + switch indirectType.Kind() { + case reflect.Slice: + elemType := indirectType.Elem() + if elemType.Kind() == reflect.Ptr { + elemType = elemType.Elem() + } + + if elemType.Kind() == reflect.Struct { + if foreignKey == "" { + foreignKey = scopeType.Name() + "Id" + } + + if many2many := gormSettings["MANY2MANY"]; many2many != "" { + relationship.Kind = "many_to_many" + associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"] + if associationForeignKey == "" { + associationForeignKey = elemType.Name() + "Id" + } + + relationship.ForeignFieldName = foreignKey + relationship.ForeignDBName = ToDBName(foreignKey) + relationship.AssociationForeignFieldName = associationForeignKey + relationship.AssociationForeignDBName = ToDBName(associationForeignKey) + + joinTableHandler := JoinTableHandler{} + joinTableHandler.Setup(relationship, many2many, scopeType, elemType) + relationship.JoinTableHandler = &joinTableHandler + field.Relationship = relationship + } else { + relationship.Kind = "has_many" + if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { + relationship.ForeignFieldName = foreignField.Name + relationship.ForeignDBName = foreignField.DBName + foreignField.IsForeignKey = true + field.Relationship = relationship + } else if relationship.ForeignFieldName != "" { + field.Relationship = relationship + } + } + } else { + field.IsNormal = true + } + case reflect.Struct: + if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { + for _, toField := range toScope.GetStructFields() { + toField = toField.clone() + toField.Names = append([]string{fieldStruct.Name}, toField.Names...) + modelStruct.StructFields = append(modelStruct.StructFields, toField) + if toField.IsPrimaryKey { + modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField) + } + } + continue + } else { + belongsToForeignKey := foreignKey + if belongsToForeignKey == "" { + belongsToForeignKey = field.Name + "Id" + } + + if foreignField := getForeignField(belongsToForeignKey, fields); foreignField != nil { + relationship.Kind = "belongs_to" + relationship.ForeignFieldName = foreignField.Name + relationship.ForeignDBName = foreignField.DBName + foreignField.IsForeignKey = true + field.Relationship = relationship + } else { + if foreignKey == "" { + foreignKey = modelStruct.ModelType.Name() + "Id" + } + relationship.Kind = "has_one" + if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil { + relationship.ForeignFieldName = foreignField.Name + relationship.ForeignDBName = foreignField.DBName + foreignField.IsForeignKey = true + field.Relationship = relationship + } else if relationship.ForeignFieldName != "" { + field.Relationship = relationship + } + } + } + default: + field.IsNormal = true + } + } + + if field.IsNormal { + if len(modelStruct.PrimaryFields) == 0 && field.DBName == "id" { + field.IsPrimaryKey = true + modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) + } } } + modelStruct.StructFields = append(modelStruct.StructFields, field) } - modelStruct.StructFields = append(modelStruct.StructFields, field) - } + }() modelStructs[scopeType] = &modelStruct @@ -308,7 +312,7 @@ func (scope *Scope) generateSqlTag(field *StructField) string { additionalType := sqlSettings["NOT NULL"] + " " + sqlSettings["UNIQUE"] if value, ok := sqlSettings["DEFAULT"]; ok { - additionalType = additionalType + "DEFAULT " + value + additionalType = additionalType + " DEFAULT " + value } if field.IsScanner { diff --git a/mssql.go b/mssql.go index dc8e2917..c44541c7 100644 --- a/mssql.go +++ b/mssql.go @@ -7,21 +7,15 @@ import ( "time" ) -type mssql struct{} - -func (s *mssql) BinVar(i int) string { - return "$$" // ? +type mssql struct { + commonDialect } -func (s *mssql) SupportLastInsertId() bool { +func (mssql) HasTop() bool { return true } -func (s *mssql) HasTop() bool { - return true -} - -func (s *mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { +func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "bit" @@ -57,19 +51,7 @@ func (s *mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String())) } -func (s *mssql) ReturningStr(tableName, key string) string { - return "" -} - -func (s *mssql) SelectFromDummyTable() string { - return "" -} - -func (s *mssql) Quote(key string) string { - return fmt.Sprintf(" \"%s\"", key) -} - -func (s *mssql) databaseName(scope *Scope) string { +func (mssql) databaseName(scope *Scope) string { dbStr := strings.Split(scope.db.parent.source, ";") for _, value := range dbStr { s := strings.Split(value, "=") @@ -80,24 +62,20 @@ func (s *mssql) databaseName(scope *Scope) string { return "" } -func (s *mssql) HasTable(scope *Scope, tableName string) bool { +func (s mssql) HasTable(scope *Scope, tableName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.databaseName(scope)).Row().Scan(&count) return count > 0 } -func (s *mssql) HasColumn(scope *Scope, tableName string, columnName string) bool { +func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count) return count > 0 } -func (s *mssql) HasIndex(scope *Scope, tableName string, indexName string) bool { +func (mssql) HasIndex(scope *Scope, tableName string, indexName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Row().Scan(&count) return count > 0 } - -func (s *mssql) RemoveIndex(scope *Scope, indexName string) { - scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())) -} diff --git a/mysql.go b/mysql.go index d2eb08a5..e37a23e0 100644 --- a/mysql.go +++ b/mysql.go @@ -3,25 +3,14 @@ package gorm import ( "fmt" "reflect" - "strings" "time" ) -type mysql struct{} - -func (s *mysql) BinVar(i int) string { - return "$$" // ? +type mysql struct { + commonDialect } -func (s *mysql) SupportLastInsertId() bool { - return true -} - -func (s *mysql) HasTop() bool { - return false -} - -func (s *mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { +func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "boolean" @@ -57,45 +46,10 @@ func (s *mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String())) } -func (s *mysql) ReturningStr(tableName, key string) string { - return "" -} - -func (s *mysql) SelectFromDummyTable() string { - return "FROM DUAL" -} - -func (s *mysql) Quote(key string) string { +func (mysql) Quote(key string) string { return fmt.Sprintf("`%s`", key) } -func (s *mysql) databaseName(scope *Scope) string { - from := strings.Index(scope.db.parent.source, "/") + 1 - to := strings.Index(scope.db.parent.source, "?") - if to == -1 { - to = len(scope.db.parent.source) - } - return scope.db.parent.source[from:to] -} - -func (s *mysql) HasTable(scope *Scope, tableName string) bool { - var count int - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = ? AND table_schema = ?", tableName, s.databaseName(scope)).Row().Scan(&count) - return count > 0 -} - -func (s *mysql) HasColumn(scope *Scope, tableName string, columnName string) bool { - var count int - scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count) - return count > 0 -} - -func (s *mysql) HasIndex(scope *Scope, tableName string, indexName string) bool { - var count int - scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count) - return count > 0 -} - -func (s *mysql) RemoveIndex(scope *Scope, indexName string) { - scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName())) +func (mysql) SelectFromDummyTable() string { + return "FROM DUAL" } diff --git a/postgres.go b/postgres.go index 83c37e1f..4218e1ba 100644 --- a/postgres.go +++ b/postgres.go @@ -11,21 +11,18 @@ import ( ) type postgres struct { + commonDialect } -func (s *postgres) BinVar(i int) string { +func (postgres) BinVar(i int) string { return fmt.Sprintf("$%v", i) } -func (s *postgres) SupportLastInsertId() bool { +func (postgres) SupportLastInsertId() bool { return false } -func (s *postgres) HasTop() bool { - return false -} - -func (s *postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string { +func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "boolean" @@ -62,35 +59,27 @@ func (s *postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) stri panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String())) } -func (s *postgres) ReturningStr(tableName, key string) string { +func (s postgres) ReturningStr(tableName, key string) string { return fmt.Sprintf("RETURNING %v.%v", s.Quote(tableName), key) } -func (s *postgres) SelectFromDummyTable() string { - return "" -} - -func (s *postgres) Quote(key string) string { - return fmt.Sprintf("\"%s\"", key) -} - -func (s *postgres) HasTable(scope *Scope, tableName string) bool { +func (postgres) HasTable(scope *Scope, tableName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName).Row().Scan(&count) return count > 0 } -func (s *postgres) HasColumn(scope *Scope, tableName string, columnName string) bool { +func (postgres) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count) return count > 0 } -func (s *postgres) RemoveIndex(scope *Scope, indexName string) { +func (postgres) RemoveIndex(scope *Scope, indexName string) { scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)) } -func (s *postgres) HasIndex(scope *Scope, tableName string, indexName string) bool { +func (postgres) HasIndex(scope *Scope, tableName string, indexName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName).Row().Scan(&count) return count > 0 diff --git a/scope.go b/scope.go index fccc5b88..d8e39348 100644 --- a/scope.go +++ b/scope.go @@ -398,3 +398,11 @@ func (scope *Scope) changeableField(field *Field) bool { return !field.IsIgnored } + +func (scope *Scope) shouldSaveAssociations() bool { + saveAssociations, ok := scope.Get("gorm:save_associations") + if ok && !saveAssociations.(bool) { + return false + } + return true +} diff --git a/scope_private.go b/scope_private.go index e51e1faf..8de12ced 100644 --- a/scope_private.go +++ b/scope_private.go @@ -3,7 +3,6 @@ package gorm import ( "database/sql" "database/sql/driver" - "errors" "fmt" "reflect" "regexp" @@ -360,7 +359,7 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope { dest := reflect.Indirect(reflect.ValueOf(value)) scope.Search.Select(column) if dest.Kind() != reflect.Slice { - scope.Err(errors.New("results should be a slice")) + scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) return scope } @@ -402,18 +401,8 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { if fromField != nil { if relationship := fromField.Relationship; relationship != nil { if relationship.Kind == "many_to_many" { - joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable) - quotedJoinTable := scope.Quote(joinTableHandler.Table(scope.db, relationship)) - - joinSql := fmt.Sprintf( - "INNER JOIN %v ON %v.%v = %v.%v", - quotedJoinTable, - quotedJoinTable, - scope.Quote(relationship.AssociationForeignDBName), - toScope.QuotedTableName(), - scope.Quote(toScope.PrimaryKey())) - whereSql := fmt.Sprintf("%v.%v = ?", quotedJoinTable, scope.Quote(relationship.ForeignDBName)) - scope.Err(toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value).Error) + joinTableHandler := relationship.JoinTableHandler + scope.Err(joinTableHandler.JoinWith(toScope.db, scope.Value).Find(value).Error) } else if relationship.Kind == "belongs_to" { sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) foreignKeyValue := fromFields[relationship.ForeignDBName].Field.Interface() @@ -443,9 +432,9 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { } func (scope *Scope) createJoinTable(field *StructField) { - if relationship := field.Relationship; relationship != nil && relationship.JoinTable != "" { - joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable) - joinTable := joinTableHandler.Table(scope.db, relationship) + if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil { + joinTableHandler := relationship.JoinTableHandler + joinTable := joinTableHandler.Table(scope.db) if !scope.Dialect().HasTable(scope, joinTable) { primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false) scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)", @@ -455,7 +444,7 @@ func (scope *Scope) createJoinTable(field *StructField) { scope.Quote(relationship.AssociationForeignDBName) + " " + primaryKeySqlType}, ",")), ).Error) } - scope.NewDB().Table(joinTable).AutoMigrate() + scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) } } @@ -570,7 +559,7 @@ func (scope *Scope) autoIndex() *Scope { if name == "UNIQUE_INDEX" { name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName) } - uniqueIndexes[name] = append(indexes[name], field.DBName) + uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName) } } diff --git a/sqlite3.go b/sqlite3.go index ce71ee08..afe70e3a 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -6,21 +6,11 @@ import ( "time" ) -type sqlite3 struct{} - -func (s *sqlite3) BinVar(i int) string { - return "$$" // ? +type sqlite3 struct { + commonDialect } -func (s *sqlite3) SupportLastInsertId() bool { - return true -} - -func (s *sqlite3) HasTop() bool { - return false -} - -func (s *sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string { +func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "bool" @@ -50,36 +40,24 @@ func (s *sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) strin panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String())) } -func (s *sqlite3) ReturningStr(tableName, key string) string { - return "" -} - -func (s *sqlite3) SelectFromDummyTable() string { - return "" -} - -func (s *sqlite3) Quote(key string) string { - return fmt.Sprintf("\"%s\"", key) -} - -func (s *sqlite3) HasTable(scope *Scope, tableName string) bool { +func (sqlite3) HasTable(scope *Scope, tableName string) bool { var count int scope.NewDB().Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Row().Scan(&count) return count > 0 } -func (s *sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool { +func (sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool { var count int scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", columnName, columnName, columnName, columnName), tableName).Row().Scan(&count) return count > 0 } -func (s *sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool { +func (sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool { var count int scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Row().Scan(&count) return count > 0 } -func (s *sqlite3) RemoveIndex(scope *Scope, indexName string) { +func (sqlite3) RemoveIndex(scope *Scope, indexName string) { scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName)) } diff --git a/test_all.sh b/test_all.sh index 6c5593b3..bd28294d 100755 --- a/test_all.sh +++ b/test_all.sh @@ -1,4 +1,4 @@ -dialects=("postgres" "mysql" "sqlite") +dialects=("postgres" "foundation" "mysql" "sqlite") for dialect in "${dialects[@]}" ; do GORM_DIALECT=${dialect} go test diff --git a/update_test.go b/update_test.go index 8a019087..9a0af806 100644 --- a/update_test.go +++ b/update_test.go @@ -382,3 +382,32 @@ func TestOmitWithUpdateColumn(t *testing.T) { t.Errorf("Should omit name column when update user") } } + +func TestUpdateColumnsSkipsAssociations(t *testing.T) { + user := getPreparedUser("update_columns_user", "special_role") + user.Age = 99 + address1 := "first street" + user.BillingAddress = Address{Address1: address1} + DB.Save(user) + + // Update a single field of the user and verify that the changed address is not stored. + newAge := int64(100) + user.BillingAddress.Address1 = "second street" + db := DB.Model(user).UpdateColumns(User{Age: newAge}) + if db.RowsAffected != 1 { + t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", DB.RowsAffected) + } + + // Verify that Age now=`newAge`. + freshUser := &User{Id: user.Id} + DB.First(freshUser) + if freshUser.Age != newAge { + t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, freshUser.Age) + } + + // Verify that user's BillingAddress.Address1 is not changed and is still "first street". + DB.First(&freshUser.BillingAddress, freshUser.BillingAddressID) + if freshUser.BillingAddress.Address1 != address1 { + t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1) + } +}