commit
36ea322ee8
67
README.md
67
README.md
@ -39,7 +39,7 @@ type User struct {
|
||||
Num int `sql:"AUTO_INCREMENT"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt time.Time
|
||||
DeletedAt *time.Time
|
||||
|
||||
Emails []Email // One-To-Many relationship (has many)
|
||||
BillingAddress Address // One-To-One relationship (has one)
|
||||
@ -84,6 +84,21 @@ type User struct{} // struct User's database table name is "users" by default, w
|
||||
* Use `CreatedAt` to store record's created time if field exists
|
||||
* Use `UpdatedAt` to store record's updated time if field exists
|
||||
* Use `DeletedAt` to store record's deleted time if field exists [Soft Delete](#soft-delete)
|
||||
* Gorm provide a default model struct, you could embed it in your struct
|
||||
|
||||
```go
|
||||
type Model struct {
|
||||
ID uint `gorm:"primary_key"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt *time.Time
|
||||
}
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
}
|
||||
```
|
||||
|
||||
## Initialize Database
|
||||
|
||||
@ -102,7 +117,7 @@ db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
||||
|
||||
// You can also use an existing database connection handle
|
||||
// dbSql, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
|
||||
// db := gorm.Open("postgres", dbSql)
|
||||
// db, _ := gorm.Open("postgres", dbSql)
|
||||
|
||||
// Get database connection handle [*sql.DB](http://golang.org/pkg/database/sql/#DB)
|
||||
db.DB()
|
||||
@ -332,6 +347,13 @@ db.Preload("Orders").Preload("Profile").Preload("Role").Find(&users)
|
||||
//// SELECT * FROM roles WHERE id IN (4,5,6); // belongs to
|
||||
```
|
||||
|
||||
#### Nested Preloading
|
||||
|
||||
```go
|
||||
db.Preload("Orders.OrderItems").Find(&users)
|
||||
db.Preload("Orders", "state = ?", "paid").Preload("Orders.OrderItems").Find(&users)
|
||||
```
|
||||
|
||||
## Update
|
||||
|
||||
```go
|
||||
@ -823,23 +845,56 @@ for rows.Next() {
|
||||
}
|
||||
|
||||
db.Table("users").Select("users.name, emails.email").Joins("left join emails on emails.user_id = users.id").Scan(&results)
|
||||
|
||||
// find a user by email address
|
||||
db.Joins("inner join emails on emails.user_id = users.id").Where("emails.email = ?", "x@example.org").Find(&user)
|
||||
|
||||
// find all email addresses for a user
|
||||
db.Joins("left join users on users.id = emails.user_id").Where("users.name = ?", "jinzhu").Find(&emails)
|
||||
```
|
||||
|
||||
## Transactions
|
||||
|
||||
All individual save and delete operations are run in a transaction by default.
|
||||
To perform a set of operations within a transaction, the general flow is as below.
|
||||
The database handle returned from ``` db.Begin() ``` should be used for all operations within the transaction.
|
||||
(Note that all individual save and delete operations are run in a transaction by default.)
|
||||
|
||||
```go
|
||||
// begin
|
||||
tx := db.Begin()
|
||||
|
||||
// rollback
|
||||
// do some database operations (use 'tx' from this point, not 'db')
|
||||
tx.Create(...)
|
||||
...
|
||||
|
||||
// rollback in case of error
|
||||
tx.Rollback()
|
||||
|
||||
// commit
|
||||
// Or commit if all is ok
|
||||
tx.Commit()
|
||||
```
|
||||
|
||||
### A Specific Example
|
||||
```
|
||||
func CreateAnimals(db *gorm.DB) err {
|
||||
tx := db.Begin()
|
||||
// Note the use of tx as the database handle once you are within a transaction
|
||||
|
||||
if err := tx.Create(&Animal{Name: "Giraffe"}).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Create(&Animal{Name: "Lion"}).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
## Scopes
|
||||
|
||||
```go
|
||||
@ -1069,7 +1124,7 @@ type Product struct {
|
||||
// 2nd param : destination table(id)
|
||||
// 3rd param : ONDELETE
|
||||
// 4th param : ONUPDATE
|
||||
db.Model(&User{}).AddForeignKey("user_id", "destination_table(id)", "CASCADE", "RESTRICT")
|
||||
db.Model(&User{}).AddForeignKey("role_id", "roles", "CASCADE", "RESTRICT")
|
||||
|
||||
// Add index
|
||||
db.Model(&User{}).AddIndex("idx_user_name", "name")
|
||||
|
@ -40,6 +40,7 @@ func (association *Association) Append(values ...interface{}) *Association {
|
||||
association.setErr(errors.New("invalid association type"))
|
||||
}
|
||||
}
|
||||
scope.Search.Select(association.Column)
|
||||
scope.callCallbacks(scope.db.parent.callback.updates)
|
||||
return association.setErr(scope.db.Error)
|
||||
}
|
||||
@ -77,7 +78,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
|
||||
if relationship.Kind == "many_to_many" {
|
||||
sql := fmt.Sprintf("%v = ? AND %v IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
|
||||
query := scope.NewDB().Where(sql, association.PrimaryKey, primaryKeys)
|
||||
if err := relationship.JoinTableHandler.Delete(query, relationship); err == nil {
|
||||
if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil {
|
||||
leftValues := reflect.Zero(association.Field.Field.Type())
|
||||
for i := 0; i < association.Field.Field.Len(); i++ {
|
||||
value := association.Field.Field.Index(i)
|
||||
@ -130,9 +131,11 @@ func (association *Association) Replace(values ...interface{}) *Association {
|
||||
addedPrimaryKeys = append(addedPrimaryKeys, primaryKey)
|
||||
}
|
||||
|
||||
if len(addedPrimaryKeys) > 0 {
|
||||
sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
|
||||
query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys)
|
||||
association.setErr(relationship.JoinTableHandler.Delete(query, relationship))
|
||||
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship))
|
||||
}
|
||||
} else {
|
||||
association.setErr(errors.New("replace only support many to many"))
|
||||
}
|
||||
@ -145,7 +148,7 @@ func (association *Association) Clear() *Association {
|
||||
if relationship.Kind == "many_to_many" {
|
||||
sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName))
|
||||
query := scope.NewDB().Where(sql, association.PrimaryKey)
|
||||
if err := relationship.JoinTableHandler.Delete(query, relationship); err == nil {
|
||||
if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil {
|
||||
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
|
||||
} else {
|
||||
association.setErr(err)
|
||||
@ -163,7 +166,7 @@ func (association *Association) Count() int {
|
||||
newScope := scope.New(association.Field.Field.Interface())
|
||||
|
||||
if relationship.Kind == "many_to_many" {
|
||||
relationship.JoinTableHandler.JoinWith(scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count)
|
||||
relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count)
|
||||
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
|
||||
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
|
||||
countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey)
|
||||
|
@ -148,7 +148,6 @@ func TestManyToMany(t *testing.T) {
|
||||
t.Errorf("Query many to many relations")
|
||||
}
|
||||
|
||||
newLanguages = []Language{}
|
||||
DB.Model(&user).Association("Languages").Find(&newLanguages)
|
||||
if len(newLanguages) != len([]string{"ZH", "EN"}) {
|
||||
t.Errorf("Should be able to find many to many relations")
|
||||
@ -194,7 +193,6 @@ func TestManyToMany(t *testing.T) {
|
||||
t.Errorf("Language EE should not be deleted")
|
||||
}
|
||||
|
||||
languages = []Language{}
|
||||
DB.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages)
|
||||
|
||||
user2 := User{Name: "Many2Many_User2", Languages: languages}
|
||||
|
10
callback.go
10
callback.go
@ -9,6 +9,7 @@ type callback struct {
|
||||
updates []*func(scope *Scope)
|
||||
deletes []*func(scope *Scope)
|
||||
queries []*func(scope *Scope)
|
||||
rowQueries []*func(scope *Scope)
|
||||
processors []*callbackProcessor
|
||||
}
|
||||
|
||||
@ -55,6 +56,10 @@ func (c *callback) Query() *callbackProcessor {
|
||||
return c.addProcessor("query")
|
||||
}
|
||||
|
||||
func (c *callback) RowQuery() *callbackProcessor {
|
||||
return c.addProcessor("row_query")
|
||||
}
|
||||
|
||||
func (cp *callbackProcessor) Before(name string) *callbackProcessor {
|
||||
cp.before = name
|
||||
return cp
|
||||
@ -168,7 +173,7 @@ func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) {
|
||||
}
|
||||
|
||||
func (c *callback) sort() {
|
||||
creates, updates, deletes, queries := []*callbackProcessor{}, []*callbackProcessor{}, []*callbackProcessor{}, []*callbackProcessor{}
|
||||
var creates, updates, deletes, queries, rowQueries []*callbackProcessor
|
||||
|
||||
for _, processor := range c.processors {
|
||||
switch processor.typ {
|
||||
@ -180,6 +185,8 @@ func (c *callback) sort() {
|
||||
deletes = append(deletes, processor)
|
||||
case "query":
|
||||
queries = append(queries, processor)
|
||||
case "row_query":
|
||||
rowQueries = append(rowQueries, processor)
|
||||
}
|
||||
}
|
||||
|
||||
@ -187,6 +194,7 @@ func (c *callback) sort() {
|
||||
c.updates = sortProcessors(updates)
|
||||
c.deletes = sortProcessors(deletes)
|
||||
c.queries = sortProcessors(queries)
|
||||
c.rowQueries = sortProcessors(rowQueries)
|
||||
}
|
||||
|
||||
var DefaultCallback = &callback{processors: []*callbackProcessor{}}
|
||||
|
@ -70,18 +70,24 @@ func Create(scope *Scope) {
|
||||
id, err := result.LastInsertId()
|
||||
if scope.Err(err) == nil {
|
||||
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||
if primaryField != nil {
|
||||
if primaryField != nil && primaryField.IsBlank {
|
||||
scope.Err(scope.SetColumn(primaryField, id))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if primaryField == nil {
|
||||
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err != nil {
|
||||
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
|
||||
scope.db.RowsAffected, _ = results.RowsAffected()
|
||||
} else {
|
||||
scope.Err(err)
|
||||
}
|
||||
} else if scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil {
|
||||
} else {
|
||||
if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
|
||||
scope.db.RowsAffected = 1
|
||||
} else {
|
||||
scope.Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -16,20 +16,22 @@ func Query(scope *Scope) {
|
||||
destType reflect.Type
|
||||
)
|
||||
|
||||
var dest = scope.IndirectValue()
|
||||
if value, ok := scope.InstanceGet("gorm:query_destination"); ok {
|
||||
dest = reflect.Indirect(reflect.ValueOf(value))
|
||||
}
|
||||
|
||||
if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
|
||||
if primaryKey := scope.PrimaryKey(); primaryKey != "" {
|
||||
scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), primaryKey, orderBy))
|
||||
scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryKey), orderBy))
|
||||
}
|
||||
}
|
||||
|
||||
var dest = scope.IndirectValue()
|
||||
if value, ok := scope.Get("gorm:query_destination"); ok {
|
||||
dest = reflect.Indirect(reflect.ValueOf(value))
|
||||
}
|
||||
|
||||
if kind := dest.Kind(); kind == reflect.Slice {
|
||||
isSlice = true
|
||||
destType = dest.Type().Elem()
|
||||
dest.Set(reflect.Indirect(reflect.New(reflect.SliceOf(destType))))
|
||||
|
||||
if destType.Kind() == reflect.Ptr {
|
||||
isPtr = true
|
||||
destType = destType.Elem()
|
||||
|
@ -55,7 +55,7 @@ func SaveAfterAssociations(scope *Scope) {
|
||||
scope.Err(newDB.Save(elem).Error)
|
||||
|
||||
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
|
||||
scope.Err(joinTableHandler.Add(scope.NewDB(), scope.Value, newScope.Value))
|
||||
scope.Err(joinTableHandler.Add(joinTableHandler, scope.NewDB(), scope.Value, newScope.Value))
|
||||
}
|
||||
}
|
||||
default:
|
||||
|
@ -64,6 +64,7 @@ func Update(scope *Scope) {
|
||||
}
|
||||
}
|
||||
|
||||
if len(sqls) > 0 {
|
||||
scope.Raw(fmt.Sprintf(
|
||||
"UPDATE %v SET %v %v",
|
||||
scope.QuotedTableName(),
|
||||
@ -73,6 +74,7 @@ func Update(scope *Scope) {
|
||||
scope.Exec()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AfterUpdate(scope *Scope) {
|
||||
if _, ok := scope.Get("gorm:update_column"); !ok {
|
||||
|
@ -61,7 +61,7 @@ Gorm is powered by callbacks, so you could refer below links to learn how to wri
|
||||
|
||||
[Update callbacks](https://github.com/jinzhu/gorm/blob/master/callback_update.go)
|
||||
|
||||
[Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_create.go)
|
||||
[Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_query.go)
|
||||
|
||||
[Delete callbacks](https://github.com/jinzhu/gorm/blob/master/callback_delete.go)
|
||||
|
||||
|
@ -10,9 +10,11 @@ import (
|
||||
type JoinTableHandlerInterface interface {
|
||||
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
|
||||
Table(db *DB) string
|
||||
Add(db *DB, source interface{}, destination interface{}) error
|
||||
Delete(db *DB, sources ...interface{}) error
|
||||
JoinWith(db *DB, source interface{}) *DB
|
||||
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
|
||||
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
|
||||
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
|
||||
SourceForeignKeys() []JoinTableForeignKey
|
||||
DestinationForeignKeys() []JoinTableForeignKey
|
||||
}
|
||||
|
||||
type JoinTableForeignKey struct {
|
||||
@ -31,37 +33,58 @@ type JoinTableHandler struct {
|
||||
Destination JoinTableSource `sql:"-"`
|
||||
}
|
||||
|
||||
func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
|
||||
return s.Source.ForeignKeys
|
||||
}
|
||||
|
||||
func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
|
||||
return s.Destination.ForeignKeys
|
||||
}
|
||||
|
||||
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
|
||||
s.TableName = tableName
|
||||
|
||||
s.Source = JoinTableSource{ModelType: source}
|
||||
sourceScope := &Scope{Value: reflect.New(source).Interface()}
|
||||
for _, primaryField := range sourceScope.GetModelStruct().PrimaryFields {
|
||||
sourcePrimaryFields := sourceScope.GetModelStruct().PrimaryFields
|
||||
for _, primaryField := range sourcePrimaryFields {
|
||||
if relationship.ForeignDBName == "" {
|
||||
relationship.ForeignFieldName = source.Name() + primaryField.Name
|
||||
relationship.ForeignDBName = ToDBName(relationship.ForeignFieldName)
|
||||
}
|
||||
|
||||
var dbName string
|
||||
if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" {
|
||||
dbName = relationship.ForeignDBName
|
||||
} else {
|
||||
dbName = ToDBName(source.Name() + primaryField.Name)
|
||||
}
|
||||
|
||||
s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
|
||||
DBName: relationship.ForeignDBName,
|
||||
DBName: dbName,
|
||||
AssociationDBName: primaryField.DBName,
|
||||
})
|
||||
}
|
||||
|
||||
s.Destination = JoinTableSource{ModelType: destination}
|
||||
destinationScope := &Scope{Value: reflect.New(destination).Interface()}
|
||||
for _, primaryField := range destinationScope.GetModelStruct().PrimaryFields {
|
||||
if relationship.AssociationForeignDBName == "" {
|
||||
relationship.AssociationForeignFieldName = destination.Name() + primaryField.Name
|
||||
relationship.AssociationForeignDBName = ToDBName(relationship.AssociationForeignFieldName)
|
||||
destinationPrimaryFields := destinationScope.GetModelStruct().PrimaryFields
|
||||
for _, primaryField := range destinationPrimaryFields {
|
||||
var dbName string
|
||||
if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" {
|
||||
dbName = relationship.AssociationForeignDBName
|
||||
} else {
|
||||
dbName = ToDBName(destinationScope.GetModelStruct().ModelType.Name() + primaryField.Name)
|
||||
}
|
||||
|
||||
s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
|
||||
DBName: relationship.AssociationForeignDBName,
|
||||
DBName: dbName,
|
||||
AssociationDBName: primaryField.DBName,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s JoinTableHandler) Table(*DB) string {
|
||||
func (s JoinTableHandler) Table(db *DB) string {
|
||||
return s.TableName
|
||||
}
|
||||
|
||||
@ -85,7 +108,7 @@ func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[strin
|
||||
return values
|
||||
}
|
||||
|
||||
func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) error {
|
||||
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error {
|
||||
scope := db.NewScope("")
|
||||
searchMap := s.GetSearchMap(db, source1, source2)
|
||||
|
||||
@ -102,7 +125,7 @@ func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{})
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
quotedTable := s.Table(db)
|
||||
quotedTable := handler.Table(db)
|
||||
sql := fmt.Sprintf(
|
||||
"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)",
|
||||
quotedTable,
|
||||
@ -116,7 +139,7 @@ func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{})
|
||||
return db.Exec(sql, values...).Error
|
||||
}
|
||||
|
||||
func (s JoinTableHandler) Delete(db *DB, sources ...interface{}) error {
|
||||
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
|
||||
var conditions []string
|
||||
var values []interface{}
|
||||
|
||||
@ -125,11 +148,11 @@ func (s JoinTableHandler) Delete(db *DB, sources ...interface{}) error {
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
return db.Table(s.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
|
||||
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
|
||||
}
|
||||
|
||||
func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB {
|
||||
quotedTable := s.Table(db)
|
||||
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
|
||||
quotedTable := handler.Table(db)
|
||||
|
||||
scope := db.NewScope(source)
|
||||
modelType := scope.GetModelStruct().ModelType
|
||||
@ -138,7 +161,7 @@ func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB {
|
||||
var values []interface{}
|
||||
if s.Source.ModelType == modelType {
|
||||
for _, foreignKey := range s.Destination.ForeignKeys {
|
||||
destinationTableName := scope.New(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
|
||||
destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
|
||||
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
|
||||
}
|
||||
|
||||
|
@ -22,7 +22,7 @@ type PersonAddress struct {
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
func (*PersonAddress) Add(db *gorm.DB, foreignValue interface{}, associationValue interface{}) error {
|
||||
func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error {
|
||||
return db.Where(map[string]interface{}{
|
||||
"person_id": db.NewScope(foreignValue).PrimaryKeyValue(),
|
||||
"address_id": db.NewScope(associationValue).PrimaryKeyValue(),
|
||||
@ -33,11 +33,11 @@ func (*PersonAddress) Add(db *gorm.DB, foreignValue interface{}, associationValu
|
||||
}).FirstOrCreate(&PersonAddress{}).Error
|
||||
}
|
||||
|
||||
func (*PersonAddress) Delete(db *gorm.DB, sources ...interface{}) error {
|
||||
func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error {
|
||||
return db.Delete(&PersonAddress{}).Error
|
||||
}
|
||||
|
||||
func (pa *PersonAddress) JoinWith(db *gorm.DB, source interface{}) *gorm.DB {
|
||||
func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB {
|
||||
table := pa.Table(db)
|
||||
return db.Table(table).Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table))
|
||||
}
|
||||
|
67
main.go
67
main.go
@ -211,6 +211,10 @@ func (s *DB) Find(out interface{}, where ...interface{}) *DB {
|
||||
return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
|
||||
}
|
||||
|
||||
func (s *DB) Scan(dest interface{}) *DB {
|
||||
return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db
|
||||
}
|
||||
|
||||
func (s *DB) Row() *sql.Row {
|
||||
return s.NewScope(s.Value).row()
|
||||
}
|
||||
@ -219,10 +223,16 @@ func (s *DB) Rows() (*sql.Rows, error) {
|
||||
return s.NewScope(s.Value).rows()
|
||||
}
|
||||
|
||||
func (s *DB) Scan(dest interface{}) *DB {
|
||||
scope := s.clone().NewScope(s.Value).InstanceSet("gorm:query_destination", dest)
|
||||
Query(scope)
|
||||
return scope.db
|
||||
func (s *DB) Pluck(column string, value interface{}) *DB {
|
||||
return s.NewScope(s.Value).pluck(column, value).db
|
||||
}
|
||||
|
||||
func (s *DB) Count(value interface{}) *DB {
|
||||
return s.NewScope(s.Value).count(value).db
|
||||
}
|
||||
|
||||
func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
|
||||
return s.clone().NewScope(s.Value).related(value, foreignKeys...).db
|
||||
}
|
||||
|
||||
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
|
||||
@ -309,18 +319,6 @@ func (s *DB) Model(value interface{}) *DB {
|
||||
return c
|
||||
}
|
||||
|
||||
func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
|
||||
return s.clone().NewScope(s.Value).related(value, foreignKeys...).db
|
||||
}
|
||||
|
||||
func (s *DB) Pluck(column string, value interface{}) *DB {
|
||||
return s.NewScope(s.Value).pluck(column, value).db
|
||||
}
|
||||
|
||||
func (s *DB) Count(value interface{}) *DB {
|
||||
return s.NewScope(s.Value).count(value).db
|
||||
}
|
||||
|
||||
func (s *DB) Table(name string) *DB {
|
||||
clone := s.clone()
|
||||
clone.search.Table(name)
|
||||
@ -398,28 +396,33 @@ func (s *DB) AutoMigrate(values ...interface{}) *DB {
|
||||
}
|
||||
|
||||
func (s *DB) ModifyColumn(column string, typ string) *DB {
|
||||
s.clone().NewScope(s.Value).modifyColumn(column, typ)
|
||||
return s
|
||||
scope := s.clone().NewScope(s.Value)
|
||||
scope.modifyColumn(column, typ)
|
||||
return scope.db
|
||||
}
|
||||
|
||||
func (s *DB) DropColumn(column string) *DB {
|
||||
s.clone().NewScope(s.Value).dropColumn(column)
|
||||
return s
|
||||
scope := s.clone().NewScope(s.Value)
|
||||
scope.dropColumn(column)
|
||||
return scope.db
|
||||
}
|
||||
|
||||
func (s *DB) AddIndex(indexName string, column ...string) *DB {
|
||||
s.clone().NewScope(s.Value).addIndex(false, indexName, column...)
|
||||
return s
|
||||
scope := s.clone().NewScope(s.Value)
|
||||
scope.addIndex(false, indexName, column...)
|
||||
return scope.db
|
||||
}
|
||||
|
||||
func (s *DB) AddUniqueIndex(indexName string, column ...string) *DB {
|
||||
s.clone().NewScope(s.Value).addIndex(true, indexName, column...)
|
||||
return s
|
||||
scope := s.clone().NewScope(s.Value)
|
||||
scope.addIndex(true, indexName, column...)
|
||||
return scope.db
|
||||
}
|
||||
|
||||
func (s *DB) RemoveIndex(indexName string) *DB {
|
||||
s.clone().NewScope(s.Value).removeIndex(indexName)
|
||||
return s
|
||||
scope := s.clone().NewScope(s.Value)
|
||||
scope.removeIndex(indexName)
|
||||
return scope.db
|
||||
}
|
||||
|
||||
/*
|
||||
@ -429,8 +432,9 @@ Example:
|
||||
db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
|
||||
*/
|
||||
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
|
||||
s.clone().NewScope(s.Value).addForeignKey(field, dest, onDelete, onUpdate)
|
||||
return s
|
||||
scope := s.clone().NewScope(s.Value)
|
||||
scope.addForeignKey(field, dest, onDelete, onUpdate)
|
||||
return scope.db
|
||||
}
|
||||
|
||||
func (s *DB) Association(column string) *Association {
|
||||
@ -475,14 +479,17 @@ func (s *DB) Get(name string) (value interface{}, ok bool) {
|
||||
}
|
||||
|
||||
func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) {
|
||||
for _, field := range s.NewScope(source).GetModelStruct().StructFields {
|
||||
scope := s.NewScope(source)
|
||||
for _, field := range scope.GetModelStruct().StructFields {
|
||||
if field.Name == column || field.DBName == column {
|
||||
if many2many := parseTagSetting(field.Tag.Get("gorm"))["MANY2MANY"]; many2many != "" {
|
||||
source := (&Scope{Value: source}).GetModelStruct().ModelType
|
||||
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
|
||||
handler.Setup(field.Relationship, many2many, source, destination)
|
||||
field.Relationship.JoinTableHandler = handler
|
||||
s.Table(handler.Table(s)).AutoMigrate(handler)
|
||||
if table := handler.Table(s); scope.Dialect().HasTable(scope, table) {
|
||||
s.Table(table).AutoMigrate(handler)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
13
main_test.go
13
main_test.go
@ -61,6 +61,19 @@ func init() {
|
||||
runMigration()
|
||||
}
|
||||
|
||||
func TestStringPrimaryKey(t *testing.T) {
|
||||
type UUIDStruct struct {
|
||||
ID string `gorm:"primary_key"`
|
||||
Name string
|
||||
}
|
||||
DB.AutoMigrate(&UUIDStruct{})
|
||||
|
||||
data := UUIDStruct{ID: "uuid", Name: "hello"}
|
||||
if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" {
|
||||
t.Errorf("string primary key should not be populated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExceptionsWithInvalidSql(t *testing.T) {
|
||||
var columns []string
|
||||
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {
|
||||
|
10
model.go
Normal file
10
model.go
Normal file
@ -0,0 +1,10 @@
|
||||
package gorm
|
||||
|
||||
import "time"
|
||||
|
||||
type Model struct {
|
||||
ID uint `gorm:"primary_key"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt *time.Time
|
||||
}
|
@ -13,11 +13,19 @@ import (
|
||||
|
||||
var modelStructs = map[reflect.Type]*ModelStruct{}
|
||||
|
||||
var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
|
||||
return defaultTableName
|
||||
}
|
||||
|
||||
type ModelStruct struct {
|
||||
PrimaryFields []*StructField
|
||||
StructFields []*StructField
|
||||
ModelType reflect.Type
|
||||
TableName func(*DB) string
|
||||
defaultTableName string
|
||||
}
|
||||
|
||||
func (s ModelStruct) TableName(db *DB) string {
|
||||
return DefaultTableNameHandler(db, s.defaultTableName)
|
||||
}
|
||||
|
||||
type StructField struct {
|
||||
@ -94,14 +102,14 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||
}
|
||||
|
||||
// Set tablename
|
||||
if fm := reflect.New(scopeType).MethodByName("TableName"); fm.IsValid() {
|
||||
if results := fm.Call([]reflect.Value{}); len(results) > 0 {
|
||||
if name, ok := results[0].Interface().(string); ok {
|
||||
modelStruct.TableName = func(*DB) string {
|
||||
return name
|
||||
}
|
||||
}
|
||||
type tabler interface {
|
||||
TableName() string
|
||||
}
|
||||
|
||||
if tabler, ok := reflect.New(scopeType).Interface().(interface {
|
||||
TableName() string
|
||||
}); ok {
|
||||
modelStruct.defaultTableName = tabler.TableName()
|
||||
} else {
|
||||
name := ToDBName(scopeType.Name())
|
||||
if scope.db == nil || !scope.db.parent.singularTable {
|
||||
@ -112,9 +120,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
|
||||
}
|
||||
}
|
||||
|
||||
modelStruct.TableName = func(*DB) string {
|
||||
return name
|
||||
}
|
||||
modelStruct.defaultTableName = name
|
||||
}
|
||||
|
||||
// Get all fields
|
||||
|
46
multi_primary_keys_test.go
Normal file
46
multi_primary_keys_test.go
Normal file
@ -0,0 +1,46 @@
|
||||
package gorm_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type Blog struct {
|
||||
ID uint `gorm:"primary_key"`
|
||||
Locale string `gorm:"primary_key"`
|
||||
Subject string
|
||||
Body string
|
||||
Tags []Tag `gorm:"many2many:blog_tags;"`
|
||||
}
|
||||
|
||||
type Tag struct {
|
||||
ID uint `gorm:"primary_key"`
|
||||
Locale string `gorm:"primary_key"`
|
||||
Value string
|
||||
}
|
||||
|
||||
func TestManyToManyWithMultiPrimaryKeys(t *testing.T) {
|
||||
if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" {
|
||||
DB.Exec(fmt.Sprintf("drop table blog_tags;"))
|
||||
DB.AutoMigrate(&Blog{}, &Tag{})
|
||||
blog := Blog{
|
||||
Locale: "ZH",
|
||||
Subject: "subject",
|
||||
Body: "body",
|
||||
Tags: []Tag{
|
||||
{Locale: "ZH", Value: "tag1"},
|
||||
{Locale: "ZH", Value: "tag2"},
|
||||
},
|
||||
}
|
||||
|
||||
DB.Save(&blog)
|
||||
DB.Model(&blog).Association("Tags").Append([]Tag{{Locale: "ZH", Value: "tag3"}})
|
||||
|
||||
var tags []Tag
|
||||
DB.Model(&blog).Related(&tags, "Tags")
|
||||
if len(tags) != 3 {
|
||||
t.Errorf("should found 3 tags with blog")
|
||||
}
|
||||
}
|
||||
}
|
14
mysql.go
14
mysql.go
@ -14,16 +14,26 @@ func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
|
||||
switch value.Kind() {
|
||||
case reflect.Bool:
|
||||
return "boolean"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
|
||||
if autoIncrease {
|
||||
return "int AUTO_INCREMENT"
|
||||
}
|
||||
return "int"
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
|
||||
if autoIncrease {
|
||||
return "int unsigned AUTO_INCREMENT"
|
||||
}
|
||||
return "int unsigned"
|
||||
case reflect.Int64:
|
||||
if autoIncrease {
|
||||
return "bigint AUTO_INCREMENT"
|
||||
}
|
||||
return "bigint"
|
||||
case reflect.Uint64:
|
||||
if autoIncrease {
|
||||
return "bigint unsigned AUTO_INCREMENT"
|
||||
}
|
||||
return "bigint unsigned"
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return "double"
|
||||
case reflect.String:
|
||||
|
204
preload.go
204
preload.go
@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func getRealValue(value reflect.Value, field string) interface{} {
|
||||
@ -20,28 +21,98 @@ func equalAsString(a interface{}, b interface{}) bool {
|
||||
}
|
||||
|
||||
func Preload(scope *Scope) {
|
||||
fields := scope.Fields()
|
||||
isSlice := scope.IndirectValue().Kind() == reflect.Slice
|
||||
if scope.Search.preload == nil {
|
||||
return
|
||||
}
|
||||
|
||||
preloadMap := map[string]bool{}
|
||||
fields := scope.Fields()
|
||||
for _, preload := range scope.Search.preload {
|
||||
schema, conditions := preload.schema, preload.conditions
|
||||
keys := strings.Split(schema, ".")
|
||||
currentScope := scope
|
||||
currentFields := fields
|
||||
originalConditions := conditions
|
||||
conditions = []interface{}{}
|
||||
for i, key := range keys {
|
||||
var found bool
|
||||
if preloadMap[strings.Join(keys[:i+1], ".")] {
|
||||
goto nextLoop
|
||||
}
|
||||
|
||||
if i == len(keys)-1 {
|
||||
conditions = originalConditions
|
||||
}
|
||||
|
||||
for _, field := range currentFields {
|
||||
if field.Name != key || field.Relationship == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
found = true
|
||||
switch field.Relationship.Kind {
|
||||
case "has_one":
|
||||
currentScope.handleHasOnePreload(field, conditions)
|
||||
case "has_many":
|
||||
currentScope.handleHasManyPreload(field, conditions)
|
||||
case "belongs_to":
|
||||
currentScope.handleBelongsToPreload(field, conditions)
|
||||
case "many_to_many":
|
||||
fallthrough
|
||||
default:
|
||||
currentScope.Err(errors.New("not supported relation"))
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if !found {
|
||||
value := reflect.ValueOf(currentScope.Value)
|
||||
if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface {
|
||||
value = value.Index(0).Elem()
|
||||
}
|
||||
scope.Err(fmt.Errorf("can't find field %s in %s", key, value.Type()))
|
||||
return
|
||||
}
|
||||
|
||||
preloadMap[strings.Join(keys[:i+1], ".")] = true
|
||||
|
||||
nextLoop:
|
||||
if i < len(keys)-1 {
|
||||
currentScope = currentScope.getColumnsAsScope(key)
|
||||
currentFields = currentScope.Fields()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func makeSlice(typ reflect.Type) interface{} {
|
||||
if typ.Kind() == reflect.Slice {
|
||||
typ = typ.Elem()
|
||||
}
|
||||
sliceType := reflect.SliceOf(typ)
|
||||
slice := reflect.New(sliceType)
|
||||
slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
|
||||
return slice.Interface()
|
||||
}
|
||||
|
||||
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
|
||||
primaryName := scope.PrimaryField().Name
|
||||
primaryKeys := scope.getColumnAsArray(primaryName)
|
||||
if len(primaryKeys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if scope.Search.preload != nil {
|
||||
for key, conditions := range scope.Search.preload {
|
||||
for _, field := range fields {
|
||||
if field.Name == key && field.Relationship != nil {
|
||||
results := makeSlice(field.Struct.Type)
|
||||
relation := field.Relationship
|
||||
primaryName := scope.PrimaryField().Name
|
||||
associationPrimaryKey := scope.New(results).PrimaryField().Name
|
||||
|
||||
switch relation.Kind {
|
||||
case "has_one":
|
||||
if primaryKeys := scope.getColumnAsArray(primaryName); len(primaryKeys) > 0 {
|
||||
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName))
|
||||
scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...)
|
||||
|
||||
scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error)
|
||||
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
||||
|
||||
for i := 0; i < resultValues.Len(); i++ {
|
||||
result := resultValues.Index(i)
|
||||
if isSlice {
|
||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||
value := getRealValue(result, relation.ForeignFieldName)
|
||||
objects := scope.IndirectValue()
|
||||
for j := 0; j < objects.Len(); j++ {
|
||||
@ -51,16 +122,29 @@ func Preload(scope *Scope) {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
scope.SetColumn(field, result)
|
||||
if err := scope.SetColumn(field, result); err != nil {
|
||||
scope.Err(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
case "has_many":
|
||||
if primaryKeys := scope.getColumnAsArray(primaryName); len(primaryKeys) > 0 {
|
||||
}
|
||||
|
||||
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
|
||||
primaryName := scope.PrimaryField().Name
|
||||
primaryKeys := scope.getColumnAsArray(primaryName)
|
||||
if len(primaryKeys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
results := makeSlice(field.Struct.Type)
|
||||
relation := field.Relationship
|
||||
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName))
|
||||
scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...)
|
||||
|
||||
scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error)
|
||||
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
||||
if isSlice {
|
||||
|
||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||
for i := 0; i < resultValues.Len(); i++ {
|
||||
result := resultValues.Index(i)
|
||||
value := getRealValue(result, relation.ForeignFieldName)
|
||||
@ -78,13 +162,23 @@ func Preload(scope *Scope) {
|
||||
scope.SetColumn(field, resultValues)
|
||||
}
|
||||
}
|
||||
case "belongs_to":
|
||||
if primaryKeys := scope.getColumnAsArray(relation.ForeignFieldName); len(primaryKeys) > 0 {
|
||||
scope.NewDB().Where(primaryKeys).Find(results, conditions...)
|
||||
|
||||
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
|
||||
relation := field.Relationship
|
||||
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldName)
|
||||
if len(primaryKeys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
results := makeSlice(field.Struct.Type)
|
||||
associationPrimaryKey := scope.New(results).PrimaryField().Name
|
||||
|
||||
scope.Err(scope.NewDB().Where(primaryKeys).Find(results, conditions...).Error)
|
||||
resultValues := reflect.Indirect(reflect.ValueOf(results))
|
||||
|
||||
for i := 0; i < resultValues.Len(); i++ {
|
||||
result := resultValues.Index(i)
|
||||
if isSlice {
|
||||
if scope.IndirectValue().Kind() == reflect.Slice {
|
||||
value := getRealValue(result, associationPrimaryKey)
|
||||
objects := scope.IndirectValue()
|
||||
for j := 0; j < objects.Len(); j++ {
|
||||
@ -98,41 +192,51 @@ func Preload(scope *Scope) {
|
||||
}
|
||||
}
|
||||
}
|
||||
case "many_to_many":
|
||||
scope.Err(errors.New("not supported relation"))
|
||||
default:
|
||||
scope.Err(errors.New("not supported relation"))
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func makeSlice(typ reflect.Type) interface{} {
|
||||
if typ.Kind() == reflect.Slice {
|
||||
typ = typ.Elem()
|
||||
}
|
||||
sliceType := reflect.SliceOf(typ)
|
||||
slice := reflect.New(sliceType)
|
||||
slice.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
|
||||
return slice.Interface()
|
||||
}
|
||||
|
||||
func (scope *Scope) getColumnAsArray(column string) (primaryKeys []interface{}) {
|
||||
func (scope *Scope) getColumnAsArray(column string) (columns []interface{}) {
|
||||
values := scope.IndirectValue()
|
||||
switch values.Kind() {
|
||||
case reflect.Slice:
|
||||
primaryKeyMap := map[interface{}]bool{}
|
||||
for i := 0; i < values.Len(); i++ {
|
||||
primaryKeyMap[reflect.Indirect(values.Index(i)).FieldByName(column).Interface()] = true
|
||||
}
|
||||
for key := range primaryKeyMap {
|
||||
primaryKeys = append(primaryKeys, key)
|
||||
columns = append(columns, reflect.Indirect(values.Index(i)).FieldByName(column).Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
return []interface{}{values.FieldByName(column).Interface()}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (scope *Scope) getColumnsAsScope(column string) *Scope {
|
||||
values := scope.IndirectValue()
|
||||
switch values.Kind() {
|
||||
case reflect.Slice:
|
||||
modelType := values.Type().Elem()
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
fieldStruct, _ := modelType.FieldByName(column)
|
||||
var columns reflect.Value
|
||||
if fieldStruct.Type.Kind() == reflect.Slice || fieldStruct.Type.Kind() == reflect.Ptr {
|
||||
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem()
|
||||
} else {
|
||||
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem()
|
||||
}
|
||||
for i := 0; i < values.Len(); i++ {
|
||||
column := reflect.Indirect(values.Index(i)).FieldByName(column)
|
||||
if column.Kind() == reflect.Ptr {
|
||||
column = column.Elem()
|
||||
}
|
||||
if column.Kind() == reflect.Slice {
|
||||
for i := 0; i < column.Len(); i++ {
|
||||
columns = reflect.Append(columns, column.Index(i).Addr())
|
||||
}
|
||||
} else {
|
||||
columns = reflect.Append(columns, column.Addr())
|
||||
}
|
||||
}
|
||||
return scope.New(columns.Interface())
|
||||
case reflect.Struct:
|
||||
return scope.New(values.FieldByName(column).Addr().Interface())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
524
preload_test.go
524
preload_test.go
@ -1,6 +1,10 @@
|
||||
package gorm_test
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func getPreloadUser(name string) *User {
|
||||
return getPreparedUser(name, "Preload")
|
||||
@ -85,3 +89,521 @@ func TestPreload(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNestedPreload1(t *testing.T) {
|
||||
type (
|
||||
Level1 struct {
|
||||
ID uint
|
||||
Value string
|
||||
Level2ID uint
|
||||
}
|
||||
Level2 struct {
|
||||
ID uint
|
||||
Level1 Level1
|
||||
Level3ID uint
|
||||
}
|
||||
Level3 struct {
|
||||
ID uint
|
||||
Name string
|
||||
Level2 Level2
|
||||
}
|
||||
)
|
||||
DB.DropTableIfExists(&Level3{})
|
||||
DB.DropTableIfExists(&Level2{})
|
||||
DB.DropTableIfExists(&Level1{})
|
||||
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}}
|
||||
if err := DB.Create(&want).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var got Level3
|
||||
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNestedPreload2(t *testing.T) {
|
||||
type (
|
||||
Level1 struct {
|
||||
ID uint
|
||||
Value string
|
||||
Level2ID uint
|
||||
}
|
||||
Level2 struct {
|
||||
ID uint
|
||||
Level1s []*Level1
|
||||
Level3ID uint
|
||||
}
|
||||
Level3 struct {
|
||||
ID uint
|
||||
Name string
|
||||
Level2s []Level2
|
||||
}
|
||||
)
|
||||
DB.DropTableIfExists(&Level3{})
|
||||
DB.DropTableIfExists(&Level2{})
|
||||
DB.DropTableIfExists(&Level1{})
|
||||
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
want := Level3{
|
||||
Level2s: []Level2{
|
||||
{
|
||||
Level1s: []*Level1{
|
||||
&Level1{Value: "value1"},
|
||||
&Level1{Value: "value2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Level1s: []*Level1{
|
||||
&Level1{Value: "value3"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&want).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var got Level3
|
||||
if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNestedPreload3(t *testing.T) {
|
||||
type (
|
||||
Level1 struct {
|
||||
ID uint
|
||||
Value string
|
||||
Level2ID uint
|
||||
}
|
||||
Level2 struct {
|
||||
ID uint
|
||||
Level1 Level1
|
||||
Level3ID uint
|
||||
}
|
||||
Level3 struct {
|
||||
Name string
|
||||
ID uint
|
||||
Level2s []Level2
|
||||
}
|
||||
)
|
||||
DB.DropTableIfExists(&Level3{})
|
||||
DB.DropTableIfExists(&Level2{})
|
||||
DB.DropTableIfExists(&Level1{})
|
||||
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
want := Level3{
|
||||
Level2s: []Level2{
|
||||
{Level1: Level1{Value: "value1"}},
|
||||
{Level1: Level1{Value: "value2"}},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&want).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var got Level3
|
||||
if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNestedPreload4(t *testing.T) {
|
||||
type (
|
||||
Level1 struct {
|
||||
ID uint
|
||||
Value string
|
||||
Level2ID uint
|
||||
}
|
||||
Level2 struct {
|
||||
ID uint
|
||||
Level1s []Level1
|
||||
Level3ID uint
|
||||
}
|
||||
Level3 struct {
|
||||
ID uint
|
||||
Name string
|
||||
Level2 Level2
|
||||
}
|
||||
)
|
||||
DB.DropTableIfExists(&Level3{})
|
||||
DB.DropTableIfExists(&Level2{})
|
||||
DB.DropTableIfExists(&Level1{})
|
||||
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
want := Level3{
|
||||
Level2: Level2{
|
||||
Level1s: []Level1{
|
||||
Level1{Value: "value1"},
|
||||
Level1{Value: "value2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&want).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var got Level3
|
||||
if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
}
|
||||
|
||||
// Slice: []Level3
|
||||
func TestNestedPreload5(t *testing.T) {
|
||||
type (
|
||||
Level1 struct {
|
||||
ID uint
|
||||
Value string
|
||||
Level2ID uint
|
||||
}
|
||||
Level2 struct {
|
||||
ID uint
|
||||
Level1 Level1
|
||||
Level3ID uint
|
||||
}
|
||||
Level3 struct {
|
||||
ID uint
|
||||
Name string
|
||||
Level2 Level2
|
||||
}
|
||||
)
|
||||
DB.DropTableIfExists(&Level3{})
|
||||
DB.DropTableIfExists(&Level2{})
|
||||
DB.DropTableIfExists(&Level1{})
|
||||
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
want := make([]Level3, 2)
|
||||
want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}}
|
||||
if err := DB.Create(&want[0]).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}}
|
||||
if err := DB.Create(&want[1]).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var got []Level3
|
||||
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNestedPreload6(t *testing.T) {
|
||||
type (
|
||||
Level1 struct {
|
||||
ID uint
|
||||
Value string
|
||||
Level2ID uint
|
||||
}
|
||||
Level2 struct {
|
||||
ID uint
|
||||
Level1s []Level1
|
||||
Level3ID uint
|
||||
}
|
||||
Level3 struct {
|
||||
ID uint
|
||||
Name string
|
||||
Level2s []Level2
|
||||
}
|
||||
)
|
||||
DB.DropTableIfExists(&Level3{})
|
||||
DB.DropTableIfExists(&Level2{})
|
||||
DB.DropTableIfExists(&Level1{})
|
||||
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
want := make([]Level3, 2)
|
||||
want[0] = Level3{
|
||||
Level2s: []Level2{
|
||||
{
|
||||
Level1s: []Level1{
|
||||
{Value: "value1"},
|
||||
{Value: "value2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Level1s: []Level1{
|
||||
{Value: "value3"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&want[0]).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
want[1] = Level3{
|
||||
Level2s: []Level2{
|
||||
{
|
||||
Level1s: []Level1{
|
||||
{Value: "value3"},
|
||||
{Value: "value4"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Level1s: []Level1{
|
||||
{Value: "value5"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&want[1]).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var got []Level3
|
||||
if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNestedPreload7(t *testing.T) {
|
||||
type (
|
||||
Level1 struct {
|
||||
ID uint
|
||||
Value string
|
||||
Level2ID uint
|
||||
}
|
||||
Level2 struct {
|
||||
ID uint
|
||||
Level1 Level1
|
||||
Level3ID uint
|
||||
}
|
||||
Level3 struct {
|
||||
ID uint
|
||||
Name string
|
||||
Level2s []Level2
|
||||
}
|
||||
)
|
||||
DB.DropTableIfExists(&Level3{})
|
||||
DB.DropTableIfExists(&Level2{})
|
||||
DB.DropTableIfExists(&Level1{})
|
||||
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
want := make([]Level3, 2)
|
||||
want[0] = Level3{
|
||||
Level2s: []Level2{
|
||||
{Level1: Level1{Value: "value1"}},
|
||||
{Level1: Level1{Value: "value2"}},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&want[0]).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
want[1] = Level3{
|
||||
Level2s: []Level2{
|
||||
{Level1: Level1{Value: "value3"}},
|
||||
{Level1: Level1{Value: "value4"}},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&want[1]).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var got []Level3
|
||||
if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNestedPreload8(t *testing.T) {
|
||||
type (
|
||||
Level1 struct {
|
||||
ID uint
|
||||
Value string
|
||||
Level2ID uint
|
||||
}
|
||||
Level2 struct {
|
||||
ID uint
|
||||
Level1s []Level1
|
||||
Level3ID uint
|
||||
}
|
||||
Level3 struct {
|
||||
ID uint
|
||||
Name string
|
||||
Level2 Level2
|
||||
}
|
||||
)
|
||||
DB.DropTableIfExists(&Level3{})
|
||||
DB.DropTableIfExists(&Level2{})
|
||||
DB.DropTableIfExists(&Level1{})
|
||||
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
want := make([]Level3, 2)
|
||||
want[0] = Level3{
|
||||
Level2: Level2{
|
||||
Level1s: []Level1{
|
||||
Level1{Value: "value1"},
|
||||
Level1{Value: "value2"},
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&want[0]).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
want[1] = Level3{
|
||||
Level2: Level2{
|
||||
Level1s: []Level1{
|
||||
Level1{Value: "value3"},
|
||||
Level1{Value: "value4"},
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&want[1]).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var got []Level3
|
||||
if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNestedPreload9(t *testing.T) {
|
||||
type (
|
||||
Level0 struct {
|
||||
ID uint
|
||||
Value string
|
||||
Level1ID uint
|
||||
}
|
||||
Level1 struct {
|
||||
ID uint
|
||||
Value string
|
||||
Level2ID uint
|
||||
Level2_1ID uint
|
||||
Level0s []Level0
|
||||
}
|
||||
Level2 struct {
|
||||
ID uint
|
||||
Level1s []Level1
|
||||
Level3ID uint
|
||||
}
|
||||
Level2_1 struct {
|
||||
ID uint
|
||||
Level1s []Level1
|
||||
Level3ID uint
|
||||
}
|
||||
Level3 struct {
|
||||
ID uint
|
||||
Name string
|
||||
Level2 Level2
|
||||
Level2_1 Level2_1
|
||||
}
|
||||
)
|
||||
DB.DropTableIfExists(&Level3{})
|
||||
DB.DropTableIfExists(&Level2{})
|
||||
DB.DropTableIfExists(&Level2_1{})
|
||||
DB.DropTableIfExists(&Level1{})
|
||||
DB.DropTableIfExists(&Level0{})
|
||||
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
want := make([]Level3, 2)
|
||||
want[0] = Level3{
|
||||
Level2: Level2{
|
||||
Level1s: []Level1{
|
||||
Level1{Value: "value1"},
|
||||
Level1{Value: "value2"},
|
||||
},
|
||||
},
|
||||
Level2_1: Level2_1{
|
||||
Level1s: []Level1{
|
||||
Level1{
|
||||
Value: "value1-1",
|
||||
Level0s: []Level0{{Value: "Level0-1"}},
|
||||
},
|
||||
Level1{
|
||||
Value: "value2-2",
|
||||
Level0s: []Level0{{Value: "Level0-2"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&want[0]).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
want[1] = Level3{
|
||||
Level2: Level2{
|
||||
Level1s: []Level1{
|
||||
Level1{Value: "value3"},
|
||||
Level1{Value: "value4"},
|
||||
},
|
||||
},
|
||||
Level2_1: Level2_1{
|
||||
Level1s: []Level1{
|
||||
Level1{Value: "value3-3"},
|
||||
Level1{Value: "value4-4"},
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := DB.Create(&want[1]).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var got []Level3
|
||||
if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).Error; err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
|
||||
}
|
||||
}
|
||||
|
||||
func toJSONString(v interface{}) []byte {
|
||||
r, _ := json.MarshalIndent(v, "", " ")
|
||||
return r
|
||||
}
|
||||
|
@ -98,49 +98,41 @@ func TestSearchWithPlainSQL(t *testing.T) {
|
||||
t.Errorf("Should found 2 users that age > 1, but got %v", len(users))
|
||||
}
|
||||
|
||||
users = []User{}
|
||||
DB.Where("name LIKE ?", "%PlainSqlUser%").Where("age >= ?", 1).Find(&users)
|
||||
if len(users) != 3 {
|
||||
t.Errorf("Should found 3 users that age >= 1, but got %v", len(users))
|
||||
}
|
||||
|
||||
users = []User{}
|
||||
scopedb.Where("age <> ?", 20).Find(&users)
|
||||
if len(users) != 2 {
|
||||
t.Errorf("Should found 2 users age != 20, but got %v", len(users))
|
||||
}
|
||||
|
||||
users = []User{}
|
||||
scopedb.Where("birthday > ?", now.MustParse("2000-1-1")).Find(&users)
|
||||
if len(users) != 2 {
|
||||
t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users))
|
||||
}
|
||||
|
||||
users = []User{}
|
||||
scopedb.Where("birthday > ?", "2002-10-10").Find(&users)
|
||||
if len(users) != 2 {
|
||||
t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users))
|
||||
}
|
||||
|
||||
users = []User{}
|
||||
scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users)
|
||||
if len(users) != 1 {
|
||||
t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users))
|
||||
}
|
||||
|
||||
users = []User{}
|
||||
DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users)
|
||||
if len(users) != 2 {
|
||||
t.Errorf("Should found 2 users, but got %v", len(users))
|
||||
}
|
||||
|
||||
users = []User{}
|
||||
DB.Where("id in (?)", []int64{user1.Id, user2.Id, user3.Id}).Find(&users)
|
||||
if len(users) != 3 {
|
||||
t.Errorf("Should found 3 users, but got %v", len(users))
|
||||
}
|
||||
|
||||
users = []User{}
|
||||
DB.Where("id in (?)", user1.Id).Find(&users)
|
||||
if len(users) != 1 {
|
||||
t.Errorf("Should found 1 users, but got %v", len(users))
|
||||
@ -191,7 +183,6 @@ func TestSearchWithStruct(t *testing.T) {
|
||||
t.Errorf("Search first record with where struct")
|
||||
}
|
||||
|
||||
users = []User{}
|
||||
DB.Find(&users, &User{Name: user2.Name})
|
||||
if len(users) != 1 {
|
||||
t.Errorf("Search all records with inline struct")
|
||||
@ -222,7 +213,6 @@ func TestSearchWithMap(t *testing.T) {
|
||||
t.Errorf("Search all records with inline map")
|
||||
}
|
||||
|
||||
users = []User{}
|
||||
DB.Find(&users, map[string]interface{}{"name": user3.Name})
|
||||
if len(users) != 1 {
|
||||
t.Errorf("Search all records with inline map")
|
||||
@ -395,13 +385,11 @@ func TestNot(t *testing.T) {
|
||||
t.Errorf("Should find all users's name not equal 3")
|
||||
}
|
||||
|
||||
users4 = []User{}
|
||||
DB.Not("name = ?", "user3").Find(&users4)
|
||||
if len(users1)-len(users4) != int(name3Count) {
|
||||
t.Errorf("Should find all users's name not equal 3")
|
||||
}
|
||||
|
||||
users4 = []User{}
|
||||
DB.Not("name <> ?", "user3").Find(&users4)
|
||||
if len(users4) != int(name3Count) {
|
||||
t.Errorf("Should find all users's name not equal 3")
|
||||
|
38
scope.go
38
scope.go
@ -110,6 +110,14 @@ func (scope *Scope) HasError() bool {
|
||||
return scope.db.Error != nil
|
||||
}
|
||||
|
||||
func (scope *Scope) PrimaryFields() []*Field {
|
||||
var fields = []*Field{}
|
||||
for _, field := range scope.GetModelStruct().PrimaryFields {
|
||||
fields = append(fields, scope.Fields()[field.DBName])
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
func (scope *Scope) PrimaryField() *Field {
|
||||
if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 {
|
||||
if len(primaryFields) > 1 {
|
||||
@ -158,13 +166,18 @@ func (scope *Scope) HasColumn(column string) bool {
|
||||
func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
|
||||
if field, ok := column.(*Field); ok {
|
||||
return field.Set(value)
|
||||
} else if dbName, ok := column.(string); ok {
|
||||
} else if name, ok := column.(string); ok {
|
||||
|
||||
if field, ok := scope.Fields()[name]; ok {
|
||||
return field.Set(value)
|
||||
}
|
||||
|
||||
dbName := ToDBName(name)
|
||||
if field, ok := scope.Fields()[dbName]; ok {
|
||||
return field.Set(value)
|
||||
}
|
||||
|
||||
dbName = ToDBName(dbName)
|
||||
if field, ok := scope.Fields()[dbName]; ok {
|
||||
if field, ok := scope.FieldByName(name); ok {
|
||||
return field.Set(value)
|
||||
}
|
||||
}
|
||||
@ -172,7 +185,7 @@ func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
|
||||
}
|
||||
|
||||
func (scope *Scope) CallMethod(name string, checkError bool) {
|
||||
if scope.Value == nil && (!checkError || !scope.HasError()) {
|
||||
if scope.Value == nil || (checkError && scope.HasError()) {
|
||||
return
|
||||
}
|
||||
|
||||
@ -246,17 +259,14 @@ func (scope *Scope) TableName() string {
|
||||
return tabler.TableName(scope.db)
|
||||
}
|
||||
|
||||
if scope.GetModelStruct().TableName != nil {
|
||||
scope.Search.tableName = scope.GetModelStruct().TableName(scope.db)
|
||||
return scope.Search.tableName
|
||||
}
|
||||
|
||||
scope.Err(errors.New("wrong table name"))
|
||||
return ""
|
||||
return scope.GetModelStruct().TableName(scope.db.Model(scope.Value))
|
||||
}
|
||||
|
||||
func (scope *Scope) QuotedTableName() (name string) {
|
||||
if scope.Search != nil && len(scope.Search.tableName) > 0 {
|
||||
if strings.Index(scope.Search.tableName, " ") != -1 {
|
||||
return scope.Search.tableName
|
||||
}
|
||||
return scope.Quote(scope.Search.tableName)
|
||||
} else {
|
||||
return scope.Quote(scope.TableName())
|
||||
@ -271,7 +281,7 @@ func (scope *Scope) CombinedConditionSql() string {
|
||||
|
||||
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
|
||||
for _, field := range scope.Fields() {
|
||||
if field.Name == name {
|
||||
if field.Name == name || field.DBName == name {
|
||||
return field, true
|
||||
}
|
||||
}
|
||||
@ -290,7 +300,7 @@ func (scope *Scope) Exec() *Scope {
|
||||
|
||||
if !scope.HasError() {
|
||||
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
|
||||
if count, err := result.RowsAffected(); err == nil {
|
||||
if count, err := result.RowsAffected(); scope.Err(err) == nil {
|
||||
scope.db.RowsAffected = count
|
||||
}
|
||||
}
|
||||
@ -364,6 +374,8 @@ func (scope *Scope) SelectAttrs() []string {
|
||||
for _, value := range scope.Search.selects {
|
||||
if str, ok := value.(string); ok {
|
||||
attrs = append(attrs, str)
|
||||
} else if strs, ok := value.([]string); ok {
|
||||
attrs = append(attrs, strs...)
|
||||
} else if strs, ok := value.([]interface{}); ok {
|
||||
for _, str := range strs {
|
||||
attrs = append(attrs, fmt.Sprintf("%v", str))
|
||||
|
@ -38,7 +38,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
|
||||
case interface{}:
|
||||
var sqls []string
|
||||
for _, field := range scope.New(value).Fields() {
|
||||
if !field.IsBlank {
|
||||
if !field.IsIgnored && !field.IsBlank {
|
||||
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||
}
|
||||
}
|
||||
@ -336,12 +336,14 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore
|
||||
|
||||
func (scope *Scope) row() *sql.Row {
|
||||
defer scope.Trace(NowFunc())
|
||||
scope.callCallbacks(scope.db.parent.callback.rowQueries)
|
||||
scope.prepareQuerySql()
|
||||
return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...)
|
||||
}
|
||||
|
||||
func (scope *Scope) rows() (*sql.Rows, error) {
|
||||
defer scope.Trace(NowFunc())
|
||||
scope.callCallbacks(scope.db.parent.callback.rowQueries)
|
||||
scope.prepareQuerySql()
|
||||
return scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
|
||||
}
|
||||
@ -411,7 +413,7 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
|
||||
if relationship := fromField.Relationship; relationship != nil {
|
||||
if relationship.Kind == "many_to_many" {
|
||||
joinTableHandler := relationship.JoinTableHandler
|
||||
scope.Err(joinTableHandler.JoinWith(toScope.db, scope.Value).Find(value).Error)
|
||||
scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error)
|
||||
} else if relationship.Kind == "belongs_to" {
|
||||
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
|
||||
foreignKeyValue := fromFields[relationship.ForeignDBName].Field.Interface()
|
||||
@ -445,13 +447,19 @@ func (scope *Scope) createJoinTable(field *StructField) {
|
||||
joinTableHandler := relationship.JoinTableHandler
|
||||
joinTable := joinTableHandler.Table(scope.db)
|
||||
if !scope.Dialect().HasTable(scope, joinTable) {
|
||||
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false)
|
||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
|
||||
scope.Quote(joinTable),
|
||||
strings.Join([]string{
|
||||
scope.Quote(relationship.ForeignDBName) + " " + primaryKeySqlType,
|
||||
scope.Quote(relationship.AssociationForeignDBName) + " " + primaryKeySqlType}, ",")),
|
||||
).Error)
|
||||
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
|
||||
|
||||
var sqlTypes []string
|
||||
for _, s := range []*Scope{scope, toScope} {
|
||||
for _, primaryField := range s.GetModelStruct().PrimaryFields {
|
||||
value := reflect.Indirect(reflect.New(primaryField.Struct.Type))
|
||||
primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false)
|
||||
dbName := ToDBName(s.GetModelStruct().ModelType.Name() + primaryField.Name)
|
||||
sqlTypes = append(sqlTypes, scope.Quote(dbName)+" "+primaryKeySqlType)
|
||||
}
|
||||
}
|
||||
|
||||
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)", scope.Quote(joinTable), strings.Join(sqlTypes, ","))).Error)
|
||||
}
|
||||
scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
|
||||
}
|
||||
@ -467,7 +475,7 @@ func (scope *Scope) createTable() *Scope {
|
||||
}
|
||||
|
||||
if field.IsPrimaryKey {
|
||||
primaryKeys = append(primaryKeys, field.DBName)
|
||||
primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
|
||||
}
|
||||
scope.createJoinTable(field)
|
||||
}
|
||||
@ -522,7 +530,7 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on
|
||||
var table = scope.TableName()
|
||||
var keyName = fmt.Sprintf("%s_%s_foreign", table, field)
|
||||
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
|
||||
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), keyName, field, dest, onDelete, onUpdate)).Exec()
|
||||
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.Quote(keyName), scope.Quote(field), scope.Quote(dest), onDelete, onUpdate)).Exec()
|
||||
}
|
||||
|
||||
func (scope *Scope) removeIndex(indexName string) {
|
||||
@ -540,7 +548,7 @@ func (scope *Scope) autoMigrate() *Scope {
|
||||
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
|
||||
if field.IsNormal {
|
||||
sqlTag := scope.generateSqlTag(field)
|
||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, sqlTag)).Exec()
|
||||
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
|
||||
}
|
||||
}
|
||||
scope.createJoinTable(field)
|
||||
|
40
search.go
40
search.go
@ -14,7 +14,7 @@ type search struct {
|
||||
omits []string
|
||||
orders []string
|
||||
joins string
|
||||
preload map[string][]interface{}
|
||||
preload []searchPreload
|
||||
offset string
|
||||
limit string
|
||||
group string
|
||||
@ -23,26 +23,14 @@ type search struct {
|
||||
Unscoped bool
|
||||
}
|
||||
|
||||
func (s *search) clone() *search {
|
||||
return &search{
|
||||
preload: s.preload,
|
||||
whereConditions: s.whereConditions,
|
||||
orConditions: s.orConditions,
|
||||
notConditions: s.notConditions,
|
||||
havingCondition: s.havingCondition,
|
||||
initAttrs: s.initAttrs,
|
||||
assignAttrs: s.assignAttrs,
|
||||
selects: s.selects,
|
||||
omits: s.omits,
|
||||
orders: s.orders,
|
||||
joins: s.joins,
|
||||
offset: s.offset,
|
||||
limit: s.limit,
|
||||
group: s.group,
|
||||
tableName: s.tableName,
|
||||
raw: s.raw,
|
||||
Unscoped: s.Unscoped,
|
||||
type searchPreload struct {
|
||||
schema string
|
||||
conditions []interface{}
|
||||
}
|
||||
|
||||
func (s *search) clone() *search {
|
||||
clone := *s
|
||||
return &clone
|
||||
}
|
||||
|
||||
func (s *search) Where(query interface{}, values ...interface{}) *search {
|
||||
@ -114,11 +102,15 @@ func (s *search) Joins(query string) *search {
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *search) Preload(column string, values ...interface{}) *search {
|
||||
if s.preload == nil {
|
||||
s.preload = map[string][]interface{}{}
|
||||
func (s *search) Preload(schema string, values ...interface{}) *search {
|
||||
var preloads []searchPreload
|
||||
for _, preload := range s.preload {
|
||||
if preload.schema != schema {
|
||||
preloads = append(preloads, preload)
|
||||
}
|
||||
s.preload[column] = values
|
||||
}
|
||||
preloads = append(preloads, searchPreload{schema, values})
|
||||
s.preload = preloads
|
||||
return s
|
||||
}
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
dialects=("postgres" "foundation" "mysql" "sqlite")
|
||||
dialects=("postgres" "mysql" "sqlite")
|
||||
|
||||
for dialect in "${dialects[@]}" ; do
|
||||
GORM_DIALECT=${dialect} go test
|
||||
|
Loading…
x
Reference in New Issue
Block a user