commit
4e39820d9a
@ -96,6 +96,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
||||||
|
// db, err := gorm.Open("foundation", "dbname=gorm") // FoundationDB.
|
||||||
// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local")
|
// db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True&loc=Local")
|
||||||
// db, err := gorm.Open("sqlite3", "/tmp/gorm.db")
|
// db, err := gorm.Open("sqlite3", "/tmp/gorm.db")
|
||||||
|
|
||||||
@ -360,7 +361,7 @@ db.Model(&user).Updates(User{Name: "hello", Age: 18})
|
|||||||
|
|
||||||
### Update Without Callbacks
|
### Update Without Callbacks
|
||||||
|
|
||||||
By default, update will call BeforeUpdate, AfterUpdate callbacks, if you want to update w/o callbacks:
|
By default, update will call BeforeUpdate, AfterUpdate callbacks, if you want to update w/o callbacks and w/o saving associations:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
db.Model(&user).UpdateColumn("name", "hello")
|
db.Model(&user).UpdateColumn("name", "hello")
|
||||||
@ -546,7 +547,7 @@ Supports polymorphic has-many and has-one associations.
|
|||||||
Id int
|
Id int
|
||||||
Name string
|
Name string
|
||||||
OwnerId int
|
OwnerId int
|
||||||
OwnerType int
|
OwnerType string
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
Note: polymorphic belongs-to and many-to-many are explicitly NOT supported, and will throw errors.
|
Note: polymorphic belongs-to and many-to-many are explicitly NOT supported, and will throw errors.
|
||||||
|
@ -77,7 +77,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
|
|||||||
if relationship.Kind == "many_to_many" {
|
if relationship.Kind == "many_to_many" {
|
||||||
sql := fmt.Sprintf("%v = ? AND %v IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
|
sql := fmt.Sprintf("%v = ? AND %v IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
|
||||||
query := scope.NewDB().Where(sql, association.PrimaryKey, primaryKeys)
|
query := scope.NewDB().Where(sql, association.PrimaryKey, primaryKeys)
|
||||||
if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil {
|
if err := relationship.JoinTableHandler.Delete(query, relationship); err == nil {
|
||||||
leftValues := reflect.Zero(association.Field.Field.Type())
|
leftValues := reflect.Zero(association.Field.Field.Type())
|
||||||
for i := 0; i < association.Field.Field.Len(); i++ {
|
for i := 0; i < association.Field.Field.Len(); i++ {
|
||||||
value := association.Field.Field.Index(i)
|
value := association.Field.Field.Index(i)
|
||||||
@ -132,7 +132,7 @@ func (association *Association) Replace(values ...interface{}) *Association {
|
|||||||
|
|
||||||
sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
|
sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
|
||||||
query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys)
|
query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys)
|
||||||
association.setErr(scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship))
|
association.setErr(relationship.JoinTableHandler.Delete(query, relationship))
|
||||||
} else {
|
} else {
|
||||||
association.setErr(errors.New("replace only support many to many"))
|
association.setErr(errors.New("replace only support many to many"))
|
||||||
}
|
}
|
||||||
@ -145,7 +145,7 @@ func (association *Association) Clear() *Association {
|
|||||||
if relationship.Kind == "many_to_many" {
|
if relationship.Kind == "many_to_many" {
|
||||||
sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName))
|
sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName))
|
||||||
query := scope.NewDB().Where(sql, association.PrimaryKey)
|
query := scope.NewDB().Where(sql, association.PrimaryKey)
|
||||||
if err := scope.db.GetJoinTableHandler(relationship.JoinTable).Delete(query, relationship); err == nil {
|
if err := relationship.JoinTableHandler.Delete(query, relationship); err == nil {
|
||||||
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
|
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
|
||||||
} else {
|
} else {
|
||||||
association.setErr(err)
|
association.setErr(err)
|
||||||
@ -163,9 +163,7 @@ func (association *Association) Count() int {
|
|||||||
newScope := scope.New(association.Field.Field.Interface())
|
newScope := scope.New(association.Field.Field.Interface())
|
||||||
|
|
||||||
if relationship.Kind == "many_to_many" {
|
if relationship.Kind == "many_to_many" {
|
||||||
query := scope.DB().Select("COUNT(DISTINCT ?)", relationship.AssociationForeignDBName).
|
relationship.JoinTableHandler.JoinWith(scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count)
|
||||||
Where(relationship.ForeignDBName+" = ?", association.PrimaryKey)
|
|
||||||
scope.db.GetJoinTableHandler(relationship.JoinTable).Scope(query, relationship).Row().Scan(&count)
|
|
||||||
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
||||||
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
|
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
|
||||||
countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey)
|
countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey)
|
||||||
|
@ -21,7 +21,7 @@ func Query(scope *Scope) {
|
|||||||
dest = reflect.Indirect(reflect.ValueOf(value))
|
dest = reflect.Indirect(reflect.ValueOf(value))
|
||||||
}
|
}
|
||||||
|
|
||||||
if orderBy, ok := scope.InstanceGet("gorm:order_by_primary_key"); ok {
|
if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
|
||||||
if primaryKey := scope.PrimaryKey(); primaryKey != "" {
|
if primaryKey := scope.PrimaryKey(); primaryKey != "" {
|
||||||
scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), primaryKey, orderBy))
|
scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), primaryKey, orderBy))
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,9 @@ func CommitOrRollbackTransaction(scope *Scope) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func SaveBeforeAssociations(scope *Scope) {
|
func SaveBeforeAssociations(scope *Scope) {
|
||||||
|
if !scope.shouldSaveAssociations() {
|
||||||
|
return
|
||||||
|
}
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
|
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
|
||||||
if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||||
@ -25,6 +28,9 @@ func SaveBeforeAssociations(scope *Scope) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func SaveAfterAssociations(scope *Scope) {
|
func SaveAfterAssociations(scope *Scope) {
|
||||||
|
if !scope.shouldSaveAssociations() {
|
||||||
|
return
|
||||||
|
}
|
||||||
for _, field := range scope.Fields() {
|
for _, field := range scope.Fields() {
|
||||||
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
|
if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored {
|
||||||
if relationship := field.Relationship; relationship != nil &&
|
if relationship := field.Relationship; relationship != nil &&
|
||||||
@ -38,7 +44,7 @@ func SaveAfterAssociations(scope *Scope) {
|
|||||||
elem := value.Index(i).Addr().Interface()
|
elem := value.Index(i).Addr().Interface()
|
||||||
newScope := newDB.NewScope(elem)
|
newScope := newDB.NewScope(elem)
|
||||||
|
|
||||||
if relationship.JoinTable == "" && relationship.ForeignFieldName != "" {
|
if relationship.JoinTableHandler == nil && relationship.ForeignFieldName != "" {
|
||||||
scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue()))
|
scope.Err(newScope.SetColumn(relationship.ForeignFieldName, scope.PrimaryKeyValue()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -48,9 +54,8 @@ func SaveAfterAssociations(scope *Scope) {
|
|||||||
|
|
||||||
scope.Err(newDB.Save(elem).Error)
|
scope.Err(newDB.Save(elem).Error)
|
||||||
|
|
||||||
if joinTable := relationship.JoinTable; joinTable != "" {
|
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
|
||||||
scope.Err(scope.db.GetJoinTableHandler(joinTable).
|
scope.Err(joinTableHandler.Add(scope.NewDB(), scope.Value, newScope.Value))
|
||||||
Add(scope.NewDB(), relationship, scope.PrimaryKeyValue(), newScope.PrimaryKeyValue()))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
@ -9,19 +9,19 @@ import (
|
|||||||
|
|
||||||
type commonDialect struct{}
|
type commonDialect struct{}
|
||||||
|
|
||||||
func (s *commonDialect) BinVar(i int) string {
|
func (commonDialect) BinVar(i int) string {
|
||||||
return "?"
|
return "$$" // ?
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *commonDialect) SupportLastInsertId() bool {
|
func (commonDialect) SupportLastInsertId() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *commonDialect) HasTop() bool {
|
func (commonDialect) HasTop() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
func (commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||||
switch value.Kind() {
|
switch value.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
return "BOOLEAN"
|
return "BOOLEAN"
|
||||||
@ -57,19 +57,19 @@ func (s *commonDialect) SqlTag(value reflect.Value, size int, autoIncrease bool)
|
|||||||
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
|
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", value.Type().Name(), value.Kind().String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *commonDialect) ReturningStr(tableName, key string) string {
|
func (commonDialect) ReturningStr(tableName, key string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *commonDialect) SelectFromDummyTable() string {
|
func (commonDialect) SelectFromDummyTable() string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *commonDialect) Quote(key string) string {
|
func (commonDialect) Quote(key string) string {
|
||||||
return fmt.Sprintf("`%s`", key)
|
return fmt.Sprintf(`"%s"`, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *commonDialect) databaseName(scope *Scope) string {
|
func (commonDialect) databaseName(scope *Scope) string {
|
||||||
from := strings.Index(scope.db.parent.source, "/") + 1
|
from := strings.Index(scope.db.parent.source, "/") + 1
|
||||||
to := strings.Index(scope.db.parent.source, "?")
|
to := strings.Index(scope.db.parent.source, "?")
|
||||||
if to == -1 {
|
if to == -1 {
|
||||||
@ -78,24 +78,24 @@ func (s *commonDialect) databaseName(scope *Scope) string {
|
|||||||
return scope.db.parent.source[from:to]
|
return scope.db.parent.source[from:to]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *commonDialect) HasTable(scope *Scope, tableName string) bool {
|
func (c commonDialect) HasTable(scope *Scope, tableName string) bool {
|
||||||
var count int
|
var count int
|
||||||
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, s.databaseName(scope)).Row().Scan(&count)
|
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_name = ? AND table_schema = ?", tableName, c.databaseName(scope)).Row().Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
func (c commonDialect) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
||||||
var count int
|
var count int
|
||||||
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count)
|
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", c.databaseName(scope), tableName, columnName).Row().Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
func (commonDialect) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||||
var count int
|
var count int
|
||||||
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count)
|
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *commonDialect) RemoveIndex(scope *Scope, indexName string) {
|
func (commonDialect) RemoveIndex(scope *Scope, indexName string) {
|
||||||
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName()))
|
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName()))
|
||||||
}
|
}
|
||||||
|
@ -24,6 +24,8 @@ func NewDialect(driver string) Dialect {
|
|||||||
switch driver {
|
switch driver {
|
||||||
case "postgres":
|
case "postgres":
|
||||||
d = &postgres{}
|
d = &postgres{}
|
||||||
|
case "foundation":
|
||||||
|
d = &foundation{}
|
||||||
case "mysql":
|
case "mysql":
|
||||||
d = &mysql{}
|
d = &mysql{}
|
||||||
case "sqlite3":
|
case "sqlite3":
|
||||||
|
8
field.go
8
field.go
@ -27,9 +27,13 @@ func (field *Field) Set(value interface{}) error {
|
|||||||
|
|
||||||
if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
|
if scanner, ok := field.Field.Addr().Interface().(sql.Scanner); ok {
|
||||||
if v, ok := value.(reflect.Value); ok {
|
if v, ok := value.(reflect.Value); ok {
|
||||||
scanner.Scan(v.Interface())
|
if err := scanner.Scan(v.Interface()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
scanner.Scan(value)
|
if err := scanner.Scan(value); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
reflectValue, ok := value.(reflect.Value)
|
reflectValue, ok := value.(reflect.Value)
|
||||||
|
78
foundation.go
Normal file
78
foundation.go
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type foundation struct {
|
||||||
|
commonDialect
|
||||||
|
}
|
||||||
|
|
||||||
|
func (foundation) BinVar(i int) string {
|
||||||
|
return fmt.Sprintf("$%v", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (foundation) SupportLastInsertId() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (foundation) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||||
|
switch value.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
return "boolean"
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||||
|
if autoIncrease {
|
||||||
|
return "serial"
|
||||||
|
}
|
||||||
|
return "int"
|
||||||
|
case reflect.Int64, reflect.Uint64:
|
||||||
|
if autoIncrease {
|
||||||
|
return "bigserial"
|
||||||
|
}
|
||||||
|
return "bigint"
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
return "double"
|
||||||
|
case reflect.String:
|
||||||
|
if size > 0 && size < 65532 {
|
||||||
|
return fmt.Sprintf("varchar(%d)", size)
|
||||||
|
}
|
||||||
|
return "clob"
|
||||||
|
case reflect.Struct:
|
||||||
|
if _, ok := value.Interface().(time.Time); ok {
|
||||||
|
return "datetime"
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if _, ok := value.Interface().([]byte); ok {
|
||||||
|
return "blob"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
panic(fmt.Sprintf("invalid sql type %s (%s) for foundation", value.Type().Name(), value.Kind().String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f foundation) ReturningStr(tableName, key string) string {
|
||||||
|
return fmt.Sprintf("RETURNING %v.%v", f.Quote(tableName), key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (foundation) HasTable(scope *Scope, tableName string) bool {
|
||||||
|
var count int
|
||||||
|
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_schema = current_schema AND table_type = 'TABLE' AND table_name = ?", tableName).Row().Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (foundation) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
||||||
|
var count int
|
||||||
|
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = current_schema AND table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f foundation) RemoveIndex(scope *Scope, indexName string) {
|
||||||
|
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", f.Quote(indexName)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (foundation) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||||
|
var count int
|
||||||
|
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.indexes WHERE table_schema = current_schema AND table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count)
|
||||||
|
return count > 0
|
||||||
|
}
|
@ -1,48 +0,0 @@
|
|||||||
package gorm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
type JoinTableHandler interface {
|
|
||||||
Table(*DB, *Relationship) string
|
|
||||||
Add(*DB, *Relationship, interface{}, interface{}) error
|
|
||||||
Delete(*DB, *Relationship) error
|
|
||||||
Scope(*DB, *Relationship) *DB
|
|
||||||
}
|
|
||||||
|
|
||||||
type defaultJoinTableHandler struct{}
|
|
||||||
|
|
||||||
func (s *defaultJoinTableHandler) Table(db *DB, relationship *Relationship) string {
|
|
||||||
return relationship.JoinTable
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *defaultJoinTableHandler) Add(db *DB, relationship *Relationship, foreignValue interface{}, associationValue interface{}) error {
|
|
||||||
scope := db.NewScope("")
|
|
||||||
quotedForeignDBName := scope.Quote(relationship.ForeignDBName)
|
|
||||||
quotedAssociationDBName := scope.Quote(relationship.AssociationForeignDBName)
|
|
||||||
table := s.Table(db, relationship)
|
|
||||||
|
|
||||||
sql := fmt.Sprintf(
|
|
||||||
"INSERT INTO %v (%v) SELECT ?,? %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v = ? AND %v = ?);",
|
|
||||||
scope.Quote(table),
|
|
||||||
strings.Join([]string{quotedForeignDBName, quotedAssociationDBName}, ","),
|
|
||||||
scope.Dialect().SelectFromDummyTable(),
|
|
||||||
scope.Quote(table),
|
|
||||||
quotedForeignDBName,
|
|
||||||
quotedAssociationDBName,
|
|
||||||
)
|
|
||||||
|
|
||||||
return db.Exec(sql, foreignValue, associationValue, foreignValue, associationValue).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *defaultJoinTableHandler) Delete(db *DB, relationship *Relationship) error {
|
|
||||||
return db.Table(s.Table(db, relationship)).Delete("").Error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *defaultJoinTableHandler) Scope(db *DB, relationship *Relationship) *DB {
|
|
||||||
return db.Table(s.Table(db, relationship))
|
|
||||||
}
|
|
||||||
|
|
||||||
var DefaultJoinTableHandler = &defaultJoinTableHandler{}
|
|
155
join_table_handler.go
Normal file
155
join_table_handler.go
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
package gorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type JoinTableHandlerInterface interface {
|
||||||
|
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
|
||||||
|
Table(db *DB) string
|
||||||
|
Add(db *DB, source interface{}, destination interface{}) error
|
||||||
|
Delete(db *DB, sources ...interface{}) error
|
||||||
|
JoinWith(db *DB, source interface{}) *DB
|
||||||
|
}
|
||||||
|
|
||||||
|
type JoinTableForeignKey struct {
|
||||||
|
DBName string
|
||||||
|
AssociationDBName string
|
||||||
|
}
|
||||||
|
|
||||||
|
type JoinTableSource struct {
|
||||||
|
ModelType reflect.Type
|
||||||
|
ForeignKeys []JoinTableForeignKey
|
||||||
|
}
|
||||||
|
|
||||||
|
type JoinTableHandler struct {
|
||||||
|
TableName string `sql:"-"`
|
||||||
|
Source JoinTableSource `sql:"-"`
|
||||||
|
Destination JoinTableSource `sql:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
|
||||||
|
s.TableName = tableName
|
||||||
|
|
||||||
|
s.Source = JoinTableSource{ModelType: source}
|
||||||
|
sourceScope := &Scope{Value: reflect.New(source).Interface()}
|
||||||
|
for _, primaryField := range sourceScope.GetModelStruct().PrimaryFields {
|
||||||
|
if relationship.ForeignDBName == "" {
|
||||||
|
relationship.ForeignFieldName = source.Name() + primaryField.Name
|
||||||
|
relationship.ForeignDBName = ToDBName(relationship.ForeignFieldName)
|
||||||
|
}
|
||||||
|
s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
|
||||||
|
DBName: relationship.ForeignDBName,
|
||||||
|
AssociationDBName: primaryField.DBName,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Destination = JoinTableSource{ModelType: destination}
|
||||||
|
destinationScope := &Scope{Value: reflect.New(destination).Interface()}
|
||||||
|
for _, primaryField := range destinationScope.GetModelStruct().PrimaryFields {
|
||||||
|
if relationship.AssociationForeignDBName == "" {
|
||||||
|
relationship.AssociationForeignFieldName = destination.Name() + primaryField.Name
|
||||||
|
relationship.AssociationForeignDBName = ToDBName(relationship.AssociationForeignFieldName)
|
||||||
|
}
|
||||||
|
s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
|
||||||
|
DBName: relationship.AssociationForeignDBName,
|
||||||
|
AssociationDBName: primaryField.DBName,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s JoinTableHandler) Table(*DB) string {
|
||||||
|
return s.TableName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} {
|
||||||
|
values := map[string]interface{}{}
|
||||||
|
|
||||||
|
for _, source := range sources {
|
||||||
|
scope := db.NewScope(source)
|
||||||
|
modelType := scope.GetModelStruct().ModelType
|
||||||
|
|
||||||
|
if s.Source.ModelType == modelType {
|
||||||
|
for _, foreignKey := range s.Source.ForeignKeys {
|
||||||
|
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
|
||||||
|
}
|
||||||
|
} else if s.Destination.ModelType == modelType {
|
||||||
|
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||||
|
values[foreignKey.DBName] = scope.Fields()[foreignKey.AssociationDBName].Field.Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) error {
|
||||||
|
scope := db.NewScope("")
|
||||||
|
searchMap := s.GetSearchMap(db, source1, source2)
|
||||||
|
|
||||||
|
var assignColumns, binVars, conditions []string
|
||||||
|
var values []interface{}
|
||||||
|
for key, value := range searchMap {
|
||||||
|
assignColumns = append(assignColumns, key)
|
||||||
|
binVars = append(binVars, `?`)
|
||||||
|
conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key)))
|
||||||
|
values = append(values, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, value := range values {
|
||||||
|
values = append(values, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
quotedTable := s.Table(db)
|
||||||
|
sql := fmt.Sprintf(
|
||||||
|
"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)",
|
||||||
|
quotedTable,
|
||||||
|
strings.Join(assignColumns, ","),
|
||||||
|
strings.Join(binVars, ","),
|
||||||
|
scope.Dialect().SelectFromDummyTable(),
|
||||||
|
quotedTable,
|
||||||
|
strings.Join(conditions, " AND "),
|
||||||
|
)
|
||||||
|
|
||||||
|
return db.Exec(sql, values...).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s JoinTableHandler) Delete(db *DB, sources ...interface{}) error {
|
||||||
|
var conditions []string
|
||||||
|
var values []interface{}
|
||||||
|
|
||||||
|
for key, value := range s.GetSearchMap(db, sources...) {
|
||||||
|
conditions = append(conditions, fmt.Sprintf("%v = ?", key))
|
||||||
|
values = append(values, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db.Table(s.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB {
|
||||||
|
quotedTable := s.Table(db)
|
||||||
|
|
||||||
|
scope := db.NewScope(source)
|
||||||
|
modelType := scope.GetModelStruct().ModelType
|
||||||
|
var joinConditions []string
|
||||||
|
var queryConditions []string
|
||||||
|
var values []interface{}
|
||||||
|
if s.Source.ModelType == modelType {
|
||||||
|
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||||
|
destinationTableName := scope.New(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
|
||||||
|
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, foreignKey := range s.Source.ForeignKeys {
|
||||||
|
queryConditions = append(queryConditions, fmt.Sprintf("%v.%v = ?", quotedTable, scope.Quote(foreignKey.DBName)))
|
||||||
|
values = append(values, scope.Fields()[foreignKey.AssociationDBName].Field.Interface())
|
||||||
|
}
|
||||||
|
return db.Joins(fmt.Sprintf("INNER JOIN %v ON %v", quotedTable, strings.Join(joinConditions, " AND "))).
|
||||||
|
Where(strings.Join(queryConditions, " AND "), values...)
|
||||||
|
} else {
|
||||||
|
db.Error = errors.New("wrong source type for join table handler")
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
}
|
@ -15,40 +15,37 @@ type Person struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type PersonAddress struct {
|
type PersonAddress struct {
|
||||||
|
gorm.JoinTableHandler
|
||||||
PersonID int
|
PersonID int
|
||||||
AddressID int
|
AddressID int
|
||||||
DeletedAt time.Time
|
DeletedAt time.Time
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*PersonAddress) Table(db *gorm.DB, relationship *gorm.Relationship) string {
|
func (*PersonAddress) Add(db *gorm.DB, foreignValue interface{}, associationValue interface{}) error {
|
||||||
return relationship.JoinTable
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*PersonAddress) Add(db *gorm.DB, relationship *gorm.Relationship, foreignValue interface{}, associationValue interface{}) error {
|
|
||||||
return db.Where(map[string]interface{}{
|
return db.Where(map[string]interface{}{
|
||||||
relationship.ForeignDBName: foreignValue,
|
"person_id": db.NewScope(foreignValue).PrimaryKeyValue(),
|
||||||
relationship.AssociationForeignDBName: associationValue,
|
"address_id": db.NewScope(associationValue).PrimaryKeyValue(),
|
||||||
}).Assign(map[string]interface{}{
|
}).Assign(map[string]interface{}{
|
||||||
relationship.ForeignFieldName: foreignValue,
|
"person_id": foreignValue,
|
||||||
relationship.AssociationForeignFieldName: associationValue,
|
"address_id": associationValue,
|
||||||
"DeletedAt": gorm.Expr("NULL"),
|
"deleted_at": gorm.Expr("NULL"),
|
||||||
}).FirstOrCreate(&PersonAddress{}).Error
|
}).FirstOrCreate(&PersonAddress{}).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*PersonAddress) Delete(db *gorm.DB, relationship *gorm.Relationship) error {
|
func (*PersonAddress) Delete(db *gorm.DB, sources ...interface{}) error {
|
||||||
return db.Delete(&PersonAddress{}).Error
|
return db.Delete(&PersonAddress{}).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pa *PersonAddress) Scope(db *gorm.DB, relationship *gorm.Relationship) *gorm.DB {
|
func (pa *PersonAddress) JoinWith(db *gorm.DB, source interface{}) *gorm.DB {
|
||||||
table := pa.Table(db, relationship)
|
table := pa.Table(db)
|
||||||
return db.Table(table).Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table))
|
return db.Table(table).Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestJoinTable(t *testing.T) {
|
func TestJoinTable(t *testing.T) {
|
||||||
DB.Exec("drop table person_addresses;")
|
DB.Exec("drop table person_addresses;")
|
||||||
DB.AutoMigrate(&Person{})
|
DB.AutoMigrate(&Person{})
|
||||||
DB.SetJoinTableHandler(&PersonAddress{}, "person_addresses")
|
DB.SetJoinTableHandler(&Person{}, "Addresses", &PersonAddress{})
|
||||||
|
|
||||||
address1 := &Address{Address1: "address 1"}
|
address1 := &Address{Address1: "address 1"}
|
||||||
address2 := &Address{Address1: "address 2"}
|
address2 := &Address{Address1: "address 2"}
|
||||||
|
41
main.go
41
main.go
@ -55,6 +55,9 @@ func Open(dialect string, args ...interface{}) (DB, error) {
|
|||||||
driver = value
|
driver = value
|
||||||
source = args[1].(string)
|
source = args[1].(string)
|
||||||
}
|
}
|
||||||
|
if driver == "foundation" {
|
||||||
|
driver = "postgres" // FoundationDB speaks a postgres-compatible protocol.
|
||||||
|
}
|
||||||
dbSql, err = sql.Open(driver, source)
|
dbSql, err = sql.Open(driver, source)
|
||||||
case sqlCommon:
|
case sqlCommon:
|
||||||
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
|
source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String()
|
||||||
@ -193,14 +196,14 @@ func (s *DB) Assign(attrs ...interface{}) *DB {
|
|||||||
func (s *DB) First(out interface{}, where ...interface{}) *DB {
|
func (s *DB) First(out interface{}, where ...interface{}) *DB {
|
||||||
newScope := s.clone().NewScope(out)
|
newScope := s.clone().NewScope(out)
|
||||||
newScope.Search.Limit(1)
|
newScope.Search.Limit(1)
|
||||||
return newScope.InstanceSet("gorm:order_by_primary_key", "ASC").
|
return newScope.Set("gorm:order_by_primary_key", "ASC").
|
||||||
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
func (s *DB) Last(out interface{}, where ...interface{}) *DB {
|
||||||
newScope := s.clone().NewScope(out)
|
newScope := s.clone().NewScope(out)
|
||||||
newScope.Search.Limit(1)
|
newScope.Search.Limit(1)
|
||||||
return newScope.InstanceSet("gorm:order_by_primary_key", "DESC").
|
return newScope.Set("gorm:order_by_primary_key", "DESC").
|
||||||
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -266,6 +269,7 @@ func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
|
|||||||
func (s *DB) UpdateColumns(values interface{}) *DB {
|
func (s *DB) UpdateColumns(values interface{}) *DB {
|
||||||
return s.clone().NewScope(s.Value).
|
return s.clone().NewScope(s.Value).
|
||||||
Set("gorm:update_column", true).
|
Set("gorm:update_column", true).
|
||||||
|
Set("gorm:save_associations", false).
|
||||||
InstanceSet("gorm:update_interface", values).
|
InstanceSet("gorm:update_interface", values).
|
||||||
callCallbacks(s.parent.callback.updates).db
|
callCallbacks(s.parent.callback.updates).db
|
||||||
}
|
}
|
||||||
@ -470,29 +474,16 @@ func (s *DB) Get(name string) (value interface{}, ok bool) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DB) GetJoinTableHandler(table string) JoinTableHandler {
|
func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) {
|
||||||
if s.parent.joinTableHandlers != nil {
|
for _, field := range s.NewScope(source).GetModelStruct().StructFields {
|
||||||
if joinTableHandler, ok := s.parent.joinTableHandlers[table]; ok {
|
if field.Name == column || field.DBName == column {
|
||||||
return joinTableHandler
|
if many2many := parseTagSetting(field.Tag.Get("gorm"))["MANY2MANY"]; many2many != "" {
|
||||||
|
source := (&Scope{Value: source}).GetModelStruct().ModelType
|
||||||
|
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
|
||||||
|
handler.Setup(field.Relationship, many2many, source, destination)
|
||||||
|
field.Relationship.JoinTableHandler = handler
|
||||||
|
s.Table(handler.Table(s)).AutoMigrate(handler)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if joinTableHandler, ok := s.parent.joinTableHandlers["*"]; ok {
|
|
||||||
return joinTableHandler
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return DefaultJoinTableHandler
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *DB) SetJoinTableHandler(joinTableHandler JoinTableHandler, tables ...string) {
|
|
||||||
if s.parent.joinTableHandlers == nil {
|
|
||||||
s.parent.joinTableHandlers = map[string]JoinTableHandler{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(tables) > 0 {
|
|
||||||
for _, table := range tables {
|
|
||||||
s.parent.joinTableHandlers[table] = joinTableHandler
|
|
||||||
s.Table(table).AutoMigrate(joinTableHandler)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
s.parent.joinTableHandlers["*"] = joinTableHandler
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
31
main_test.go
31
main_test.go
@ -36,6 +36,9 @@ func init() {
|
|||||||
case "postgres":
|
case "postgres":
|
||||||
fmt.Println("testing postgres...")
|
fmt.Println("testing postgres...")
|
||||||
DB, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable")
|
DB, err = gorm.Open("postgres", "user=gorm DB.name=gorm sslmode=disable")
|
||||||
|
case "foundation":
|
||||||
|
fmt.Println("testing foundation...")
|
||||||
|
DB, err = gorm.Open("foundation", "dbname=gorm port=15432 sslmode=disable")
|
||||||
case "mssql":
|
case "mssql":
|
||||||
fmt.Println("testing mssql...")
|
fmt.Println("testing mssql...")
|
||||||
DB, err = gorm.Open("mssql", "server=SERVER_HERE;database=rogue;user id=USER_HERE;password=PW_HERE;port=1433")
|
DB, err = gorm.Open("mssql", "server=SERVER_HERE;database=rogue;user id=USER_HERE;password=PW_HERE;port=1433")
|
||||||
@ -445,6 +448,14 @@ func TestHaving(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DialectHasTzSupport() bool {
|
||||||
|
// NB: mssql and FoundationDB do not support time zones.
|
||||||
|
if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" || dialect == "foundation" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func TestTimeWithZone(t *testing.T) {
|
func TestTimeWithZone(t *testing.T) {
|
||||||
var format = "2006-01-02 15:04:05 -0700"
|
var format = "2006-01-02 15:04:05 -0700"
|
||||||
var times []time.Time
|
var times []time.Time
|
||||||
@ -456,26 +467,30 @@ func TestTimeWithZone(t *testing.T) {
|
|||||||
name := "time_with_zone_" + strconv.Itoa(index)
|
name := "time_with_zone_" + strconv.Itoa(index)
|
||||||
user := User{Name: name, Birthday: vtime}
|
user := User{Name: name, Birthday: vtime}
|
||||||
|
|
||||||
// TODO mssql does not support time zones
|
if !DialectHasTzSupport() {
|
||||||
if dialect := os.Getenv("GORM_DIALECT"); dialect == "mssql" {
|
// If our driver dialect doesn't support TZ's, just use UTC for everything here.
|
||||||
user.Birthday = vtime.UTC()
|
user.Birthday = vtime.UTC()
|
||||||
}
|
}
|
||||||
|
|
||||||
DB.Save(&user)
|
DB.Save(&user)
|
||||||
if user.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" {
|
expectedBirthday := "2013-02-18 17:51:49 +0000"
|
||||||
t.Errorf("User's birthday should not be changed after save")
|
foundBirthday := user.Birthday.UTC().Format(format)
|
||||||
|
if foundBirthday != expectedBirthday {
|
||||||
|
t.Errorf("User's birthday should not be changed after save for name=%s, expected bday=%+v but actual value=%+v", name, expectedBirthday, foundBirthday)
|
||||||
}
|
}
|
||||||
|
|
||||||
var findUser, findUser2, findUser3 User
|
var findUser, findUser2, findUser3 User
|
||||||
DB.First(&findUser, "name = ?", name)
|
DB.First(&findUser, "name = ?", name)
|
||||||
if findUser.Birthday.UTC().Format(format) != "2013-02-18 17:51:49 +0000" {
|
foundBirthday = findUser.Birthday.UTC().Format(format)
|
||||||
t.Errorf("User's birthday should not be changed after find")
|
if foundBirthday != expectedBirthday {
|
||||||
|
t.Errorf("User's birthday should not be changed after find for name=%s, expected bday=%+v but actual value=%+v or %+v", name, expectedBirthday, foundBirthday)
|
||||||
}
|
}
|
||||||
|
|
||||||
if DB.Where("id = ? AND birthday >= ?", findUser.Id, vtime.Add(-time.Minute)).First(&findUser2).RecordNotFound() {
|
if DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(-time.Minute)).First(&findUser2).RecordNotFound() {
|
||||||
t.Errorf("User should be found")
|
t.Errorf("User should be found")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !DB.Where("id = ? AND birthday >= ?", findUser.Id, vtime.Add(time.Minute)).First(&findUser3).RecordNotFound() {
|
if !DB.Where("id = ? AND birthday >= ?", findUser.Id, user.Birthday.Add(time.Minute)).First(&findUser3).RecordNotFound() {
|
||||||
t.Errorf("User should not be found")
|
t.Errorf("User should not be found")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
262
model_struct.go
262
model_struct.go
@ -60,7 +60,7 @@ type Relationship struct {
|
|||||||
ForeignDBName string
|
ForeignDBName string
|
||||||
AssociationForeignFieldName string
|
AssociationForeignFieldName string
|
||||||
AssociationForeignDBName string
|
AssociationForeignDBName string
|
||||||
JoinTable string
|
JoinTableHandler JoinTableHandlerInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")}
|
var pluralMapKeys = []*regexp.Regexp{regexp.MustCompile("ch$"), regexp.MustCompile("ss$"), regexp.MustCompile("sh$"), regexp.MustCompile("day$"), regexp.MustCompile("y$"), regexp.MustCompile("x$"), regexp.MustCompile("([^s])s?$")}
|
||||||
@ -146,143 +146,147 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, field := range fields {
|
defer func() {
|
||||||
if !field.IsIgnored {
|
for _, field := range fields {
|
||||||
fieldStruct := field.Struct
|
if !field.IsIgnored {
|
||||||
fieldType, indirectType := fieldStruct.Type, fieldStruct.Type
|
fieldStruct := field.Struct
|
||||||
if indirectType.Kind() == reflect.Ptr {
|
fieldType, indirectType := fieldStruct.Type, fieldStruct.Type
|
||||||
indirectType = indirectType.Elem()
|
if indirectType.Kind() == reflect.Ptr {
|
||||||
}
|
indirectType = indirectType.Elem()
|
||||||
|
|
||||||
if _, isScanner := reflect.New(fieldType).Interface().(sql.Scanner); isScanner {
|
|
||||||
field.IsScanner, field.IsNormal = true, true
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime {
|
|
||||||
field.IsNormal = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if !field.IsNormal {
|
|
||||||
gormSettings := parseTagSetting(field.Tag.Get("gorm"))
|
|
||||||
toScope := scope.New(reflect.New(fieldStruct.Type).Interface())
|
|
||||||
|
|
||||||
getForeignField := func(column string, fields []*StructField) *StructField {
|
|
||||||
for _, field := range fields {
|
|
||||||
if field.Name == column || field.DBName == ToDBName(column) {
|
|
||||||
return field
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var relationship = &Relationship{}
|
if _, isScanner := reflect.New(fieldType).Interface().(sql.Scanner); isScanner {
|
||||||
|
field.IsScanner, field.IsNormal = true, true
|
||||||
foreignKey := gormSettings["FOREIGNKEY"]
|
|
||||||
if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" {
|
|
||||||
if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil {
|
|
||||||
if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil {
|
|
||||||
relationship.ForeignFieldName = polymorphicField.Name
|
|
||||||
relationship.ForeignDBName = polymorphicField.DBName
|
|
||||||
relationship.PolymorphicType = polymorphicType.Name
|
|
||||||
relationship.PolymorphicDBName = polymorphicType.DBName
|
|
||||||
polymorphicType.IsForeignKey = true
|
|
||||||
polymorphicField.IsForeignKey = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch indirectType.Kind() {
|
if _, isTime := reflect.New(indirectType).Interface().(*time.Time); isTime {
|
||||||
case reflect.Slice:
|
|
||||||
elemType := indirectType.Elem()
|
|
||||||
if elemType.Kind() == reflect.Ptr {
|
|
||||||
elemType = elemType.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if elemType.Kind() == reflect.Struct {
|
|
||||||
if foreignKey == "" {
|
|
||||||
foreignKey = scopeType.Name() + "Id"
|
|
||||||
}
|
|
||||||
|
|
||||||
if many2many := gormSettings["MANY2MANY"]; many2many != "" {
|
|
||||||
relationship.Kind = "many_to_many"
|
|
||||||
relationship.JoinTable = many2many
|
|
||||||
|
|
||||||
associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]
|
|
||||||
if associationForeignKey == "" {
|
|
||||||
associationForeignKey = elemType.Name() + "Id"
|
|
||||||
}
|
|
||||||
|
|
||||||
relationship.ForeignFieldName = foreignKey
|
|
||||||
relationship.ForeignDBName = ToDBName(foreignKey)
|
|
||||||
relationship.AssociationForeignFieldName = associationForeignKey
|
|
||||||
relationship.AssociationForeignDBName = ToDBName(associationForeignKey)
|
|
||||||
field.Relationship = relationship
|
|
||||||
} else {
|
|
||||||
relationship.Kind = "has_many"
|
|
||||||
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
|
|
||||||
relationship.ForeignFieldName = foreignField.Name
|
|
||||||
relationship.ForeignDBName = foreignField.DBName
|
|
||||||
foreignField.IsForeignKey = true
|
|
||||||
field.Relationship = relationship
|
|
||||||
} else if relationship.ForeignFieldName != "" {
|
|
||||||
field.Relationship = relationship
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
field.IsNormal = true
|
|
||||||
}
|
|
||||||
case reflect.Struct:
|
|
||||||
if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
|
|
||||||
for _, toField := range toScope.GetStructFields() {
|
|
||||||
toField = toField.clone()
|
|
||||||
toField.Names = append([]string{fieldStruct.Name}, toField.Names...)
|
|
||||||
modelStruct.StructFields = append(modelStruct.StructFields, toField)
|
|
||||||
if toField.IsPrimaryKey {
|
|
||||||
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
} else {
|
|
||||||
belongsToForeignKey := foreignKey
|
|
||||||
if belongsToForeignKey == "" {
|
|
||||||
belongsToForeignKey = field.Name + "Id"
|
|
||||||
}
|
|
||||||
|
|
||||||
if foreignField := getForeignField(belongsToForeignKey, fields); foreignField != nil {
|
|
||||||
relationship.Kind = "belongs_to"
|
|
||||||
relationship.ForeignFieldName = foreignField.Name
|
|
||||||
relationship.ForeignDBName = foreignField.DBName
|
|
||||||
foreignField.IsForeignKey = true
|
|
||||||
field.Relationship = relationship
|
|
||||||
} else {
|
|
||||||
if foreignKey == "" {
|
|
||||||
foreignKey = modelStruct.ModelType.Name() + "Id"
|
|
||||||
}
|
|
||||||
relationship.Kind = "has_one"
|
|
||||||
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
|
|
||||||
relationship.ForeignFieldName = foreignField.Name
|
|
||||||
relationship.ForeignDBName = foreignField.DBName
|
|
||||||
foreignField.IsForeignKey = true
|
|
||||||
field.Relationship = relationship
|
|
||||||
} else if relationship.ForeignFieldName != "" {
|
|
||||||
field.Relationship = relationship
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
field.IsNormal = true
|
field.IsNormal = true
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if field.IsNormal {
|
if !field.IsNormal {
|
||||||
if len(modelStruct.PrimaryFields) == 0 && field.DBName == "id" {
|
gormSettings := parseTagSetting(field.Tag.Get("gorm"))
|
||||||
field.IsPrimaryKey = true
|
toScope := scope.New(reflect.New(fieldStruct.Type).Interface())
|
||||||
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
|
|
||||||
|
getForeignField := func(column string, fields []*StructField) *StructField {
|
||||||
|
for _, field := range fields {
|
||||||
|
if field.Name == column || field.DBName == ToDBName(column) {
|
||||||
|
return field
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var relationship = &Relationship{}
|
||||||
|
|
||||||
|
foreignKey := gormSettings["FOREIGNKEY"]
|
||||||
|
if polymorphic := gormSettings["POLYMORPHIC"]; polymorphic != "" {
|
||||||
|
if polymorphicField := getForeignField(polymorphic+"Id", toScope.GetStructFields()); polymorphicField != nil {
|
||||||
|
if polymorphicType := getForeignField(polymorphic+"Type", toScope.GetStructFields()); polymorphicType != nil {
|
||||||
|
relationship.ForeignFieldName = polymorphicField.Name
|
||||||
|
relationship.ForeignDBName = polymorphicField.DBName
|
||||||
|
relationship.PolymorphicType = polymorphicType.Name
|
||||||
|
relationship.PolymorphicDBName = polymorphicType.DBName
|
||||||
|
polymorphicType.IsForeignKey = true
|
||||||
|
polymorphicField.IsForeignKey = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch indirectType.Kind() {
|
||||||
|
case reflect.Slice:
|
||||||
|
elemType := indirectType.Elem()
|
||||||
|
if elemType.Kind() == reflect.Ptr {
|
||||||
|
elemType = elemType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if elemType.Kind() == reflect.Struct {
|
||||||
|
if foreignKey == "" {
|
||||||
|
foreignKey = scopeType.Name() + "Id"
|
||||||
|
}
|
||||||
|
|
||||||
|
if many2many := gormSettings["MANY2MANY"]; many2many != "" {
|
||||||
|
relationship.Kind = "many_to_many"
|
||||||
|
associationForeignKey := gormSettings["ASSOCIATIONFOREIGNKEY"]
|
||||||
|
if associationForeignKey == "" {
|
||||||
|
associationForeignKey = elemType.Name() + "Id"
|
||||||
|
}
|
||||||
|
|
||||||
|
relationship.ForeignFieldName = foreignKey
|
||||||
|
relationship.ForeignDBName = ToDBName(foreignKey)
|
||||||
|
relationship.AssociationForeignFieldName = associationForeignKey
|
||||||
|
relationship.AssociationForeignDBName = ToDBName(associationForeignKey)
|
||||||
|
|
||||||
|
joinTableHandler := JoinTableHandler{}
|
||||||
|
joinTableHandler.Setup(relationship, many2many, scopeType, elemType)
|
||||||
|
relationship.JoinTableHandler = &joinTableHandler
|
||||||
|
field.Relationship = relationship
|
||||||
|
} else {
|
||||||
|
relationship.Kind = "has_many"
|
||||||
|
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
|
||||||
|
relationship.ForeignFieldName = foreignField.Name
|
||||||
|
relationship.ForeignDBName = foreignField.DBName
|
||||||
|
foreignField.IsForeignKey = true
|
||||||
|
field.Relationship = relationship
|
||||||
|
} else if relationship.ForeignFieldName != "" {
|
||||||
|
field.Relationship = relationship
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
field.IsNormal = true
|
||||||
|
}
|
||||||
|
case reflect.Struct:
|
||||||
|
if _, ok := gormSettings["EMBEDDED"]; ok || fieldStruct.Anonymous {
|
||||||
|
for _, toField := range toScope.GetStructFields() {
|
||||||
|
toField = toField.clone()
|
||||||
|
toField.Names = append([]string{fieldStruct.Name}, toField.Names...)
|
||||||
|
modelStruct.StructFields = append(modelStruct.StructFields, toField)
|
||||||
|
if toField.IsPrimaryKey {
|
||||||
|
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, toField)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
belongsToForeignKey := foreignKey
|
||||||
|
if belongsToForeignKey == "" {
|
||||||
|
belongsToForeignKey = field.Name + "Id"
|
||||||
|
}
|
||||||
|
|
||||||
|
if foreignField := getForeignField(belongsToForeignKey, fields); foreignField != nil {
|
||||||
|
relationship.Kind = "belongs_to"
|
||||||
|
relationship.ForeignFieldName = foreignField.Name
|
||||||
|
relationship.ForeignDBName = foreignField.DBName
|
||||||
|
foreignField.IsForeignKey = true
|
||||||
|
field.Relationship = relationship
|
||||||
|
} else {
|
||||||
|
if foreignKey == "" {
|
||||||
|
foreignKey = modelStruct.ModelType.Name() + "Id"
|
||||||
|
}
|
||||||
|
relationship.Kind = "has_one"
|
||||||
|
if foreignField := getForeignField(foreignKey, toScope.GetStructFields()); foreignField != nil {
|
||||||
|
relationship.ForeignFieldName = foreignField.Name
|
||||||
|
relationship.ForeignDBName = foreignField.DBName
|
||||||
|
foreignField.IsForeignKey = true
|
||||||
|
field.Relationship = relationship
|
||||||
|
} else if relationship.ForeignFieldName != "" {
|
||||||
|
field.Relationship = relationship
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
field.IsNormal = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if field.IsNormal {
|
||||||
|
if len(modelStruct.PrimaryFields) == 0 && field.DBName == "id" {
|
||||||
|
field.IsPrimaryKey = true
|
||||||
|
modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
modelStruct.StructFields = append(modelStruct.StructFields, field)
|
||||||
}
|
}
|
||||||
modelStruct.StructFields = append(modelStruct.StructFields, field)
|
}()
|
||||||
}
|
|
||||||
|
|
||||||
modelStructs[scopeType] = &modelStruct
|
modelStructs[scopeType] = &modelStruct
|
||||||
|
|
||||||
@ -308,7 +312,7 @@ func (scope *Scope) generateSqlTag(field *StructField) string {
|
|||||||
|
|
||||||
additionalType := sqlSettings["NOT NULL"] + " " + sqlSettings["UNIQUE"]
|
additionalType := sqlSettings["NOT NULL"] + " " + sqlSettings["UNIQUE"]
|
||||||
if value, ok := sqlSettings["DEFAULT"]; ok {
|
if value, ok := sqlSettings["DEFAULT"]; ok {
|
||||||
additionalType = additionalType + "DEFAULT " + value
|
additionalType = additionalType + " DEFAULT " + value
|
||||||
}
|
}
|
||||||
|
|
||||||
if field.IsScanner {
|
if field.IsScanner {
|
||||||
|
38
mssql.go
38
mssql.go
@ -7,21 +7,15 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mssql struct{}
|
type mssql struct {
|
||||||
|
commonDialect
|
||||||
func (s *mssql) BinVar(i int) string {
|
|
||||||
return "$$" // ?
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mssql) SupportLastInsertId() bool {
|
func (mssql) HasTop() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mssql) HasTop() bool {
|
func (mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
|
||||||
switch value.Kind() {
|
switch value.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
return "bit"
|
return "bit"
|
||||||
@ -57,19 +51,7 @@ func (s *mssql) SqlTag(value reflect.Value, size int, autoIncrease bool) string
|
|||||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
|
panic(fmt.Sprintf("invalid sql type %s (%s) for mssql", value.Type().Name(), value.Kind().String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mssql) ReturningStr(tableName, key string) string {
|
func (mssql) databaseName(scope *Scope) string {
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mssql) SelectFromDummyTable() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mssql) Quote(key string) string {
|
|
||||||
return fmt.Sprintf(" \"%s\"", key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mssql) databaseName(scope *Scope) string {
|
|
||||||
dbStr := strings.Split(scope.db.parent.source, ";")
|
dbStr := strings.Split(scope.db.parent.source, ";")
|
||||||
for _, value := range dbStr {
|
for _, value := range dbStr {
|
||||||
s := strings.Split(value, "=")
|
s := strings.Split(value, "=")
|
||||||
@ -80,24 +62,20 @@ func (s *mssql) databaseName(scope *Scope) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mssql) HasTable(scope *Scope, tableName string) bool {
|
func (s mssql) HasTable(scope *Scope, tableName string) bool {
|
||||||
var count int
|
var count int
|
||||||
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.databaseName(scope)).Row().Scan(&count)
|
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?", tableName, s.databaseName(scope)).Row().Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
func (s mssql) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
||||||
var count int
|
var count int
|
||||||
scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count)
|
scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_catalog = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
func (mssql) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||||
var count int
|
var count int
|
||||||
scope.NewDB().Raw("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Row().Scan(&count)
|
scope.NewDB().Raw("SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)", indexName, tableName).Row().Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mssql) RemoveIndex(scope *Scope, indexName string) {
|
|
||||||
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName()))
|
|
||||||
}
|
|
||||||
|
58
mysql.go
58
mysql.go
@ -3,25 +3,14 @@ package gorm
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mysql struct{}
|
type mysql struct {
|
||||||
|
commonDialect
|
||||||
func (s *mysql) BinVar(i int) string {
|
|
||||||
return "$$" // ?
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mysql) SupportLastInsertId() bool {
|
func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mysql) HasTop() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
|
||||||
switch value.Kind() {
|
switch value.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
return "boolean"
|
return "boolean"
|
||||||
@ -57,45 +46,10 @@ func (s *mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string
|
|||||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
|
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", value.Type().Name(), value.Kind().String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mysql) ReturningStr(tableName, key string) string {
|
func (mysql) Quote(key string) string {
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mysql) SelectFromDummyTable() string {
|
|
||||||
return "FROM DUAL"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mysql) Quote(key string) string {
|
|
||||||
return fmt.Sprintf("`%s`", key)
|
return fmt.Sprintf("`%s`", key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mysql) databaseName(scope *Scope) string {
|
func (mysql) SelectFromDummyTable() string {
|
||||||
from := strings.Index(scope.db.parent.source, "/") + 1
|
return "FROM DUAL"
|
||||||
to := strings.Index(scope.db.parent.source, "?")
|
|
||||||
if to == -1 {
|
|
||||||
to = len(scope.db.parent.source)
|
|
||||||
}
|
|
||||||
return scope.db.parent.source[from:to]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mysql) HasTable(scope *Scope, tableName string) bool {
|
|
||||||
var count int
|
|
||||||
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables where table_name = ? AND table_schema = ?", tableName, s.databaseName(scope)).Row().Scan(&count)
|
|
||||||
return count > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mysql) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
|
||||||
var count int
|
|
||||||
scope.NewDB().Raw("SELECT count(*) FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.databaseName(scope), tableName, columnName).Row().Scan(&count)
|
|
||||||
return count > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mysql) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
|
||||||
var count int
|
|
||||||
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS where table_name = ? AND index_name = ?", tableName, indexName).Row().Scan(&count)
|
|
||||||
return count > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *mysql) RemoveIndex(scope *Scope, indexName string) {
|
|
||||||
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v ON %v", indexName, scope.QuotedTableName()))
|
|
||||||
}
|
}
|
||||||
|
29
postgres.go
29
postgres.go
@ -11,21 +11,18 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type postgres struct {
|
type postgres struct {
|
||||||
|
commonDialect
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *postgres) BinVar(i int) string {
|
func (postgres) BinVar(i int) string {
|
||||||
return fmt.Sprintf("$%v", i)
|
return fmt.Sprintf("$%v", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *postgres) SupportLastInsertId() bool {
|
func (postgres) SupportLastInsertId() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *postgres) HasTop() bool {
|
func (postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
|
||||||
switch value.Kind() {
|
switch value.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
return "boolean"
|
return "boolean"
|
||||||
@ -62,35 +59,27 @@ func (s *postgres) SqlTag(value reflect.Value, size int, autoIncrease bool) stri
|
|||||||
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
|
panic(fmt.Sprintf("invalid sql type %s (%s) for postgres", value.Type().Name(), value.Kind().String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *postgres) ReturningStr(tableName, key string) string {
|
func (s postgres) ReturningStr(tableName, key string) string {
|
||||||
return fmt.Sprintf("RETURNING %v.%v", s.Quote(tableName), key)
|
return fmt.Sprintf("RETURNING %v.%v", s.Quote(tableName), key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *postgres) SelectFromDummyTable() string {
|
func (postgres) HasTable(scope *Scope, tableName string) bool {
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *postgres) Quote(key string) string {
|
|
||||||
return fmt.Sprintf("\"%s\"", key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *postgres) HasTable(scope *Scope, tableName string) bool {
|
|
||||||
var count int
|
var count int
|
||||||
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName).Row().Scan(&count)
|
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_type = 'BASE TABLE'", tableName).Row().Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *postgres) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
func (postgres) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
||||||
var count int
|
var count int
|
||||||
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count)
|
scope.NewDB().Raw("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = ? AND column_name = ?", tableName, columnName).Row().Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *postgres) RemoveIndex(scope *Scope, indexName string) {
|
func (postgres) RemoveIndex(scope *Scope, indexName string) {
|
||||||
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName))
|
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
func (postgres) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||||
var count int
|
var count int
|
||||||
scope.NewDB().Raw("SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName).Row().Scan(&count)
|
scope.NewDB().Raw("SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ?", tableName, indexName).Row().Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
|
8
scope.go
8
scope.go
@ -398,3 +398,11 @@ func (scope *Scope) changeableField(field *Field) bool {
|
|||||||
|
|
||||||
return !field.IsIgnored
|
return !field.IsIgnored
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (scope *Scope) shouldSaveAssociations() bool {
|
||||||
|
saveAssociations, ok := scope.Get("gorm:save_associations")
|
||||||
|
if ok && !saveAssociations.(bool) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
@ -3,7 +3,6 @@ package gorm
|
|||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
@ -360,7 +359,7 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope {
|
|||||||
dest := reflect.Indirect(reflect.ValueOf(value))
|
dest := reflect.Indirect(reflect.ValueOf(value))
|
||||||
scope.Search.Select(column)
|
scope.Search.Select(column)
|
||||||
if dest.Kind() != reflect.Slice {
|
if dest.Kind() != reflect.Slice {
|
||||||
scope.Err(errors.New("results should be a slice"))
|
scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind()))
|
||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -402,18 +401,8 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
|||||||
if fromField != nil {
|
if fromField != nil {
|
||||||
if relationship := fromField.Relationship; relationship != nil {
|
if relationship := fromField.Relationship; relationship != nil {
|
||||||
if relationship.Kind == "many_to_many" {
|
if relationship.Kind == "many_to_many" {
|
||||||
joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable)
|
joinTableHandler := relationship.JoinTableHandler
|
||||||
quotedJoinTable := scope.Quote(joinTableHandler.Table(scope.db, relationship))
|
scope.Err(joinTableHandler.JoinWith(toScope.db, scope.Value).Find(value).Error)
|
||||||
|
|
||||||
joinSql := fmt.Sprintf(
|
|
||||||
"INNER JOIN %v ON %v.%v = %v.%v",
|
|
||||||
quotedJoinTable,
|
|
||||||
quotedJoinTable,
|
|
||||||
scope.Quote(relationship.AssociationForeignDBName),
|
|
||||||
toScope.QuotedTableName(),
|
|
||||||
scope.Quote(toScope.PrimaryKey()))
|
|
||||||
whereSql := fmt.Sprintf("%v.%v = ?", quotedJoinTable, scope.Quote(relationship.ForeignDBName))
|
|
||||||
scope.Err(toScope.db.Joins(joinSql).Where(whereSql, scope.PrimaryKeyValue()).Find(value).Error)
|
|
||||||
} else if relationship.Kind == "belongs_to" {
|
} else if relationship.Kind == "belongs_to" {
|
||||||
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
|
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
|
||||||
foreignKeyValue := fromFields[relationship.ForeignDBName].Field.Interface()
|
foreignKeyValue := fromFields[relationship.ForeignDBName].Field.Interface()
|
||||||
@ -443,9 +432,9 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (scope *Scope) createJoinTable(field *StructField) {
|
func (scope *Scope) createJoinTable(field *StructField) {
|
||||||
if relationship := field.Relationship; relationship != nil && relationship.JoinTable != "" {
|
if relationship := field.Relationship; relationship != nil && relationship.JoinTableHandler != nil {
|
||||||
joinTableHandler := scope.db.GetJoinTableHandler(relationship.JoinTable)
|
joinTableHandler := relationship.JoinTableHandler
|
||||||
joinTable := joinTableHandler.Table(scope.db, relationship)
|
joinTable := joinTableHandler.Table(scope.db)
|
||||||
if !scope.Dialect().HasTable(scope, joinTable) {
|
if !scope.Dialect().HasTable(scope, joinTable) {
|
||||||
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false)
|
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false)
|
||||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
|
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
|
||||||
@ -455,7 +444,7 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
|||||||
scope.Quote(relationship.AssociationForeignDBName) + " " + primaryKeySqlType}, ",")),
|
scope.Quote(relationship.AssociationForeignDBName) + " " + primaryKeySqlType}, ",")),
|
||||||
).Error)
|
).Error)
|
||||||
}
|
}
|
||||||
scope.NewDB().Table(joinTable).AutoMigrate()
|
scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -570,7 +559,7 @@ func (scope *Scope) autoIndex() *Scope {
|
|||||||
if name == "UNIQUE_INDEX" {
|
if name == "UNIQUE_INDEX" {
|
||||||
name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName)
|
name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName)
|
||||||
}
|
}
|
||||||
uniqueIndexes[name] = append(indexes[name], field.DBName)
|
uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
36
sqlite3.go
36
sqlite3.go
@ -6,21 +6,11 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type sqlite3 struct{}
|
type sqlite3 struct {
|
||||||
|
commonDialect
|
||||||
func (s *sqlite3) BinVar(i int) string {
|
|
||||||
return "$$" // ?
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sqlite3) SupportLastInsertId() bool {
|
func (sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *sqlite3) HasTop() bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
|
||||||
switch value.Kind() {
|
switch value.Kind() {
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
return "bool"
|
return "bool"
|
||||||
@ -50,36 +40,24 @@ func (s *sqlite3) SqlTag(value reflect.Value, size int, autoIncrease bool) strin
|
|||||||
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
|
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite3", value.Type().Name(), value.Kind().String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sqlite3) ReturningStr(tableName, key string) string {
|
func (sqlite3) HasTable(scope *Scope, tableName string) bool {
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *sqlite3) SelectFromDummyTable() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *sqlite3) Quote(key string) string {
|
|
||||||
return fmt.Sprintf("\"%s\"", key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *sqlite3) HasTable(scope *Scope, tableName string) bool {
|
|
||||||
var count int
|
var count int
|
||||||
scope.NewDB().Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Row().Scan(&count)
|
scope.NewDB().Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Row().Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
func (sqlite3) HasColumn(scope *Scope, tableName string, columnName string) bool {
|
||||||
var count int
|
var count int
|
||||||
scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", columnName, columnName, columnName, columnName), tableName).Row().Scan(&count)
|
scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%(\"%v\" %%' OR sql LIKE '%%,\"%v\" %%' OR sql LIKE '%%( %v %%' OR sql LIKE '%%, %v %%');\n", columnName, columnName, columnName, columnName), tableName).Row().Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
func (sqlite3) HasIndex(scope *Scope, tableName string, indexName string) bool {
|
||||||
var count int
|
var count int
|
||||||
scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Row().Scan(&count)
|
scope.NewDB().Raw(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Row().Scan(&count)
|
||||||
return count > 0
|
return count > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *sqlite3) RemoveIndex(scope *Scope, indexName string) {
|
func (sqlite3) RemoveIndex(scope *Scope, indexName string) {
|
||||||
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName))
|
scope.NewDB().Exec(fmt.Sprintf("DROP INDEX %v", indexName))
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
dialects=("postgres" "mysql" "sqlite")
|
dialects=("postgres" "foundation" "mysql" "sqlite")
|
||||||
|
|
||||||
for dialect in "${dialects[@]}" ; do
|
for dialect in "${dialects[@]}" ; do
|
||||||
GORM_DIALECT=${dialect} go test
|
GORM_DIALECT=${dialect} go test
|
||||||
|
@ -382,3 +382,32 @@ func TestOmitWithUpdateColumn(t *testing.T) {
|
|||||||
t.Errorf("Should omit name column when update user")
|
t.Errorf("Should omit name column when update user")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateColumnsSkipsAssociations(t *testing.T) {
|
||||||
|
user := getPreparedUser("update_columns_user", "special_role")
|
||||||
|
user.Age = 99
|
||||||
|
address1 := "first street"
|
||||||
|
user.BillingAddress = Address{Address1: address1}
|
||||||
|
DB.Save(user)
|
||||||
|
|
||||||
|
// Update a single field of the user and verify that the changed address is not stored.
|
||||||
|
newAge := int64(100)
|
||||||
|
user.BillingAddress.Address1 = "second street"
|
||||||
|
db := DB.Model(user).UpdateColumns(User{Age: newAge})
|
||||||
|
if db.RowsAffected != 1 {
|
||||||
|
t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", DB.RowsAffected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that Age now=`newAge`.
|
||||||
|
freshUser := &User{Id: user.Id}
|
||||||
|
DB.First(freshUser)
|
||||||
|
if freshUser.Age != newAge {
|
||||||
|
t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, freshUser.Age)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that user's BillingAddress.Address1 is not changed and is still "first street".
|
||||||
|
DB.First(&freshUser.BillingAddress, freshUser.BillingAddressID)
|
||||||
|
if freshUser.BillingAddress.Address1 != address1 {
|
||||||
|
t.Errorf("Expected user's BillingAddress.Address1=%s to remain unchanged after UpdateColumns invocation, but BillingAddress.Address1=%s", address1, freshUser.BillingAddress.Address1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user