Merge pull request #1 from jinzhu/master

Update
This commit is contained in:
Vincenzo Prignano 2015-03-31 22:30:49 +01:00
commit 4e39820d9a
22 changed files with 534 additions and 407 deletions

View File

@ -96,6 +96,7 @@ import (
)
db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
// db, err := gorm.Open("foundation", "dbname=gorm") // FoundationDB.
// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local")
// db, err := gorm.Open("sqlite3", "/tmp/gorm.db")
@ -360,7 +361,7 @@ db.Model(&user).Updates(User{Name: "hello", Age: 18})
### Update Without Callbacks
By default, update will call BeforeUpdate, AfterUpdate callbacks, if you want to update w/o callbacks:
By default, update will call BeforeUpdate, AfterUpdate callbacks, if you want to update w/o callbacks and w/o saving associations:
```go
db.Model(&user).UpdateColumn("name", "hello")
@ -546,7 +547,7 @@ Supports polymorphic has-many and has-one associations.
Id int
Name string
OwnerId int
OwnerType int
OwnerType string
}
```
Note: polymorphic belongs-to and many-to-many are explicitly NOT supported, and will throw errors.

View File

@ -77,7 +77,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
if relationship.Kind == "many_to_many" {
sql := fmt.Sprintf("%v = ? AND %v IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
query := scope.NewDB().Where(sql, association.PrimaryKey, primaryKeys)
if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil {
if err := relationship.JoinTableHandler.Delete(query, relationship); err == nil {
leftValues := reflect.Zero(association.Field.Field.Type())
for i := 0; i < association.Field.Field.Len(); i++ {
value := association.Field.Field.Index(i)
@ -132,7 +132,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys)
association.setErr(scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship))
association.setErr(relationship.JoinTableHandler.Delete(query, relationship))
} else {
association.setErr(errors.New("replace only support many to many"))
}
@ -145,7 +145,7 @@ func (association *Association) Clear() *Association {
if relationship.Kind == "many_to_many" {
sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName))
query := scope.NewDB().Where(sql, association.PrimaryKey)
if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil {
if err := relationship.JoinTableHandler.Delete(query, relationship); err == nil {
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
} else {
association.setErr(err)
@ -163,9 +163,7 @@ func (association *Association) Count() int {
newScope := scope.New(association.Field.Field.Interface())
if relationship.Kind == "many_to_many" {
query := scope.DB().Select("COUNT(DISTINCT ?)", relationship.AssociationForeignDBName).
Where(relationship.ForeignDBName+" = ?", association.PrimaryKey)
scope.db.GetJoinTableHandler(relationship.JoinTable).Scope(query, relationship).Row().Scan(&count)
relationship.JoinTableHandler.JoinWith(scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey)

View File

@ -21,7 +21,7 @@ func Query(scope *Scope) {
dest = reflect.Indirect(reflect.ValueOf(value))
}
if orderBy, ok := scope.InstanceGet("gorm:order_by_primary_key"); ok {
if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
if primaryKey := scope.PrimaryKey(); primaryKey != "" {
scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), primaryKey, orderBy))
}

View File

@ -11,6 +11,9 @@ func CommitOrRollbackTransaction(scope *Scope) {
}
func SaveBeforeAssociations(scope *Scope) {
if !scope.shouldSaveAssociations() {
return
}
for _, field := range scope.Fields() {
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
@ -25,6 +28,9 @@ func SaveBeforeAssociations(scope *Scope) {
}
func SaveAfterAssociations(scope *Scope) {
if !scope.shouldSaveAssociations() {
return
}
for _, field := range scope.Fields() {
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
if relationship := field.Relationship; relationship != nil &&
@ -38,7 +44,7 @@ func SaveAfterAssociations(scope *Scope) {
elem := value.Index(i).Addr().Interface()
newScope := newDB.NewScope(elem)
if relationship.JoinTable == "" && relationship.ForeignFieldName != "" {
if relationship.JoinTableHandler == nil && relationship.ForeignFieldName != "" {
scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue()))
}
@ -48,9 +54,8 @@ func SaveAfterAssociations(scope *Scope) {
scope.Err(newDB.Save(elem).Error)
if joinTable := relationship.JoinTable; joinTable != "" {
scope.Err(scope.db.GetJoinTableHandler(joinTable).
Add(scope.NewDB(), relationship, scope.PrimaryKeyValue(), newScope.PrimaryKeyValue()))
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
scope.Err(joinTableHandler.Add(scope.NewDB(), scope.Value, newScope.Value))
}
}
default:

View File

@ -9,19 +9,19 @@ import (
type commonDialect struct{}
func (s *commonDialect) BinVar(i int) string {
return "?"
func (commonDialect) BinVar(i int) string {
return "$$" // ?
}
func (s *commonDialect) SupportLastInsertId() bool {
func (commonDialect) SupportLastInsertId() bool {
return true
}
func (s *commonDialect) HasTop() bool {
func (commonDialect) HasTop() bool {
return false
}
func (s *commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "BOOLEAN"
@ -57,19 +57,19 @@ func (s *commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool)
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
}
func (s *commonDialect) ReturningStr(tableName, key string) string {
func (commonDialect) ReturningStr(tableName, key string) string {
return ""
}
func (s *commonDialect) SelectFromDummyTable() string {
func (commonDialect) SelectFromDummyTable() string {
return ""
}
func (s *commonDialect) Quote(key string) string {
return fmt.Sprintf("`%s`", key)
func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}
func (s *commonDialect) databaseName(scope *Scope) string {
func (commonDialect) databaseName(scope *Scope) string {
from := strings.Index(scope.db.parent.source, "/") + 1
to := strings.Index(scope.db.parent.source, "?")
if to == -1 {
@ -78,24 +78,24 @@ func (s *commonDialect) databaseName(scope *Scope) string {
return scope.db.parent.source[from:to]
}
func (s *commonDialect) HasTable(scope *Scope, tableName string) bool {
func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, s.databaseName(scope)).Row().Scan(&count)
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, c.databaseName(scope)).Row().Scan(&count)
return count > 0
}
func (s *commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count)
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", c.databaseName(scope), tableName, columnName).Row().Scan(&count)
return count > 0
}
func (s *commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
func (commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count)
return count > 0
}
func (s *commonDialect) RemoveIndex(scope *Scope, indexName string) {
func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName()))
}

View File

@ -24,6 +24,8 @@ func NewDialect(driver string) Dialect {
switch driver {
case "postgres":
d = &postgres{}
case "foundation":
d = &foundation{}
case "mysql":
d = &mysql{}
case "sqlite3":

View File

@ -27,9 +27,13 @@ func (field *Field) Set(value interface{}) error {
if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
if v, ok := value.(reflect.Value); ok {
scanner.Scan(v.Interface())
if err := scanner.Scan(v.Interface()); err != nil {
return err
}
} else {
scanner.Scan(value)
if err := scanner.Scan(value); err != nil {
return err
}
}
} else {
reflectValue, ok := value.(reflect.Value)

78
foundation.go Normal file
View File

@ -0,0 +1,78 @@
package gorm
import (
"fmt"
"reflect"
"time"
)
type foundation struct {
commonDialect
}
func (foundation) BinVar(i int) string {
return fmt.Sprintf("$%v", i)
}
func (foundation) SupportLastInsertId() bool {
return false
}
func (foundation) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "serial"
}
return "int"
case reflect.Int64, reflect.Uint64:
if autoIncrease {
return "bigserial"
}
return "bigint"
case reflect.Float32, reflect.Float64:
return "double"
case reflect.String:
if size > 0 && size < 65532 {
return fmt.Sprintf("varchar(%d)", size)
}
return "clob"
case reflect.Struct:
if _, ok := value.Interface().(time.Time); ok {
return "datetime"
}
default:
if _, ok := value.Interface().([]byte); ok {
return "blob"
}
}
panic(fmt.Sprintf("invalid sql type %s (%s) for foundation", value.Type().Name(), value.Kind().String()))
}
func (f foundation) ReturningStr(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", f.Quote(tableName), key)
}
func (foundation) HasTable(scope *Scope, tableName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_schema = current_schema AND table_type = 'TABLE' AND table_name = ?", tableName).Row().Scan(&count)
return count > 0
}
func (foundation) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = current_schema AND table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count)
return count > 0
}
func (f foundation) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", f.Quote(indexName)))
}
func (foundation) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.indexes WHERE table_schema = current_schema AND table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count)
return count > 0
}

View File

@ -1,48 +0,0 @@
package gorm
import (
"fmt"
"strings"
)
type JoinTableHandler interface {
Table(*DB, *Relationship) string
Add(*DB, *Relationship, interface{}, interface{}) error
Delete(*DB, *Relationship) error
Scope(*DB, *Relationship) *DB
}
type defaultJoinTableHandler struct{}
func (s *defaultJoinTableHandler) Table(db *DB, relationship *Relationship) string {
return relationship.JoinTable
}
func (s *defaultJoinTableHandler) Add(db *DB, relationship *Relationship, foreignValue interface{}, associationValue interface{}) error {
scope := db.NewScope("")
quotedForeignDBName := scope.Quote(relationship.ForeignDBName)
quotedAssociationDBName := scope.Quote(relationship.AssociationForeignDBName)
table := s.Table(db, relationship)
sql := fmt.Sprintf(
"INSERT INTO %v (%v) SELECT ?,? %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = ? AND %v = ?);",
scope.Quote(table),
strings.Join([]string{quotedForeignDBName, quotedAssociationDBName}, ","),
scope.Dialect().SelectFromDummyTable(),
scope.Quote(table),
quotedForeignDBName,
quotedAssociationDBName,
)
return db.Exec(sql, foreignValue, associationValue, foreignValue, associationValue).Error
}
func (s *defaultJoinTableHandler) Delete(db *DB, relationship *Relationship) error {
return db.Table(s.Table(db, relationship)).Delete("").Error
}
func (s *defaultJoinTableHandler) Scope(db *DB, relationship *Relationship) *DB {
return db.Table(s.Table(db, relationship))
}
var DefaultJoinTableHandler = &defaultJoinTableHandler{}

155
join_table_handler.go Normal file
View File

@ -0,0 +1,155 @@
package gorm
import (
"errors"
"fmt"
"reflect"
"strings"
)
type JoinTableHandlerInterface interface {
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
Table(db *DB) string
Add(db *DB, source interface{}, destination interface{}) error
Delete(db *DB, sources ...interface{}) error
JoinWith(db *DB, source interface{}) *DB
}
type JoinTableForeignKey struct {
DBName string
AssociationDBName string
}
type JoinTableSource struct {
ModelType reflect.Type
ForeignKeys []JoinTableForeignKey
}
type JoinTableHandler struct {
TableName string `sql:"-"`
Source JoinTableSource `sql:"-"`
Destination JoinTableSource `sql:"-"`
}
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
s.TableName = tableName
s.Source = JoinTableSource{ModelType: source}
sourceScope := &Scope{Value: reflect.New(source).Interface()}
for _, primaryField := range sourceScope.GetModelStruct().PrimaryFields {
if relationship.ForeignDBName == "" {
relationship.ForeignFieldName = source.Name() + primaryField.Name
relationship.ForeignDBName = ToDBName(relationship.ForeignFieldName)
}
s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
DBName: relationship.ForeignDBName,
AssociationDBName: primaryField.DBName,
})
}
s.Destination = JoinTableSource{ModelType: destination}
destinationScope := &Scope{Value: reflect.New(destination).Interface()}
for _, primaryField := range destinationScope.GetModelStruct().PrimaryFields {
if relationship.AssociationForeignDBName == "" {
relationship.AssociationForeignFieldName = destination.Name() + primaryField.Name
relationship.AssociationForeignDBName = ToDBName(relationship.AssociationForeignFieldName)
}
s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
DBName: relationship.AssociationForeignDBName,
AssociationDBName: primaryField.DBName,
})
}
}
func (s JoinTableHandler) Table(*DB) string {
return s.TableName
}
func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
values := map[string]interface{}{}
for _, source := range sources {
scope := db.NewScope(source)
modelType := scope.GetModelStruct().ModelType
if s.Source.ModelType == modelType {
for _, foreignKey := range s.Source.ForeignKeys {
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
}
} else if s.Destination.ModelType == modelType {
for _, foreignKey := range s.Destination.ForeignKeys {
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
}
}
}
return values
}
func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) error {
scope := db.NewScope("")
searchMap := s.GetSearchMap(db, source1, source2)
var assignColumns, binVars, conditions []string
var values []interface{}
for key, value := range searchMap {
assignColumns = append(assignColumns, key)
binVars = append(binVars, `?`)
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
values = append(values, value)
}
for _, value := range values {
values = append(values, value)
}
quotedTable := s.Table(db)
sql := fmt.Sprintf(
"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)",
quotedTable,
strings.Join(assignColumns, ","),
strings.Join(binVars, ","),
scope.Dialect().SelectFromDummyTable(),
quotedTable,
strings.Join(conditions, " AND "),
)
return db.Exec(sql, values...).Error
}
func (s JoinTableHandler) Delete(db *DB, sources ...interface{}) error {
var conditions []string
var values []interface{}
for key, value := range s.GetSearchMap(db, sources...) {
conditions = append(conditions, fmt.Sprintf("%v = ?", key))
values = append(values, value)
}
return db.Table(s.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
}
func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB {
quotedTable := s.Table(db)
scope := db.NewScope(source)
modelType := scope.GetModelStruct().ModelType
var joinConditions []string
var queryConditions []string
var values []interface{}
if s.Source.ModelType == modelType {
for _, foreignKey := range s.Destination.ForeignKeys {
destinationTableName := scope.New(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
}
for _, foreignKey := range s.Source.ForeignKeys {
queryConditions = append(queryConditions, fmt.Sprintf("%v.%v = ?", quotedTable, scope.Quote(foreignKey.DBName)))
values = append(values, scope.Fields()[foreignKey.AssociationDBName].Field.Interface())
}
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))).
Where(strings.Join(queryConditions, " AND "), values...)
} else {
db.Error = errors.New("wrong source type for join table handler")
return db
}
}

