diff --git a/README.md b/README.md index 0c4ea83a..ccab06db 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ type User struct { Num int `sql:"AUTO_INCREMENT"` CreatedAt time.Time UpdatedAt time.Time - DeletedAt time.Time + DeletedAt *time.Time Emails []Email // One-To-Many relationship (has many) BillingAddress Address // One-To-One relationship (has one) @@ -84,6 +84,21 @@ type User struct{} // struct User's database table name is "users" by default, w * Use `CreatedAt` to store record's created time if field exists * Use `UpdatedAt` to store record's updated time if field exists * Use `DeletedAt` to store record's deleted time if field exists [Soft Delete](#soft-delete) +* Gorm provide a default model struct, you could embed it in your struct + +```go +type Model struct { + ID uint `gorm:"primary_key"` + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time +} + +type User struct { + gorm.Model + Name string +} +``` ## Initialize Database @@ -102,7 +117,7 @@ db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable") // You can also use an existing database connection handle // dbSql, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable") -// db := gorm.Open("postgres", dbSql) +// db, _ := gorm.Open("postgres", dbSql) // Get database connection handle [*sql.DB](http://golang.org/pkg/database/sql/#DB) db.DB() @@ -332,6 +347,13 @@ db.Preload("Orders").Preload("Profile").Preload("Role").Find(&users) //// SELECT * FROM roles WHERE id IN (4,5,6); // belongs to ``` +#### Nested Preloading + +```go +db.Preload("Orders.OrderItems").Find(&users) +db.Preload("Orders", "state = ?", "paid").Preload("Orders.OrderItems").Find(&users) +``` + ## Update ```go @@ -823,23 +845,56 @@ for rows.Next() { } db.Table("users").Select("users.name, emails.email").Joins("left join emails on emails.user_id = users.id").Scan(&results) + +// find a user by email address +db.Joins("inner join emails on emails.user_id = users.id").Where("emails.email = ?", "x@example.org").Find(&user) + +// find all email addresses for a user +db.Joins("left join users on users.id = emails.user_id").Where("users.name = ?", "jinzhu").Find(&emails) ``` ## Transactions -All individual save and delete operations are run in a transaction by default. +To perform a set of operations within a transaction, the general flow is as below. +The database handle returned from ``` db.Begin() ``` should be used for all operations within the transaction. +(Note that all individual save and delete operations are run in a transaction by default.) ```go // begin tx := db.Begin() -// rollback +// do some database operations (use 'tx' from this point, not 'db') +tx.Create(...) +... + +// rollback in case of error tx.Rollback() -// commit +// Or commit if all is ok tx.Commit() ``` +### A Specific Example +``` +func CreateAnimals(db *gorm.DB) err { + tx := db.Begin() + // Note the use of tx as the database handle once you are within a transaction + + if err := tx.Create(&Animal{Name: "Giraffe"}).Error; err != nil { + tx.Rollback() + return err + } + + if err := tx.Create(&Animal{Name: "Lion"}).Error; err != nil { + tx.Rollback() + return err + } + + tx.Commit() + return nil +} +``` + ## Scopes ```go @@ -1069,7 +1124,7 @@ type Product struct { // 2nd param : destination table(id) // 3rd param : ONDELETE // 4th param : ONUPDATE -db.Model(&User{}).AddForeignKey("user_id", "destination_table(id)", "CASCADE", "RESTRICT") +db.Model(&User{}).AddForeignKey("role_id", "roles", "CASCADE", "RESTRICT") // Add index db.Model(&User{}).AddIndex("idx_user_name", "name") diff --git a/association.go b/association.go index 89bb1bec..e34a10bd 100644 --- a/association.go +++ b/association.go @@ -40,6 +40,7 @@ func (association *Association) Append(values ...interface{}) *Association { association.setErr(errors.New("invalid association type")) } } + scope.Search.Select(association.Column) scope.callCallbacks(scope.db.parent.callback.updates) return association.setErr(scope.db.Error) } @@ -77,7 +78,7 @@ func (association *Association) Delete(values ...interface{}) *Association { if relationship.Kind == "many_to_many" { sql := fmt.Sprintf("%v = ? AND %v IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) query := scope.NewDB().Where(sql, association.PrimaryKey, primaryKeys) - if err := relationship.JoinTableHandler.Delete(query, relationship); err == nil { + if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { leftValues := reflect.Zero(association.Field.Field.Type()) for i := 0; i < association.Field.Field.Len(); i++ { value := association.Field.Field.Index(i) @@ -130,9 +131,11 @@ func (association *Association) Replace(values ...interface{}) *Association { addedPrimaryKeys = append(addedPrimaryKeys, primaryKey) } - sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) - query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys) - association.setErr(relationship.JoinTableHandler.Delete(query, relationship)) + if len(addedPrimaryKeys) > 0 { + sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName)) + query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys) + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship)) + } } else { association.setErr(errors.New("replace only support many to many")) } @@ -145,7 +148,7 @@ func (association *Association) Clear() *Association { if relationship.Kind == "many_to_many" { sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName)) query := scope.NewDB().Where(sql, association.PrimaryKey) - if err := relationship.JoinTableHandler.Delete(query, relationship); err == nil { + if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil { association.Field.Set(reflect.Zero(association.Field.Field.Type())) } else { association.setErr(err) @@ -163,7 +166,7 @@ func (association *Association) Count() int { newScope := scope.New(association.Field.Field.Interface()) if relationship.Kind == "many_to_many" { - relationship.JoinTableHandler.JoinWith(scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count) + relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName)) countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey) diff --git a/association_test.go b/association_test.go index 3ffd8880..ea5b1b80 100644 --- a/association_test.go +++ b/association_test.go @@ -148,7 +148,6 @@ func TestManyToMany(t *testing.T) { t.Errorf("Query many to many relations") } - newLanguages = []Language{} DB.Model(&user).Association("Languages").Find(&newLanguages) if len(newLanguages) != len([]string{"ZH", "EN"}) { t.Errorf("Should be able to find many to many relations") @@ -194,7 +193,6 @@ func TestManyToMany(t *testing.T) { t.Errorf("Language EE should not be deleted") } - languages = []Language{} DB.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages) user2 := User{Name: "Many2Many_User2", Languages: languages} diff --git a/callback.go b/callback.go index be9a7f12..603e5111 100644 --- a/callback.go +++ b/callback.go @@ -9,6 +9,7 @@ type callback struct { updates []*func(scope *Scope) deletes []*func(scope *Scope) queries []*func(scope *Scope) + rowQueries []*func(scope *Scope) processors []*callbackProcessor } @@ -55,6 +56,10 @@ func (c *callback) Query() *callbackProcessor { return c.addProcessor("query") } +func (c *callback) RowQuery() *callbackProcessor { + return c.addProcessor("row_query") +} + func (cp *callbackProcessor) Before(name string) *callbackProcessor { cp.before = name return cp @@ -168,7 +173,7 @@ func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) { } func (c *callback) sort() { - creates, updates, deletes, queries := []*callbackProcessor{}, []*callbackProcessor{}, []*callbackProcessor{}, []*callbackProcessor{} + var creates, updates, deletes, queries, rowQueries []*callbackProcessor for _, processor := range c.processors { switch processor.typ { @@ -180,6 +185,8 @@ func (c *callback) sort() { deletes = append(deletes, processor) case "query": queries = append(queries, processor) + case "row_query": + rowQueries = append(rowQueries, processor) } } @@ -187,6 +194,7 @@ func (c *callback) sort() { c.updates = sortProcessors(updates) c.deletes = sortProcessors(deletes) c.queries = sortProcessors(queries) + c.rowQueries = sortProcessors(rowQueries) } var DefaultCallback = &callback{processors: []*callbackProcessor{}} diff --git a/callback_create.go b/callback_create.go index b21df08b..7f21ed6a 100644 --- a/callback_create.go +++ b/callback_create.go @@ -70,18 +70,24 @@ func Create(scope *Scope) { id, err := result.LastInsertId() if scope.Err(err) == nil { scope.db.RowsAffected, _ = result.RowsAffected() - if primaryField != nil { + if primaryField != nil && primaryField.IsBlank { scope.Err(scope.SetColumn(primaryField, id)) } } } } else { if primaryField == nil { - if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err != nil { + if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil { scope.db.RowsAffected, _ = results.RowsAffected() + } else { + scope.Err(err) + } + } else { + if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil { + scope.db.RowsAffected = 1 + } else { + scope.Err(err) } - } else if scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil { - scope.db.RowsAffected = 1 } } } diff --git a/callback_query.go b/callback_query.go index 5daa5fec..4de911e8 100644 --- a/callback_query.go +++ b/callback_query.go @@ -16,20 +16,22 @@ func Query(scope *Scope) { destType reflect.Type ) - var dest = scope.IndirectValue() - if value, ok := scope.InstanceGet("gorm:query_destination"); ok { - dest = reflect.Indirect(reflect.ValueOf(value)) - } - if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok { if primaryKey := scope.PrimaryKey(); primaryKey != "" { - scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), primaryKey, orderBy)) + scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryKey), orderBy)) } } + var dest = scope.IndirectValue() + if value, ok := scope.Get("gorm:query_destination"); ok { + dest = reflect.Indirect(reflect.ValueOf(value)) + } + if kind := dest.Kind(); kind == reflect.Slice { isSlice = true destType = dest.Type().Elem() + dest.Set(reflect.Indirect(reflect.New(reflect.SliceOf(destType)))) + if destType.Kind() == reflect.Ptr { isPtr = true destType = destType.Elem() diff --git a/callback_shared.go b/callback_shared.go index 88158cfc..c1b9bd00 100644 --- a/callback_shared.go +++ b/callback_shared.go @@ -55,7 +55,7 @@ func SaveAfterAssociations(scope *Scope) { scope.Err(newDB.Save(elem).Error) if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { - scope.Err(joinTableHandler.Add(scope.NewDB(), scope.Value, newScope.Value)) + scope.Err(joinTableHandler.Add(joinTableHandler, scope.NewDB(), scope.Value, newScope.Value)) } } default: diff --git a/callback_update.go b/callback_update.go index 1167871c..c3f7b4b6 100644 --- a/callback_update.go +++ b/callback_update.go @@ -64,13 +64,15 @@ func Update(scope *Scope) { } } - scope.Raw(fmt.Sprintf( - "UPDATE %v SET %v %v", - scope.QuotedTableName(), - strings.Join(sqls, ", "), - scope.CombinedConditionSql(), - )) - scope.Exec() + if len(sqls) > 0 { + scope.Raw(fmt.Sprintf( + "UPDATE %v SET %v %v", + scope.QuotedTableName(), + strings.Join(sqls, ", "), + scope.CombinedConditionSql(), + )) + scope.Exec() + } } } diff --git a/doc/development.md b/doc/development.md index 674cfc43..08166661 100644 --- a/doc/development.md +++ b/doc/development.md @@ -61,7 +61,7 @@ Gorm is powered by callbacks, so you could refer below links to learn how to wri [Update callbacks](https://github.com/jinzhu/gorm/blob/master/callback_update.go) -[Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_create.go) +[Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_query.go) [Delete callbacks](https://github.com/jinzhu/gorm/blob/master/callback_delete.go) diff --git a/join_table_handler.go b/join_table_handler.go index 9f705564..07ecee2e 100644 --- a/join_table_handler.go +++ b/join_table_handler.go @@ -10,9 +10,11 @@ import ( type JoinTableHandlerInterface interface { Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) Table(db *DB) string - Add(db *DB, source interface{}, destination interface{}) error - Delete(db *DB, sources ...interface{}) error - JoinWith(db *DB, source interface{}) *DB + Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error + Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error + JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB + SourceForeignKeys() []JoinTableForeignKey + DestinationForeignKeys() []JoinTableForeignKey } type JoinTableForeignKey struct { @@ -31,37 +33,58 @@ type JoinTableHandler struct { Destination JoinTableSource `sql:"-"` } +func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey { + return s.Source.ForeignKeys +} + +func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey { + return s.Destination.ForeignKeys +} + func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) { s.TableName = tableName s.Source = JoinTableSource{ModelType: source} sourceScope := &Scope{Value: reflect.New(source).Interface()} - for _, primaryField := range sourceScope.GetModelStruct().PrimaryFields { + sourcePrimaryFields := sourceScope.GetModelStruct().PrimaryFields + for _, primaryField := range sourcePrimaryFields { if relationship.ForeignDBName == "" { relationship.ForeignFieldName = source.Name() + primaryField.Name relationship.ForeignDBName = ToDBName(relationship.ForeignFieldName) } + + var dbName string + if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" { + dbName = relationship.ForeignDBName + } else { + dbName = ToDBName(source.Name() + primaryField.Name) + } + s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ - DBName: relationship.ForeignDBName, + DBName: dbName, AssociationDBName: primaryField.DBName, }) } s.Destination = JoinTableSource{ModelType: destination} destinationScope := &Scope{Value: reflect.New(destination).Interface()} - for _, primaryField := range destinationScope.GetModelStruct().PrimaryFields { - if relationship.AssociationForeignDBName == "" { - relationship.AssociationForeignFieldName = destination.Name() + primaryField.Name - relationship.AssociationForeignDBName = ToDBName(relationship.AssociationForeignFieldName) + destinationPrimaryFields := destinationScope.GetModelStruct().PrimaryFields + for _, primaryField := range destinationPrimaryFields { + var dbName string + if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" { + dbName = relationship.AssociationForeignDBName + } else { + dbName = ToDBName(destinationScope.GetModelStruct().ModelType.Name() + primaryField.Name) } + s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ - DBName: relationship.AssociationForeignDBName, + DBName: dbName, AssociationDBName: primaryField.DBName, }) } } -func (s JoinTableHandler) Table(*DB) string { +func (s JoinTableHandler) Table(db *DB) string { return s.TableName } @@ -85,7 +108,7 @@ func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[strin return values } -func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) error { +func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error { scope := db.NewScope("") searchMap := s.GetSearchMap(db, source1, source2) @@ -102,7 +125,7 @@ func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) values = append(values, value) } - quotedTable := s.Table(db) + quotedTable := handler.Table(db) sql := fmt.Sprintf( "INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)", quotedTable, @@ -116,7 +139,7 @@ func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) return db.Exec(sql, values...).Error } -func (s JoinTableHandler) Delete(db *DB, sources ...interface{}) error { +func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { var conditions []string var values []interface{} @@ -125,11 +148,11 @@ func (s JoinTableHandler) Delete(db *DB, sources ...interface{}) error { values = append(values, value) } - return db.Table(s.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error + return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error } -func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB { - quotedTable := s.Table(db) +func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB { + quotedTable := handler.Table(db) scope := db.NewScope(source) modelType := scope.GetModelStruct().ModelType @@ -138,7 +161,7 @@ func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB { var values []interface{} if s.Source.ModelType == modelType { for _, foreignKey := range s.Destination.ForeignKeys { - destinationTableName := scope.New(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() + destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName() joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName))) } diff --git a/join_table_test.go b/join_table_test.go index f8b097b6..3353aee2 100644 --- a/join_table_test.go +++ b/join_table_test.go @@ -22,7 +22,7 @@ type PersonAddress struct { CreatedAt time.Time } -func (*PersonAddress) Add(db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { +func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error { return db.Where(map[string]interface{}{ "person_id": db.NewScope(foreignValue).PrimaryKeyValue(), "address_id": db.NewScope(associationValue).PrimaryKeyValue(), @@ -33,11 +33,11 @@ func (*PersonAddress) Add(db *gorm.DB, foreignValue interface{}, associationValu }).FirstOrCreate(&PersonAddress{}).Error } -func (*PersonAddress) Delete(db *gorm.DB, sources ...interface{}) error { +func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error { return db.Delete(&PersonAddress{}).Error } -func (pa *PersonAddress) JoinWith(db *gorm.DB, source interface{}) *gorm.DB { +func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB { table := pa.Table(db) return db.Table(table).Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table)) } diff --git a/main.go b/main.go index 7049675e..aba51fc4 100644 --- a/main.go +++ b/main.go @@ -211,6 +211,10 @@ func (s *DB) Find(out interface{}, where ...interface{}) *DB { return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callback.queries).db } +func (s *DB) Scan(dest interface{}) *DB { + return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db +} + func (s *DB) Row() *sql.Row { return s.NewScope(s.Value).row() } @@ -219,10 +223,16 @@ func (s *DB) Rows() (*sql.Rows, error) { return s.NewScope(s.Value).rows() } -func (s *DB) Scan(dest interface{}) *DB { - scope := s.clone().NewScope(s.Value).InstanceSet("gorm:query_destination", dest) - Query(scope) - return scope.db +func (s *DB) Pluck(column string, value interface{}) *DB { + return s.NewScope(s.Value).pluck(column, value).db +} + +func (s *DB) Count(value interface{}) *DB { + return s.NewScope(s.Value).count(value).db +} + +func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { + return s.clone().NewScope(s.Value).related(value, foreignKeys...).db } func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { @@ -309,18 +319,6 @@ func (s *DB) Model(value interface{}) *DB { return c } -func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { - return s.clone().NewScope(s.Value).related(value, foreignKeys...).db -} - -func (s *DB) Pluck(column string, value interface{}) *DB { - return s.NewScope(s.Value).pluck(column, value).db -} - -func (s *DB) Count(value interface{}) *DB { - return s.NewScope(s.Value).count(value).db -} - func (s *DB) Table(name string) *DB { clone := s.clone() clone.search.Table(name) @@ -398,28 +396,33 @@ func (s *DB) AutoMigrate(values ...interface{}) *DB { } func (s *DB) ModifyColumn(column string, typ string) *DB { - s.clone().NewScope(s.Value).modifyColumn(column, typ) - return s + scope := s.clone().NewScope(s.Value) + scope.modifyColumn(column, typ) + return scope.db } func (s *DB) DropColumn(column string) *DB { - s.clone().NewScope(s.Value).dropColumn(column) - return s + scope := s.clone().NewScope(s.Value) + scope.dropColumn(column) + return scope.db } func (s *DB) AddIndex(indexName string, column ...string) *DB { - s.clone().NewScope(s.Value).addIndex(false, indexName, column...) - return s + scope := s.clone().NewScope(s.Value) + scope.addIndex(false, indexName, column...) + return scope.db } func (s *DB) AddUniqueIndex(indexName string, column ...string) *DB { - s.clone().NewScope(s.Value).addIndex(true, indexName, column...) - return s + scope := s.clone().NewScope(s.Value) + scope.addIndex(true, indexName, column...) + return scope.db } func (s *DB) RemoveIndex(indexName string) *DB { - s.clone().NewScope(s.Value).removeIndex(indexName) - return s + scope := s.clone().NewScope(s.Value) + scope.removeIndex(indexName) + return scope.db } /* @@ -429,8 +432,9 @@ Example: db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") */ func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { - s.clone().NewScope(s.Value).addForeignKey(field, dest, onDelete, onUpdate) - return s + scope := s.clone().NewScope(s.Value) + scope.addForeignKey(field, dest, onDelete, onUpdate) + return scope.db } func (s *DB) Association(column string) *Association { @@ -475,14 +479,17 @@ func (s *DB) Get(name string) (value interface{}, ok bool) { } func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) { - for _, field := range s.NewScope(source).GetModelStruct().StructFields { + scope := s.NewScope(source) + for _, field := range scope.GetModelStruct().StructFields { if field.Name == column || field.DBName == column { if many2many := parseTagSetting(field.Tag.Get("gorm"))["MANY2MANY"]; many2many != "" { source := (&Scope{Value: source}).GetModelStruct().ModelType destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType handler.Setup(field.Relationship, many2many, source, destination) field.Relationship.JoinTableHandler = handler - s.Table(handler.Table(s)).AutoMigrate(handler) + if table := handler.Table(s); scope.Dialect().HasTable(scope, table) { + s.Table(table).AutoMigrate(handler) + } } } } diff --git a/main_test.go b/main_test.go index b547534c..0dc5e337 100644 --- a/main_test.go +++ b/main_test.go @@ -61,6 +61,19 @@ func init() { runMigration() } +func TestStringPrimaryKey(t *testing.T) { + type UUIDStruct struct { + ID string `gorm:"primary_key"` + Name string + } + DB.AutoMigrate(&UUIDStruct{}) + + data := UUIDStruct{ID: "uuid", Name: "hello"} + if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" { + t.Errorf("string primary key should not be populated") + } +} + func TestExceptionsWithInvalidSql(t *testing.T) { var columns []string if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { diff --git a/model.go b/model.go new file mode 100644 index 00000000..50fa52e6 --- /dev/null +++ b/model.go @@ -0,0 +1,10 @@ +package gorm + +import "time" + +type Model struct { + ID uint `gorm:"primary_key"` + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time +} diff --git a/model_struct.go b/model_struct.go index a70489fc..10423ae2 100644 --- a/model_struct.go +++ b/model_struct.go @@ -13,11 +13,19 @@ import ( var modelStructs = map[reflect.Type]*ModelStruct{} +var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { + return defaultTableName +} + type ModelStruct struct { - PrimaryFields []*StructField - StructFields []*StructField - ModelType reflect.Type - TableName func(*DB) string + PrimaryFields []*StructField + StructFields []*StructField + ModelType reflect.Type + defaultTableName string +} + +func (s ModelStruct) TableName(db *DB) string { + return DefaultTableNameHandler(db, s.defaultTableName) } type StructField struct { @@ -94,14 +102,14 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // Set tablename - if fm := reflect.New(scopeType).MethodByName("TableName"); fm.IsValid() { - if results := fm.Call([]reflect.Value{}); len(results) > 0 { - if name, ok := results[0].Interface().(string); ok { - modelStruct.TableName = func(*DB) string { - return name - } - } - } + type tabler interface { + TableName() string + } + + if tabler, ok := reflect.New(scopeType).Interface().(interface { + TableName() string + }); ok { + modelStruct.defaultTableName = tabler.TableName() } else { name := ToDBName(scopeType.Name()) if scope.db == nil || !scope.db.parent.singularTable { @@ -112,9 +120,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } - modelStruct.TableName = func(*DB) string { - return name - } + modelStruct.defaultTableName = name } // Get all fields diff --git a/multi_primary_keys_test.go b/multi_primary_keys_test.go new file mode 100644 index 00000000..9ca68d13 --- /dev/null +++ b/multi_primary_keys_test.go @@ -0,0 +1,46 @@ +package gorm_test + +import ( + "fmt" + "os" + "testing" +) + +type Blog struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Subject string + Body string + Tags []Tag `gorm:"many2many:blog_tags;"` +} + +type Tag struct { + ID uint `gorm:"primary_key"` + Locale string `gorm:"primary_key"` + Value string +} + +func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" { + DB.Exec(fmt.Sprintf("drop table blog_tags;")) + DB.AutoMigrate(&Blog{}, &Tag{}) + blog := Blog{ + Locale: "ZH", + Subject: "subject", + Body: "body", + Tags: []Tag{ + {Locale: "ZH", Value: "tag1"}, + {Locale: "ZH", Value: "tag2"}, + }, + } + + DB.Save(&blog) + DB.Model(&blog).Association("Tags").Append([]Tag{{Locale: "ZH", Value: "tag3"}}) + + var tags []Tag + DB.Model(&blog).Related(&tags, "Tags") + if len(tags) != 3 { + t.Errorf("should found 3 tags with blog") + } + } +} diff --git a/mysql.go b/mysql.go index e37a23e0..a5e4a459 100644 --- a/mysql.go +++ b/mysql.go @@ -14,16 +14,26 @@ func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string { switch value.Kind() { case reflect.Bool: return "boolean" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: if autoIncrease { return "int AUTO_INCREMENT" } return "int" - case reflect.Int64, reflect.Uint64: + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if autoIncrease { + return "int unsigned AUTO_INCREMENT" + } + return "int unsigned" + case reflect.Int64: if autoIncrease { return "bigint AUTO_INCREMENT" } return "bigint" + case reflect.Uint64: + if autoIncrease { + return "bigint unsigned AUTO_INCREMENT" + } + return "bigint unsigned" case reflect.Float32, reflect.Float64: return "double" case reflect.String: diff --git a/preload.go b/preload.go index d252238a..03910c44 100644 --- a/preload.go +++ b/preload.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "reflect" + "strings" ) func getRealValue(value reflect.Value, field string) interface{} { @@ -20,94 +21,69 @@ func equalAsString(a interface{}, b interface{}) bool { } func Preload(scope *Scope) { + if scope.Search.preload == nil { + return + } + + preloadMap := map[string]bool{} fields := scope.Fields() - isSlice := scope.IndirectValue().Kind() == reflect.Slice + for _, preload := range scope.Search.preload { + schema, conditions := preload.schema, preload.conditions + keys := strings.Split(schema, ".") + currentScope := scope + currentFields := fields + originalConditions := conditions + conditions = []interface{}{} + for i, key := range keys { + var found bool + if preloadMap[strings.Join(keys[:i+1], ".")] { + goto nextLoop + } - if scope.Search.preload != nil { - for key, conditions := range scope.Search.preload { - for _, field := range fields { - if field.Name == key && field.Relationship != nil { - results := makeSlice(field.Struct.Type) - relation := field.Relationship - primaryName := scope.PrimaryField().Name - associationPrimaryKey := scope.New(results).PrimaryField().Name + if i == len(keys)-1 { + conditions = originalConditions + } - switch relation.Kind { - case "has_one": - if primaryKeys := scope.getColumnAsArray(primaryName); len(primaryKeys) > 0 { - condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) - scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) - - resultValues := reflect.Indirect(reflect.ValueOf(results)) - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) - if isSlice { - value := getRealValue(result, relation.ForeignFieldName) - objects := scope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - if equalAsString(getRealValue(objects.Index(j), primaryName), value) { - reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) - break - } - } - } else { - scope.SetColumn(field, result) - } - } - } - case "has_many": - if primaryKeys := scope.getColumnAsArray(primaryName); len(primaryKeys) > 0 { - condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) - scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...) - resultValues := reflect.Indirect(reflect.ValueOf(results)) - if isSlice { - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) - value := getRealValue(result, relation.ForeignFieldName) - objects := scope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - object := reflect.Indirect(objects.Index(j)) - if equalAsString(getRealValue(object, primaryName), value) { - f := object.FieldByName(field.Name) - f.Set(reflect.Append(f, result)) - break - } - } - } - } else { - scope.SetColumn(field, resultValues) - } - } - case "belongs_to": - if primaryKeys := scope.getColumnAsArray(relation.ForeignFieldName); len(primaryKeys) > 0 { - scope.NewDB().Where(primaryKeys).Find(results, conditions...) - resultValues := reflect.Indirect(reflect.ValueOf(results)) - for i := 0; i < resultValues.Len(); i++ { - result := resultValues.Index(i) - if isSlice { - value := getRealValue(result, associationPrimaryKey) - objects := scope.IndirectValue() - for j := 0; j < objects.Len(); j++ { - object := reflect.Indirect(objects.Index(j)) - if equalAsString(getRealValue(object, relation.ForeignFieldName), value) { - object.FieldByName(field.Name).Set(result) - } - } - } else { - scope.SetColumn(field, result) - } - } - } - case "many_to_many": - scope.Err(errors.New("not supported relation")) - default: - scope.Err(errors.New("not supported relation")) - } - break + for _, field := range currentFields { + if field.Name != key || field.Relationship == nil { + continue } + + found = true + switch field.Relationship.Kind { + case "has_one": + currentScope.handleHasOnePreload(field, conditions) + case "has_many": + currentScope.handleHasManyPreload(field, conditions) + case "belongs_to": + currentScope.handleBelongsToPreload(field, conditions) + case "many_to_many": + fallthrough + default: + currentScope.Err(errors.New("not supported relation")) + } + break + } + + if !found { + value := reflect.ValueOf(currentScope.Value) + if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface { + value = value.Index(0).Elem() + } + scope.Err(fmt.Errorf("can't find field %s in %s", key, value.Type())) + return + } + + preloadMap[strings.Join(keys[:i+1], ".")] = true + + nextLoop: + if i < len(keys)-1 { + currentScope = currentScope.getColumnsAsScope(key) + currentFields = currentScope.Fields() } } } + } func makeSlice(typ reflect.Type) interface{} { @@ -120,19 +96,147 @@ func makeSlice(typ reflect.Type) interface{} { return slice.Interface() } -func (scope *Scope) getColumnAsArray(column string) (primaryKeys []interface{}) { +func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { + primaryName := scope.PrimaryField().Name + primaryKeys := scope.getColumnAsArray(primaryName) + if len(primaryKeys) == 0 { + return + } + + results := makeSlice(field.Struct.Type) + relation := field.Relationship + condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) + + scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error) + resultValues := reflect.Indirect(reflect.ValueOf(results)) + + for i := 0; i < resultValues.Len(); i++ { + result := resultValues.Index(i) + if scope.IndirectValue().Kind() == reflect.Slice { + value := getRealValue(result, relation.ForeignFieldName) + objects := scope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + if equalAsString(getRealValue(objects.Index(j), primaryName), value) { + reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result) + break + } + } + } else { + if err := scope.SetColumn(field, result); err != nil { + scope.Err(err) + return + } + } + } +} + +func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { + primaryName := scope.PrimaryField().Name + primaryKeys := scope.getColumnAsArray(primaryName) + if len(primaryKeys) == 0 { + return + } + + results := makeSlice(field.Struct.Type) + relation := field.Relationship + condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName)) + + scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error) + resultValues := reflect.Indirect(reflect.ValueOf(results)) + + if scope.IndirectValue().Kind() == reflect.Slice { + for i := 0; i < resultValues.Len(); i++ { + result := resultValues.Index(i) + value := getRealValue(result, relation.ForeignFieldName) + objects := scope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + object := reflect.Indirect(objects.Index(j)) + if equalAsString(getRealValue(object, primaryName), value) { + f := object.FieldByName(field.Name) + f.Set(reflect.Append(f, result)) + break + } + } + } + } else { + scope.SetColumn(field, resultValues) + } +} + +func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { + relation := field.Relationship + primaryKeys := scope.getColumnAsArray(relation.ForeignFieldName) + if len(primaryKeys) == 0 { + return + } + + results := makeSlice(field.Struct.Type) + associationPrimaryKey := scope.New(results).PrimaryField().Name + + scope.Err(scope.NewDB().Where(primaryKeys).Find(results, conditions...).Error) + resultValues := reflect.Indirect(reflect.ValueOf(results)) + + for i := 0; i < resultValues.Len(); i++ { + result := resultValues.Index(i) + if scope.IndirectValue().Kind() == reflect.Slice { + value := getRealValue(result, associationPrimaryKey) + objects := scope.IndirectValue() + for j := 0; j < objects.Len(); j++ { + object := reflect.Indirect(objects.Index(j)) + if equalAsString(getRealValue(object, relation.ForeignFieldName), value) { + object.FieldByName(field.Name).Set(result) + } + } + } else { + scope.SetColumn(field, result) + } + } +} + +func (scope *Scope) getColumnAsArray(column string) (columns []interface{}) { values := scope.IndirectValue() switch values.Kind() { case reflect.Slice: - primaryKeyMap := map[interface{}]bool{} for i := 0; i < values.Len(); i++ { - primaryKeyMap[reflect.Indirect(values.Index(i)).FieldByName(column).Interface()] = true - } - for key := range primaryKeyMap { - primaryKeys = append(primaryKeys, key) + columns = append(columns, reflect.Indirect(values.Index(i)).FieldByName(column).Interface()) } case reflect.Struct: return []interface{}{values.FieldByName(column).Interface()} } return } + +func (scope *Scope) getColumnsAsScope(column string) *Scope { + values := scope.IndirectValue() + switch values.Kind() { + case reflect.Slice: + modelType := values.Type().Elem() + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + fieldStruct, _ := modelType.FieldByName(column) + var columns reflect.Value + if fieldStruct.Type.Kind() == reflect.Slice || fieldStruct.Type.Kind() == reflect.Ptr { + columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem() + } else { + columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem() + } + for i := 0; i < values.Len(); i++ { + column := reflect.Indirect(values.Index(i)).FieldByName(column) + if column.Kind() == reflect.Ptr { + column = column.Elem() + } + if column.Kind() == reflect.Slice { + for i := 0; i < column.Len(); i++ { + columns = reflect.Append(columns, column.Index(i).Addr()) + } + } else { + columns = reflect.Append(columns, column.Addr()) + } + } + return scope.New(columns.Interface()) + case reflect.Struct: + return scope.New(values.FieldByName(column).Addr().Interface()) + } + return nil +} diff --git a/preload_test.go b/preload_test.go index 2547933b..a6647bbd 100644 --- a/preload_test.go +++ b/preload_test.go @@ -1,6 +1,10 @@ package gorm_test -import "testing" +import ( + "encoding/json" + "reflect" + "testing" +) func getPreloadUser(name string) *User { return getPreparedUser(name, "Preload") @@ -85,3 +89,521 @@ func TestPreload(t *testing.T) { } } } + +func TestNestedPreload1(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}} + if err := DB.Create(&want).Error; err != nil { + panic(err) + } + + var got Level3 + if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload2(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []*Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := Level3{ + Level2s: []Level2{ + { + Level1s: []*Level1{ + &Level1{Value: "value1"}, + &Level1{Value: "value2"}, + }, + }, + { + Level1s: []*Level1{ + &Level1{Value: "value3"}, + }, + }, + }, + } + if err := DB.Create(&want).Error; err != nil { + panic(err) + } + + var got Level3 + if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload3(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + Name string + ID uint + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value1"}}, + {Level1: Level1{Value: "value2"}}, + }, + } + if err := DB.Create(&want).Error; err != nil { + panic(err) + } + + var got Level3 + if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload4(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := Level3{ + Level2: Level2{ + Level1s: []Level1{ + Level1{Value: "value1"}, + Level1{Value: "value2"}, + }, + }, + } + if err := DB.Create(&want).Error; err != nil { + panic(err) + } + + var got Level3 + if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +// Slice: []Level3 +func TestNestedPreload5(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := make([]Level3, 2) + want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}} + if err := DB.Create(&want[0]).Error; err != nil { + panic(err) + } + want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}} + if err := DB.Create(&want[1]).Error; err != nil { + panic(err) + } + + var got []Level3 + if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload6(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2s: []Level2{ + { + Level1s: []Level1{ + {Value: "value1"}, + {Value: "value2"}, + }, + }, + { + Level1s: []Level1{ + {Value: "value3"}, + }, + }, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + panic(err) + } + + want[1] = Level3{ + Level2s: []Level2{ + { + Level1s: []Level1{ + {Value: "value3"}, + {Value: "value4"}, + }, + }, + { + Level1s: []Level1{ + {Value: "value5"}, + }, + }, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + panic(err) + } + + var got []Level3 + if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload7(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1 Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2s []Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value1"}}, + {Level1: Level1{Value: "value2"}}, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + panic(err) + } + + want[1] = Level3{ + Level2s: []Level2{ + {Level1: Level1{Value: "value3"}}, + {Level1: Level1{Value: "value4"}}, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + panic(err) + } + + var got []Level3 + if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload8(t *testing.T) { + type ( + Level1 struct { + ID uint + Value string + Level2ID uint + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2 Level2 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level1{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil { + panic(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + Level1{Value: "value1"}, + Level1{Value: "value2"}, + }, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + panic(err) + } + want[1] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + Level1{Value: "value3"}, + Level1{Value: "value4"}, + }, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + panic(err) + } + + var got []Level3 + if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func TestNestedPreload9(t *testing.T) { + type ( + Level0 struct { + ID uint + Value string + Level1ID uint + } + Level1 struct { + ID uint + Value string + Level2ID uint + Level2_1ID uint + Level0s []Level0 + } + Level2 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level2_1 struct { + ID uint + Level1s []Level1 + Level3ID uint + } + Level3 struct { + ID uint + Name string + Level2 Level2 + Level2_1 Level2_1 + } + ) + DB.DropTableIfExists(&Level3{}) + DB.DropTableIfExists(&Level2{}) + DB.DropTableIfExists(&Level2_1{}) + DB.DropTableIfExists(&Level1{}) + DB.DropTableIfExists(&Level0{}) + if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).Error; err != nil { + panic(err) + } + + want := make([]Level3, 2) + want[0] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + Level1{Value: "value1"}, + Level1{Value: "value2"}, + }, + }, + Level2_1: Level2_1{ + Level1s: []Level1{ + Level1{ + Value: "value1-1", + Level0s: []Level0{{Value: "Level0-1"}}, + }, + Level1{ + Value: "value2-2", + Level0s: []Level0{{Value: "Level0-2"}}, + }, + }, + }, + } + if err := DB.Create(&want[0]).Error; err != nil { + panic(err) + } + want[1] = Level3{ + Level2: Level2{ + Level1s: []Level1{ + Level1{Value: "value3"}, + Level1{Value: "value4"}, + }, + }, + Level2_1: Level2_1{ + Level1s: []Level1{ + Level1{Value: "value3-3"}, + Level1{Value: "value4-4"}, + }, + }, + } + if err := DB.Create(&want[1]).Error; err != nil { + panic(err) + } + + var got []Level3 + if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).Error; err != nil { + panic(err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) + } +} + +func toJSONString(v interface{}) []byte { + r, _ := json.MarshalIndent(v, "", " ") + return r +} diff --git a/query_test.go b/query_test.go index d84fae93..b15d01ba 100644 --- a/query_test.go +++ b/query_test.go @@ -98,49 +98,41 @@ func TestSearchWithPlainSQL(t *testing.T) { t.Errorf("Should found 2 users that age > 1, but got %v", len(users)) } - users = []User{} DB.Where("name LIKE ?", "%PlainSqlUser%").Where("age >= ?", 1).Find(&users) if len(users) != 3 { t.Errorf("Should found 3 users that age >= 1, but got %v", len(users)) } - users = []User{} scopedb.Where("age <> ?", 20).Find(&users) if len(users) != 2 { t.Errorf("Should found 2 users age != 20, but got %v", len(users)) } - users = []User{} scopedb.Where("birthday > ?", now.MustParse("2000-1-1")).Find(&users) if len(users) != 2 { t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users)) } - users = []User{} scopedb.Where("birthday > ?", "2002-10-10").Find(&users) if len(users) != 2 { t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users)) } - users = []User{} scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users) if len(users) != 1 { t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users)) } - users = []User{} DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users) if len(users) != 2 { t.Errorf("Should found 2 users, but got %v", len(users)) } - users = []User{} DB.Where("id in (?)", []int64{user1.Id, user2.Id, user3.Id}).Find(&users) if len(users) != 3 { t.Errorf("Should found 3 users, but got %v", len(users)) } - users = []User{} DB.Where("id in (?)", user1.Id).Find(&users) if len(users) != 1 { t.Errorf("Should found 1 users, but got %v", len(users)) @@ -191,7 +183,6 @@ func TestSearchWithStruct(t *testing.T) { t.Errorf("Search first record with where struct") } - users = []User{} DB.Find(&users, &User{Name: user2.Name}) if len(users) != 1 { t.Errorf("Search all records with inline struct") @@ -222,7 +213,6 @@ func TestSearchWithMap(t *testing.T) { t.Errorf("Search all records with inline map") } - users = []User{} DB.Find(&users, map[string]interface{}{"name": user3.Name}) if len(users) != 1 { t.Errorf("Search all records with inline map") @@ -395,13 +385,11 @@ func TestNot(t *testing.T) { t.Errorf("Should find all users's name not equal 3") } - users4 = []User{} DB.Not("name = ?", "user3").Find(&users4) if len(users1)-len(users4) != int(name3Count) { t.Errorf("Should find all users's name not equal 3") } - users4 = []User{} DB.Not("name <> ?", "user3").Find(&users4) if len(users4) != int(name3Count) { t.Errorf("Should find all users's name not equal 3") diff --git a/scope.go b/scope.go index 86994a85..cd6b235d 100644 --- a/scope.go +++ b/scope.go @@ -110,6 +110,14 @@ func (scope *Scope) HasError() bool { return scope.db.Error != nil } +func (scope *Scope) PrimaryFields() []*Field { + var fields = []*Field{} + for _, field := range scope.GetModelStruct().PrimaryFields { + fields = append(fields, scope.Fields()[field.DBName]) + } + return fields +} + func (scope *Scope) PrimaryField() *Field { if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 { if len(primaryFields) > 1 { @@ -158,13 +166,18 @@ func (scope *Scope) HasColumn(column string) bool { func (scope *Scope) SetColumn(column interface{}, value interface{}) error { if field, ok := column.(*Field); ok { return field.Set(value) - } else if dbName, ok := column.(string); ok { + } else if name, ok := column.(string); ok { + + if field, ok := scope.Fields()[name]; ok { + return field.Set(value) + } + + dbName := ToDBName(name) if field, ok := scope.Fields()[dbName]; ok { return field.Set(value) } - dbName = ToDBName(dbName) - if field, ok := scope.Fields()[dbName]; ok { + if field, ok := scope.FieldByName(name); ok { return field.Set(value) } } @@ -172,7 +185,7 @@ func (scope *Scope) SetColumn(column interface{}, value interface{}) error { } func (scope *Scope) CallMethod(name string, checkError bool) { - if scope.Value == nil && (!checkError || !scope.HasError()) { + if scope.Value == nil || (checkError && scope.HasError()) { return } @@ -246,17 +259,14 @@ func (scope *Scope) TableName() string { return tabler.TableName(scope.db) } - if scope.GetModelStruct().TableName != nil { - scope.Search.tableName = scope.GetModelStruct().TableName(scope.db) - return scope.Search.tableName - } - - scope.Err(errors.New("wrong table name")) - return "" + return scope.GetModelStruct().TableName(scope.db.Model(scope.Value)) } func (scope *Scope) QuotedTableName() (name string) { if scope.Search != nil && len(scope.Search.tableName) > 0 { + if strings.Index(scope.Search.tableName, " ") != -1 { + return scope.Search.tableName + } return scope.Quote(scope.Search.tableName) } else { return scope.Quote(scope.TableName()) @@ -271,7 +281,7 @@ func (scope *Scope) CombinedConditionSql() string { func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { for _, field := range scope.Fields() { - if field.Name == name { + if field.Name == name || field.DBName == name { return field, true } } @@ -290,7 +300,7 @@ func (scope *Scope) Exec() *Scope { if !scope.HasError() { if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil { - if count, err := result.RowsAffected(); err == nil { + if count, err := result.RowsAffected(); scope.Err(err) == nil { scope.db.RowsAffected = count } } @@ -364,6 +374,8 @@ func (scope *Scope) SelectAttrs() []string { for _, value := range scope.Search.selects { if str, ok := value.(string); ok { attrs = append(attrs, str) + } else if strs, ok := value.([]string); ok { + attrs = append(attrs, strs...) } else if strs, ok := value.([]interface{}); ok { for _, str := range strs { attrs = append(attrs, fmt.Sprintf("%v", str)) diff --git a/scope_private.go b/scope_private.go index 99dda2ed..85f07e99 100644 --- a/scope_private.go +++ b/scope_private.go @@ -38,7 +38,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri case interface{}: var sqls []string for _, field := range scope.New(value).Fields() { - if !field.IsBlank { + if !field.IsIgnored && !field.IsBlank { sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) } } @@ -336,12 +336,14 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore func (scope *Scope) row() *sql.Row { defer scope.Trace(NowFunc()) + scope.callCallbacks(scope.db.parent.callback.rowQueries) scope.prepareQuerySql() return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...) } func (scope *Scope) rows() (*sql.Rows, error) { defer scope.Trace(NowFunc()) + scope.callCallbacks(scope.db.parent.callback.rowQueries) scope.prepareQuerySql() return scope.SqlDB().Query(scope.Sql, scope.SqlVars...) } @@ -411,7 +413,7 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { if relationship := fromField.Relationship; relationship != nil { if relationship.Kind == "many_to_many" { joinTableHandler := relationship.JoinTableHandler - scope.Err(joinTableHandler.JoinWith(toScope.db, scope.Value).Find(value).Error) + scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error) } else if relationship.Kind == "belongs_to" { sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) foreignKeyValue := fromFields[relationship.ForeignDBName].Field.Interface() @@ -445,13 +447,19 @@ func (scope *Scope) createJoinTable(field *StructField) { joinTableHandler := relationship.JoinTableHandler joinTable := joinTableHandler.Table(scope.db) if !scope.Dialect().HasTable(scope, joinTable) { - primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false) - scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)", - scope.Quote(joinTable), - strings.Join([]string{ - scope.Quote(relationship.ForeignDBName) + " " + primaryKeySqlType, - scope.Quote(relationship.AssociationForeignDBName) + " " + primaryKeySqlType}, ",")), - ).Error) + toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()} + + var sqlTypes []string + for _, s := range []*Scope{scope, toScope} { + for _, primaryField := range s.GetModelStruct().PrimaryFields { + value := reflect.Indirect(reflect.New(primaryField.Struct.Type)) + primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false) + dbName := ToDBName(s.GetModelStruct().ModelType.Name() + primaryField.Name) + sqlTypes = append(sqlTypes, scope.Quote(dbName)+" "+primaryKeySqlType) + } + } + + scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)", scope.Quote(joinTable), strings.Join(sqlTypes, ","))).Error) } scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) } @@ -467,7 +475,7 @@ func (scope *Scope) createTable() *Scope { } if field.IsPrimaryKey { - primaryKeys = append(primaryKeys, field.DBName) + primaryKeys = append(primaryKeys, scope.Quote(field.DBName)) } scope.createJoinTable(field) } @@ -522,7 +530,7 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on var table = scope.TableName() var keyName = fmt.Sprintf("%s_%s_foreign", table, field) var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;` - scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), keyName, field, dest, onDelete, onUpdate)).Exec() + scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.Quote(keyName), scope.Quote(field), scope.Quote(dest), onDelete, onUpdate)).Exec() } func (scope *Scope) removeIndex(indexName string) { @@ -540,7 +548,7 @@ func (scope *Scope) autoMigrate() *Scope { if !scope.Dialect().HasColumn(scope, tableName, field.DBName) { if field.IsNormal { sqlTag := scope.generateSqlTag(field) - scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, sqlTag)).Exec() + scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec() } } scope.createJoinTable(field) diff --git a/search.go b/search.go index 502c226f..9411af43 100644 --- a/search.go +++ b/search.go @@ -14,7 +14,7 @@ type search struct { omits []string orders []string joins string - preload map[string][]interface{} + preload []searchPreload offset string limit string group string @@ -23,26 +23,14 @@ type search struct { Unscoped bool } +type searchPreload struct { + schema string + conditions []interface{} +} + func (s *search) clone() *search { - return &search{ - preload: s.preload, - whereConditions: s.whereConditions, - orConditions: s.orConditions, - notConditions: s.notConditions, - havingCondition: s.havingCondition, - initAttrs: s.initAttrs, - assignAttrs: s.assignAttrs, - selects: s.selects, - omits: s.omits, - orders: s.orders, - joins: s.joins, - offset: s.offset, - limit: s.limit, - group: s.group, - tableName: s.tableName, - raw: s.raw, - Unscoped: s.Unscoped, - } + clone := *s + return &clone } func (s *search) Where(query interface{}, values ...interface{}) *search { @@ -114,11 +102,15 @@ func (s *search) Joins(query string) *search { return s } -func (s *search) Preload(column string, values ...interface{}) *search { - if s.preload == nil { - s.preload = map[string][]interface{}{} +func (s *search) Preload(schema string, values ...interface{}) *search { + var preloads []searchPreload + for _, preload := range s.preload { + if preload.schema != schema { + preloads = append(preloads, preload) + } } - s.preload[column] = values + preloads = append(preloads, searchPreload{schema, values}) + s.preload = preloads return s } diff --git a/test_all.sh b/test_all.sh index bd28294d..6c5593b3 100755 --- a/test_all.sh +++ b/test_all.sh @@ -1,4 +1,4 @@ -dialects=("postgres" "foundation" "mysql" "sqlite") +dialects=("postgres" "mysql" "sqlite") for dialect in "${dialects[@]}" ; do GORM_DIALECT=${dialect} go test