Update structure

This commit is contained in:
Jinzhu 2018-02-26 22:29:21 +08:00
parent ef35e82691
commit b0c6124926
7 changed files with 55 additions and 31 deletions

2
.gitignore vendored
View File

@ -1,2 +0,0 @@
documents
_book

35
api.go
View File

@ -2,7 +2,6 @@ package gorm
import ( import (
"github.com/jinzhu/gorm/builder" "github.com/jinzhu/gorm/builder"
"github.com/jinzhu/gorm/dialects"
) )
// Where add condition // Where add condition
@ -150,12 +149,17 @@ func (s *DB) Find(out interface{}, where ...interface{}) *DB {
stmt := tx.Statement stmt := tx.Statement
stmt.Dest = out stmt.Dest = out
// has inline condition
if len(where) > 0 { if len(where) > 0 {
clone := tx.clone()
stmt = s.Statement.Clone() stmt = s.Statement.Clone()
stmt.Conditions = append(stmt.Conditions, s.Statement.BuildCondition(where[0], where[1:]...)) stmt.Conditions = append(stmt.Conditions, s.Statement.BuildCondition(where[0], where[1:]...))
tx.AddError(clone.Dialect().Query(clone))
tx.AddError(clone.Error)
} else {
tx.AddError(tx.Dialect().Query(tx))
} }
tx.AddError(tx.Dialect().Query(stmt))
return tx return tx
} }
@ -168,7 +172,7 @@ func (s *DB) Scan(dest interface{}) *DB {
stmt.Table = stmt.Dest stmt.Table = stmt.Dest
stmt.Dest = dest stmt.Dest = dest
tx.AddError(tx.Dialect().Query(stmt)) tx.AddError(tx.Dialect().Query(tx))
return tx return tx
} }
@ -176,7 +180,7 @@ func (s *DB) Scan(dest interface{}) *DB {
func (s *DB) Create(value interface{}) *DB { func (s *DB) Create(value interface{}) *DB {
tx := s.init() tx := s.init()
tx.Statement.Dest = value tx.Statement.Dest = value
tx.AddError(tx.Dialect().Insert(tx.Statement)) tx.AddError(tx.Dialect().Insert(tx))
return tx return tx
} }
@ -185,7 +189,7 @@ func (s *DB) Save(value interface{}) *DB {
tx := s.init() tx := s.init()
tx.Statement.Dest = value tx.Statement.Dest = value
// FIXME check primary key has value or not // FIXME check primary key has value or not
tx.AddError(tx.Dialect().Update(tx.Statement)) tx.AddError(tx.Dialect().Update(tx))
return tx return tx
} }
@ -193,7 +197,7 @@ func (s *DB) Save(value interface{}) *DB {
func (s *DB) Update(column string, value interface{}) *DB { func (s *DB) Update(column string, value interface{}) *DB {
tx := s.init() tx := s.init()
tx.Statement.Assignments = append(tx.Statement.Assignments, builder.Assignment{Column: column, Value: value}) tx.Statement.Assignments = append(tx.Statement.Assignments, builder.Assignment{Column: column, Value: value})
tx.AddError(tx.Dialect().Update(tx.Statement)) tx.AddError(tx.Dialect().Update(tx))
return tx return tx
} }
@ -201,7 +205,7 @@ func (s *DB) Update(column string, value interface{}) *DB {
func (s *DB) Updates(values interface{}) *DB { func (s *DB) Updates(values interface{}) *DB {
tx := s.init() tx := s.init()
tx.Statement.Assignments = append(tx.Statement.Assignments, builder.Assignment{Value: values}) tx.Statement.Assignments = append(tx.Statement.Assignments, builder.Assignment{Value: values})
tx.AddError(tx.Dialect().Update(tx.Statement)) tx.AddError(tx.Dialect().Update(tx))
return tx return tx
} }
@ -211,12 +215,17 @@ func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
stmt := tx.Statement stmt := tx.Statement
stmt.Dest = value stmt.Dest = value
// has inline condition
if len(where) > 0 { if len(where) > 0 {
clone := tx.clone()
stmt = s.Statement.Clone() stmt = s.Statement.Clone()
stmt.Conditions = append(stmt.Conditions, s.Statement.BuildCondition(where[0], where[1:]...)) stmt.Conditions = append(stmt.Conditions, s.Statement.BuildCondition(where[0], where[1:]...))
tx.AddError(clone.Dialect().Update(clone))
tx.AddError(clone.Error)
} else {
tx.AddError(tx.Dialect().Update(tx))
} }
tx.AddError(tx.Dialect().Update(stmt))
return tx return tx
} }
@ -264,7 +273,7 @@ func (s *DB) GetErrors() []error {
} }
// Dialect return DB dialect // Dialect return DB dialect
func (s *DB) Dialect() dialects.Dialect { func (s *DB) Dialect() Dialect {
if s.TxDialect != nil { if s.TxDialect != nil {
return s.TxDialect return s.TxDialect
} }
@ -302,3 +311,11 @@ func (s *DB) init() *DB {
} }
return s return s
} }
func (s *DB) clone() *DB {
return &DB{
TxDialect: s.TxDialect,
Statement: s.Statement,
Config: s.Config,
}
}