View File

@ -15,40 +15,37 @@ type Person struct {
}
type PersonAddress struct {
gorm.JoinTableHandler
PersonID int
AddressID int
DeletedAt time.Time
CreatedAt time.Time
}
func (*PersonAddress) Table(db *gorm.DB, relationship *gorm.Relationship) string {
return relationship.JoinTable
}
func (*PersonAddress) Add(db *gorm.DB, relationship *gorm.Relationship, foreignValue interface{}, associationValue interface{}) error {
func (*PersonAddress) Add(db *gorm.DB, foreignValue interface{}, associationValue interface{}) error {
return db.Where(map[string]interface{}{
relationship.ForeignDBName: foreignValue,
relationship.AssociationForeignDBName: associationValue,
"person_id": db.NewScope(foreignValue).PrimaryKeyValue(),
"address_id": db.NewScope(associationValue).PrimaryKeyValue(),
}).Assign(map[string]interface{}{
relationship.ForeignFieldName: foreignValue,
relationship.AssociationForeignFieldName: associationValue,
"DeletedAt": gorm.Expr("NULL"),
"person_id": foreignValue,
"address_id": associationValue,
"deleted_at": gorm.Expr("NULL"),
}).FirstOrCreate(&PersonAddress{}).Error
}
func (*PersonAddress) Delete(db *gorm.DB, relationship *gorm.Relationship) error {
func (*PersonAddress) Delete(db *gorm.DB, sources ...interface{}) error {
return db.Delete(&PersonAddress{}).Error
}
func (pa *PersonAddress) Scope(db *gorm.DB, relationship *gorm.Relationship) *gorm.DB {
table := pa.Table(db, relationship)
return db.Table(table).Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table))
func (pa *PersonAddress) JoinWith(db *gorm.DB, source interface{}) *gorm.DB {
table := pa.Table(db)
return db.Table(table).Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table))
}
func TestJoinTable(t *testing.T) {
DB.Exec("drop table person_addresses;")
DB.AutoMigrate(&Person{})
DB.SetJoinTableHandler(&PersonAddress{}, "person_addresses")
DB.SetJoinTableHandler(&Person{}, "Addresses", &PersonAddress{})
address1 := &Address{Address1: "address 1"}
address2 := &Address{Address1: "address 2"}

41
main.go
View File

@ -55,6 +55,9 @@ func Open(dialect string, args ...interface{}) (DB, error) {
driver = value
source = args[1].(string)
}
if driver == "foundation" {
driver = "postgres" // FoundationDB speaks a postgres-compatible protocol.
}
dbSql, err = sql.Open(driver, source)
case sqlCommon:
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
@ -193,14 +196,14 @@ func (s *DB) Assign(attrs ...interface{}) *DB {
func (s *DB) First(out interface{}, where ...interface{}) *DB {
newScope := s.clone().NewScope(out)
newScope.Search.Limit(1)
return newScope.InstanceSet("gorm:order_by_primary_key", "ASC").
return newScope.Set("gorm:order_by_primary_key", "ASC").
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
}
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
newScope := s.clone().NewScope(out)
newScope.Search.Limit(1)
return newScope.InstanceSet("gorm:order_by_primary_key", "DESC").
return newScope.Set("gorm:order_by_primary_key", "DESC").
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
}
@ -266,6 +269,7 @@ func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
func (s *DB) UpdateColumns(values interface{}) *DB {
return s.clone().NewScope(s.Value).
Set("gorm:update_column", true).
Set("gorm:save_associations", false).
InstanceSet("gorm:update_interface", values).
callCallbacks(s.parent.callback.updates).db
}
@ -470,29 +474,16 @@ func (s *DB) Get(name string) (value interface{}, ok bool) {
return
}
func (s *DB) GetJoinTableHandler(table string) JoinTableHandler {
if s.parent.joinTableHandlers != nil {
if joinTableHandler, ok := s.parent.joinTableHandlers[table]; ok {
return joinTableHandler
func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) {
for _, field := range s.NewScope(source).GetModelStruct().StructFields {
if field.Name == column || field.DBName == column {
if many2many := parseTagSetting(field.Tag.Get("gorm"))["MANY2MANY"]; many2many != "" {
source := (&Scope{Value: source}).GetModelStruct().ModelType
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
handler.Setup(field.Relationship, many2many, source, destination)
field.Relationship.JoinTableHandler = handler
s.Table(handler.Table(s)).AutoMigrate(handler)
}
}
if joinTableHandler, ok := s.parent.joinTableHandlers["*"]; ok {
return joinTableHandler
}
}
return DefaultJoinTableHandler
}
func (s *DB) SetJoinTableHandler(joinTableHandler JoinTableHandler, tables ...string) {
if s.parent.joinTableHandlers == nil {
s.parent.joinTableHandlers = map[string]JoinTableHandler{}
}
if len(tables) > 0 {
for _, table := range tables {
s.parent.joinTableHandlers[table] = joinTableHandler
s.Table(table).AutoMigrate(joinTableHandler)
}
} else {
s.parent.joinTableHandlers["*"] = joinTableHandler
}
}

View File

@ -36,6 +36,9 @@ func init() {
case "postgres":
fmt.Println("testing postgres...")
DB, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable")
case "foundation":
fmt.Println("testing foundation...")
DB, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable")
case "mssql":
fmt.Println("testing mssql...")
DB, err = gorm.Open("mssql", "server=SERVER_HERE;database=rogue;user id=USER_HERE;password=PW_HERE;port=1433")
@ -445,6 +448,14 @@ func TestHaving(t *testing.T) {
}
}
func DialectHasTzSupport() bool {
// NB: mssql and FoundationDB do not support time zones.
if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" || dialect == "foundation" {
return false
}
return true
}
func TestTimeWithZone(t *testing.T) {
var format = "2006-01-02 15:04:05 -0700"
var times []time.Time
@ -456,26 +467,30 @@ func TestTimeWithZone(t *testing.T) {
name := "time_with_zone_" + strconv.Itoa(index)
user := User{Name: name, Birthday: vtime}
// TODO mssql does not support time zones
if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" {
if !DialectHasTzSupport() {
// If our driver dialect doesn't support TZ's, just use UTC for everything here.
user.Birthday = vtime.UTC()
}
DB.Save(&user)
if user.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" {
t.Errorf("User's birthday should not be changed after save")
expectedBirthday := "2013-02-18 17:51:49 +0000"
foundBirthday := user.Birthday.UTC().Format(format)
if foundBirthday != expectedBirthday {
t.Errorf("User's birthday should not be changed after save for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday)
}
var findUser, findUser2, findUser3 User
DB.First(&findUser, "name = ?", name)
if findUser.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" {
t.Errorf("User's birthday should not be changed after find")
foundBirthday = findUser.Birthday.UTC().Format(format)
if foundBirthday != expectedBirthday {
t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v or %+v", name, expectedBirthday, foundBirthday)
}
if DB.Where("id = ? AND birthday >= ?", findUser.Id, vtime.Add(-time.Minute)).First(&findUser2).RecordNotFound() {
if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() {
t.Errorf("User should be found")
}
if !DB.Where("id = ? AND birthday >= ?", findUser.Id, vtime.Add(time.Minute)).First(&findUser3).RecordNotFound() {
if !DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(time.Minute)).First(&findUser3).RecordNotFound() {
t.Errorf("User should not be found")
}
}

View File

@ -60,7 +60,7 @@ type Relationship struct {
ForeignDBName string
AssociationForeignFieldName string
AssociationForeignDBName string
JoinTable string
JoinTableHandler JoinTableHandlerInterface
}
var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")}
@ -146,143 +146,147 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
}
}
for _, field := range fields {
if !field.IsIgnored {
fieldStruct := field.Struct
fieldType, indirectType := fieldStruct.Type, fieldStruct.Type
if indirectType.Kind() == reflect.Ptr {
indirectType = indirectType.Elem()
}
if _, isScanner := reflect.New(fieldType).Interface().(sql.Scanner); isScanner {
field.IsScanner, field.IsNormal = true, true
}
if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime {
field.IsNormal = true
}
if !field.IsNormal {
gormSettings := parseTagSetting(field.Tag.Get("gorm"))
toScope := scope.New(reflect.New(fieldStruct.Type).Interface())
getForeignField := func(column string, fields []*StructField) *StructField {
for _, field := range fields {
if field.Name == column || field.DBName == ToDBName(column) {
return field
}
}
return nil
defer func() {
for _, field := range fields {
if !field.IsIgnored {
fieldStruct := field.Struct
fieldType, indirectType := fieldStruct.Type, fieldStruct.Type
if indirectType.Kind() == reflect.Ptr {
indirectType = indirectType.Elem()
}
var relationship = &Relationship{}
foreignKey := gormSettings["FOREIGNKEY"]
if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" {
if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil {
if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil {
relationship.ForeignFieldName = polymorphicField.Name
relationship.ForeignDBName = polymorphicField.DBName
relationship.PolymorphicType = polymorphicType.Name
relationship.PolymorphicDBName = polymorphicType.DBName
polymorphicType.IsForeignKey = true
polymorphicField.IsForeignKey = true
}
}
if _, isScanner := reflect.New(fieldType).Interface().(sql.Scanner); isScanner {
field.IsScanner, field.IsNormal = true, true
}
switch indirectType.Kind() {
case reflect.Slice:
elemType := indirectType.Elem()
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
}
if elemType.Kind() == reflect.Struct {
if foreignKey == "" {
foreignKey = scopeType.Name() + "Id"
}
if many2many := gormSettings["MANY2MANY"]; many2many != "" {
relationship.Kind = "many_to_many"
relationship.JoinTable = many2many
associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]
if associationForeignKey == "" {
associationForeignKey = elemType.Name() + "Id"
}
relationship.ForeignFieldName = foreignKey
relationship.ForeignDBName = ToDBName(foreignKey)
relationship.AssociationForeignFieldName = associationForeignKey
relationship.AssociationForeignDBName = ToDBName(associationForeignKey)
field.Relationship = relationship
} else {
relationship.Kind = "has_many"
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
relationship.ForeignFieldName = foreignField.Name
relationship.ForeignDBName = foreignField.DBName
foreignField.IsForeignKey = true
field.Relationship = relationship
} else if relationship.ForeignFieldName != "" {
field.Relationship = relationship
}
}
} else {
field.IsNormal = true
}
case reflect.Struct:
if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
for _, toField := range toScope.GetStructFields() {
toField = toField.clone()
toField.Names = append([]string{fieldStruct.Name}, toField.Names...)
modelStruct.StructFields = append(modelStruct.StructFields, toField)
if toField.IsPrimaryKey {
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField)
}
}
continue
} else {
belongsToForeignKey := foreignKey
if belongsToForeignKey == "" {
belongsToForeignKey = field.Name + "Id"
}
if foreignField := getForeignField(belongsToForeignKey, fields); foreignField != nil {
relationship.Kind = "belongs_to"
relationship.ForeignFieldName = foreignField.Name
relationship.ForeignDBName = foreignField.DBName
foreignField.IsForeignKey = true
field.Relationship = relationship
} else {
if foreignKey == "" {
foreignKey = modelStruct.ModelType.Name() + "Id"
}
relationship.Kind = "has_one"
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
relationship.ForeignFieldName = foreignField.Name
relationship.ForeignDBName = foreignField.DBName
foreignField.IsForeignKey = true
field.Relationship = relationship
} else if relationship.ForeignFieldName != "" {
field.Relationship = relationship
}
}
}
default:
if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime {
field.IsNormal = true
}
}
if field.IsNormal {
if len(modelStruct.PrimaryFields) == 0 && field.DBName == "id" {
field.IsPrimaryKey = true
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
if !field.IsNormal {
gormSettings := parseTagSetting(field.Tag.Get("gorm"))
toScope := scope.New(reflect.New(fieldStruct.Type).Interface())
getForeignField := func(column string, fields []*StructField) *StructField {
for _, field := range fields {
if field.Name == column || field.DBName == ToDBName(column) {
return field
}
}
return nil
}
var relationship = &Relationship{}
foreignKey := gormSettings["FOREIGNKEY"]
if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" {
if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil {
if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil {
relationship.ForeignFieldName = polymorphicField.Name
relationship.ForeignDBName = polymorphicField.DBName
relationship.PolymorphicType = polymorphicType.Name
relationship.PolymorphicDBName = polymorphicType.DBName
polymorphicType.IsForeignKey = true
polymorphicField.IsForeignKey = true
}
}
}
switch indirectType.Kind() {
case reflect.Slice:
elemType := indirectType.Elem()
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
}
if elemType.Kind() == reflect.Struct {
if foreignKey == "" {
foreignKey = scopeType.Name() + "Id"
}
if many2many := gormSettings["MANY2MANY"]; many2many != "" {
relationship.Kind = "many_to_many"
associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]
if associationForeignKey == "" {
associationForeignKey = elemType.Name() + "Id"
}
relationship.ForeignFieldName = foreignKey
relationship.ForeignDBName = ToDBName(foreignKey)
relationship.AssociationForeignFieldName = associationForeignKey
relationship.AssociationForeignDBName = ToDBName(associationForeignKey)
joinTableHandler := JoinTableHandler{}
joinTableHandler.Setup(relationship, many2many, scopeType, elemType)
relationship.JoinTableHandler = &joinTableHandler
field.Relationship = relationship
} else {
relationship.Kind = "has_many"
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
relationship.ForeignFieldName = foreignField.Name
relationship.ForeignDBName = foreignField.DBName
foreignField.IsForeignKey = true
field.Relationship = relationship
} else if relationship.ForeignFieldName != "" {
field.Relationship = relationship
}
}
} else {
field.IsNormal = true
}
case reflect.Struct:
if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
for _, toField := range toScope.GetStructFields() {
toField = toField.clone()
toField.Names = append([]string{fieldStruct.Name}, toField.Names...)
modelStruct.StructFields = append(modelStruct.StructFields, toField)
if toField.IsPrimaryKey {
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField)
}
}
continue
} else {
belongsToForeignKey := foreignKey
if belongsToForeignKey == "" {
belongsToForeignKey = field.Name + "Id"
}
if foreignField := getForeignField(belongsToForeignKey, fields); foreignField != nil {
relationship.Kind = "belongs_to"
relationship.ForeignFieldName = foreignField.Name
relationship.ForeignDBName = foreignField.DBName
foreignField.IsForeignKey = true
field.Relationship = relationship
} else {
if foreignKey == "" {
foreignKey = modelStruct.ModelType.Name() + "Id"
}
relationship.Kind = "has_one"
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
relationship.ForeignFieldName = foreignField.Name
relationship.ForeignDBName = foreignField.DBName
foreignField.IsForeignKey = true
field.Relationship = relationship
} else if relationship.ForeignFieldName != "" {
field.Relationship = relationship
}
}
}
default:
field.IsNormal = true
}
}
if field.IsNormal {
if len(modelStruct.PrimaryFields) == 0 && field.DBName == "id" {
field.IsPrimaryKey = true
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
}
}
}
modelStruct.StructFields = append(modelStruct.StructFields, field)
}
modelStruct.StructFields = append(modelStruct.StructFields, field)
}
}()
modelStructs[scopeType] = &modelStruct
@ -308,7 +312,7 @@ func (scope *Scope) generateSqlTag(field *StructField) string {
additionalType := sqlSettings["NOT NULL"] + " " + sqlSettings["UNIQUE"]
if value, ok := sqlSettings["DEFAULT"]; ok {
additionalType = additionalType + "DEFAULT " + value
additionalType = additionalType + " DEFAULT " + value
}
if field.IsScanner {

View File

@ -7,21 +7,15 @@ import (
"time"
)
type mssql struct{}
func (s *mssql) BinVar(i int) string {
return "$$" // ?
type mssql struct {
commonDialect
}
func (s *mssql) SupportLastInsertId() bool {
func (mssql) HasTop() bool {
return true
}
func (s *mssql) HasTop() bool {
return true
}
func (s *mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "bit"
@ -57,19 +51,7 @@ func (s *mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
}
func (s *mssql) ReturningStr(tableName, key string) string {
return ""
}
func (s *mssql) SelectFromDummyTable() string {
return ""
}
func (s *mssql) Quote(key string) string {
return fmt.Sprintf(" \"%s\"", key)
}
func (s *mssql) databaseName(scope *Scope) string {
func (mssql) databaseName(scope *Scope) string {
dbStr := strings.Split(scope.db.parent.source, ";")
for _, value := range dbStr {
s := strings.Split(value, "=")
@ -80,24 +62,20 @@ func (s *mssql) databaseName(scope *Scope) string {
return ""
}
func (s *mssql) HasTable(scope *Scope, tableName string) bool {
func (s mssql) HasTable(scope *Scope, tableName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.databaseName(scope)).Row().Scan(&count)
return count > 0
}
func (s *mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count)
return count > 0
}
func (s *mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
func (mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Row().Scan(&count)
return count > 0
}
func (s *mssql) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName()))
}

View File

@ -3,25 +3,14 @@ package gorm
import (
"fmt"
"reflect"
"strings"
"time"
)
type mysql struct{}
func (s *mysql) BinVar(i int) string {
return "$$" // ?
type mysql struct {
commonDialect
}
func (s *mysql) SupportLastInsertId() bool {
return true
}
func (s *mysql) HasTop() bool {
return false
}
func (s *mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "boolean"
@ -57,45 +46,10 @@ func (s *mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
}
func (s *mysql) ReturningStr(tableName, key string) string {
return ""
}
func (s *mysql) SelectFromDummyTable() string {
return "FROM DUAL"
}
func (s *mysql) Quote(key string) string {
func (mysql) Quote(key string) string {
return fmt.Sprintf("`%s`", key)
}
func (s *mysql) databaseName(scope *Scope) string {
from := strings.Index(scope.db.parent.source, "/") + 1
to := strings.Index(scope.db.parent.source, "?")
if to == -1 {
to = len(scope.db.parent.source)
}
return scope.db.parent.source[from:to]
}
func (s *mysql) HasTable(scope *Scope, tableName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = ? AND table_schema = ?", tableName, s.databaseName(scope)).Row().Scan(&count)
return count > 0
}
func (s *mysql) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count)
return count > 0
}
func (s *mysql) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count)
return count > 0
}
func (s *mysql) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName()))
func (mysql) SelectFromDummyTable() string {
return "FROM DUAL"
}

View File

@ -11,21 +11,18 @@ import (
)
type postgres struct {
commonDialect
}
func (s *postgres) BinVar(i int) string {
func (postgres) BinVar(i int) string {
return fmt.Sprintf("$%v", i)
}
func (s *postgres) SupportLastInsertId() bool {
func (postgres) SupportLastInsertId() bool {
return false
}
func (s *postgres) HasTop() bool {
return false
}
func (s *postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "boolean"
@ -62,35 +59,27 @@ func (s *postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) stri
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
}
func (s *postgres) ReturningStr(tableName, key string) string {
func (s postgres) ReturningStr(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", s.Quote(tableName), key)
}
func (s *postgres) SelectFromDummyTable() string {
return ""
}
func (s *postgres) Quote(key string) string {
return fmt.Sprintf("\"%s\"", key)
}
func (s *postgres) HasTable(scope *Scope, tableName string) bool {
func (postgres) HasTable(scope *Scope, tableName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName).Row().Scan(&count)
return count > 0
}
func (s *postgres) HasColumn(scope *Scope, tableName string, columnName string) bool {
func (postgres) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count)
return count > 0
}
func (s *postgres) RemoveIndex(scope *Scope, indexName string) {
func (postgres) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName))
}
func (s *postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
func (postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName).Row().Scan(&count)
return count > 0

View File

@ -398,3 +398,11 @@ func (scope *Scope) changeableField(field *Field) bool {
return !field.IsIgnored
}
func (scope *Scope) shouldSaveAssociations() bool {
saveAssociations, ok := scope.Get("gorm:save_associations")
if ok && !saveAssociations.(bool) {
return false
}
return true
}

View File

@ -3,7 +3,6 @@ package gorm
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"reflect"
"regexp"
@ -360,7 +359,7 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope {
dest := reflect.Indirect(reflect.ValueOf(value))
scope.Search.Select(column)
if dest.Kind() != reflect.Slice {
scope.Err(errors.New("results should be a slice"))
scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind()))
return scope
}
@ -402,18 +401,8 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
if fromField != nil {
if relationship := fromField.Relationship; relationship != nil {
if relationship.Kind == "many_to_many" {
joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable)
quotedJoinTable := scope.Quote(joinTableHandler.Table(scope.db, relationship))
joinSql := fmt.Sprintf(
"INNER JOIN %v ON %v.%v = %v.%v",
quotedJoinTable,
quotedJoinTable,
scope.Quote(relationship.AssociationForeignDBName),
toScope.QuotedTableName(),
scope.Quote(toScope.PrimaryKey()))
whereSql := fmt.Sprintf("%v.%v = ?", quotedJoinTable, scope.Quote(relationship.ForeignDBName))
scope.Err(toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value).Error)
joinTableHandler := relationship.JoinTableHandler
scope.Err(joinTableHandler.JoinWith(toScope.db, scope.Value).Find(value).Error)
} else if relationship.Kind == "belongs_to" {
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
foreignKeyValue := fromFields[relationship.ForeignDBName].Field.Interface()
@ -443,9 +432,9 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
}
func (scope *Scope) createJoinTable(field *StructField) {
if relationship := field.Relationship; relationship != nil && relationship.JoinTable != "" {
joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable)
joinTable := joinTableHandler.Table(scope.db, relationship)
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
joinTableHandler := relationship.JoinTableHandler
joinTable := joinTableHandler.Table(scope.db)
if !scope.Dialect().HasTable(scope, joinTable) {
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false)
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
@ -455,7 +444,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
scope.Quote(relationship.AssociationForeignDBName) + " " + primaryKeySqlType}, ",")),
).Error)
}
scope.NewDB().Table(joinTable).AutoMigrate()
scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
}
}
@ -570,7 +559,7 @@ func (scope *Scope) autoIndex() *Scope {
if name == "UNIQUE_INDEX" {
name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName)
}
uniqueIndexes[name] = append(indexes[name], field.DBName)
uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
}
}

