From e2eef50fb4bd50c8cb899f73d317723a22a1fdfc Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 8 Apr 2015 14:00:29 +0800 Subject: [PATCH 01/52] Add SetTableNameHandler --- main.go | 4 ++++ scope.go | 3 +-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/main.go b/main.go index 7049675e..86301207 100644 --- a/main.go +++ b/main.go @@ -487,3 +487,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join } } } + +func (s *DB) SetTableNameHandler(source interface{}, handler func(*DB) string) { + s.NewScope(source).GetModelStruct().TableName = handler +} diff --git a/scope.go b/scope.go index 86994a85..b83581e2 100644 --- a/scope.go +++ b/scope.go @@ -247,8 +247,7 @@ func (scope *Scope) TableName() string { } if scope.GetModelStruct().TableName != nil { - scope.Search.tableName = scope.GetModelStruct().TableName(scope.db) - return scope.Search.tableName + return scope.GetModelStruct().TableName(scope.db) } scope.Err(errors.New("wrong table name")) From 67266ebdb3f7deb12c913104ef56ece51a83e193 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 9 Apr 2015 15:51:47 +0800 Subject: [PATCH 02/52] Fix create join table with multi primary keys --- scope_private.go | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/scope_private.go b/scope_private.go index 99dda2ed..03a47acc 100644 --- a/scope_private.go +++ b/scope_private.go @@ -445,13 +445,19 @@ func (scope *Scope) createJoinTable(field *StructField) { joinTableHandler := relationship.JoinTableHandler joinTable := joinTableHandler.Table(scope.db) if !scope.Dialect().HasTable(scope, joinTable) { - primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false) - scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)", - scope.Quote(joinTable), - strings.Join([]string{ - scope.Quote(relationship.ForeignDBName) + " " + primaryKeySqlType, - scope.Quote(relationship.AssociationForeignDBName) + " " + primaryKeySqlType}, ",")), - ).Error) + toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} + + var sqlTypes []string + for _, s := range []*Scope{scope, toScope} { + for _, primaryField := range s.GetModelStruct().PrimaryFields { + value := reflect.Indirect(reflect.New(primaryField.Struct.Type)) + primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false) + dbName := ToDBName(s.GetModelStruct().ModelType.Name() + primaryField.Name) + sqlTypes = append(sqlTypes, scope.Quote(dbName)+" "+primaryKeySqlType) + } + } + + scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)", scope.Quote(joinTable), strings.Join(sqlTypes, ","))).Error) } scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) } From 1eb1ed091f2b725699a930e03a05890c861d2c50 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 10 Apr 2015 16:55:53 +0800 Subject: [PATCH 03/52] Test ManyToMany relations with multi primary keys --- association.go | 1 + callback_update.go | 16 +++++++------ join_table_handler.go | 27 ++++++++++++++++------ multi_primary_keys_test.go | 46 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 76 insertions(+), 14 deletions(-) create mode 100644 multi_primary_keys_test.go diff --git a/association.go b/association.go index 89bb1bec..37e10516 100644 --- a/association.go +++ b/association.go @@ -40,6 +40,7 @@ func (association *Association) Append(values ...interface{}) *Association { association.setErr(errors.New("invalid association type")) } } + scope.Search.Select(association.Column) scope.callCallbacks(scope.db.parent.callback.updates) return association.setErr(scope.db.Error) } diff --git a/callback_update.go b/callback_update.go index 1167871c..c3f7b4b6 100644 --- a/callback_update.go +++ b/callback_update.go @@ -64,13 +64,15 @@ func Update(scope *Scope) { } } - scope.Raw(fmt.Sprintf( - "UPDATE %v SET %v %v", - scope.QuotedTableName(), - strings.Join(sqls, ", "), - scope.CombinedConditionSql(), - )) - scope.Exec() + if len(sqls) > 0 { + scope.Raw(fmt.Sprintf( + "UPDATE %v SET %v %v", + scope.QuotedTableName(), + strings.Join(sqls, ", "), + scope.CombinedConditionSql(), + )) + scope.Exec() + } } } diff --git a/join_table_handler.go b/join_table_handler.go index 9f705564..e589d6f5 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -36,26 +36,39 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s s.Source = JoinTableSource{ModelType: source} sourceScope := &Scope{Value: reflect.New(source).Interface()} - for _, primaryField := range sourceScope.GetModelStruct().PrimaryFields { + sourcePrimaryFields := sourceScope.GetModelStruct().PrimaryFields + for _, primaryField := range sourcePrimaryFields { if relationship.ForeignDBName == "" { relationship.ForeignFieldName = source.Name() + primaryField.Name relationship.ForeignDBName = ToDBName(relationship.ForeignFieldName) } + + var dbName string + if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" { + dbName = relationship.ForeignDBName + } else { + dbName = ToDBName(source.Name() + primaryField.Name) + } + s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ - DBName: relationship.ForeignDBName, + DBName: dbName, AssociationDBName: primaryField.DBName, }) } s.Destination = JoinTableSource{ModelType: destination} destinationScope := &Scope{Value: reflect.New(destination).Interface()} - for _, primaryField := range destinationScope.GetModelStruct().PrimaryFields { - if relationship.AssociationForeignDBName == "" { - relationship.AssociationForeignFieldName = destination.Name() + primaryField.Name - relationship.AssociationForeignDBName = ToDBName(relationship.AssociationForeignFieldName) + destinationPrimaryFields := destinationScope.GetModelStruct().PrimaryFields + for _, primaryField := range destinationPrimaryFields { + var dbName string + if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" { + dbName = relationship.AssociationForeignDBName + } else { + dbName = ToDBName(destinationScope.GetModelStruct().ModelType.Name() + primaryField.Name) } + s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ - DBName: relationship.AssociationForeignDBName, + DBName: dbName, AssociationDBName: primaryField.DBName, }) } diff --git a/multi_primary_keys_test.go b/multi_primary_keys_test.go new file mode 100644 index 00000000..4aa3517e --- /dev/null +++ b/multi_primary_keys_test.go @@ -0,0 +1,46 @@ +package gorm_test + +import ( + "fmt" + "os" + "testing" +) + +type Blog struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Subject string + Body string + Tags []Tag `gorm:"many2many:blog_tags;"` +} + +type Tag struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Value string +} + +func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "sqlite" { + DB.Exec(fmt.Sprintf("drop table blog_tags;")) + DB.AutoMigrate(&Blog{}, &Tag{}) + blog := Blog{ + Locale: "ZH", + Subject: "subject", + Body: "body", + Tags: []Tag{ + {Locale: "ZH", Value: "tag1"}, + {Locale: "ZH", Value: "tag2"}, + }, + } + + DB.Save(&blog) + DB.Model(&blog).Association("Tags").Append([]Tag{{Locale: "ZH", Value: "tag3"}}) + + var tags []Tag + DB.Model(&blog).Related(&tags, "Tags") + if len(tags) != 3 { + t.Errorf("should found 3 tags with blog") + } + } +} From d61af54b96cb2dc0806e2247df05582e79b2636b Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 13 Apr 2015 10:09:00 +0800 Subject: [PATCH 04/52] Add default model struct --- README.md | 17 ++++++++++++++++- model.go | 10 ++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) create mode 100644 model.go diff --git a/README.md b/README.md index 0c4ea83a..4bcba989 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ type User struct { Num int `sql:"AUTO_INCREMENT"` CreatedAt time.Time UpdatedAt time.Time - DeletedAt time.Time + DeletedAt *time.Time Emails []Email // One-To-Many relationship (has many) BillingAddress Address // One-To-One relationship (has one) @@ -84,6 +84,21 @@ type User struct{} // struct User's database table name is "users" by default, w * Use `CreatedAt` to store record's created time if field exists * Use `UpdatedAt` to store record's updated time if field exists * Use `DeletedAt` to store record's deleted time if field exists [Soft Delete](#soft-delete) +* Gorm provide a default model struct, you could embed it in your struct + +```go +type Model struct { + ID uint `gorm:"primary_key"` + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time +} + +type User struct { + gorm.Model + Name string +} +``` ## Initialize Database diff --git a/model.go b/model.go new file mode 100644 index 00000000..50fa52e6 --- /dev/null +++ b/model.go @@ -0,0 +1,10 @@ +package gorm + +import "time" + +type Model struct { + ID uint `gorm:"primary_key"` + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time +} From 7966cde51410986039fb6707a2ae417ce1bbf869 Mon Sep 17 00:00:00 2001 From: li3p Date: Thu, 16 Apr 2015 14:08:13 +0800 Subject: [PATCH 05/52] look up the Field.Name at Scope.SetColumn --- scope.go | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/scope.go b/scope.go index b83581e2..c42c6afb 100644 --- a/scope.go +++ b/scope.go @@ -158,13 +158,18 @@ func (scope *Scope) HasColumn(column string) bool { func (scope *Scope) SetColumn(column interface{}, value interface{}) error { if field, ok := column.(*Field); ok { return field.Set(value) - } else if dbName, ok := column.(string); ok { + } else if name, ok := column.(string); ok { + + if field, ok := scope.Fields()[name]; ok { + return field.Set(value) + } + + dbName = ToDBName(name) if field, ok := scope.Fields()[dbName]; ok { return field.Set(value) } - dbName = ToDBName(dbName) - if field, ok := scope.Fields()[dbName]; ok { + if field, ok := scope.FieldByName(name); ok { return field.Set(value) } } From f9bd6bcc6418ccd5fcc54862b3b6568747f894a8 Mon Sep 17 00:00:00 2001 From: li3p Date: Thu, 16 Apr 2015 14:28:55 +0800 Subject: [PATCH 06/52] look up the Field.Name at Scope.SetColumn --- scope.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scope.go b/scope.go index c42c6afb..f1733aa3 100644 --- a/scope.go +++ b/scope.go @@ -164,7 +164,7 @@ func (scope *Scope) SetColumn(column interface{}, value interface{}) error { return field.Set(value) } - dbName = ToDBName(name) + dbName := ToDBName(name) if field, ok := scope.Fields()[dbName]; ok { return field.Set(value) } From 5aca010140e2cd4b30ce9bbbb2f61e4afc0b8c33 Mon Sep 17 00:00:00 2001 From: bom_d_van Date: Thu, 16 Apr 2015 17:36:22 +0800 Subject: [PATCH 07/52] ignore empty dialect in TestManyToManyWithMultiPrimaryKeys --- multi_primary_keys_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/multi_primary_keys_test.go b/multi_primary_keys_test.go index 4aa3517e..9ca68d13 100644 --- a/multi_primary_keys_test.go +++ b/multi_primary_keys_test.go @@ -21,7 +21,7 @@ type Tag struct { } func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { - if dialect := os.Getenv("GORM_DIALECT"); dialect != "sqlite" { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { DB.Exec(fmt.Sprintf("drop table blog_tags;")) DB.AutoMigrate(&Blog{}, &Tag{}) blog := Blog{ From 681aa90995c744fc794398219c05125981f61436 Mon Sep 17 00:00:00 2001 From: bom_d_van Date: Thu, 16 Apr 2015 17:36:55 +0800 Subject: [PATCH 08/52] simplify clone --- search.go | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/search.go b/search.go index 502c226f..a180fb92 100644 --- a/search.go +++ b/search.go @@ -24,25 +24,8 @@ type search struct { } func (s *search) clone() *search { - return &search{ - preload: s.preload, - whereConditions: s.whereConditions, - orConditions: s.orConditions, - notConditions: s.notConditions, - havingCondition: s.havingCondition, - initAttrs: s.initAttrs, - assignAttrs: s.assignAttrs, - selects: s.selects, - omits: s.omits, - orders: s.orders, - joins: s.joins, - offset: s.offset, - limit: s.limit, - group: s.group, - tableName: s.tableName, - raw: s.raw, - Unscoped: s.Unscoped, - } + clone := *s + return &clone } func (s *search) Where(query interface{}, values ...interface{}) *search { From 0e2cd4475f53c8143fd5f758933355425dcafdff Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 16 Apr 2015 18:39:30 +0800 Subject: [PATCH 09/52] Only load Fields when defined preload --- preload.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/preload.go b/preload.go index d252238a..541f7b95 100644 --- a/preload.go +++ b/preload.go @@ -20,10 +20,10 @@ func equalAsString(a interface{}, b interface{}) bool { } func Preload(scope *Scope) { - fields := scope.Fields() - isSlice := scope.IndirectValue().Kind() == reflect.Slice - if scope.Search.preload != nil { + fields := scope.Fields() + isSlice := scope.IndirectValue().Kind() == reflect.Slice + for key, conditions := range scope.Search.preload { for _, field := range fields { if field.Name == key && field.Relationship != nil { From 4fbc9d2a8ff31c9e7f43ef9a81d9fc715d13d70c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 17 Apr 2015 14:52:59 +0800 Subject: [PATCH 10/52] Remove foundationdb from tests all script because it is not downloadable from offical site --- main.go | 4 +--- test_all.sh | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/main.go b/main.go index 86301207..f39a373f 100644 --- a/main.go +++ b/main.go @@ -220,9 +220,7 @@ func (s *DB) Rows() (*sql.Rows, error) { } func (s *DB) Scan(dest interface{}) *DB { - scope := s.clone().NewScope(s.Value).InstanceSet("gorm:query_destination", dest) - Query(scope) - return scope.db + return s.clone().NewScope(s.Value).InstanceSet("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db } func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { diff --git a/test_all.sh b/test_all.sh index bd28294d..6c5593b3 100755 --- a/test_all.sh +++ b/test_all.sh @@ -1,4 +1,4 @@ -dialects=("postgres" "foundation" "mysql" "sqlite") +dialects=("postgres" "mysql" "sqlite") for dialect in "${dialects[@]}" ; do GORM_DIALECT=${dialect} go test From ef4299b39879ad31b5511acecc12ef4457276d40 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 17 Apr 2015 18:27:20 +0800 Subject: [PATCH 11/52] Add RowQuery callback --- callback.go | 10 +++++++++- callback_query.go | 10 +++++----- main.go | 28 ++++++++++++++-------------- scope_private.go | 2 ++ 4 files changed, 30 insertions(+), 20 deletions(-) diff --git a/callback.go b/callback.go index be9a7f12..603e5111 100644 --- a/callback.go +++ b/callback.go @@ -9,6 +9,7 @@ type callback struct { updates []*func(scope *Scope) deletes []*func(scope *Scope) queries []*func(scope *Scope) + rowQueries []*func(scope *Scope) processors []*callbackProcessor } @@ -55,6 +56,10 @@ func (c *callback) Query() *callbackProcessor { return c.addProcessor("query") } +func (c *callback) RowQuery() *callbackProcessor { + return c.addProcessor("row_query") +} + func (cp *callbackProcessor) Before(name string) *callbackProcessor { cp.before = name return cp @@ -168,7 +173,7 @@ func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) { } func (c *callback) sort() { - creates, updates, deletes, queries := []*callbackProcessor{}, []*callbackProcessor{}, []*callbackProcessor{}, []*callbackProcessor{} + var creates, updates, deletes, queries, rowQueries []*callbackProcessor for _, processor := range c.processors { switch processor.typ { @@ -180,6 +185,8 @@ func (c *callback) sort() { deletes = append(deletes, processor) case "query": queries = append(queries, processor) + case "row_query": + rowQueries = append(rowQueries, processor) } } @@ -187,6 +194,7 @@ func (c *callback) sort() { c.updates = sortProcessors(updates) c.deletes = sortProcessors(deletes) c.queries = sortProcessors(queries) + c.rowQueries = sortProcessors(rowQueries) } var DefaultCallback = &callback{processors: []*callbackProcessor{}} diff --git a/callback_query.go b/callback_query.go index 5daa5fec..0eea6f89 100644 --- a/callback_query.go +++ b/callback_query.go @@ -16,17 +16,17 @@ func Query(scope *Scope) { destType reflect.Type ) - var dest = scope.IndirectValue() - if value, ok := scope.InstanceGet("gorm:query_destination"); ok { - dest = reflect.Indirect(reflect.ValueOf(value)) - } - if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok { if primaryKey := scope.PrimaryKey(); primaryKey != "" { scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), primaryKey, orderBy)) } } + var dest = scope.IndirectValue() + if value, ok := scope.InstanceGet("gorm:query_destination"); ok { + dest = reflect.Indirect(reflect.ValueOf(value)) + } + if kind := dest.Kind(); kind == reflect.Slice { isSlice = true destType = dest.Type().Elem() diff --git a/main.go b/main.go index f39a373f..04f59bcf 100644 --- a/main.go +++ b/main.go @@ -211,6 +211,10 @@ func (s *DB) Find(out interface{}, where ...interface{}) *DB { return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callback.queries).db } +func (s *DB) Scan(dest interface{}) *DB { + return s.clone().NewScope(s.Value).InstanceSet("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db +} + func (s *DB) Row() *sql.Row { return s.NewScope(s.Value).row() } @@ -219,8 +223,16 @@ func (s *DB) Rows() (*sql.Rows, error) { return s.NewScope(s.Value).rows() } -func (s *DB) Scan(dest interface{}) *DB { - return s.clone().NewScope(s.Value).InstanceSet("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db +func (s *DB) Pluck(column string, value interface{}) *DB { + return s.NewScope(s.Value).pluck(column, value).db +} + +func (s *DB) Count(value interface{}) *DB { + return s.NewScope(s.Value).count(value).db +} + +func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { + return s.clone().NewScope(s.Value).related(value, foreignKeys...).db } func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { @@ -307,18 +319,6 @@ func (s *DB) Model(value interface{}) *DB { return c } -func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { - return s.clone().NewScope(s.Value).related(value, foreignKeys...).db -} - -func (s *DB) Pluck(column string, value interface{}) *DB { - return s.NewScope(s.Value).pluck(column, value).db -} - -func (s *DB) Count(value interface{}) *DB { - return s.NewScope(s.Value).count(value).db -} - func (s *DB) Table(name string) *DB { clone := s.clone() clone.search.Table(name) diff --git a/scope_private.go b/scope_private.go index 03a47acc..b9476455 100644 --- a/scope_private.go +++ b/scope_private.go @@ -336,12 +336,14 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore func (scope *Scope) row() *sql.Row { defer scope.Trace(NowFunc()) + scope.callCallbacks(scope.db.parent.callback.rowQueries) scope.prepareQuerySql() return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...) } func (scope *Scope) rows() (*sql.Rows, error) { defer scope.Trace(NowFunc()) + scope.callCallbacks(scope.db.parent.callback.rowQueries) scope.prepareQuerySql() return scope.SqlDB().Query(scope.Sql, scope.SqlVars...) } From 055bf79f8bb07cf81706c25e5a7d18cd422c1d35 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 21 Apr 2015 11:24:48 +0800 Subject: [PATCH 12/52] Don't call method if value is nil --- scope.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scope.go b/scope.go index f1733aa3..559b4daa 100644 --- a/scope.go +++ b/scope.go @@ -177,7 +177,7 @@ func (scope *Scope) SetColumn(column interface{}, value interface{}) error { } func (scope *Scope) CallMethod(name string, checkError bool) { - if scope.Value == nil && (!checkError || !scope.HasError()) { + if scope.Value == nil || (checkError && scope.HasError()) { return } From 6d58dc9f4ec47282650fe3e0d3d2c7a0e95915fa Mon Sep 17 00:00:00 2001 From: bom_d_van Date: Tue, 21 Apr 2015 15:00:36 +0800 Subject: [PATCH 13/52] support nested preloading --- preload.go | 245 ++++++++++++++++++---------- preload_test.go | 421 +++++++++++++++++++++++++++++++++++++++++++++++- search.go | 19 ++- 3 files changed, 594 insertions(+), 91 deletions(-) diff --git a/preload.go b/preload.go index 541f7b95..0c8d70ad 100644 --- a/preload.go +++ b/preload.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "reflect" + "strings" ) func getRealValue(value reflect.Value, field string) interface{} { @@ -20,90 +21,139 @@ func equalAsString(a interface{}, b interface{}) bool { } func Preload(scope *Scope) { + preloadMap := map[string]bool{} if scope.Search.preload != nil { fields := scope.Fields() isSlice := scope.IndirectValue().Kind() == reflect.Slice - for key, conditions := range scope.Search.preload { - for _, field := range fields { - if field.Name == key && field.Relationship != nil { - results := makeSlice(field.Struct.Type) - relation := field.Relationship - primaryName := scope.PrimaryField().Name - associationPrimaryKey := scope.New(results).PrimaryField().Name - - switch relation.Kind { - case "has_one": - if primaryKeys := scope.getColumnAsArray(primaryName); len(primaryKeys) > 0 { - condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) - scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) - - resultValues := reflect.Indirect(reflect.ValueOf(results)) - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) - if isSlice { - value := getRealValue(result, relation.ForeignFieldName) - objects := scope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - if equalAsString(getRealValue(objects.Index(j), primaryName), value) { - reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) - break - } - } - } else { - scope.SetColumn(field, result) - } - } - } - case "has_many": - if primaryKeys := scope.getColumnAsArray(primaryName); len(primaryKeys) > 0 { - condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) - scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) - resultValues := reflect.Indirect(reflect.ValueOf(results)) - if isSlice { - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) - value := getRealValue(result, relation.ForeignFieldName) - objects := scope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - object := reflect.Indirect(objects.Index(j)) - if equalAsString(getRealValue(object, primaryName), value) { - f := object.FieldByName(field.Name) - f.Set(reflect.Append(f, result)) - break - } - } - } - } else { - scope.SetColumn(field, resultValues) - } - } - case "belongs_to": - if primaryKeys := scope.getColumnAsArray(relation.ForeignFieldName); len(primaryKeys) > 0 { - scope.NewDB().Where(primaryKeys).Find(results, conditions...) - resultValues := reflect.Indirect(reflect.ValueOf(results)) - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) - if isSlice { - value := getRealValue(result, associationPrimaryKey) - objects := scope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - object := reflect.Indirect(objects.Index(j)) - if equalAsString(getRealValue(object, relation.ForeignFieldName), value) { - object.FieldByName(field.Name).Set(result) - } - } - } else { - scope.SetColumn(field, result) - } - } - } - case "many_to_many": - scope.Err(errors.New("not supported relation")) - default: - scope.Err(errors.New("not supported relation")) + for _, preload := range scope.Search.preload { + schema, conditions := preload.schema, preload.conditions + keys := strings.Split(schema, ".") + currentScope := scope + currentFields := fields + currentIsSlice := isSlice + originalConditions := conditions + conditions = []interface{}{} + for i, key := range keys { + // log.Printf("--> %+v\n", key) + if !preloadMap[strings.Join(keys[:i+1], ".")] { + if i == len(keys)-1 { + // log.Printf("--> %+v\n", originalConditions) + conditions = originalConditions } - break + + var found bool + for _, field := range currentFields { + if field.Name == key && field.Relationship != nil { + found = true + // log.Printf("--> %+v\n", field.Name) + results := makeSlice(field.Struct.Type) + relation := field.Relationship + primaryName := currentScope.PrimaryField().Name + associationPrimaryKey := currentScope.New(results).PrimaryField().Name + + switch relation.Kind { + case "has_one": + if primaryKeys := currentScope.getColumnAsArray(primaryName); len(primaryKeys) > 0 { + condition := fmt.Sprintf("%v IN (?)", currentScope.Quote(relation.ForeignDBName)) + currentScope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) + + resultValues := reflect.Indirect(reflect.ValueOf(results)) + for i := 0; i < resultValues.Len(); i++ { + result := resultValues.Index(i) + if currentIsSlice { + value := getRealValue(result, relation.ForeignFieldName) + objects := currentScope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + if equalAsString(getRealValue(objects.Index(j), primaryName), value) { + reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) + break + } + } + } else { + // log.Printf("--> %+v\n", result.Interface()) + err := currentScope.SetColumn(field, result) + if err != nil { + scope.Err(err) + return + } + // printutils.PrettyPrint(currentScope.Value) + } + } + // printutils.PrettyPrint(currentScope.Value) + } + case "has_many": + // log.Printf("--> %+v\n", key) + if primaryKeys := currentScope.getColumnAsArray(primaryName); len(primaryKeys) > 0 { + condition := fmt.Sprintf("%v IN (?)", currentScope.Quote(relation.ForeignDBName)) + currentScope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) + resultValues := reflect.Indirect(reflect.ValueOf(results)) + if currentIsSlice { + for i := 0; i < resultValues.Len(); i++ { + result := resultValues.Index(i) + value := getRealValue(result, relation.ForeignFieldName) + objects := currentScope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + object := reflect.Indirect(objects.Index(j)) + if equalAsString(getRealValue(object, primaryName), value) { + f := object.FieldByName(field.Name) + f.Set(reflect.Append(f, result)) + break + } + } + } + // printutils.PrettyPrint(currentScope.IndirectValue().Interface()) + } else { + currentScope.SetColumn(field, resultValues) + } + } + case "belongs_to": + if primaryKeys := currentScope.getColumnAsArray(relation.ForeignFieldName); len(primaryKeys) > 0 { + currentScope.NewDB().Where(primaryKeys).Find(results, conditions...) + resultValues := reflect.Indirect(reflect.ValueOf(results)) + for i := 0; i < resultValues.Len(); i++ { + result := resultValues.Index(i) + if currentIsSlice { + value := getRealValue(result, associationPrimaryKey) + objects := currentScope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + object := reflect.Indirect(objects.Index(j)) + if equalAsString(getRealValue(object, relation.ForeignFieldName), value) { + object.FieldByName(field.Name).Set(result) + } + } + } else { + currentScope.SetColumn(field, result) + } + } + } + case "many_to_many": + // currentScope.Err(errors.New("not supported relation")) + fallthrough + default: + currentScope.Err(errors.New("not supported relation")) + } + break + } + } + + if !found { + value := reflect.ValueOf(currentScope.Value) + if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface { + value = value.Index(0).Elem() + } + scope.Err(fmt.Errorf("can't found field %s in %s", key, value.Type())) + return + } + + preloadMap[strings.Join(keys[:i+1], ".")] = true + } + + if i < len(keys)-1 { + // TODO: update current scope + currentScope = currentScope.getColumnsAsScope(key) + currentFields = currentScope.Fields() + currentIsSlice = currentScope.IndirectValue().Kind() == reflect.Slice } } } @@ -120,19 +170,44 @@ func makeSlice(typ reflect.Type) interface{} { return slice.Interface() } -func (scope *Scope) getColumnAsArray(column string) (primaryKeys []interface{}) { +func (scope *Scope) getColumnAsArray(column string) (columns []interface{}) { values := scope.IndirectValue() switch values.Kind() { case reflect.Slice: - primaryKeyMap := map[interface{}]bool{} for i := 0; i < values.Len(); i++ { - primaryKeyMap[reflect.Indirect(values.Index(i)).FieldByName(column).Interface()] = true - } - for key := range primaryKeyMap { - primaryKeys = append(primaryKeys, key) + columns = append(columns, reflect.Indirect(values.Index(i)).FieldByName(column).Interface()) } case reflect.Struct: return []interface{}{values.FieldByName(column).Interface()} } return } + +func (scope *Scope) getColumnsAsScope(column string) *Scope { + values := scope.IndirectValue() + // log.Println(values.Type(), column) + switch values.Kind() { + case reflect.Slice: + fieldType, _ := values.Type().Elem().FieldByName(column) + var columns reflect.Value + if fieldType.Type.Kind() == reflect.Slice { + columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType.Type.Elem()))).Elem() + } else { + columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType.Type))).Elem() + } + for i := 0; i < values.Len(); i++ { + column := reflect.Indirect(values.Index(i)).FieldByName(column) + if column.Kind() == reflect.Slice { + for i := 0; i < column.Len(); i++ { + columns = reflect.Append(columns, column.Index(i).Addr()) + } + } else { + columns = reflect.Append(columns, column.Addr()) + } + } + return scope.New(columns.Interface()) + case reflect.Struct: + return scope.New(values.FieldByName(column).Addr().Interface()) + } + return nil +} diff --git a/preload_test.go b/preload_test.go index 2547933b..c5d395d4 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1,6 +1,11 @@ package gorm_test -import "testing" +import ( + "encoding/json" + "log" + "reflect" + "testing" +) func getPreloadUser(name string) *User { return getPreparedUser(name, "Preload") @@ -85,3 +90,417 @@ func TestPreload(t *testing.T) { } } } + +func TestNestedPreload(t *testing.T) { + log.SetFlags(log.Lshortfile) + // Struct: Level3 + { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}} + if err := DB.Create(&want).Error; err != nil { + panic(err) + } + + var got Level3 + if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + } + { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := Level3{ + Level2s: []Level2{ + { + Level1s: []Level1{ + {Value: "value1"}, + {Value: "value2"}, + }, + }, + { + Level1s: []Level1{ + {Value: "value3"}, + }, + }, + }, + } + if err := DB.Create(&want).Error; err != nil { + panic(err) + } + + var got Level3 + if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + } + { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value1"}}, + {Level1: Level1{Value: "value2"}}, + }, + } + if err := DB.Create(&want).Error; err != nil { + panic(err) + } + + var got Level3 + if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + } + { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := Level3{ + Level2: Level2{ + Level1s: []Level1{ + Level1{Value: "value1"}, + Level1{Value: "value2"}, + }, + }, + } + if err := DB.Create(&want).Error; err != nil { + panic(err) + } + + var got Level3 + if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + } + + // Slice: []Level3 + { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := make([]Level3, 2) + want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}} + if err := DB.Create(&want[0]).Error; err != nil { + panic(err) + } + want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}} + if err := DB.Create(&want[1]).Error; err != nil { + panic(err) + } + + var got []Level3 + if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + } + { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2s: []Level2{ + { + Level1s: []Level1{ + {Value: "value1"}, + {Value: "value2"}, + }, + }, + { + Level1s: []Level1{ + {Value: "value3"}, + }, + }, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + panic(err) + } + want[1] = Level3{ + Level2s: []Level2{ + { + Level1s: []Level1{ + {Value: "value3"}, + {Value: "value4"}, + }, + }, + { + Level1s: []Level1{ + {Value: "value5"}, + }, + }, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + panic(err) + } + + var got []Level3 + if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + } + { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value1"}}, + {Level1: Level1{Value: "value2"}}, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + panic(err) + } + want[1] = Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value3"}}, + {Level1: Level1{Value: "value4"}}, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + panic(err) + } + + var got []Level3 + if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + } + { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + Level1{Value: "value1"}, + Level1{Value: "value2"}, + }, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + panic(err) + } + want[1] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + Level1{Value: "value3"}, + Level1{Value: "value4"}, + }, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + panic(err) + } + + var got []Level3 + if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } + } +} + +func toJSONString(v interface{}) []byte { + r, _ := json.MarshalIndent(v, "", " ") + return r +} diff --git a/search.go b/search.go index a180fb92..9411af43 100644 --- a/search.go +++ b/search.go @@ -14,7 +14,7 @@ type search struct { omits []string orders []string joins string - preload map[string][]interface{} + preload []searchPreload offset string limit string group string @@ -23,6 +23,11 @@ type search struct { Unscoped bool } +type searchPreload struct { + schema string + conditions []interface{} +} + func (s *search) clone() *search { clone := *s return &clone @@ -97,11 +102,15 @@ func (s *search) Joins(query string) *search { return s } -func (s *search) Preload(column string, values ...interface{}) *search { - if s.preload == nil { - s.preload = map[string][]interface{}{} +func (s *search) Preload(schema string, values ...interface{}) *search { + var preloads []searchPreload + for _, preload := range s.preload { + if preload.schema != schema { + preloads = append(preloads, preload) + } } - s.preload[column] = values + preloads = append(preloads, searchPreload{schema, values}) + s.preload = preloads return s } From 9e9367e8157453e151f0b8da795d45a688cacb98 Mon Sep 17 00:00:00 2001 From: bom_d_van Date: Tue, 21 Apr 2015 16:51:52 +0800 Subject: [PATCH 14/52] refactor preload and its tests --- preload.go | 288 ++++++++-------- preload_test.go | 862 +++++++++++++++++++++++++++--------------------- 2 files changed, 635 insertions(+), 515 deletions(-) diff --git a/preload.go b/preload.go index 0c8d70ad..42836067 100644 --- a/preload.go +++ b/preload.go @@ -21,143 +21,69 @@ func equalAsString(a interface{}, b interface{}) bool { } func Preload(scope *Scope) { + if scope.Search.preload == nil { + return + } + preloadMap := map[string]bool{} - if scope.Search.preload != nil { - fields := scope.Fields() - isSlice := scope.IndirectValue().Kind() == reflect.Slice + fields := scope.Fields() + for _, preload := range scope.Search.preload { + schema, conditions := preload.schema, preload.conditions + keys := strings.Split(schema, ".") + currentScope := scope + currentFields := fields + originalConditions := conditions + conditions = []interface{}{} + for i, key := range keys { + var found bool + if preloadMap[strings.Join(keys[:i+1], ".")] { + goto nextLoop + } - for _, preload := range scope.Search.preload { - schema, conditions := preload.schema, preload.conditions - keys := strings.Split(schema, ".") - currentScope := scope - currentFields := fields - currentIsSlice := isSlice - originalConditions := conditions - conditions = []interface{}{} - for i, key := range keys { - // log.Printf("--> %+v\n", key) - if !preloadMap[strings.Join(keys[:i+1], ".")] { - if i == len(keys)-1 { - // log.Printf("--> %+v\n", originalConditions) - conditions = originalConditions - } + if i == len(keys)-1 { + conditions = originalConditions + } - var found bool - for _, field := range currentFields { - if field.Name == key && field.Relationship != nil { - found = true - // log.Printf("--> %+v\n", field.Name) - results := makeSlice(field.Struct.Type) - relation := field.Relationship - primaryName := currentScope.PrimaryField().Name - associationPrimaryKey := currentScope.New(results).PrimaryField().Name - - switch relation.Kind { - case "has_one": - if primaryKeys := currentScope.getColumnAsArray(primaryName); len(primaryKeys) > 0 { - condition := fmt.Sprintf("%v IN (?)", currentScope.Quote(relation.ForeignDBName)) - currentScope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) - - resultValues := reflect.Indirect(reflect.ValueOf(results)) - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) - if currentIsSlice { - value := getRealValue(result, relation.ForeignFieldName) - objects := currentScope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - if equalAsString(getRealValue(objects.Index(j), primaryName), value) { - reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) - break - } - } - } else { - // log.Printf("--> %+v\n", result.Interface()) - err := currentScope.SetColumn(field, result) - if err != nil { - scope.Err(err) - return - } - // printutils.PrettyPrint(currentScope.Value) - } - } - // printutils.PrettyPrint(currentScope.Value) - } - case "has_many": - // log.Printf("--> %+v\n", key) - if primaryKeys := currentScope.getColumnAsArray(primaryName); len(primaryKeys) > 0 { - condition := fmt.Sprintf("%v IN (?)", currentScope.Quote(relation.ForeignDBName)) - currentScope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) - resultValues := reflect.Indirect(reflect.ValueOf(results)) - if currentIsSlice { - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) - value := getRealValue(result, relation.ForeignFieldName) - objects := currentScope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - object := reflect.Indirect(objects.Index(j)) - if equalAsString(getRealValue(object, primaryName), value) { - f := object.FieldByName(field.Name) - f.Set(reflect.Append(f, result)) - break - } - } - } - // printutils.PrettyPrint(currentScope.IndirectValue().Interface()) - } else { - currentScope.SetColumn(field, resultValues) - } - } - case "belongs_to": - if primaryKeys := currentScope.getColumnAsArray(relation.ForeignFieldName); len(primaryKeys) > 0 { - currentScope.NewDB().Where(primaryKeys).Find(results, conditions...) - resultValues := reflect.Indirect(reflect.ValueOf(results)) - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) - if currentIsSlice { - value := getRealValue(result, associationPrimaryKey) - objects := currentScope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - object := reflect.Indirect(objects.Index(j)) - if equalAsString(getRealValue(object, relation.ForeignFieldName), value) { - object.FieldByName(field.Name).Set(result) - } - } - } else { - currentScope.SetColumn(field, result) - } - } - } - case "many_to_many": - // currentScope.Err(errors.New("not supported relation")) - fallthrough - default: - currentScope.Err(errors.New("not supported relation")) - } - break - } - } - - if !found { - value := reflect.ValueOf(currentScope.Value) - if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface { - value = value.Index(0).Elem() - } - scope.Err(fmt.Errorf("can't found field %s in %s", key, value.Type())) - return - } - - preloadMap[strings.Join(keys[:i+1], ".")] = true + for _, field := range currentFields { + if field.Name != key || field.Relationship == nil { + continue } - if i < len(keys)-1 { - // TODO: update current scope - currentScope = currentScope.getColumnsAsScope(key) - currentFields = currentScope.Fields() - currentIsSlice = currentScope.IndirectValue().Kind() == reflect.Slice + found = true + switch field.Relationship.Kind { + case "has_one": + currentScope.handleHasOnePreload(field, conditions) + case "has_many": + currentScope.handleHasManyPreload(field, conditions) + case "belongs_to": + currentScope.handleBelongsToPreload(field, conditions) + case "many_to_many": + fallthrough + default: + currentScope.Err(errors.New("not supported relation")) } + break + } + + if !found { + value := reflect.ValueOf(currentScope.Value) + if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface { + value = value.Index(0).Elem() + } + scope.Err(fmt.Errorf("can't found field %s in %s", key, value.Type())) + return + } + + preloadMap[strings.Join(keys[:i+1], ".")] = true + + nextLoop: + if i < len(keys)-1 { + currentScope = currentScope.getColumnsAsScope(key) + currentFields = currentScope.Fields() } } } + } func makeSlice(typ reflect.Type) interface{} { @@ -170,6 +96,105 @@ func makeSlice(typ reflect.Type) interface{} { return slice.Interface() } +func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { + primaryName := scope.PrimaryField().Name + primaryKeys := scope.getColumnAsArray(primaryName) + if len(primaryKeys) == 0 { + return + } + + results := makeSlice(field.Struct.Type) + relation := field.Relationship + condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) + resultValues := reflect.Indirect(reflect.ValueOf(results)) + + // TODO: handle error? + scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) + + for i := 0; i < resultValues.Len(); i++ { + result := resultValues.Index(i) + if scope.IndirectValue().Kind() == reflect.Slice { + value := getRealValue(result, relation.ForeignFieldName) + objects := scope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + if equalAsString(getRealValue(objects.Index(j), primaryName), value) { + reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) + break + } + } + } else { + err := scope.SetColumn(field, result) + if err != nil { + scope.Err(err) + return + } + } + } +} + +func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { + primaryName := scope.PrimaryField().Name + primaryKeys := scope.getColumnAsArray(primaryName) + if len(primaryKeys) == 0 { + return + } + + results := makeSlice(field.Struct.Type) + relation := field.Relationship + condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) + resultValues := reflect.Indirect(reflect.ValueOf(results)) + + scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) + + if scope.IndirectValue().Kind() == reflect.Slice { + for i := 0; i < resultValues.Len(); i++ { + result := resultValues.Index(i) + value := getRealValue(result, relation.ForeignFieldName) + objects := scope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + object := reflect.Indirect(objects.Index(j)) + if equalAsString(getRealValue(object, primaryName), value) { + f := object.FieldByName(field.Name) + f.Set(reflect.Append(f, result)) + break + } + } + } + } else { + scope.SetColumn(field, resultValues) + } +} + +func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { + relation := field.Relationship + primaryKeys := scope.getColumnAsArray(relation.ForeignFieldName) + if len(primaryKeys) == 0 { + return + } + + results := makeSlice(field.Struct.Type) + resultValues := reflect.Indirect(reflect.ValueOf(results)) + associationPrimaryKey := scope.New(results).PrimaryField().Name + + scope.NewDB().Where(primaryKeys).Find(results, conditions...) + + for i := 0; i < resultValues.Len(); i++ { + result := resultValues.Index(i) + if scope.IndirectValue().Kind() == reflect.Slice { + value := getRealValue(result, associationPrimaryKey) + objects := scope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + object := reflect.Indirect(objects.Index(j)) + if equalAsString(getRealValue(object, relation.ForeignFieldName), value) { + object.FieldByName(field.Name).Set(result) + } + } + } else { + scope.SetColumn(field, result) + } + } +} + func (scope *Scope) getColumnAsArray(column string) (columns []interface{}) { values := scope.IndirectValue() switch values.Kind() { @@ -185,10 +210,13 @@ func (scope *Scope) getColumnAsArray(column string) (columns []interface{}) { func (scope *Scope) getColumnsAsScope(column string) *Scope { values := scope.IndirectValue() - // log.Println(values.Type(), column) switch values.Kind() { case reflect.Slice: - fieldType, _ := values.Type().Elem().FieldByName(column) + model := values.Type().Elem() + if model.Kind() == reflect.Ptr { + model = model.Elem() + } + fieldType, _ := model.FieldByName(column) var columns reflect.Value if fieldType.Type.Kind() == reflect.Slice { columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType.Type.Elem()))).Elem() diff --git a/preload_test.go b/preload_test.go index c5d395d4..2929392c 100644 --- a/preload_test.go +++ b/preload_test.go @@ -2,7 +2,6 @@ package gorm_test import ( "encoding/json" - "log" "reflect" "testing" ) @@ -91,412 +90,505 @@ func TestPreload(t *testing.T) { } } -func TestNestedPreload(t *testing.T) { - log.SetFlags(log.Lshortfile) - // Struct: Level3 - { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - ID uint - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - panic(err) +func TestNestedPreload1(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint } - - want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}} - if err := DB.Create(&want).Error; err != nil { - panic(err) + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint } - - var got Level3 - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { - panic(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + Level3 struct { + ID uint + Level2 Level2 } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) } - { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - panic(err) - } - want := Level3{ - Level2s: []Level2{ - { - Level1s: []Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - { - Level1s: []Level1{ - {Value: "value3"}, - }, + want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}} + if err := DB.Create(&want).Error; err != nil { + panic(err) + } + + var got Level3 + if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload2(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []*Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := Level3{ + Level2s: []Level2{ + { + Level1s: []*Level1{ + &Level1{Value: "value1"}, + &Level1{Value: "value2"}, }, }, - } - if err := DB.Create(&want).Error; err != nil { - panic(err) - } - - var got Level3 - if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { - panic(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - } - { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - ID uint - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - panic(err) - } - - want := Level3{ - Level2s: []Level2{ - {Level1: Level1{Value: "value1"}}, - {Level1: Level1{Value: "value2"}}, + { + Level1s: []*Level1{ + &Level1{Value: "value3"}, + }, }, - } - if err := DB.Create(&want).Error; err != nil { - panic(err) - } - - var got Level3 - if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { - panic(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } + }, + } + if err := DB.Create(&want).Error; err != nil { + panic(err) } - { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - panic(err) - } - want := Level3{ - Level2: Level2{ + var got Level3 + if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload3(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value1"}}, + {Level1: Level1{Value: "value2"}}, + }, + } + if err := DB.Create(&want).Error; err != nil { + panic(err) + } + + var got Level3 + if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload4(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := Level3{ + Level2: Level2{ + Level1s: []Level1{ + Level1{Value: "value1"}, + Level1{Value: "value2"}, + }, + }, + } + if err := DB.Create(&want).Error; err != nil { + panic(err) + } + + var got Level3 + if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +// Slice: []Level3 +func TestNestedPreload5(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := make([]Level3, 2) + want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}} + if err := DB.Create(&want[0]).Error; err != nil { + panic(err) + } + want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}} + if err := DB.Create(&want[1]).Error; err != nil { + panic(err) + } + + var got []Level3 + if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload6(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2s: []Level2{ + { Level1s: []Level1{ - Level1{Value: "value1"}, - Level1{Value: "value2"}, + {Value: "value1"}, + {Value: "value2"}, }, }, - } - if err := DB.Create(&want).Error; err != nil { - panic(err) - } - - var got Level3 - if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { - panic(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - } - - // Slice: []Level3 - { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - ID uint - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - panic(err) - } - - want := make([]Level3, 2) - want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}} - if err := DB.Create(&want[0]).Error; err != nil { - panic(err) - } - want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}} - if err := DB.Create(&want[1]).Error; err != nil { - panic(err) - } - - var got []Level3 - if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { - panic(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - } - { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - panic(err) - } - - want := make([]Level3, 2) - want[0] = Level3{ - Level2s: []Level2{ - { - Level1s: []Level1{ - {Value: "value1"}, - {Value: "value2"}, - }, - }, - { - Level1s: []Level1{ - {Value: "value3"}, - }, - }, - }, - } - if err := DB.Create(&want[0]).Error; err != nil { - panic(err) - } - want[1] = Level3{ - Level2s: []Level2{ - { - Level1s: []Level1{ - {Value: "value3"}, - {Value: "value4"}, - }, - }, - { - Level1s: []Level1{ - {Value: "value5"}, - }, - }, - }, - } - if err := DB.Create(&want[1]).Error; err != nil { - panic(err) - } - - var got []Level3 - if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { - panic(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - } - { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1 Level1 - Level3ID uint - } - Level3 struct { - ID uint - Level2s []Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - panic(err) - } - - want := make([]Level3, 2) - want[0] = Level3{ - Level2s: []Level2{ - {Level1: Level1{Value: "value1"}}, - {Level1: Level1{Value: "value2"}}, - }, - } - if err := DB.Create(&want[0]).Error; err != nil { - panic(err) - } - want[1] = Level3{ - Level2s: []Level2{ - {Level1: Level1{Value: "value3"}}, - {Level1: Level1{Value: "value4"}}, - }, - } - if err := DB.Create(&want[1]).Error; err != nil { - panic(err) - } - - var got []Level3 - if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { - panic(err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) - } - } - { - type ( - Level1 struct { - ID uint - Value string - Level2ID uint - } - Level2 struct { - ID uint - Level1s []Level1 - Level3ID uint - } - Level3 struct { - ID uint - Level2 Level2 - } - ) - DB.DropTableIfExists(&Level3{}) - DB.DropTableIfExists(&Level2{}) - DB.DropTableIfExists(&Level1{}) - if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { - panic(err) - } - - want := make([]Level3, 2) - want[0] = Level3{ - Level2: Level2{ + { Level1s: []Level1{ - Level1{Value: "value1"}, - Level1{Value: "value2"}, + {Value: "value3"}, }, }, - } - if err := DB.Create(&want[0]).Error; err != nil { - panic(err) - } - want[1] = Level3{ - Level2: Level2{ + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + panic(err) + } + want[1] = Level3{ + Level2s: []Level2{ + { Level1s: []Level1{ - Level1{Value: "value3"}, - Level1{Value: "value4"}, + {Value: "value3"}, + {Value: "value4"}, }, }, - } - if err := DB.Create(&want[1]).Error; err != nil { - panic(err) - } + { + Level1s: []Level1{ + {Value: "value5"}, + }, + }, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + panic(err) + } - var got []Level3 - if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { - panic(err) - } + var got []Level3 + if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { + panic(err) + } - if !reflect.DeepEqual(got, want) { - t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload7(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value1"}}, + {Level1: Level1{Value: "value2"}}, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + panic(err) + } + want[1] = Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value3"}}, + {Level1: Level1{Value: "value4"}}, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + panic(err) + } + + var got []Level3 + if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload8(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + Level1{Value: "value1"}, + Level1{Value: "value2"}, + }, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + panic(err) + } + want[1] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + Level1{Value: "value3"}, + Level1{Value: "value4"}, + }, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + panic(err) + } + + var got []Level3 + if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload9(t *testing.T) { + type ( + Level0 struct { + ID uint + Value string + Level1ID uint + } + Level1 struct { + ID uint + Value string + Level2ID uint + Level2_1ID uint + Level0s []Level0 + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level2_1 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Level2 Level2 + Level2_1 Level2_1 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level2_1{}) + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists(&Level0{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).Error; err != nil { + panic(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + Level1{Value: "value1"}, + Level1{Value: "value2"}, + }, + }, + Level2_1: Level2_1{ + Level1s: []Level1{ + Level1{ + Value: "value1-1", + Level0s: []Level0{{Value: "Level0-1"}}, + }, + Level1{ + Value: "value2-2", + Level0s: []Level0{{Value: "Level0-2"}}, + }, + }, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + panic(err) + } + want[1] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + Level1{Value: "value3"}, + Level1{Value: "value4"}, + }, + }, + Level2_1: Level2_1{ + Level1s: []Level1{ + Level1{Value: "value3-3"}, + Level1{Value: "value4-4"}, + }, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + panic(err) + } + + var got []Level3 + if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } } From 7ef8b06cb40757138ab81f03d20e0ebb2ca73da6 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 22 Apr 2015 14:28:50 +0800 Subject: [PATCH 15/52] Fix tests with mysql --- preload_test.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/preload_test.go b/preload_test.go index 2929392c..7fbba5c0 100644 --- a/preload_test.go +++ b/preload_test.go @@ -104,6 +104,7 @@ func TestNestedPreload1(t *testing.T) { } Level3 struct { ID uint + Name string Level2 Level2 } ) @@ -143,6 +144,7 @@ func TestNestedPreload2(t *testing.T) { } Level3 struct { ID uint + Name string Level2s []Level2 } ) @@ -154,6 +156,7 @@ func TestNestedPreload2(t *testing.T) { } want := Level3{ + Name: "name", Level2s: []Level2{ { Level1s: []*Level1{ @@ -195,6 +198,7 @@ func TestNestedPreload3(t *testing.T) { Level3ID uint } Level3 struct { + Name string ID uint Level2s []Level2 } @@ -207,6 +211,7 @@ func TestNestedPreload3(t *testing.T) { } want := Level3{ + Name: "name", Level2s: []Level2{ {Level1: Level1{Value: "value1"}}, {Level1: Level1{Value: "value2"}}, @@ -240,6 +245,7 @@ func TestNestedPreload4(t *testing.T) { } Level3 struct { ID uint + Name string Level2 Level2 } ) @@ -287,6 +293,7 @@ func TestNestedPreload5(t *testing.T) { } Level3 struct { ID uint + Name string Level2 Level2 } ) @@ -331,6 +338,7 @@ func TestNestedPreload6(t *testing.T) { } Level3 struct { ID uint + Name string Level2s []Level2 } ) @@ -403,6 +411,7 @@ func TestNestedPreload7(t *testing.T) { } Level3 struct { ID uint + Name string Level2s []Level2 } ) @@ -457,6 +466,7 @@ func TestNestedPreload8(t *testing.T) { } Level3 struct { ID uint + Name string Level2 Level2 } ) @@ -527,6 +537,7 @@ func TestNestedPreload9(t *testing.T) { } Level3 struct { ID uint + Name string Level2 Level2 Level2_1 Level2_1 } From 7693c093a96ef1bd0c8312b938201997dfe99f46 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 22 Apr 2015 15:36:10 +0800 Subject: [PATCH 16/52] Refactor Preload --- preload.go | 30 ++++++++++++++---------------- preload_test.go | 4 ++-- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/preload.go b/preload.go index 42836067..add077ab 100644 --- a/preload.go +++ b/preload.go @@ -106,10 +106,9 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) results := makeSlice(field.Struct.Type) relation := field.Relationship condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) - resultValues := reflect.Indirect(reflect.ValueOf(results)) - // TODO: handle error? - scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) + scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error) + resultValues := reflect.Indirect(reflect.ValueOf(results)) for i := 0; i < resultValues.Len(); i++ { result := resultValues.Index(i) @@ -123,8 +122,7 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) } } } else { - err := scope.SetColumn(field, result) - if err != nil { + if err := scope.SetColumn(field, result); err != nil { scope.Err(err) return } @@ -142,9 +140,9 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) results := makeSlice(field.Struct.Type) relation := field.Relationship condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) - resultValues := reflect.Indirect(reflect.ValueOf(results)) - scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) + scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error) + resultValues := reflect.Indirect(reflect.ValueOf(results)) if scope.IndirectValue().Kind() == reflect.Slice { for i := 0; i < resultValues.Len(); i++ { @@ -173,10 +171,10 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ } results := makeSlice(field.Struct.Type) - resultValues := reflect.Indirect(reflect.ValueOf(results)) associationPrimaryKey := scope.New(results).PrimaryField().Name - scope.NewDB().Where(primaryKeys).Find(results, conditions...) + scope.Err(scope.NewDB().Where(primaryKeys).Find(results, conditions...).Error) + resultValues := reflect.Indirect(reflect.ValueOf(results)) for i := 0; i < resultValues.Len(); i++ { result := resultValues.Index(i) @@ -212,16 +210,16 @@ func (scope *Scope) getColumnsAsScope(column string) *Scope { values := scope.IndirectValue() switch values.Kind() { case reflect.Slice: - model := values.Type().Elem() - if model.Kind() == reflect.Ptr { - model = model.Elem() + modelType := values.Type().Elem() + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() } - fieldType, _ := model.FieldByName(column) + fieldStruct, _ := modelType.FieldByName(column) var columns reflect.Value - if fieldType.Type.Kind() == reflect.Slice { - columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType.Type.Elem()))).Elem() + if fieldStruct.Type.Kind() == reflect.Slice { + columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem() } else { - columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType.Type))).Elem() + columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem() } for i := 0; i < values.Len(); i++ { column := reflect.Indirect(values.Index(i)).FieldByName(column) diff --git a/preload_test.go b/preload_test.go index 7fbba5c0..a6647bbd 100644 --- a/preload_test.go +++ b/preload_test.go @@ -156,7 +156,6 @@ func TestNestedPreload2(t *testing.T) { } want := Level3{ - Name: "name", Level2s: []Level2{ { Level1s: []*Level1{ @@ -211,7 +210,6 @@ func TestNestedPreload3(t *testing.T) { } want := Level3{ - Name: "name", Level2s: []Level2{ {Level1: Level1{Value: "value1"}}, {Level1: Level1{Value: "value2"}}, @@ -368,6 +366,7 @@ func TestNestedPreload6(t *testing.T) { if err := DB.Create(&want[0]).Error; err != nil { panic(err) } + want[1] = Level3{ Level2s: []Level2{ { @@ -432,6 +431,7 @@ func TestNestedPreload7(t *testing.T) { if err := DB.Create(&want[0]).Error; err != nil { panic(err) } + want[1] = Level3{ Level2s: []Level2{ {Level1: Level1{Value: "value3"}}, From 3c2915a9dfaaeaeb1e049436ae9291f5e84c95fa Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 22 Apr 2015 15:38:22 +0800 Subject: [PATCH 17/52] Add Nested Preloading to README --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index 4bcba989..5ab38976 100644 --- a/README.md +++ b/README.md @@ -347,6 +347,13 @@ db.Preload("Orders").Preload("Profile").Preload("Role").Find(&users) //// SELECT * FROM roles WHERE id IN (4,5,6); // belongs to ``` +#### Nested Preloading + +```go +db.Preload("Orders.OrderItems").Find(&users) +db.Preload("Orders", "state = ?", "paid").Preload("Orders.OrderItems").Find(&users) +``` + ## Update ```go From a9aef2dd90791b88febfb0b85fa4d7ce597acd93 Mon Sep 17 00:00:00 2001 From: nicholasf Date: Wed, 22 Apr 2015 12:54:19 -0700 Subject: [PATCH 18/52] Simpler (and correct) example for showing how to declare foreign keys. Impact: Trivial. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 5ab38976..6e929c4c 100644 --- a/README.md +++ b/README.md @@ -1091,7 +1091,7 @@ type Product struct { // 2nd param : destination table(id) // 3rd param : ONDELETE // 4th param : ONUPDATE -db.Model(&User{}).AddForeignKey("user_id", "destination_table(id)", "CASCADE", "RESTRICT") +db.Model(&User{}).AddForeignKey("role_id", "roles", "CASCADE", "RESTRICT") // Add index db.Model(&User{}).AddIndex("idx_user_name", "name") From fa696b0e2f84ee48dd9ce374edeec881a01aaa9a Mon Sep 17 00:00:00 2001 From: Femaref Date: Sun, 26 Apr 2015 16:34:52 +0200 Subject: [PATCH 19/52] Quote the primary column name when doing queries Postgresql requires certain column names to be quoted. When unquoted, all upper-case characters will be converted to lower-case, and column names like 'typeID' will result in a query on 'typeid'. --- callback_query.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/callback_query.go b/callback_query.go index 0eea6f89..825caac1 100644 --- a/callback_query.go +++ b/callback_query.go @@ -18,7 +18,7 @@ func Query(scope *Scope) { if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok { if primaryKey := scope.PrimaryKey(); primaryKey != "" { - scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), primaryKey, orderBy)) + scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryKey), orderBy)) } } From a4a29d6025bd28f6ad7755e563e6e65aebb14703 Mon Sep 17 00:00:00 2001 From: Nguyen Dang Minh Date: Sat, 2 May 2015 16:19:54 +0700 Subject: [PATCH 20/52] Update README.md The line: // db := gorm.Open("postgres", dbSql) should be // db, _ := gorm.Open("postgres", dbSql) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 6e929c4c..c4ff661d 100644 --- a/README.md +++ b/README.md @@ -117,7 +117,7 @@ db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") // You can also use an existing database connection handle // dbSql, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable") -// db := gorm.Open("postgres", dbSql) +// db, _ := gorm.Open("postgres", dbSql) // Get database connection handle [*sql.DB](http://golang.org/pkg/database/sql/#DB) db.DB() From a0c527f1cce14d6b0d9f13b5f41080365e4b86ae Mon Sep 17 00:00:00 2001 From: Constantin Schomburg Date: Sat, 9 May 2015 13:12:13 +0200 Subject: [PATCH 21/52] Fix including ignored field in Where condition --- scope_private.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scope_private.go b/scope_private.go index b9476455..4ecefe3a 100644 --- a/scope_private.go +++ b/scope_private.go @@ -38,7 +38,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri case interface{}: var sqls []string for _, field := range scope.New(value).Fields() { - if !field.IsBlank { + if !field.IsIgnored && !field.IsBlank { sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) } } From 5af077cd2d8ac5992d4d79661ee6a1db161204c4 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 11 May 2015 15:17:35 +0800 Subject: [PATCH 22/52] Handle []string for Select --- scope.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scope.go b/scope.go index 559b4daa..21ad2991 100644 --- a/scope.go +++ b/scope.go @@ -368,6 +368,8 @@ func (scope *Scope) SelectAttrs() []string { for _, value := range scope.Search.selects { if str, ok := value.(string); ok { attrs = append(attrs, str) + } else if strs, ok := value.([]string); ok { + attrs = append(attrs, strs...) } else if strs, ok := value.([]interface{}); ok { for _, str := range strs { attrs = append(attrs, fmt.Sprintf("%v", str)) From 2adfd70bb5e92ecf91186f4414cdb59f73513314 Mon Sep 17 00:00:00 2001 From: Daniel Perez Date: Wed, 13 May 2015 14:20:20 +0900 Subject: [PATCH 23/52] Add examples for join. --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index c4ff661d..21707a8c 100644 --- a/README.md +++ b/README.md @@ -845,6 +845,12 @@ for rows.Next() { } db.Table("users").Select("users.name, emails.email").Joins("left join emails on emails.user_id = users.id").Scan(&results) + +// find a user by email address +db.Joins("inner join emails on emails.user_id = users.id").Where("emails.email = ?", "x@example.org").Find(&user) + +// find all email addresses for a user +db.Joins("left join users on users.id = emails.user_id").Where("users.name = ?", "jinzhu").Find(&emails) ``` ## Transactions From dc55c59b842d79778fec9929bd84c9dbf52b4fbc Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 19 May 2015 10:43:32 +0800 Subject: [PATCH 24/52] Fix use SQL as table name --- scope.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scope.go b/scope.go index 21ad2991..54bf5c84 100644 --- a/scope.go +++ b/scope.go @@ -261,6 +261,9 @@ func (scope *Scope) TableName() string { func (scope *Scope) QuotedTableName() (name string) { if scope.Search != nil && len(scope.Search.tableName) > 0 { + if strings.Index(scope.Search.tableName, " ") != -1 { + return scope.Search.tableName + } return scope.Quote(scope.Search.tableName) } else { return scope.Quote(scope.TableName()) From c2dda88f9a8a0b5c73733e87689dc94cf2818fa2 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 19 May 2015 16:58:33 +0800 Subject: [PATCH 25/52] Use Get to replace InstanceGet --- callback_query.go | 2 +- main.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/callback_query.go b/callback_query.go index 825caac1..59022eba 100644 --- a/callback_query.go +++ b/callback_query.go @@ -23,7 +23,7 @@ func Query(scope *Scope) { } var dest = scope.IndirectValue() - if value, ok := scope.InstanceGet("gorm:query_destination"); ok { + if value, ok := scope.Get("gorm:query_destination"); ok { dest = reflect.Indirect(reflect.ValueOf(value)) } diff --git a/main.go b/main.go index 04f59bcf..bf8acbae 100644 --- a/main.go +++ b/main.go @@ -212,7 +212,7 @@ func (s *DB) Find(out interface{}, where ...interface{}) *DB { } func (s *DB) Scan(dest interface{}) *DB { - return s.clone().NewScope(s.Value).InstanceSet("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db + return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db } func (s *DB) Row() *sql.Row { From cbebcf6d6fe4961d6f0ce4115f97301a942bd9ce Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 22 May 2015 11:13:14 +0800 Subject: [PATCH 26/52] Quote db name when create primary keys --- scope_private.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scope_private.go b/scope_private.go index 4ecefe3a..63fcea46 100644 --- a/scope_private.go +++ b/scope_private.go @@ -475,7 +475,7 @@ func (scope *Scope) createTable() *Scope { } if field.IsPrimaryKey { - primaryKeys = append(primaryKeys, field.DBName) + primaryKeys = append(primaryKeys, scope.Quote(field.DBName)) } scope.createJoinTable(field) } From b96ca76e5924c4f83d9bac72415dfe754e2812bd Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 27 May 2015 12:19:48 +0800 Subject: [PATCH 27/52] Set table name handler --- main.go | 4 ---- model_struct.go | 36 +++++++++++++++++++++--------------- scope.go | 7 +------ 3 files changed, 22 insertions(+), 25 deletions(-) diff --git a/main.go b/main.go index bf8acbae..181722fd 100644 --- a/main.go +++ b/main.go @@ -485,7 +485,3 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join } } } - -func (s *DB) SetTableNameHandler(source interface{}, handler func(*DB) string) { - s.NewScope(source).GetModelStruct().TableName = handler -} diff --git a/model_struct.go b/model_struct.go index a70489fc..10423ae2 100644 --- a/model_struct.go +++ b/model_struct.go @@ -13,11 +13,19 @@ import ( var modelStructs = map[reflect.Type]*ModelStruct{} +var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { + return defaultTableName +} + type ModelStruct struct { - PrimaryFields []*StructField - StructFields []*StructField - ModelType reflect.Type - TableName func(*DB) string + PrimaryFields []*StructField + StructFields []*StructField + ModelType reflect.Type + defaultTableName string +} + +func (s ModelStruct) TableName(db *DB) string { + return DefaultTableNameHandler(db, s.defaultTableName) } type StructField struct { @@ -94,14 +102,14 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // Set tablename - if fm := reflect.New(scopeType).MethodByName("TableName"); fm.IsValid() { - if results := fm.Call([]reflect.Value{}); len(results) > 0 { - if name, ok := results[0].Interface().(string); ok { - modelStruct.TableName = func(*DB) string { - return name - } - } - } + type tabler interface { + TableName() string + } + + if tabler, ok := reflect.New(scopeType).Interface().(interface { + TableName() string + }); ok { + modelStruct.defaultTableName = tabler.TableName() } else { name := ToDBName(scopeType.Name()) if scope.db == nil || !scope.db.parent.singularTable { @@ -112,9 +120,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } - modelStruct.TableName = func(*DB) string { - return name - } + modelStruct.defaultTableName = name } // Get all fields diff --git a/scope.go b/scope.go index 54bf5c84..960a653c 100644 --- a/scope.go +++ b/scope.go @@ -251,12 +251,7 @@ func (scope *Scope) TableName() string { return tabler.TableName(scope.db) } - if scope.GetModelStruct().TableName != nil { - return scope.GetModelStruct().TableName(scope.db) - } - - scope.Err(errors.New("wrong table name")) - return "" + return scope.GetModelStruct().TableName(scope.db) } func (scope *Scope) QuotedTableName() (name string) { From 331d8ceabdff5ebf4dc2aba2ccc19772ba2db17f Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 1 Jun 2015 11:04:11 +0800 Subject: [PATCH 28/52] Don't overwrite primary key if already it is already exist --- callback_create.go | 2 +- main_test.go | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/callback_create.go b/callback_create.go index b21df08b..e7ec40bb 100644 --- a/callback_create.go +++ b/callback_create.go @@ -70,7 +70,7 @@ func Create(scope *Scope) { id, err := result.LastInsertId() if scope.Err(err) == nil { scope.db.RowsAffected, _ = result.RowsAffected() - if primaryField != nil { + if primaryField != nil && primaryField.IsBlank { scope.Err(scope.SetColumn(primaryField, id)) } } diff --git a/main_test.go b/main_test.go index b547534c..666ba564 100644 --- a/main_test.go +++ b/main_test.go @@ -61,6 +61,18 @@ func init() { runMigration() } +func TestStringPrimaryKey(t *testing.T) { + type UUIDStruct struct { + ID string `gorm:"primary_key"` + Name string + } + + data := UUIDStruct{ID: "uuid", Name: "hello"} + if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" { + t.Errorf("string primary key should not be populated") + } +} + func TestExceptionsWithInvalidSql(t *testing.T) { var columns []string if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { From 58f379b098dfaa2614b5705ed97dc7ed5d01d26e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 1 Jun 2015 11:17:51 +0800 Subject: [PATCH 29/52] Add auto migration --- main_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/main_test.go b/main_test.go index 666ba564..0dc5e337 100644 --- a/main_test.go +++ b/main_test.go @@ -66,6 +66,7 @@ func TestStringPrimaryKey(t *testing.T) { ID string `gorm:"primary_key"` Name string } + DB.AutoMigrate(&UUIDStruct{}) data := UUIDStruct{ID: "uuid", Name: "hello"} if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" { From 4fd6e62a022815853eecdd556add706ffd005b54 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 1 Jun 2015 14:02:15 +0800 Subject: [PATCH 30/52] Add unsigned support for mysql --- mysql.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/mysql.go b/mysql.go index e37a23e0..a5e4a459 100644 --- a/mysql.go +++ b/mysql.go @@ -14,16 +14,26 @@ func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "boolean" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: if autoIncrease { return "int AUTO_INCREMENT" } return "int" - case reflect.Int64, reflect.Uint64: + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if autoIncrease { + return "int unsigned AUTO_INCREMENT" + } + return "int unsigned" + case reflect.Int64: if autoIncrease { return "bigint AUTO_INCREMENT" } return "bigint" + case reflect.Uint64: + if autoIncrease { + return "bigint unsigned AUTO_INCREMENT" + } + return "bigint unsigned" case reflect.Float32, reflect.Float64: return "double" case reflect.String: From 65b42ad6f3fe491bc4ead1ab6439997419d7053b Mon Sep 17 00:00:00 2001 From: Soroush Mirzaei Date: Mon, 1 Jun 2015 14:20:23 +0430 Subject: [PATCH 31/52] Fixed the "Query callbacks" URL. Changed it to point to `callback_query.go` instead of `callback_create.go`. --- doc/development.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/development.md b/doc/development.md index 674cfc43..08166661 100644 --- a/doc/development.md +++ b/doc/development.md @@ -61,7 +61,7 @@ Gorm is powered by callbacks, so you could refer below links to learn how to wri [Update callbacks](https://github.com/jinzhu/gorm/blob/master/callback_update.go) -[Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_create.go) +[Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_query.go) [Delete callbacks](https://github.com/jinzhu/gorm/blob/master/callback_delete.go) From 0b8c9f29a9e28708c17e0ef1aed71487e8bd356c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 4 Jun 2015 12:10:09 +0800 Subject: [PATCH 32/52] Find Field by db name also --- scope.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scope.go b/scope.go index 960a653c..11bad777 100644 --- a/scope.go +++ b/scope.go @@ -273,7 +273,7 @@ func (scope *Scope) CombinedConditionSql() string { func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { for _, field := range scope.Fields() { - if field.Name == name { + if field.Name == name || field.DBName == name { return field, true } } From 94f56e646bb69d8128545a1b0cee90eb783bafca Mon Sep 17 00:00:00 2001 From: Bojan Petrovic Date: Thu, 4 Jun 2015 13:47:25 +0200 Subject: [PATCH 33/52] Fix Create when dialect does not support last inserted id --- callback_create.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/callback_create.go b/callback_create.go index e7ec40bb..9f0c9bc2 100644 --- a/callback_create.go +++ b/callback_create.go @@ -77,7 +77,7 @@ func Create(scope *Scope) { } } else { if primaryField == nil { - if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err != nil { + if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil { scope.db.RowsAffected, _ = results.RowsAffected() } } else if scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil { From d9faa4949cf3cc73602389ba5bbf95ea012ca9eb Mon Sep 17 00:00:00 2001 From: Bojan Petrovic Date: Thu, 4 Jun 2015 14:23:57 +0200 Subject: [PATCH 34/52] Fix Create error reporting. --- callback_create.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/callback_create.go b/callback_create.go index 9f0c9bc2..7f21ed6a 100644 --- a/callback_create.go +++ b/callback_create.go @@ -79,9 +79,15 @@ func Create(scope *Scope) { if primaryField == nil { if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil { scope.db.RowsAffected, _ = results.RowsAffected() + } else { + scope.Err(err) + } + } else { + if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil { + scope.db.RowsAffected = 1 + } else { + scope.Err(err) } - } else if scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil { - scope.db.RowsAffected = 1 } } } From f05a6b37949f3a28f07debeaa25c14792323c7a8 Mon Sep 17 00:00:00 2001 From: crystalin Date: Fri, 5 Jun 2015 12:54:52 +0200 Subject: [PATCH 35/52] Support for preload of Struct Ptr This fixes the issue when preloading .Preload("Project.Repositories").Find(&[]User{}) with type User struct { Project *Project } type Project struct { Repositories []Repository } type Repository struct { ... } --- preload.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/preload.go b/preload.go index add077ab..f1c0fae5 100644 --- a/preload.go +++ b/preload.go @@ -216,13 +216,16 @@ func (scope *Scope) getColumnsAsScope(column string) *Scope { } fieldStruct, _ := modelType.FieldByName(column) var columns reflect.Value - if fieldStruct.Type.Kind() == reflect.Slice { + if fieldStruct.Type.Kind() == reflect.Slice || fieldStruct.Type.Kind() == reflect.Ptr { columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem() } else { columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem() } for i := 0; i < values.Len(); i++ { column := reflect.Indirect(values.Index(i)).FieldByName(column) + if column.Kind() == reflect.Ptr { + column = column.Elem() + } if column.Kind() == reflect.Slice { for i := 0; i < column.Len(); i++ { columns = reflect.Append(columns, column.Index(i).Addr()) From 5b282263d87696d58d90842ffca85d79ed0681a6 Mon Sep 17 00:00:00 2001 From: Yan-Fa Li Date: Tue, 9 Jun 2015 14:51:46 -0700 Subject: [PATCH 36/52] Update README.md with Transaction example - add a more detailed slightly more realistic example for a transaction. --- README.md | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/README.md b/README.md index 21707a8c..5e790a1c 100644 --- a/README.md +++ b/README.md @@ -868,6 +868,27 @@ tx.Rollback() tx.Commit() ``` +### More Complex Example +``` +func CreateAnimals(db *gorm.DB) err { + tx := db.Begin() + // Note the use of tx as the database handle once you are within a transaction + + if err := tx.Create(&Animal{Name: "Giraffe"}).Error; err != nil { + tx.Rollback() + return err + } + + if err := tx.Create(&Animal{Name: "Lion"}).Error; err != nil { + tx.Rollback() + return err + } + + tx.Commit() + return nil +} +``` + ## Scopes ```go From 14dde4b9f2401901747f24f8c48700f87ef74ade Mon Sep 17 00:00:00 2001 From: Joakim Lundborg Date: Thu, 11 Jun 2015 16:14:36 +0200 Subject: [PATCH 37/52] Correct error message --- preload.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/preload.go b/preload.go index f1c0fae5..03910c44 100644 --- a/preload.go +++ b/preload.go @@ -70,7 +70,7 @@ func Preload(scope *Scope) { if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface { value = value.Index(0).Elem() } - scope.Err(fmt.Errorf("can't found field %s in %s", key, value.Type())) + scope.Err(fmt.Errorf("can't find field %s in %s", key, value.Type())) return } From 64f61aaaf99af2b8e8623fe9ac17f1b44a20bc70 Mon Sep 17 00:00:00 2001 From: Rohan Allison Date: Mon, 15 Jun 2015 14:37:58 -0500 Subject: [PATCH 38/52] Update README for revertable transactions --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 5e790a1c..f2b9978b 100644 --- a/README.md +++ b/README.md @@ -861,6 +861,9 @@ All individual save and delete operations are run in a transaction by default. // begin tx := db.Begin() +// do revertable work in a transaction (use 'tx' in place of 'db') +tx.Exec + // rollback tx.Rollback() From ded91a21fe275c1d8dd999d0141560e077db899e Mon Sep 17 00:00:00 2001 From: Rohan Allison Date: Tue, 16 Jun 2015 23:28:54 -0500 Subject: [PATCH 39/52] Update README with clear explanation of transaction db handle --- README.md | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index f2b9978b..ccab06db 100644 --- a/README.md +++ b/README.md @@ -855,23 +855,26 @@ db.Joins("left join users on users.id = emails.user_id").Where("users.name = ?", ## Transactions -All individual save and delete operations are run in a transaction by default. +To perform a set of operations within a transaction, the general flow is as below. +The database handle returned from ``` db.Begin() ``` should be used for all operations within the transaction. +(Note that all individual save and delete operations are run in a transaction by default.) ```go // begin tx := db.Begin() -// do revertable work in a transaction (use 'tx' in place of 'db') -tx.Exec +// do some database operations (use 'tx' from this point, not 'db') +tx.Create(...) +... -// rollback +// rollback in case of error tx.Rollback() -// commit +// Or commit if all is ok tx.Commit() ``` -### More Complex Example +### A Specific Example ``` func CreateAnimals(db *gorm.DB) err { tx := db.Begin() From cad0a428754eb0cd93ef5c6cdb062853872649dd Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 18 Jun 2015 15:39:21 +0800 Subject: [PATCH 40/52] Get correct quoted table name --- join_table_handler.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/join_table_handler.go b/join_table_handler.go index e589d6f5..27051cbd 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -151,7 +151,7 @@ func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB { var values []interface{} if s.Source.ModelType == modelType { for _, foreignKey := range s.Destination.ForeignKeys { - destinationTableName := scope.New(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() + destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) } From 7e8622f67140131e4afdc0cb54acb5acfff71943 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Thu, 18 Jun 2015 18:23:11 +0800 Subject: [PATCH 41/52] Don't need to delete join table records if no record added --- association.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/association.go b/association.go index 37e10516..dbc928e8 100644 --- a/association.go +++ b/association.go @@ -131,9 +131,11 @@ func (association *Association) Replace(values ...interface{}) *Association { addedPrimaryKeys = append(addedPrimaryKeys, primaryKey) } - sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) - query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys) - association.setErr(relationship.JoinTableHandler.Delete(query, relationship)) + if len(addedPrimaryKeys) > 0 { + sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) + query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys) + association.setErr(relationship.JoinTableHandler.Delete(query, relationship)) + } } else { association.setErr(errors.New("replace only support many to many")) } From d75612b86fc0bf9ec921b98ac275fa444d64cb16 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 19 Jun 2015 11:32:11 +0800 Subject: [PATCH 42/52] Update JoinTableHandler API --- association.go | 8 ++++---- callback_shared.go | 2 +- join_table_handler.go | 26 +++++++++++++++----------- join_table_test.go | 6 +++--- scope_private.go | 2 +- 5 files changed, 24 insertions(+), 20 deletions(-) diff --git a/association.go b/association.go index dbc928e8..e34a10bd 100644 --- a/association.go +++ b/association.go @@ -78,7 +78,7 @@ func (association *Association) Delete(values ...interface{}) *Association { if relationship.Kind == "many_to_many" { sql := fmt.Sprintf("%v = ? AND %v IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) query := scope.NewDB().Where(sql, association.PrimaryKey, primaryKeys) - if err := relationship.JoinTableHandler.Delete(query, relationship); err == nil { + if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { leftValues := reflect.Zero(association.Field.Field.Type()) for i := 0; i < association.Field.Field.Len(); i++ { value := association.Field.Field.Index(i) @@ -134,7 +134,7 @@ func (association *Association) Replace(values ...interface{}) *Association { if len(addedPrimaryKeys) > 0 { sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys) - association.setErr(relationship.JoinTableHandler.Delete(query, relationship)) + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship)) } } else { association.setErr(errors.New("replace only support many to many")) @@ -148,7 +148,7 @@ func (association *Association) Clear() *Association { if relationship.Kind == "many_to_many" { sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName)) query := scope.NewDB().Where(sql, association.PrimaryKey) - if err := relationship.JoinTableHandler.Delete(query, relationship); err == nil { + if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { association.Field.Set(reflect.Zero(association.Field.Field.Type())) } else { association.setErr(err) @@ -166,7 +166,7 @@ func (association *Association) Count() int { newScope := scope.New(association.Field.Field.Interface()) if relationship.Kind == "many_to_many" { - relationship.JoinTableHandler.JoinWith(scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count) + relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey) diff --git a/callback_shared.go b/callback_shared.go index 88158cfc..c1b9bd00 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -55,7 +55,7 @@ func SaveAfterAssociations(scope *Scope) { scope.Err(newDB.Save(elem).Error) if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { - scope.Err(joinTableHandler.Add(scope.NewDB(), scope.Value, newScope.Value)) + scope.Err(joinTableHandler.Add(joinTableHandler, scope.NewDB(), scope.Value, newScope.Value)) } } default: diff --git a/join_table_handler.go b/join_table_handler.go index 27051cbd..dceb4277 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -10,9 +10,9 @@ import ( type JoinTableHandlerInterface interface { Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) Table(db *DB) string - Add(db *DB, source interface{}, destination interface{}) error - Delete(db *DB, sources ...interface{}) error - JoinWith(db *DB, source interface{}) *DB + Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error + Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error + JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB } type JoinTableForeignKey struct { @@ -74,8 +74,12 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s } } -func (s JoinTableHandler) Table(*DB) string { - return s.TableName +func (s JoinTableHandler) Table(db *DB) string { + if draftMode, ok := db.Get("publish:draft_mode"); ok && draftMode.(bool) { + return s.TableName + "_draft" + } else { + return s.TableName + } } func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} { @@ -98,7 +102,7 @@ func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[strin return values } -func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) error { +func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error { scope := db.NewScope("") searchMap := s.GetSearchMap(db, source1, source2) @@ -115,7 +119,7 @@ func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) values = append(values, value) } - quotedTable := s.Table(db) + quotedTable := handler.Table(db) sql := fmt.Sprintf( "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)", quotedTable, @@ -129,7 +133,7 @@ func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) return db.Exec(sql, values...).Error } -func (s JoinTableHandler) Delete(db *DB, sources ...interface{}) error { +func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { var conditions []string var values []interface{} @@ -138,11 +142,11 @@ func (s JoinTableHandler) Delete(db *DB, sources ...interface{}) error { values = append(values, value) } - return db.Table(s.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error + return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error } -func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB { - quotedTable := s.Table(db) +func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { + quotedTable := handler.Table(db) scope := db.NewScope(source) modelType := scope.GetModelStruct().ModelType diff --git a/join_table_test.go b/join_table_test.go index f8b097b6..3353aee2 100644 --- a/join_table_test.go +++ b/join_table_test.go @@ -22,7 +22,7 @@ type PersonAddress struct { CreatedAt time.Time } -func (*PersonAddress) Add(db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { +func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { return db.Where(map[string]interface{}{ "person_id": db.NewScope(foreignValue).PrimaryKeyValue(), "address_id": db.NewScope(associationValue).PrimaryKeyValue(), @@ -33,11 +33,11 @@ func (*PersonAddress) Add(db *gorm.DB, foreignValue interface{}, associationValu }).FirstOrCreate(&PersonAddress{}).Error } -func (*PersonAddress) Delete(db *gorm.DB, sources ...interface{}) error { +func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error { return db.Delete(&PersonAddress{}).Error } -func (pa *PersonAddress) JoinWith(db *gorm.DB, source interface{}) *gorm.DB { +func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB { table := pa.Table(db) return db.Table(table).Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) } diff --git a/scope_private.go b/scope_private.go index 63fcea46..5faebe2e 100644 --- a/scope_private.go +++ b/scope_private.go @@ -413,7 +413,7 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { if relationship := fromField.Relationship; relationship != nil { if relationship.Kind == "many_to_many" { joinTableHandler := relationship.JoinTableHandler - scope.Err(joinTableHandler.JoinWith(toScope.db, scope.Value).Find(value).Error) + scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error) } else if relationship.Kind == "belongs_to" { sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) foreignKeyValue := fromFields[relationship.ForeignDBName].Field.Interface() From 7e587724e850cc3c9d84a004efe3466517e2ec76 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 23 Jun 2015 11:31:56 +0800 Subject: [PATCH 43/52] Remove dummy code used for qor --- join_table_handler.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/join_table_handler.go b/join_table_handler.go index dceb4277..ac909966 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -75,11 +75,7 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s } func (s JoinTableHandler) Table(db *DB) string { - if draftMode, ok := db.Get("publish:draft_mode"); ok && draftMode.(bool) { - return s.TableName + "_draft" - } else { - return s.TableName - } + return s.TableName } func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[string]interface{} { From dbedca4e5fe81ba68969835399c3d0866a260fa3 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 23 Jun 2015 14:19:59 +0800 Subject: [PATCH 44/52] Don't run auto migrate if join table doesn't exist --- main.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/main.go b/main.go index 181722fd..c3a1629c 100644 --- a/main.go +++ b/main.go @@ -473,14 +473,17 @@ func (s *DB) Get(name string) (value interface{}, ok bool) { } func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { - for _, field := range s.NewScope(source).GetModelStruct().StructFields { + scope := s.NewScope(source) + for _, field := range scope.GetModelStruct().StructFields { if field.Name == column || field.DBName == column { if many2many := parseTagSetting(field.Tag.Get("gorm"))["MANY2MANY"]; many2many != "" { source := (&Scope{Value: source}).GetModelStruct().ModelType destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType handler.Setup(field.Relationship, many2many, source, destination) field.Relationship.JoinTableHandler = handler - s.Table(handler.Table(s)).AutoMigrate(handler) + if table := handler.Table(s); scope.Dialect().HasTable(scope, table) { + s.Table(table).AutoMigrate(handler) + } } } } From c2c1dd1fc80c16f466ab889a29c787840b7c6d21 Mon Sep 17 00:00:00 2001 From: Jay Taylor Date: Tue, 23 Jun 2015 15:27:21 -0700 Subject: [PATCH 45/52] Fix errors being inaccessible due to errors being set on different *DB instance than what is returned. --- main.go | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/main.go b/main.go index c3a1629c..ff7182bf 100644 --- a/main.go +++ b/main.go @@ -396,28 +396,33 @@ func (s *DB) AutoMigrate(values ...interface{}) *DB { } func (s *DB) ModifyColumn(column string, typ string) *DB { - s.clone().NewScope(s.Value).modifyColumn(column, typ) - return s + scope := s.clone().NewScope(s.Value) + scope.modifyColumn(column, typ) + return scope.db } func (s *DB) DropColumn(column string) *DB { - s.clone().NewScope(s.Value).dropColumn(column) - return s + scope := s.clone().NewScope(s.Value) + scope.dropColumn(column) + return scope.db } func (s *DB) AddIndex(indexName string, column ...string) *DB { - s.clone().NewScope(s.Value).addIndex(false, indexName, column...) - return s + scope := s.clone().NewScope(s.Value) + scope.addIndex(false, indexName, column...) + return scope.db } func (s *DB) AddUniqueIndex(indexName string, column ...string) *DB { - s.clone().NewScope(s.Value).addIndex(true, indexName, column...) - return s + scope := s.clone().NewScope(s.Value) + scope.addIndex(true, indexName, column...) + return scope.db } func (s *DB) RemoveIndex(indexName string) *DB { - s.clone().NewScope(s.Value).removeIndex(indexName) - return s + scope := s.clone().NewScope(s.Value) + scope.removeIndex(indexName) + return scope.db } /* @@ -427,7 +432,8 @@ Example: db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") */ func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { - s.clone().NewScope(s.Value).addForeignKey(field, dest, onDelete, onUpdate) + scope := s.clone().NewScope(s.Value) + scope.addForeignKey(field, dest, onDelete, onUpdate) return s } From 2d802c3445e3e55535fa7648c111b530c262b32a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 24 Jun 2015 13:56:30 +0800 Subject: [PATCH 46/52] Overwrite slice results with Find --- association_test.go | 2 -- callback_query.go | 2 ++ query_test.go | 12 ------------ 3 files changed, 2 insertions(+), 14 deletions(-) diff --git a/association_test.go b/association_test.go index 3ffd8880..ea5b1b80 100644 --- a/association_test.go +++ b/association_test.go @@ -148,7 +148,6 @@ func TestManyToMany(t *testing.T) { t.Errorf("Query many to many relations") } - newLanguages = []Language{} DB.Model(&user).Association("Languages").Find(&newLanguages) if len(newLanguages) != len([]string{"ZH", "EN"}) { t.Errorf("Should be able to find many to many relations") @@ -194,7 +193,6 @@ func TestManyToMany(t *testing.T) { t.Errorf("Language EE should not be deleted") } - languages = []Language{} DB.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages) user2 := User{Name: "Many2Many_User2", Languages: languages} diff --git a/callback_query.go b/callback_query.go index 59022eba..4de911e8 100644 --- a/callback_query.go +++ b/callback_query.go @@ -30,6 +30,8 @@ func Query(scope *Scope) { if kind := dest.Kind(); kind == reflect.Slice { isSlice = true destType = dest.Type().Elem() + dest.Set(reflect.Indirect(reflect.New(reflect.SliceOf(destType)))) + if destType.Kind() == reflect.Ptr { isPtr = true destType = destType.Elem() diff --git a/query_test.go b/query_test.go index d84fae93..b15d01ba 100644 --- a/query_test.go +++ b/query_test.go @@ -98,49 +98,41 @@ func TestSearchWithPlainSQL(t *testing.T) { t.Errorf("Should found 2 users that age > 1, but got %v", len(users)) } - users = []User{} DB.Where("name LIKE ?", "%PlainSqlUser%").Where("age >= ?", 1).Find(&users) if len(users) != 3 { t.Errorf("Should found 3 users that age >= 1, but got %v", len(users)) } - users = []User{} scopedb.Where("age <> ?", 20).Find(&users) if len(users) != 2 { t.Errorf("Should found 2 users age != 20, but got %v", len(users)) } - users = []User{} scopedb.Where("birthday > ?", now.MustParse("2000-1-1")).Find(&users) if len(users) != 2 { t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users)) } - users = []User{} scopedb.Where("birthday > ?", "2002-10-10").Find(&users) if len(users) != 2 { t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users)) } - users = []User{} scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users) if len(users) != 1 { t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) } - users = []User{} DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users) if len(users) != 2 { t.Errorf("Should found 2 users, but got %v", len(users)) } - users = []User{} DB.Where("id in (?)", []int64{user1.Id, user2.Id, user3.Id}).Find(&users) if len(users) != 3 { t.Errorf("Should found 3 users, but got %v", len(users)) } - users = []User{} DB.Where("id in (?)", user1.Id).Find(&users) if len(users) != 1 { t.Errorf("Should found 1 users, but got %v", len(users)) @@ -191,7 +183,6 @@ func TestSearchWithStruct(t *testing.T) { t.Errorf("Search first record with where struct") } - users = []User{} DB.Find(&users, &User{Name: user2.Name}) if len(users) != 1 { t.Errorf("Search all records with inline struct") @@ -222,7 +213,6 @@ func TestSearchWithMap(t *testing.T) { t.Errorf("Search all records with inline map") } - users = []User{} DB.Find(&users, map[string]interface{}{"name": user3.Name}) if len(users) != 1 { t.Errorf("Search all records with inline map") @@ -395,13 +385,11 @@ func TestNot(t *testing.T) { t.Errorf("Should find all users's name not equal 3") } - users4 = []User{} DB.Not("name = ?", "user3").Find(&users4) if len(users1)-len(users4) != int(name3Count) { t.Errorf("Should find all users's name not equal 3") } - users4 = []User{} DB.Not("name <> ?", "user3").Find(&users4) if len(users4) != int(name3Count) { t.Errorf("Should find all users's name not equal 3") From bdb6fc55e8bb267450a99dc722c9d69dfe814298 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 24 Jun 2015 13:56:30 +0800 Subject: [PATCH 47/52] Overwrite slice results with Find --- association_test.go | 2 -- callback_query.go | 2 ++ query_test.go | 12 ------------ 3 files changed, 2 insertions(+), 14 deletions(-) diff --git a/association_test.go b/association_test.go index 3ffd8880..ea5b1b80 100644 --- a/association_test.go +++ b/association_test.go @@ -148,7 +148,6 @@ func TestManyToMany(t *testing.T) { t.Errorf("Query many to many relations") } - newLanguages = []Language{} DB.Model(&user).Association("Languages").Find(&newLanguages) if len(newLanguages) != len([]string{"ZH", "EN"}) { t.Errorf("Should be able to find many to many relations") @@ -194,7 +193,6 @@ func TestManyToMany(t *testing.T) { t.Errorf("Language EE should not be deleted") } - languages = []Language{} DB.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages) user2 := User{Name: "Many2Many_User2", Languages: languages} diff --git a/callback_query.go b/callback_query.go index 59022eba..4de911e8 100644 --- a/callback_query.go +++ b/callback_query.go @@ -30,6 +30,8 @@ func Query(scope *Scope) { if kind := dest.Kind(); kind == reflect.Slice { isSlice = true destType = dest.Type().Elem() + dest.Set(reflect.Indirect(reflect.New(reflect.SliceOf(destType)))) + if destType.Kind() == reflect.Ptr { isPtr = true destType = destType.Elem() diff --git a/query_test.go b/query_test.go index d84fae93..b15d01ba 100644 --- a/query_test.go +++ b/query_test.go @@ -98,49 +98,41 @@ func TestSearchWithPlainSQL(t *testing.T) { t.Errorf("Should found 2 users that age > 1, but got %v", len(users)) } - users = []User{} DB.Where("name LIKE ?", "%PlainSqlUser%").Where("age >= ?", 1).Find(&users) if len(users) != 3 { t.Errorf("Should found 3 users that age >= 1, but got %v", len(users)) } - users = []User{} scopedb.Where("age <> ?", 20).Find(&users) if len(users) != 2 { t.Errorf("Should found 2 users age != 20, but got %v", len(users)) } - users = []User{} scopedb.Where("birthday > ?", now.MustParse("2000-1-1")).Find(&users) if len(users) != 2 { t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users)) } - users = []User{} scopedb.Where("birthday > ?", "2002-10-10").Find(&users) if len(users) != 2 { t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users)) } - users = []User{} scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users) if len(users) != 1 { t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) } - users = []User{} DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users) if len(users) != 2 { t.Errorf("Should found 2 users, but got %v", len(users)) } - users = []User{} DB.Where("id in (?)", []int64{user1.Id, user2.Id, user3.Id}).Find(&users) if len(users) != 3 { t.Errorf("Should found 3 users, but got %v", len(users)) } - users = []User{} DB.Where("id in (?)", user1.Id).Find(&users) if len(users) != 1 { t.Errorf("Should found 1 users, but got %v", len(users)) @@ -191,7 +183,6 @@ func TestSearchWithStruct(t *testing.T) { t.Errorf("Search first record with where struct") } - users = []User{} DB.Find(&users, &User{Name: user2.Name}) if len(users) != 1 { t.Errorf("Search all records with inline struct") @@ -222,7 +213,6 @@ func TestSearchWithMap(t *testing.T) { t.Errorf("Search all records with inline map") } - users = []User{} DB.Find(&users, map[string]interface{}{"name": user3.Name}) if len(users) != 1 { t.Errorf("Search all records with inline map") @@ -395,13 +385,11 @@ func TestNot(t *testing.T) { t.Errorf("Should find all users's name not equal 3") } - users4 = []User{} DB.Not("name = ?", "user3").Find(&users4) if len(users1)-len(users4) != int(name3Count) { t.Errorf("Should find all users's name not equal 3") } - users4 = []User{} DB.Not("name <> ?", "user3").Find(&users4) if len(users4) != int(name3Count) { t.Errorf("Should find all users's name not equal 3") From 2a1d64c3e067bd1af99987157f9625a4605ed85e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Wed, 24 Jun 2015 14:09:59 +0800 Subject: [PATCH 48/52] Return cloned db instance for AddForeignKey --- main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.go b/main.go index ff7182bf..aba51fc4 100644 --- a/main.go +++ b/main.go @@ -434,7 +434,7 @@ Example: func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { scope := s.clone().NewScope(s.Value) scope.addForeignKey(field, dest, onDelete, onUpdate) - return s + return scope.db } func (s *DB) Association(column string) *Association { From 308c96ee4c525837965cd256afe18816f2ca60fb Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 29 Jun 2015 18:04:15 +0800 Subject: [PATCH 49/52] Add PrimaryFields --- join_table_handler.go | 10 ++++++++++ scope.go | 8 ++++++++ 2 files changed, 18 insertions(+) diff --git a/join_table_handler.go b/join_table_handler.go index ac909966..07ecee2e 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -13,6 +13,8 @@ type JoinTableHandlerInterface interface { Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB + SourceForeignKeys() []JoinTableForeignKey + DestinationForeignKeys() []JoinTableForeignKey } type JoinTableForeignKey struct { @@ -31,6 +33,14 @@ type JoinTableHandler struct { Destination JoinTableSource `sql:"-"` } +func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey { + return s.Source.ForeignKeys +} + +func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey { + return s.Destination.ForeignKeys +} + func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) { s.TableName = tableName diff --git a/scope.go b/scope.go index 11bad777..de1b6159 100644 --- a/scope.go +++ b/scope.go @@ -110,6 +110,14 @@ func (scope *Scope) HasError() bool { return scope.db.Error != nil } +func (scope *Scope) PrimaryFields() []*Field { + var fields = []*Field{} + for _, field := range scope.GetModelStruct().PrimaryFields { + fields = append(fields, scope.Fields()[field.DBName]) + } + return fields +} + func (scope *Scope) PrimaryField() *Field { if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 { if len(primaryFields) > 1 { From e31752757a46b92728d63112714fa52b813a7546 Mon Sep 17 00:00:00 2001 From: Jay Taylor Date: Mon, 29 Jun 2015 15:35:50 -0700 Subject: [PATCH 50/52] Added missing field name quoting for `ALTER TABLE' and `CREATE INDEX' statements. --- scope_private.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scope_private.go b/scope_private.go index 5faebe2e..85f07e99 100644 --- a/scope_private.go +++ b/scope_private.go @@ -530,7 +530,7 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on var table = scope.TableName() var keyName = fmt.Sprintf("%s_%s_foreign", table, field) var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` - scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), keyName, field, dest, onDelete, onUpdate)).Exec() + scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.Quote(keyName), scope.Quote(field), scope.Quote(dest), onDelete, onUpdate)).Exec() } func (scope *Scope) removeIndex(indexName string) { @@ -548,7 +548,7 @@ func (scope *Scope) autoMigrate() *Scope { if !scope.Dialect().HasColumn(scope, tableName, field.DBName) { if field.IsNormal { sqlTag := scope.generateSqlTag(field) - scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, sqlTag)).Exec() + scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec() } } scope.createJoinTable(field) From 5ddca7c427a2e7558191412ad70cc210901164c5 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Tue, 30 Jun 2015 10:39:29 +0800 Subject: [PATCH 51/52] Fix table name for association --- scope.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scope.go b/scope.go index de1b6159..8ff514a9 100644 --- a/scope.go +++ b/scope.go @@ -259,7 +259,7 @@ func (scope *Scope) TableName() string { return tabler.TableName(scope.db) } - return scope.GetModelStruct().TableName(scope.db) + return scope.GetModelStruct().TableName(scope.db.Model(scope.Value)) } func (scope *Scope) QuotedTableName() (name string) { From 923ca15b6f8a45678ea3b864ffeb07da6c7b4f0c Mon Sep 17 00:00:00 2001 From: Jay Taylor Date: Thu, 2 Jul 2015 12:06:06 -0700 Subject: [PATCH 52/52] Surface errors emitted by `RowsAffected'. --- scope.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scope.go b/scope.go index 8ff514a9..cd6b235d 100644 --- a/scope.go +++ b/scope.go @@ -300,7 +300,7 @@ func (scope *Scope) Exec() *Scope { if !scope.HasError() { if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { - if count, err := result.RowsAffected(); err == nil { + if count, err := result.RowsAffected(); scope.Err(err) == nil { scope.db.RowsAffected = count } }