Add Set, Get method for DB

This commit is contained in:
Jinzhu 2014-08-20 16:25:01 +08:00
parent a80a0a1636
commit 9012c9ea09
5 changed files with 27 additions and 17 deletions

View File

@ -1022,8 +1022,6 @@ db.Where("email = ?", "x@example.org").Attrs(User{RegisteredIp: "111.111.111.111
db.Model(&[]User{}).Limit(10).Do("Search", "search func's argument") db.Model(&[]User{}).Limit(10).Do("Search", "search func's argument")
db.Mode(&User{}).Do("EditForm").Get("edit_form_html") db.Mode(&User{}).Do("EditForm").Get("edit_form_html")
DefaultValue, DefaultTimeZone, R/W Splitting, Validation DefaultValue, DefaultTimeZone, R/W Splitting, Validation
* Getter/Setter
share or not? transaction?
* Github Pages * Github Pages
* Includes * Includes
* AlertColumn, DropColumn * AlertColumn, DropColumn

View File

@ -63,7 +63,7 @@ func TestRelated(t *testing.T) {
CreditCard: CreditCard{Number: "1234567890"}, CreditCard: CreditCard{Number: "1234567890"},
} }
db.Save(&user) db.Debug().Save(&user)
if user.CreditCard.Id == 0 { if user.CreditCard.Id == 0 {
t.Errorf("After user save, credit card should have id") t.Errorf("After user save, credit card should have id")

25
main.go
View File

@ -21,6 +21,7 @@ type DB struct {
tagIdentifier string tagIdentifier string
singularTable bool singularTable bool
source string source string
values map[string]interface{}
} }
func Open(dialect string, drivesources ...string) (DB, error) { func Open(dialect string, drivesources ...string) (DB, error) {
@ -39,7 +40,7 @@ func Open(dialect string, drivesources ...string) (DB, error) {
source = drivesources[1] source = drivesources[1]
} }
db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql", logger: defaultLogger, callback: DefaultCallback, source: source} db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql", logger: defaultLogger, callback: DefaultCallback, source: source, values: map[string]interface{}{}}
db.db, err = sql.Open(driver, source) db.db, err = sql.Open(driver, source)
db.parent = &db db.parent = &db
} }
@ -184,7 +185,7 @@ func (s *DB) Rows() (*sql.Rows, error) {
} }
func (s *DB) Scan(dest interface{}) *DB { func (s *DB) Scan(dest interface{}) *DB {
scope := s.clone().NewScope(s.Value).Set("gorm:query_destination", dest) scope := s.clone().Set("gorm:query_destination", dest).NewScope(s.Value)
Query(scope) Query(scope)
return scope.db return scope.db
} }
@ -212,7 +213,7 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
} }
c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates) c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates)
} else if len(c.search.AssignAttrs) > 0 { } else if len(c.search.AssignAttrs) > 0 {
c.NewScope(out).Set("gorm:update_interface", s.search.AssignAttrs).callCallbacks(s.parent.callback.updates) c.Set("gorm:update_interface", s.search.AssignAttrs).NewScope(out).callCallbacks(s.parent.callback.updates)
} }
return c return c
} }
@ -222,9 +223,10 @@ func (s *DB) Update(attrs ...interface{}) *DB {
} }
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
return s.clone().NewScope(s.Value). return s.clone().
Set("gorm:update_interface", values). Set("gorm:update_interface", values).
Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
NewScope(s.Value).
callCallbacks(s.parent.callback.updates).db callCallbacks(s.parent.callback.updates).db
} }
@ -233,9 +235,10 @@ func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
} }
func (s *DB) UpdateColumns(values interface{}) *DB { func (s *DB) UpdateColumns(values interface{}) *DB {
return s.clone().NewScope(s.Value). return s.clone().
Set("gorm:update_interface", values). Set("gorm:update_interface", values).
Set("gorm:update_column", true). Set("gorm:update_column", true).
NewScope(s.Value).
callCallbacks(s.parent.callback.updates).db callCallbacks(s.parent.callback.updates).db
} }
@ -398,3 +401,15 @@ func (s *DB) Association(column string) *Association {
return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, Field: field} return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, Field: field}
} }
// Set set value by name
func (s *DB) Set(name string, value interface{}) *DB {
s.values[name] = value
return s
}
// Get get value by name
func (s *DB) Get(name string) (value interface{}, ok bool) {
value, ok = s.values[name]
return
}

View File

@ -6,7 +6,7 @@ import (
) )
func (s *DB) clone() *DB { func (s *DB) clone() *DB {
db := DB{db: s.db, parent: s.parent, logMode: s.logMode, Value: s.Value, Error: s.Error} db := DB{db: s.db, parent: s.parent, logMode: s.logMode, Value: s.Value, Error: s.Error, values: s.values}
if s.search == nil { if s.search == nil {
db.search = &search{} db.search = &search{}

View File

@ -18,7 +18,6 @@ type Scope struct {
Sql string Sql string
SqlVars []interface{} SqlVars []interface{}
db *DB db *DB
_values map[string]interface{}
skipLeft bool skipLeft bool
primaryKey string primaryKey string
} }
@ -34,7 +33,7 @@ func (scope *Scope) IndirectValue() reflect.Value {
// NewScope create scope for callbacks, including DB's search information // NewScope create scope for callbacks, including DB's search information
func (db *DB) NewScope(value interface{}) *Scope { func (db *DB) NewScope(value interface{}) *Scope {
db.Value = value db.Value = value
return &Scope{db: db, Search: db.search, Value: value, _values: map[string]interface{}{}} return &Scope{db: db, Search: db.search, Value: value}
} }
// New create a new Scope without search information // New create a new Scope without search information
@ -363,15 +362,13 @@ func (scope *Scope) Exec() *Scope {
} }
// Set set value by name // Set set value by name
func (scope *Scope) Set(name string, value interface{}) *Scope { func (scope *Scope) Set(name string, value interface{}) {
scope._values[name] = value scope.db.Set(name, value)
return scope
} }
// Get get value by name // Get get value by name
func (scope *Scope) Get(name string) (value interface{}, ok bool) { func (scope *Scope) Get(name string) (interface{}, bool) {
value, ok = scope._values[name] return scope.db.Get(name)
return
} }
// Trace print sql log // Trace print sql log