Merge pull request #3 from jinzhu/master

Getting our fork up to date
This commit is contained in:
Leonard Austin 2015-07-13 18:20:04 +01:00
commit 36ea322ee8
24 changed files with 1067 additions and 252 deletions

View File

@ -39,7 +39,7 @@ type User struct {
Num int `sql:"AUTO_INCREMENT"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt time.Time
DeletedAt *time.Time
Emails []Email // One-To-Many relationship (has many)
BillingAddress Address // One-To-One relationship (has one)
@ -84,6 +84,21 @@ type User struct{} // struct User's database table name is "users" by default, w
* Use `CreatedAt` to store record's created time if field exists
* Use `UpdatedAt` to store record's updated time if field exists
* Use `DeletedAt` to store record's deleted time if field exists [Soft Delete](#soft-delete)
* Gorm provide a default model struct, you could embed it in your struct
```go
type Model struct {
ID uint `gorm:"primary_key"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time
}
type User struct {
gorm.Model
Name string
}
```
## Initialize Database
@ -102,7 +117,7 @@ db, err := gorm.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
// You can also use an existing database connection handle
// dbSql, _ := sql.Open("postgres", "user=gorm dbname=gorm sslmode=disable")
// db := gorm.Open("postgres", dbSql)
// db, _ := gorm.Open("postgres", dbSql)
// Get database connection handle [*sql.DB](http://golang.org/pkg/database/sql/#DB)
db.DB()
@ -332,6 +347,13 @@ db.Preload("Orders").Preload("Profile").Preload("Role").Find(&users)
//// SELECT * FROM roles WHERE id IN (4,5,6); // belongs to
```
#### Nested Preloading
```go
db.Preload("Orders.OrderItems").Find(&users)
db.Preload("Orders", "state = ?", "paid").Preload("Orders.OrderItems").Find(&users)
```
## Update
```go
@ -823,23 +845,56 @@ for rows.Next() {
}
db.Table("users").Select("users.name, emails.email").Joins("left join emails on emails.user_id = users.id").Scan(&results)
// find a user by email address
db.Joins("inner join emails on emails.user_id = users.id").Where("emails.email = ?", "x@example.org").Find(&user)
// find all email addresses for a user
db.Joins("left join users on users.id = emails.user_id").Where("users.name = ?", "jinzhu").Find(&emails)
```
## Transactions
All individual save and delete operations are run in a transaction by default.
To perform a set of operations within a transaction, the general flow is as below.
The database handle returned from ``` db.Begin() ``` should be used for all operations within the transaction.
(Note that all individual save and delete operations are run in a transaction by default.)
```go
// begin
tx := db.Begin()
// rollback
// do some database operations (use 'tx' from this point, not 'db')
tx.Create(...)
...
// rollback in case of error
tx.Rollback()
// commit
// Or commit if all is ok
tx.Commit()
```
### A Specific Example
```
func CreateAnimals(db *gorm.DB) err {
tx := db.Begin()
// Note the use of tx as the database handle once you are within a transaction
if err := tx.Create(&Animal{Name: "Giraffe"}).Error; err != nil {
tx.Rollback()
return err
}
if err := tx.Create(&Animal{Name: "Lion"}).Error; err != nil {
tx.Rollback()
return err
}
tx.Commit()
return nil
}
```
## Scopes
```go
@ -1069,7 +1124,7 @@ type Product struct {
// 2nd param : destination table(id)
// 3rd param : ONDELETE
// 4th param : ONUPDATE
db.Model(&User{}).AddForeignKey("user_id", "destination_table(id)", "CASCADE", "RESTRICT")
db.Model(&User{}).AddForeignKey("role_id", "roles", "CASCADE", "RESTRICT")
// Add index
db.Model(&User{}).AddIndex("idx_user_name", "name")

View File

@ -40,6 +40,7 @@ func (association *Association) Append(values ...interface{}) *Association {
association.setErr(errors.New("invalid association type"))
}
}
scope.Search.Select(association.Column)
scope.callCallbacks(scope.db.parent.callback.updates)
return association.setErr(scope.db.Error)
}
@ -77,7 +78,7 @@ func (association *Association) Delete(values ...interface{}) *Association {
if relationship.Kind == "many_to_many" {
sql := fmt.Sprintf("%v = ? AND %v IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
query := scope.NewDB().Where(sql, association.PrimaryKey, primaryKeys)
if err := relationship.JoinTableHandler.Delete(query, relationship); err == nil {
if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil {
leftValues := reflect.Zero(association.Field.Field.Type())
for i := 0; i < association.Field.Field.Len(); i++ {
value := association.Field.Field.Index(i)
@ -130,9 +131,11 @@ func (association *Association) Replace(values ...interface{}) *Association {
addedPrimaryKeys = append(addedPrimaryKeys, primaryKey)
}
sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys)
association.setErr(relationship.JoinTableHandler.Delete(query, relationship))
if len(addedPrimaryKeys) > 0 {
sql := fmt.Sprintf("%v = ? AND %v NOT IN (?)", scope.Quote(relationship.ForeignDBName), scope.Quote(relationship.AssociationForeignDBName))
query := scope.NewDB().Where(sql, association.PrimaryKey, addedPrimaryKeys)
association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship))
}
} else {
association.setErr(errors.New("replace only support many to many"))
}
@ -145,7 +148,7 @@ func (association *Association) Clear() *Association {
if relationship.Kind == "many_to_many" {
sql := fmt.Sprintf("%v = ?", scope.Quote(relationship.ForeignDBName))
query := scope.NewDB().Where(sql, association.PrimaryKey)
if err := relationship.JoinTableHandler.Delete(query, relationship); err == nil {
if err := relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, query, relationship); err == nil {
association.Field.Set(reflect.Zero(association.Field.Field.Type()))
} else {
association.setErr(err)
@ -163,7 +166,7 @@ func (association *Association) Count() int {
newScope := scope.New(association.Field.Field.Interface())
if relationship.Kind == "many_to_many" {
relationship.JoinTableHandler.JoinWith(scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count)
relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, scope.NewDB(), association.Scope.Value).Table(newScope.TableName()).Count(&count)
} else if relationship.Kind == "has_many" || relationship.Kind == "has_one" {
whereSql := fmt.Sprintf("%v.%v = ?", newScope.QuotedTableName(), newScope.Quote(relationship.ForeignDBName))
countScope := scope.DB().Table(newScope.TableName()).Where(whereSql, association.PrimaryKey)

View File

@ -148,7 +148,6 @@ func TestManyToMany(t *testing.T) {
t.Errorf("Query many to many relations")
}
newLanguages = []Language{}
DB.Model(&user).Association("Languages").Find(&newLanguages)
if len(newLanguages) != len([]string{"ZH", "EN"}) {
t.Errorf("Should be able to find many to many relations")
@ -194,7 +193,6 @@ func TestManyToMany(t *testing.T) {
t.Errorf("Language EE should not be deleted")
}
languages = []Language{}
DB.Where("name IN (?)", []string{"CC", "DD"}).Find(&languages)
user2 := User{Name: "Many2Many_User2", Languages: languages}

View File

@ -9,6 +9,7 @@ type callback struct {
updates []*func(scope *Scope)
deletes []*func(scope *Scope)
queries []*func(scope *Scope)
rowQueries []*func(scope *Scope)
processors []*callbackProcessor
}
@ -55,6 +56,10 @@ func (c *callback) Query() *callbackProcessor {
return c.addProcessor("query")
}
func (c *callback) RowQuery() *callbackProcessor {
return c.addProcessor("row_query")
}
func (cp *callbackProcessor) Before(name string) *callbackProcessor {
cp.before = name
return cp
@ -168,7 +173,7 @@ func sortProcessors(cps []*callbackProcessor) []*func(scope *Scope) {
}
func (c *callback) sort() {
creates, updates, deletes, queries := []*callbackProcessor{}, []*callbackProcessor{}, []*callbackProcessor{}, []*callbackProcessor{}
var creates, updates, deletes, queries, rowQueries []*callbackProcessor
for _, processor := range c.processors {
switch processor.typ {
@ -180,6 +185,8 @@ func (c *callback) sort() {
deletes = append(deletes, processor)
case "query":
queries = append(queries, processor)
case "row_query":
rowQueries = append(rowQueries, processor)
}
}
@ -187,6 +194,7 @@ func (c *callback) sort() {
c.updates = sortProcessors(updates)
c.deletes = sortProcessors(deletes)
c.queries = sortProcessors(queries)
c.rowQueries = sortProcessors(rowQueries)
}
var DefaultCallback = &callback{processors: []*callbackProcessor{}}

View File

@ -70,18 +70,24 @@ func Create(scope *Scope) {
id, err := result.LastInsertId()
if scope.Err(err) == nil {
scope.db.RowsAffected, _ = result.RowsAffected()
if primaryField != nil {
if primaryField != nil && primaryField.IsBlank {
scope.Err(scope.SetColumn(primaryField, id))
}
}
}
} else {
if primaryField == nil {
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err != nil {
if results, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); err == nil {
scope.db.RowsAffected, _ = results.RowsAffected()
} else {
scope.Err(err)
}
} else {
if err := scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())); err == nil {
scope.db.RowsAffected = 1
} else {
scope.Err(err)
}
} else if scope.Err(scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...).Scan(primaryField.Field.Addr().Interface())) == nil {
scope.db.RowsAffected = 1
}
}
}

View File

@ -16,20 +16,22 @@ func Query(scope *Scope) {
destType reflect.Type
)
var dest = scope.IndirectValue()
if value, ok := scope.InstanceGet("gorm:query_destination"); ok {
dest = reflect.Indirect(reflect.ValueOf(value))
}
if orderBy, ok := scope.Get("gorm:order_by_primary_key"); ok {
if primaryKey := scope.PrimaryKey(); primaryKey != "" {
scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), primaryKey, orderBy))
scope.Search.Order(fmt.Sprintf("%v.%v %v", scope.QuotedTableName(), scope.Quote(primaryKey), orderBy))
}
}
var dest = scope.IndirectValue()
if value, ok := scope.Get("gorm:query_destination"); ok {
dest = reflect.Indirect(reflect.ValueOf(value))
}
if kind := dest.Kind(); kind == reflect.Slice {
isSlice = true
destType = dest.Type().Elem()
dest.Set(reflect.Indirect(reflect.New(reflect.SliceOf(destType))))
if destType.Kind() == reflect.Ptr {
isPtr = true
destType = destType.Elem()

View File

@ -55,7 +55,7 @@ func SaveAfterAssociations(scope *Scope) {
scope.Err(newDB.Save(elem).Error)
if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil {
scope.Err(joinTableHandler.Add(scope.NewDB(), scope.Value, newScope.Value))
scope.Err(joinTableHandler.Add(joinTableHandler, scope.NewDB(), scope.Value, newScope.Value))
}
}
default:

View File

@ -64,13 +64,15 @@ func Update(scope *Scope) {
}
}
scope.Raw(fmt.Sprintf(
"UPDATE %v SET %v %v",
scope.QuotedTableName(),
strings.Join(sqls, ", "),
scope.CombinedConditionSql(),
))
scope.Exec()
if len(sqls) > 0 {
scope.Raw(fmt.Sprintf(
"UPDATE %v SET %v %v",
scope.QuotedTableName(),
strings.Join(sqls, ", "),
scope.CombinedConditionSql(),
))
scope.Exec()
}
}
}

View File

@ -61,7 +61,7 @@ Gorm is powered by callbacks, so you could refer below links to learn how to wri
[Update callbacks](https://github.com/jinzhu/gorm/blob/master/callback_update.go)
[Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_create.go)
[Query callbacks](https://github.com/jinzhu/gorm/blob/master/callback_query.go)
[Delete callbacks](https://github.com/jinzhu/gorm/blob/master/callback_delete.go)

View File

@ -10,9 +10,11 @@ import (
type JoinTableHandlerInterface interface {
Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type)
Table(db *DB) string
Add(db *DB, source interface{}, destination interface{}) error
Delete(db *DB, sources ...interface{}) error
JoinWith(db *DB, source interface{}) *DB
Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error
Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error
JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB
SourceForeignKeys() []JoinTableForeignKey
DestinationForeignKeys() []JoinTableForeignKey
}
type JoinTableForeignKey struct {
@ -31,37 +33,58 @@ type JoinTableHandler struct {
Destination JoinTableSource `sql:"-"`
}
func (s *JoinTableHandler) SourceForeignKeys() []JoinTableForeignKey {
return s.Source.ForeignKeys
}
func (s *JoinTableHandler) DestinationForeignKeys() []JoinTableForeignKey {
return s.Destination.ForeignKeys
}
func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, source reflect.Type, destination reflect.Type) {
s.TableName = tableName
s.Source = JoinTableSource{ModelType: source}
sourceScope := &Scope{Value: reflect.New(source).Interface()}
for _, primaryField := range sourceScope.GetModelStruct().PrimaryFields {
sourcePrimaryFields := sourceScope.GetModelStruct().PrimaryFields
for _, primaryField := range sourcePrimaryFields {
if relationship.ForeignDBName == "" {
relationship.ForeignFieldName = source.Name() + primaryField.Name
relationship.ForeignDBName = ToDBName(relationship.ForeignFieldName)
}
var dbName string
if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" {
dbName = relationship.ForeignDBName
} else {
dbName = ToDBName(source.Name() + primaryField.Name)
}
s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{
DBName: relationship.ForeignDBName,
DBName: dbName,
AssociationDBName: primaryField.DBName,
})
}
s.Destination = JoinTableSource{ModelType: destination}
destinationScope := &Scope{Value: reflect.New(destination).Interface()}
for _, primaryField := range destinationScope.GetModelStruct().PrimaryFields {
if relationship.AssociationForeignDBName == "" {
relationship.AssociationForeignFieldName = destination.Name() + primaryField.Name
relationship.AssociationForeignDBName = ToDBName(relationship.AssociationForeignFieldName)
destinationPrimaryFields := destinationScope.GetModelStruct().PrimaryFields
for _, primaryField := range destinationPrimaryFields {
var dbName string
if len(sourcePrimaryFields) == 1 || primaryField.DBName == "id" {
dbName = relationship.AssociationForeignDBName
} else {
dbName = ToDBName(destinationScope.GetModelStruct().ModelType.Name() + primaryField.Name)
}
s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{
DBName: relationship.AssociationForeignDBName,
DBName: dbName,
AssociationDBName: primaryField.DBName,
})
}
}
func (s JoinTableHandler) Table(*DB) string {
func (s JoinTableHandler) Table(db *DB) string {
return s.TableName
}
@ -85,7 +108,7 @@ func (s JoinTableHandler) GetSearchMap(db *DB, sources ...interface{}) map[strin
return values
}
func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{}) error {
func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source1 interface{}, source2 interface{}) error {
scope := db.NewScope("")
searchMap := s.GetSearchMap(db, source1, source2)
@ -102,7 +125,7 @@ func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{})
values = append(values, value)
}
quotedTable := s.Table(db)
quotedTable := handler.Table(db)
sql := fmt.Sprintf(
"INSERT INTO %v (%v) SELECT %v %v WHERE NOT EXISTS (SELECT * FROM %v WHERE %v)",
quotedTable,
@ -116,7 +139,7 @@ func (s JoinTableHandler) Add(db *DB, source1 interface{}, source2 interface{})
return db.Exec(sql, values...).Error
}
func (s JoinTableHandler) Delete(db *DB, sources ...interface{}) error {
func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error {
var conditions []string
var values []interface{}
@ -125,11 +148,11 @@ func (s JoinTableHandler) Delete(db *DB, sources ...interface{}) error {
values = append(values, value)
}
return db.Table(s.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
return db.Table(handler.Table(db)).Where(strings.Join(conditions, " AND "), values...).Delete("").Error
}
func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB {
quotedTable := s.Table(db)
func (s JoinTableHandler) JoinWith(handler JoinTableHandlerInterface, db *DB, source interface{}) *DB {
quotedTable := handler.Table(db)
scope := db.NewScope(source)
modelType := scope.GetModelStruct().ModelType
@ -138,7 +161,7 @@ func (s JoinTableHandler) JoinWith(db *DB, source interface{}) *DB {
var values []interface{}
if s.Source.ModelType == modelType {
for _, foreignKey := range s.Destination.ForeignKeys {
destinationTableName := scope.New(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
destinationTableName := db.NewScope(reflect.New(s.Destination.ModelType).Interface()).QuotedTableName()
joinConditions = append(joinConditions, fmt.Sprintf("%v.%v = %v.%v", quotedTable, scope.Quote(foreignKey.DBName), destinationTableName, scope.Quote(foreignKey.AssociationDBName)))
}

View File

@ -22,7 +22,7 @@ type PersonAddress struct {
CreatedAt time.Time
}
func (*PersonAddress) Add(db *gorm.DB, foreignValue interface{}, associationValue interface{}) error {
func (*PersonAddress) Add(handler gorm.JoinTableHandlerInterface, db *gorm.DB, foreignValue interface{}, associationValue interface{}) error {
return db.Where(map[string]interface{}{
"person_id": db.NewScope(foreignValue).PrimaryKeyValue(),
"address_id": db.NewScope(associationValue).PrimaryKeyValue(),
@ -33,11 +33,11 @@ func (*PersonAddress) Add(db *gorm.DB, foreignValue interface{}, associationValu
}).FirstOrCreate(&PersonAddress{}).Error
}
func (*PersonAddress) Delete(db *gorm.DB, sources ...interface{}) error {
func (*PersonAddress) Delete(handler gorm.JoinTableHandlerInterface, db *gorm.DB, sources ...interface{}) error {
return db.Delete(&PersonAddress{}).Error
}
func (pa *PersonAddress) JoinWith(db *gorm.DB, source interface{}) *gorm.DB {
func (pa *PersonAddress) JoinWith(handler gorm.JoinTableHandlerInterface, db *gorm.DB, source interface{}) *gorm.DB {
table := pa.Table(db)
return db.Table(table).Joins("INNER JOIN person_addresses ON person_addresses.address_id = addresses.id").Where(fmt.Sprintf("%v.deleted_at IS NULL OR %v.deleted_at <= '0001-01-02'", table, table))
}

67
main.go
View File

@ -211,6 +211,10 @@ func (s *DB) Find(out interface{}, where ...interface{}) *DB {
return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callback.queries).db
}
func (s *DB) Scan(dest interface{}) *DB {
return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callback.queries).db
}
func (s *DB) Row() *sql.Row {
return s.NewScope(s.Value).row()
}
@ -219,10 +223,16 @@ func (s *DB) Rows() (*sql.Rows, error) {
return s.NewScope(s.Value).rows()
}
func (s *DB) Scan(dest interface{}) *DB {
scope := s.clone().NewScope(s.Value).InstanceSet("gorm:query_destination", dest)
Query(scope)
return scope.db
func (s *DB) Pluck(column string, value interface{}) *DB {
return s.NewScope(s.Value).pluck(column, value).db
}
func (s *DB) Count(value interface{}) *DB {
return s.NewScope(s.Value).count(value).db
}
func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
return s.clone().NewScope(s.Value).related(value, foreignKeys...).db
}
func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB {
@ -309,18 +319,6 @@ func (s *DB) Model(value interface{}) *DB {
return c
}
func (s *DB) Related(value interface{}, foreignKeys ...string) *DB {
return s.clone().NewScope(s.Value).related(value, foreignKeys...).db
}
func (s *DB) Pluck(column string, value interface{}) *DB {
return s.NewScope(s.Value).pluck(column, value).db
}
func (s *DB) Count(value interface{}) *DB {
return s.NewScope(s.Value).count(value).db
}
func (s *DB) Table(name string) *DB {
clone := s.clone()
clone.search.Table(name)
@ -398,28 +396,33 @@ func (s *DB) AutoMigrate(values ...interface{}) *DB {
}
func (s *DB) ModifyColumn(column string, typ string) *DB {
s.clone().NewScope(s.Value).modifyColumn(column, typ)
return s
scope := s.clone().NewScope(s.Value)
scope.modifyColumn(column, typ)
return scope.db
}
func (s *DB) DropColumn(column string) *DB {
s.clone().NewScope(s.Value).dropColumn(column)
return s
scope := s.clone().NewScope(s.Value)
scope.dropColumn(column)
return scope.db
}
func (s *DB) AddIndex(indexName string, column ...string) *DB {
s.clone().NewScope(s.Value).addIndex(false, indexName, column...)
return s
scope := s.clone().NewScope(s.Value)
scope.addIndex(false, indexName, column...)
return scope.db
}
func (s *DB) AddUniqueIndex(indexName string, column ...string) *DB {
s.clone().NewScope(s.Value).addIndex(true, indexName, column...)
return s
scope := s.clone().NewScope(s.Value)
scope.addIndex(true, indexName, column...)
return scope.db
}
func (s *DB) RemoveIndex(indexName string) *DB {
s.clone().NewScope(s.Value).removeIndex(indexName)
return s
scope := s.clone().NewScope(s.Value)
scope.removeIndex(indexName)
return scope.db
}
/*
@ -429,8 +432,9 @@ Example:
db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT")
*/
func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB {
s.clone().NewScope(s.Value).addForeignKey(field, dest, onDelete, onUpdate)
return s
scope := s.clone().NewScope(s.Value)
scope.addForeignKey(field, dest, onDelete, onUpdate)
return scope.db
}
func (s *DB) Association(column string) *Association {
@ -475,14 +479,17 @@ func (s *DB) Get(name string) (value interface{}, ok bool) {
}
func (s *DB) SetJoinTableHandler(source interface{}, column string, handler JoinTableHandlerInterface) {
for _, field := range s.NewScope(source).GetModelStruct().StructFields {
scope := s.NewScope(source)
for _, field := range scope.GetModelStruct().StructFields {
if field.Name == column || field.DBName == column {
if many2many := parseTagSetting(field.Tag.Get("gorm"))["MANY2MANY"]; many2many != "" {
source := (&Scope{Value: source}).GetModelStruct().ModelType
destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType
handler.Setup(field.Relationship, many2many, source, destination)
field.Relationship.JoinTableHandler = handler
s.Table(handler.Table(s)).AutoMigrate(handler)
if table := handler.Table(s); scope.Dialect().HasTable(scope, table) {
s.Table(table).AutoMigrate(handler)
}
}
}
}

View File

@ -61,6 +61,19 @@ func init() {
runMigration()
}
func TestStringPrimaryKey(t *testing.T) {
type UUIDStruct struct {
ID string `gorm:"primary_key"`
Name string
}
DB.AutoMigrate(&UUIDStruct{})
data := UUIDStruct{ID: "uuid", Name: "hello"}
if err := DB.Save(&data).Error; err != nil || data.ID != "uuid" {
t.Errorf("string primary key should not be populated")
}
}
func TestExceptionsWithInvalidSql(t *testing.T) {
var columns []string
if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil {

10
model.go Normal file
View File

@ -0,0 +1,10 @@
package gorm
import "time"
type Model struct {
ID uint `gorm:"primary_key"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time
}

View File

@ -13,11 +13,19 @@ import (
var modelStructs = map[reflect.Type]*ModelStruct{}
var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
return defaultTableName
}
type ModelStruct struct {
PrimaryFields []*StructField
StructFields []*StructField
ModelType reflect.Type
TableName func(*DB) string
PrimaryFields []*StructField
StructFields []*StructField
ModelType reflect.Type
defaultTableName string
}
func (s ModelStruct) TableName(db *DB) string {
return DefaultTableNameHandler(db, s.defaultTableName)
}
type StructField struct {
@ -94,14 +102,14 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
}
// Set tablename
if fm := reflect.New(scopeType).MethodByName("TableName"); fm.IsValid() {
if results := fm.Call([]reflect.Value{}); len(results) > 0 {
if name, ok := results[0].Interface().(string); ok {
modelStruct.TableName = func(*DB) string {
return name
}
}
}
type tabler interface {
TableName() string
}
if tabler, ok := reflect.New(scopeType).Interface().(interface {
TableName() string
}); ok {
modelStruct.defaultTableName = tabler.TableName()
} else {
name := ToDBName(scopeType.Name())
if scope.db == nil || !scope.db.parent.singularTable {
@ -112,9 +120,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
}
}
modelStruct.TableName = func(*DB) string {
return name
}
modelStruct.defaultTableName = name
}
// Get all fields

View File

@ -0,0 +1,46 @@
package gorm_test
import (
"fmt"
"os"
"testing"
)
type Blog struct {
ID uint `gorm:"primary_key"`
Locale string `gorm:"primary_key"`
Subject string
Body string
Tags []Tag `gorm:"many2many:blog_tags;"`
}
type Tag struct {
ID uint `gorm:"primary_key"`
Locale string `gorm:"primary_key"`
Value string
}
func TestManyToManyWithMultiPrimaryKeys(t *testing.T) {
if dialect := os.Getenv("GORM_DIALECT"); dialect != "" && dialect != "sqlite" {
DB.Exec(fmt.Sprintf("drop table blog_tags;"))
DB.AutoMigrate(&Blog{}, &Tag{})
blog := Blog{
Locale: "ZH",
Subject: "subject",
Body: "body",
Tags: []Tag{
{Locale: "ZH", Value: "tag1"},
{Locale: "ZH", Value: "tag2"},
},
}
DB.Save(&blog)
DB.Model(&blog).Association("Tags").Append([]Tag{{Locale: "ZH", Value: "tag3"}})
var tags []Tag
DB.Model(&blog).Related(&tags, "Tags")
if len(tags) != 3 {
t.Errorf("should found 3 tags with blog")
}
}
}

View File

@ -14,16 +14,26 @@ func (mysql) SqlTag(value reflect.Value, size int, autoIncrease bool) string {
switch value.Kind() {
case reflect.Bool:
return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32:
if autoIncrease {
return "int AUTO_INCREMENT"
}
return "int"
case reflect.Int64, reflect.Uint64:
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if autoIncrease {
return "int unsigned AUTO_INCREMENT"
}
return "int unsigned"
case reflect.Int64:
if autoIncrease {
return "bigint AUTO_INCREMENT"
}
return "bigint"
case reflect.Uint64:
if autoIncrease {
return "bigint unsigned AUTO_INCREMENT"
}
return "bigint unsigned"
case reflect.Float32, reflect.Float64:
return "double"
case reflect.String:

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"reflect"
"strings"
)
func getRealValue(value reflect.Value, field string) interface{} {
@ -20,94 +21,69 @@ func equalAsString(a interface{}, b interface{}) bool {
}
func Preload(scope *Scope) {
if scope.Search.preload == nil {
return
}
preloadMap := map[string]bool{}
fields := scope.Fields()
isSlice := scope.IndirectValue().Kind() == reflect.Slice
for _, preload := range scope.Search.preload {
schema, conditions := preload.schema, preload.conditions
keys := strings.Split(schema, ".")
currentScope := scope
currentFields := fields
originalConditions := conditions
conditions = []interface{}{}
for i, key := range keys {
var found bool
if preloadMap[strings.Join(keys[:i+1], ".")] {
goto nextLoop
}
if scope.Search.preload != nil {
for key, conditions := range scope.Search.preload {
for _, field := range fields {
if field.Name == key && field.Relationship != nil {
results := makeSlice(field.Struct.Type)
relation := field.Relationship
primaryName := scope.PrimaryField().Name
associationPrimaryKey := scope.New(results).PrimaryField().Name
if i == len(keys)-1 {
conditions = originalConditions
}
switch relation.Kind {
case "has_one":
if primaryKeys := scope.getColumnAsArray(primaryName); len(primaryKeys) > 0 {
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName))
scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...)
resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
if isSlice {
value := getRealValue(result, relation.ForeignFieldName)
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
if equalAsString(getRealValue(objects.Index(j), primaryName), value) {
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
break
}
}
} else {
scope.SetColumn(field, result)
}
}
}
case "has_many":
if primaryKeys := scope.getColumnAsArray(primaryName); len(primaryKeys) > 0 {
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName))
scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...)
resultValues := reflect.Indirect(reflect.ValueOf(results))
if isSlice {
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
value := getRealValue(result, relation.ForeignFieldName)
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j))
if equalAsString(getRealValue(object, primaryName), value) {
f := object.FieldByName(field.Name)
f.Set(reflect.Append(f, result))
break
}
}
}
} else {
scope.SetColumn(field, resultValues)
}
}
case "belongs_to":
if primaryKeys := scope.getColumnAsArray(relation.ForeignFieldName); len(primaryKeys) > 0 {
scope.NewDB().Where(primaryKeys).Find(results, conditions...)
resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
if isSlice {
value := getRealValue(result, associationPrimaryKey)
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j))
if equalAsString(getRealValue(object, relation.ForeignFieldName), value) {
object.FieldByName(field.Name).Set(result)
}
}
} else {
scope.SetColumn(field, result)
}
}
}
case "many_to_many":
scope.Err(errors.New("not supported relation"))
default:
scope.Err(errors.New("not supported relation"))
}
break
for _, field := range currentFields {
if field.Name != key || field.Relationship == nil {
continue
}
found = true
switch field.Relationship.Kind {
case "has_one":
currentScope.handleHasOnePreload(field, conditions)
case "has_many":
currentScope.handleHasManyPreload(field, conditions)
case "belongs_to":
currentScope.handleBelongsToPreload(field, conditions)
case "many_to_many":
fallthrough
default:
currentScope.Err(errors.New("not supported relation"))
}
break
}
if !found {
value := reflect.ValueOf(currentScope.Value)
if value.Kind() == reflect.Slice && value.Type().Elem().Kind() == reflect.Interface {
value = value.Index(0).Elem()
}
scope.Err(fmt.Errorf("can't find field %s in %s", key, value.Type()))
return
}
preloadMap[strings.Join(keys[:i+1], ".")] = true
nextLoop:
if i < len(keys)-1 {
currentScope = currentScope.getColumnsAsScope(key)
currentFields = currentScope.Fields()
}
}
}
}
func makeSlice(typ reflect.Type) interface{} {
@ -120,19 +96,147 @@ func makeSlice(typ reflect.Type) interface{} {
return slice.Interface()
}
func (scope *Scope) getColumnAsArray(column string) (primaryKeys []interface{}) {
func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) {
primaryName := scope.PrimaryField().Name
primaryKeys := scope.getColumnAsArray(primaryName)
if len(primaryKeys) == 0 {
return
}
results := makeSlice(field.Struct.Type)
relation := field.Relationship
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName))
scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error)
resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
if scope.IndirectValue().Kind() == reflect.Slice {
value := getRealValue(result, relation.ForeignFieldName)
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
if equalAsString(getRealValue(objects.Index(j), primaryName), value) {
reflect.Indirect(objects.Index(j)).FieldByName(field.Name).Set(result)
break
}
}
} else {
if err := scope.SetColumn(field, result); err != nil {
scope.Err(err)
return
}
}
}
}
func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) {
primaryName := scope.PrimaryField().Name
primaryKeys := scope.getColumnAsArray(primaryName)
if len(primaryKeys) == 0 {
return
}
results := makeSlice(field.Struct.Type)
relation := field.Relationship
condition := fmt.Sprintf("%v IN (?)", scope.Quote(relation.ForeignDBName))
scope.Err(scope.NewDB().Where(condition, primaryKeys).Find(results, conditions...).Error)
resultValues := reflect.Indirect(reflect.ValueOf(results))
if scope.IndirectValue().Kind() == reflect.Slice {
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
value := getRealValue(result, relation.ForeignFieldName)
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j))
if equalAsString(getRealValue(object, primaryName), value) {
f := object.FieldByName(field.Name)
f.Set(reflect.Append(f, result))
break
}
}
}
} else {
scope.SetColumn(field, resultValues)
}
}
func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) {
relation := field.Relationship
primaryKeys := scope.getColumnAsArray(relation.ForeignFieldName)
if len(primaryKeys) == 0 {
return
}
results := makeSlice(field.Struct.Type)
associationPrimaryKey := scope.New(results).PrimaryField().Name
scope.Err(scope.NewDB().Where(primaryKeys).Find(results, conditions...).Error)
resultValues := reflect.Indirect(reflect.ValueOf(results))
for i := 0; i < resultValues.Len(); i++ {
result := resultValues.Index(i)
if scope.IndirectValue().Kind() == reflect.Slice {
value := getRealValue(result, associationPrimaryKey)
objects := scope.IndirectValue()
for j := 0; j < objects.Len(); j++ {
object := reflect.Indirect(objects.Index(j))
if equalAsString(getRealValue(object, relation.ForeignFieldName), value) {
object.FieldByName(field.Name).Set(result)
}
}
} else {
scope.SetColumn(field, result)
}
}
}
func (scope *Scope) getColumnAsArray(column string) (columns []interface{}) {
values := scope.IndirectValue()
switch values.Kind() {
case reflect.Slice:
primaryKeyMap := map[interface{}]bool{}
for i := 0; i < values.Len(); i++ {
primaryKeyMap[reflect.Indirect(values.Index(i)).FieldByName(column).Interface()] = true
}
for key := range primaryKeyMap {
primaryKeys = append(primaryKeys, key)
columns = append(columns, reflect.Indirect(values.Index(i)).FieldByName(column).Interface())
}
case reflect.Struct:
return []interface{}{values.FieldByName(column).Interface()}
}
return
}
func (scope *Scope) getColumnsAsScope(column string) *Scope {
values := scope.IndirectValue()
switch values.Kind() {
case reflect.Slice:
modelType := values.Type().Elem()
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
fieldStruct, _ := modelType.FieldByName(column)
var columns reflect.Value
if fieldStruct.Type.Kind() == reflect.Slice || fieldStruct.Type.Kind() == reflect.Ptr {
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type.Elem()))).Elem()
} else {
columns = reflect.New(reflect.SliceOf(reflect.PtrTo(fieldStruct.Type))).Elem()
}
for i := 0; i < values.Len(); i++ {
column := reflect.Indirect(values.Index(i)).FieldByName(column)
if column.Kind() == reflect.Ptr {
column = column.Elem()
}
if column.Kind() == reflect.Slice {
for i := 0; i < column.Len(); i++ {
columns = reflect.Append(columns, column.Index(i).Addr())
}
} else {
columns = reflect.Append(columns, column.Addr())
}
}
return scope.New(columns.Interface())
case reflect.Struct:
return scope.New(values.FieldByName(column).Addr().Interface())
}
return nil
}

View File

@ -1,6 +1,10 @@
package gorm_test
import "testing"
import (
"encoding/json"
"reflect"
"testing"
)
func getPreloadUser(name string) *User {
return getPreparedUser(name, "Preload")
@ -85,3 +89,521 @@ func TestPreload(t *testing.T) {
}
}
}
func TestNestedPreload1(t *testing.T) {
type (
Level1 struct {
ID uint
Value string
Level2ID uint
}
Level2 struct {
ID uint
Level1 Level1
Level3ID uint
}
Level3 struct {
ID uint
Name string
Level2 Level2
}
)
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
panic(err)
}
want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}}
if err := DB.Create(&want).Error; err != nil {
panic(err)
}
var got Level3
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
}
func TestNestedPreload2(t *testing.T) {
type (
Level1 struct {
ID uint
Value string
Level2ID uint
}
Level2 struct {
ID uint
Level1s []*Level1
Level3ID uint
}
Level3 struct {
ID uint
Name string
Level2s []Level2
}
)
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
panic(err)
}
want := Level3{
Level2s: []Level2{
{
Level1s: []*Level1{
&Level1{Value: "value1"},
&Level1{Value: "value2"},
},
},
{
Level1s: []*Level1{
&Level1{Value: "value3"},
},
},
},
}
if err := DB.Create(&want).Error; err != nil {
panic(err)
}
var got Level3
if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
}
func TestNestedPreload3(t *testing.T) {
type (
Level1 struct {
ID uint
Value string
Level2ID uint
}
Level2 struct {
ID uint
Level1 Level1
Level3ID uint
}
Level3 struct {
Name string
ID uint
Level2s []Level2
}
)
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
panic(err)
}
want := Level3{
Level2s: []Level2{
{Level1: Level1{Value: "value1"}},
{Level1: Level1{Value: "value2"}},
},
}
if err := DB.Create(&want).Error; err != nil {
panic(err)
}
var got Level3
if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
}
func TestNestedPreload4(t *testing.T) {
type (
Level1 struct {
ID uint
Value string
Level2ID uint
}
Level2 struct {
ID uint
Level1s []Level1
Level3ID uint
}
Level3 struct {
ID uint
Name string
Level2 Level2
}
)
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
panic(err)
}
want := Level3{
Level2: Level2{
Level1s: []Level1{
Level1{Value: "value1"},
Level1{Value: "value2"},
},
},
}
if err := DB.Create(&want).Error; err != nil {
panic(err)
}
var got Level3
if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
}
// Slice: []Level3
func TestNestedPreload5(t *testing.T) {
type (
Level1 struct {
ID uint
Value string
Level2ID uint
}
Level2 struct {
ID uint
Level1 Level1
Level3ID uint
}
Level3 struct {
ID uint
Name string
Level2 Level2
}
)
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
panic(err)
}
want := make([]Level3, 2)
want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}}
if err := DB.Create(&want[0]).Error; err != nil {
panic(err)
}
want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}}
if err := DB.Create(&want[1]).Error; err != nil {
panic(err)
}
var got []Level3
if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
}
func TestNestedPreload6(t *testing.T) {
type (
Level1 struct {
ID uint
Value string
Level2ID uint
}
Level2 struct {
ID uint
Level1s []Level1
Level3ID uint
}
Level3 struct {
ID uint
Name string
Level2s []Level2
}
)
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
panic(err)
}
want := make([]Level3, 2)
want[0] = Level3{
Level2s: []Level2{
{
Level1s: []Level1{
{Value: "value1"},
{Value: "value2"},
},
},
{
Level1s: []Level1{
{Value: "value3"},
},
},
},
}
if err := DB.Create(&want[0]).Error; err != nil {
panic(err)
}
want[1] = Level3{
Level2s: []Level2{
{
Level1s: []Level1{
{Value: "value3"},
{Value: "value4"},
},
},
{
Level1s: []Level1{
{Value: "value5"},
},
},
},
}
if err := DB.Create(&want[1]).Error; err != nil {
panic(err)
}
var got []Level3
if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
}
func TestNestedPreload7(t *testing.T) {
type (
Level1 struct {
ID uint
Value string
Level2ID uint
}
Level2 struct {
ID uint
Level1 Level1
Level3ID uint
}
Level3 struct {
ID uint
Name string
Level2s []Level2
}
)
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
panic(err)
}
want := make([]Level3, 2)
want[0] = Level3{
Level2s: []Level2{
{Level1: Level1{Value: "value1"}},
{Level1: Level1{Value: "value2"}},
},
}
if err := DB.Create(&want[0]).Error; err != nil {
panic(err)
}
want[1] = Level3{
Level2s: []Level2{
{Level1: Level1{Value: "value3"}},
{Level1: Level1{Value: "value4"}},
},
}
if err := DB.Create(&want[1]).Error; err != nil {
panic(err)
}
var got []Level3
if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
}
func TestNestedPreload8(t *testing.T) {
type (
Level1 struct {
ID uint
Value string
Level2ID uint
}
Level2 struct {
ID uint
Level1s []Level1
Level3ID uint
}
Level3 struct {
ID uint
Name string
Level2 Level2
}
)
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level1{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}).Error; err != nil {
panic(err)
}
want := make([]Level3, 2)
want[0] = Level3{
Level2: Level2{
Level1s: []Level1{
Level1{Value: "value1"},
Level1{Value: "value2"},
},
},
}
if err := DB.Create(&want[0]).Error; err != nil {
panic(err)
}
want[1] = Level3{
Level2: Level2{
Level1s: []Level1{
Level1{Value: "value3"},
Level1{Value: "value4"},
},
},
}
if err := DB.Create(&want[1]).Error; err != nil {
panic(err)
}
var got []Level3
if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
}
func TestNestedPreload9(t *testing.T) {
type (
Level0 struct {
ID uint
Value string
Level1ID uint
}
Level1 struct {
ID uint
Value string
Level2ID uint
Level2_1ID uint
Level0s []Level0
}
Level2 struct {
ID uint
Level1s []Level1
Level3ID uint
}
Level2_1 struct {
ID uint
Level1s []Level1
Level3ID uint
}
Level3 struct {
ID uint
Name string
Level2 Level2
Level2_1 Level2_1
}
)
DB.DropTableIfExists(&Level3{})
DB.DropTableIfExists(&Level2{})
DB.DropTableIfExists(&Level2_1{})
DB.DropTableIfExists(&Level1{})
DB.DropTableIfExists(&Level0{})
if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}).Error; err != nil {
panic(err)
}
want := make([]Level3, 2)
want[0] = Level3{
Level2: Level2{
Level1s: []Level1{
Level1{Value: "value1"},
Level1{Value: "value2"},
},
},
Level2_1: Level2_1{
Level1s: []Level1{
Level1{
Value: "value1-1",
Level0s: []Level0{{Value: "Level0-1"}},
},
Level1{
Value: "value2-2",
Level0s: []Level0{{Value: "Level0-2"}},
},
},
},
}
if err := DB.Create(&want[0]).Error; err != nil {
panic(err)
}
want[1] = Level3{
Level2: Level2{
Level1s: []Level1{
Level1{Value: "value3"},
Level1{Value: "value4"},
},
},
Level2_1: Level2_1{
Level1s: []Level1{
Level1{Value: "value3-3"},
Level1{Value: "value4-4"},
},
},
}
if err := DB.Create(&want[1]).Error; err != nil {
panic(err)
}
var got []Level3
if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).Error; err != nil {
panic(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
}
}
func toJSONString(v interface{}) []byte {
r, _ := json.MarshalIndent(v, "", " ")
return r
}

View File

@ -98,49 +98,41 @@ func TestSearchWithPlainSQL(t *testing.T) {
t.Errorf("Should found 2 users that age > 1, but got %v", len(users))
}
users = []User{}
DB.Where("name LIKE ?", "%PlainSqlUser%").Where("age >= ?", 1).Find(&users)
if len(users) != 3 {
t.Errorf("Should found 3 users that age >= 1, but got %v", len(users))
}
users = []User{}
scopedb.Where("age <> ?", 20).Find(&users)
if len(users) != 2 {
t.Errorf("Should found 2 users age != 20, but got %v", len(users))
}
users = []User{}
scopedb.Where("birthday > ?", now.MustParse("2000-1-1")).Find(&users)
if len(users) != 2 {
t.Errorf("Should found 2 users's birthday > 2000-1-1, but got %v", len(users))
}
users = []User{}
scopedb.Where("birthday > ?", "2002-10-10").Find(&users)
if len(users) != 2 {
t.Errorf("Should found 2 users's birthday >= 2002-10-10, but got %v", len(users))
}
users = []User{}
scopedb.Where("birthday >= ?", "2010-1-1").Where("birthday < ?", "2020-1-1").Find(&users)
if len(users) != 1 {
t.Errorf("Should found 1 users's birthday < 2020-1-1 and >= 2010-1-1, but got %v", len(users))
}
users = []User{}
DB.Where("name in (?)", []string{user1.Name, user2.Name}).Find(&users)
if len(users) != 2 {
t.Errorf("Should found 2 users, but got %v", len(users))
}
users = []User{}
DB.Where("id in (?)", []int64{user1.Id, user2.Id, user3.Id}).Find(&users)
if len(users) != 3 {
t.Errorf("Should found 3 users, but got %v", len(users))
}
users = []User{}
DB.Where("id in (?)", user1.Id).Find(&users)
if len(users) != 1 {
t.Errorf("Should found 1 users, but got %v", len(users))
@ -191,7 +183,6 @@ func TestSearchWithStruct(t *testing.T) {
t.Errorf("Search first record with where struct")
}
users = []User{}
DB.Find(&users, &User{Name: user2.Name})
if len(users) != 1 {
t.Errorf("Search all records with inline struct")
@ -222,7 +213,6 @@ func TestSearchWithMap(t *testing.T) {
t.Errorf("Search all records with inline map")
}
users = []User{}
DB.Find(&users, map[string]interface{}{"name": user3.Name})
if len(users) != 1 {
t.Errorf("Search all records with inline map")
@ -395,13 +385,11 @@ func TestNot(t *testing.T) {
t.Errorf("Should find all users's name not equal 3")
}
users4 = []User{}
DB.Not("name = ?", "user3").Find(&users4)
if len(users1)-len(users4) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3")
}
users4 = []User{}
DB.Not("name <> ?", "user3").Find(&users4)
if len(users4) != int(name3Count) {
t.Errorf("Should find all users's name not equal 3")

View File

@ -110,6 +110,14 @@ func (scope *Scope) HasError() bool {
return scope.db.Error != nil
}
func (scope *Scope) PrimaryFields() []*Field {
var fields = []*Field{}
for _, field := range scope.GetModelStruct().PrimaryFields {
fields = append(fields, scope.Fields()[field.DBName])
}
return fields
}
func (scope *Scope) PrimaryField() *Field {
if primaryFields := scope.GetModelStruct().PrimaryFields; len(primaryFields) > 0 {
if len(primaryFields) > 1 {
@ -158,13 +166,18 @@ func (scope *Scope) HasColumn(column string) bool {
func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
if field, ok := column.(*Field); ok {
return field.Set(value)
} else if dbName, ok := column.(string); ok {
} else if name, ok := column.(string); ok {
if field, ok := scope.Fields()[name]; ok {
return field.Set(value)
}
dbName := ToDBName(name)
if field, ok := scope.Fields()[dbName]; ok {
return field.Set(value)
}
dbName = ToDBName(dbName)
if field, ok := scope.Fields()[dbName]; ok {
if field, ok := scope.FieldByName(name); ok {
return field.Set(value)
}
}
@ -172,7 +185,7 @@ func (scope *Scope) SetColumn(column interface{}, value interface{}) error {
}
func (scope *Scope) CallMethod(name string, checkError bool) {
if scope.Value == nil && (!checkError || !scope.HasError()) {
if scope.Value == nil || (checkError && scope.HasError()) {
return
}
@ -246,17 +259,14 @@ func (scope *Scope) TableName() string {
return tabler.TableName(scope.db)
}
if scope.GetModelStruct().TableName != nil {
scope.Search.tableName = scope.GetModelStruct().TableName(scope.db)
return scope.Search.tableName
}
scope.Err(errors.New("wrong table name"))
return ""
return scope.GetModelStruct().TableName(scope.db.Model(scope.Value))
}
func (scope *Scope) QuotedTableName() (name string) {
if scope.Search != nil && len(scope.Search.tableName) > 0 {
if strings.Index(scope.Search.tableName, " ") != -1 {
return scope.Search.tableName
}
return scope.Quote(scope.Search.tableName)
} else {
return scope.Quote(scope.TableName())
@ -271,7 +281,7 @@ func (scope *Scope) CombinedConditionSql() string {
func (scope *Scope) FieldByName(name string) (field *Field, ok bool) {
for _, field := range scope.Fields() {
if field.Name == name {
if field.Name == name || field.DBName == name {
return field, true
}
}
@ -290,7 +300,7 @@ func (scope *Scope) Exec() *Scope {
if !scope.HasError() {
if result, err := scope.SqlDB().Exec(scope.Sql, scope.SqlVars...); scope.Err(err) == nil {
if count, err := result.RowsAffected(); err == nil {
if count, err := result.RowsAffected(); scope.Err(err) == nil {
scope.db.RowsAffected = count
}
}
@ -364,6 +374,8 @@ func (scope *Scope) SelectAttrs() []string {
for _, value := range scope.Search.selects {
if str, ok := value.(string); ok {
attrs = append(attrs, str)
} else if strs, ok := value.([]string); ok {
attrs = append(attrs, strs...)
} else if strs, ok := value.([]interface{}); ok {
for _, str := range strs {
attrs = append(attrs, fmt.Sprintf("%v", str))

View File

@ -38,7 +38,7 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
case interface{}:
var sqls []string
for _, field := range scope.New(value).Fields() {
if !field.IsBlank {
if !field.IsIgnored && !field.IsBlank {
sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
}
}
@ -336,12 +336,14 @@ func (scope *Scope) updatedAttrsWithValues(values map[string]interface{}, ignore
func (scope *Scope) row() *sql.Row {
defer scope.Trace(NowFunc())
scope.callCallbacks(scope.db.parent.callback.rowQueries)
scope.prepareQuerySql()
return scope.SqlDB().QueryRow(scope.Sql, scope.SqlVars...)
}
func (scope *Scope) rows() (*sql.Rows, error) {
defer scope.Trace(NowFunc())
scope.callCallbacks(scope.db.parent.callback.rowQueries)
scope.prepareQuerySql()
return scope.SqlDB().Query(scope.Sql, scope.SqlVars...)
}
@ -411,7 +413,7 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope {
if relationship := fromField.Relationship; relationship != nil {
if relationship.Kind == "many_to_many" {
joinTableHandler := relationship.JoinTableHandler
scope.Err(joinTableHandler.JoinWith(toScope.db, scope.Value).Find(value).Error)
scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error)
} else if relationship.Kind == "belongs_to" {
sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey()))
foreignKeyValue := fromFields[relationship.ForeignDBName].Field.Interface()
@ -445,13 +447,19 @@ func (scope *Scope) createJoinTable(field *StructField) {
joinTableHandler := relationship.JoinTableHandler
joinTable := joinTableHandler.Table(scope.db)
if !scope.Dialect().HasTable(scope, joinTable) {
primaryKeySqlType := scope.Dialect().SqlTag(scope.PrimaryField().Field, 255, false)
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)",
scope.Quote(joinTable),
strings.Join([]string{
scope.Quote(relationship.ForeignDBName) + " " + primaryKeySqlType,
scope.Quote(relationship.AssociationForeignDBName) + " " + primaryKeySqlType}, ",")),
).Error)
toScope := &Scope{Value: reflect.New(field.Struct.Type).Interface()}
var sqlTypes []string
for _, s := range []*Scope{scope, toScope} {
for _, primaryField := range s.GetModelStruct().PrimaryFields {
value := reflect.Indirect(reflect.New(primaryField.Struct.Type))
primaryKeySqlType := scope.Dialect().SqlTag(value, 255, false)
dbName := ToDBName(s.GetModelStruct().ModelType.Name() + primaryField.Name)
sqlTypes = append(sqlTypes, scope.Quote(dbName)+" "+primaryKeySqlType)
}
}
scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v)", scope.Quote(joinTable), strings.Join(sqlTypes, ","))).Error)
}
scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler)
}
@ -467,7 +475,7 @@ func (scope *Scope) createTable() *Scope {
}
if field.IsPrimaryKey {
primaryKeys = append(primaryKeys, field.DBName)
primaryKeys = append(primaryKeys, scope.Quote(field.DBName))
}
scope.createJoinTable(field)
}
@ -522,7 +530,7 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on
var table = scope.TableName()
var keyName = fmt.Sprintf("%s_%s_foreign", table, field)
var query = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE %s ON UPDATE %s;`
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), keyName, field, dest, onDelete, onUpdate)).Exec()
scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.Quote(keyName), scope.Quote(field), scope.Quote(dest), onDelete, onUpdate)).Exec()
}
func (scope *Scope) removeIndex(indexName string) {
@ -540,7 +548,7 @@ func (scope *Scope) autoMigrate() *Scope {
if !scope.Dialect().HasColumn(scope, tableName, field.DBName) {
if field.IsNormal {
sqlTag := scope.generateSqlTag(field)
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, field.DBName, sqlTag)).Exec()
scope.Raw(fmt.Sprintf("ALTER TABLE %v ADD %v %v;", quotedTableName, scope.Quote(field.DBName), sqlTag)).Exec()
}
}
scope.createJoinTable(field)

View File

@ -14,7 +14,7 @@ type search struct {
omits []string
orders []string
joins string
preload map[string][]interface{}
preload []searchPreload
offset string
limit string
group string
@ -23,26 +23,14 @@ type search struct {
Unscoped bool
}
type searchPreload struct {
schema string
conditions []interface{}
}
func (s *search) clone() *search {
return &search{
preload: s.preload,
whereConditions: s.whereConditions,
orConditions: s.orConditions,
notConditions: s.notConditions,
havingCondition: s.havingCondition,
initAttrs: s.initAttrs,
assignAttrs: s.assignAttrs,
selects: s.selects,
omits: s.omits,
orders: s.orders,
joins: s.joins,
offset: s.offset,
limit: s.limit,
group: s.group,
tableName: s.tableName,
raw: s.raw,
Unscoped: s.Unscoped,
}
clone := *s
return &clone
}
func (s *search) Where(query interface{}, values ...interface{}) *search {
@ -114,11 +102,15 @@ func (s *search) Joins(query string) *search {
return s
}
func (s *search) Preload(column string, values ...interface{}) *search {
if s.preload == nil {
s.preload = map[string][]interface{}{}
func (s *search) Preload(schema string, values ...interface{}) *search {
var preloads []searchPreload
for _, preload := range s.preload {
if preload.schema != schema {
preloads = append(preloads, preload)
}
}
s.preload[column] = values
preloads = append(preloads, searchPreload{schema, values})
s.preload = preloads
return s
}

View File

@ -1,4 +1,4 @@
dialects=("postgres" "foundation" "mysql" "sqlite")
dialects=("postgres" "mysql" "sqlite")
for dialect in "${dialects[@]}" ; do
GORM_DIALECT=${dialect} go test