diff --git a/README.md b/README.md index 91751944..cb90d4ff 100644 --- a/README.md +++ b/README.md @@ -1022,8 +1022,6 @@ db.Where("email = ?", "x@example.org").Attrs(User{RegisteredIp: "111.111.111.111 db.Model(&[]User{}).Limit(10).Do("Search", "search func's argument") db.Mode(&User{}).Do("EditForm").Get("edit_form_html") DefaultValue, DefaultTimeZone, R/W Splitting, Validation -* Getter/Setter - share or not? transaction? * Github Pages * Includes * AlertColumn, DropColumn diff --git a/association_test.go b/association_test.go index 457b2183..5f765b04 100644 --- a/association_test.go +++ b/association_test.go @@ -63,7 +63,7 @@ func TestRelated(t *testing.T) { CreditCard: CreditCard{Number: "1234567890"}, } - db.Save(&user) + db.Debug().Save(&user) if user.CreditCard.Id == 0 { t.Errorf("After user save, credit card should have id") diff --git a/main.go b/main.go index cee5f59c..587ebf31 100644 --- a/main.go +++ b/main.go @@ -21,6 +21,7 @@ type DB struct { tagIdentifier string singularTable bool source string + values map[string]interface{} } func Open(dialect string, drivesources ...string) (DB, error) { @@ -39,7 +40,7 @@ func Open(dialect string, drivesources ...string) (DB, error) { source = drivesources[1] } - db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql", logger: defaultLogger, callback: DefaultCallback, source: source} + db = DB{dialect: NewDialect(dialect), tagIdentifier: "sql", logger: defaultLogger, callback: DefaultCallback, source: source, values: map[string]interface{}{}} db.db, err = sql.Open(driver, source) db.parent = &db } @@ -184,7 +185,7 @@ func (s *DB) Rows() (*sql.Rows, error) { } func (s *DB) Scan(dest interface{}) *DB { - scope := s.clone().NewScope(s.Value).Set("gorm:query_destination", dest) + scope := s.clone().Set("gorm:query_destination", dest).NewScope(s.Value) Query(scope) return scope.db } @@ -212,7 +213,7 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { } c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates) } else if len(c.search.AssignAttrs) > 0 { - c.NewScope(out).Set("gorm:update_interface", s.search.AssignAttrs).callCallbacks(s.parent.callback.updates) + c.Set("gorm:update_interface", s.search.AssignAttrs).NewScope(out).callCallbacks(s.parent.callback.updates) } return c } @@ -222,9 +223,10 @@ func (s *DB) Update(attrs ...interface{}) *DB { } func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { - return s.clone().NewScope(s.Value). + return s.clone(). Set("gorm:update_interface", values). Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). + NewScope(s.Value). callCallbacks(s.parent.callback.updates).db } @@ -233,9 +235,10 @@ func (s *DB) UpdateColumn(attrs ...interface{}) *DB { } func (s *DB) UpdateColumns(values interface{}) *DB { - return s.clone().NewScope(s.Value). + return s.clone(). Set("gorm:update_interface", values). Set("gorm:update_column", true). + NewScope(s.Value). callCallbacks(s.parent.callback.updates).db } @@ -398,3 +401,15 @@ func (s *DB) Association(column string) *Association { return &Association{Scope: scope, Column: column, Error: s.Error, PrimaryKey: primaryKey, Field: field} } + +// Set set value by name +func (s *DB) Set(name string, value interface{}) *DB { + s.values[name] = value + return s +} + +// Get get value by name +func (s *DB) Get(name string) (value interface{}, ok bool) { + value, ok = s.values[name] + return +} diff --git a/main_private.go b/main_private.go index dbb59dd0..dab50f2a 100644 --- a/main_private.go +++ b/main_private.go @@ -6,7 +6,7 @@ import ( ) func (s *DB) clone() *DB { - db := DB{db: s.db, parent: s.parent, logMode: s.logMode, Value: s.Value, Error: s.Error} + db := DB{db: s.db, parent: s.parent, logMode: s.logMode, Value: s.Value, Error: s.Error, values: s.values} if s.search == nil { db.search = &search{} diff --git a/scope.go b/scope.go index 18d36501..d65f5649 100644 --- a/scope.go +++ b/scope.go @@ -18,7 +18,6 @@ type Scope struct { Sql string SqlVars []interface{} db *DB - _values map[string]interface{} skipLeft bool primaryKey string } @@ -34,7 +33,7 @@ func (scope *Scope) IndirectValue() reflect.Value { // NewScope create scope for callbacks, including DB's search information func (db *DB) NewScope(value interface{}) *Scope { db.Value = value - return &Scope{db: db, Search: db.search, Value: value, _values: map[string]interface{}{}} + return &Scope{db: db, Search: db.search, Value: value} } // New create a new Scope without search information @@ -363,15 +362,13 @@ func (scope *Scope) Exec() *Scope { } // Set set value by name -func (scope *Scope) Set(name string, value interface{}) *Scope { - scope._values[name] = value - return scope +func (scope *Scope) Set(name string, value interface{}) { + scope.db.Set(name, value) } // Get get value by name -func (scope *Scope) Get(name string) (value interface{}, ok bool) { - value, ok = scope._values[name] - return +func (scope *Scope) Get(name string) (interface{}, bool) { + return scope.db.Get(name) } // Trace print sql log