View File

@ -6,21 +6,11 @@ import (
"time"
)
type sqlite3 struct{}
func (s *sqlite3) BinVar(i int) string {
return "$$" // ?
type sqlite3 struct {
commonDialect
}
func (s *sqlite3) SupportLastInsertId() bool {
return true
}
func (s *sqlite3) HasTop() bool {
return false
}
func (s *sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "bool"
@ -50,36 +40,24 @@ func (s *sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) strin
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
}
func (s *sqlite3) ReturningStr(tableName, key string) string {
return ""
}
func (s *sqlite3) SelectFromDummyTable() string {
return ""
}
func (s *sqlite3) Quote(key string) string {
return fmt.Sprintf("\"%s\"", key)
}
func (s *sqlite3) HasTable(scope *Scope, tableName string) bool {
func (sqlite3) HasTable(scope *Scope, tableName string) bool {
var count int
scope.NewDB().Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Row().Scan(&count)
return count > 0
}
func (s *sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool {
func (sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool {
var count int
scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", columnName, columnName, columnName, columnName), tableName).Row().Scan(&count)
return count > 0
}
func (s *sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
func (sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
var count int
scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Row().Scan(&count)
return count > 0
}
func (s *sqlite3) RemoveIndex(scope *Scope, indexName string) {
func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName))
}

View File

@ -1,4 +1,4 @@
dialects=("postgres" "mysql" "sqlite")
dialects=("postgres" "foundation" "mysql" "sqlite")
for dialect in "${dialects[@]}" ; do
GORM_DIALECT=${dialect} go test

View File

@ -382,3 +382,32 @@ func TestOmitWithUpdateColumn(t *testing.T) {
t.Errorf("Should omit name column when update user")
}
}
func TestUpdateColumnsSkipsAssociations(t *testing.T) {
user := getPreparedUser("update_columns_user", "special_role")
user.Age = 99
address1 := "first street"
user.BillingAddress = Address{Address1: address1}
DB.Save(user)
// Update a single field of the user and verify that the changed address is not stored.
newAge := int64(100)
user.BillingAddress.Address1 = "second street"
db := DB.Model(user).UpdateColumns(User{Age: newAge})
if db.RowsAffected != 1 {
t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", DB.RowsAffected)
}
// Verify that Age now=`newAge`.
freshUser := &User{Id: user.Id}
DB.First(freshUser)
if freshUser.Age != newAge {
t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, freshUser.Age)
}
// Verify that user's BillingAddress.Address1 is not changed and is still "first street".
DB.First(&freshUser.BillingAddress, freshUser.BillingAddressID)
if freshUser.BillingAddress.Address1 != address1 {
t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1)
}
}