9
dialect.go Normal file
View File

@ -0,0 +1,9 @@
package gorm
// Dialect GORM dialect interface
type Dialect interface {
Insert(*DB) error
Query(*DB) error
Update(*DB) error
Delete(*DB) error
}

View File

@ -1,13 +0,0 @@
package dialects
import (
"github.com/jinzhu/gorm/builder"
)
// Dialect GORM dialect interface
type Dialect interface {
Insert(*builder.Statement) error
Query(*builder.Statement) error
Update(*builder.Statement) error
Delete(*builder.Statement) error
}

View File

@ -4,7 +4,6 @@ import (
"time" "time"
"github.com/jinzhu/gorm/builder" "github.com/jinzhu/gorm/builder"
"github.com/jinzhu/gorm/dialects"
"github.com/jinzhu/gorm/logger" "github.com/jinzhu/gorm/logger"
) )
@ -24,12 +23,12 @@ type Config struct {
LogMode logger.LogLevel LogMode logger.LogLevel
// Dialect DB Dialect // Dialect DB Dialect
Dialect dialects.Dialect Dialect Dialect
} }
// DB GORM DB definition // DB GORM DB definition
type DB struct { type DB struct {
TxDialect dialects.Dialect TxDialect Dialect
Statement *builder.Statement Statement *builder.Statement
// Global config // Global config

View File

@ -1,4 +1,4 @@
package utils package model
import ( import (
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
@ -18,7 +18,22 @@ func GetCreatingAssignments(stmt *builder.Statement, errs *gorm.Errors) chan []s
// GetTable get table name // GetTable get table name
func GetTable(stmt *builder.Statement, errs *gorm.Errors) chan string { func GetTable(stmt *builder.Statement, errs *gorm.Errors) chan string {
return nil tableChan := make(chan string)
go func() {
if stmt.Table != nil {
if table, ok := stmt.Table.(string); ok {
tableChan <- DefaultTableNameHandler(stmt, table)
} else if tableSchema := schema.Parse(stmt.Table); tableSchema != nil {
if tableSchema.TableName != "" {
tableChan <- DefaultTableNameHandler(stmt, tableSchema.TableName)
}
tableSchema.ModelType.Name
}
}
}()
return tableChan
} }
// if scope.Value == nil { // if scope.Value == nil {
@ -38,6 +53,5 @@ func GetTable(stmt *builder.Statement, errs *gorm.Errors) chan string {
// s.defaultTableName = tableName // s.defaultTableName = tableName
// } // }
// } // }
// return DefaultTableNameHandler(db, s.defaultTableName) // return DefaultTableNameHandler(db, s.defaultTableName)
// } // }

View File

@ -1,4 +1,4 @@
package utils package model
// ToSearchableMap convert attrs to searchable map // ToSearchableMap convert attrs to searchable map
func ToSearchableMap(attrs ...interface{}) (result interface{}) { func ToSearchableMap(attrs ...interface{}) (result interface{}) {