Add InstanceSet, InstanceGet fomr Scope
This commit is contained in:
		
							parent
							
								
									9012c9ea09
								
							
						
					
					
						commit
						39ac95adbb
					
				@ -63,7 +63,7 @@ func TestRelated(t *testing.T) {
 | 
				
			|||||||
		CreditCard:      CreditCard{Number: "1234567890"},
 | 
							CreditCard:      CreditCard{Number: "1234567890"},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	db.Debug().Save(&user)
 | 
						db.Save(&user)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if user.CreditCard.Id == 0 {
 | 
						if user.CreditCard.Id == 0 {
 | 
				
			||||||
		t.Errorf("After user save, credit card should have id")
 | 
							t.Errorf("After user save, credit card should have id")
 | 
				
			||||||
 | 
				
			|||||||
@ -17,7 +17,7 @@ func Query(scope *Scope) {
 | 
				
			|||||||
	)
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var dest = scope.IndirectValue()
 | 
						var dest = scope.IndirectValue()
 | 
				
			||||||
	if value, ok := scope.Get("gorm:query_destination"); ok {
 | 
						if value, ok := scope.InstanceGet("gorm:query_destination"); ok {
 | 
				
			||||||
		dest = reflect.Indirect(reflect.ValueOf(value))
 | 
							dest = reflect.Indirect(reflect.ValueOf(value))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -7,16 +7,16 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func AssignUpdateAttributes(scope *Scope) {
 | 
					func AssignUpdateAttributes(scope *Scope) {
 | 
				
			||||||
	if attrs, ok := scope.Get("gorm:update_interface"); ok {
 | 
						if attrs, ok := scope.InstanceGet("gorm:update_interface"); ok {
 | 
				
			||||||
		if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
 | 
							if maps := convertInterfaceToMap(attrs); len(maps) > 0 {
 | 
				
			||||||
			protected, ok := scope.Get("gorm:ignore_protected_attrs")
 | 
								protected, ok := scope.Get("gorm:ignore_protected_attrs")
 | 
				
			||||||
			_, updateColumn := scope.Get("gorm:update_column")
 | 
								_, updateColumn := scope.Get("gorm:update_column")
 | 
				
			||||||
			updateAttrs, hasUpdate := scope.updatedAttrsWithValues(maps, ok && protected.(bool))
 | 
								updateAttrs, hasUpdate := scope.updatedAttrsWithValues(maps, ok && protected.(bool))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if updateColumn {
 | 
								if updateColumn {
 | 
				
			||||||
				scope.Set("gorm:update_attrs", maps)
 | 
									scope.InstanceSet("gorm:update_attrs", maps)
 | 
				
			||||||
			} else if len(updateAttrs) > 0 {
 | 
								} else if len(updateAttrs) > 0 {
 | 
				
			||||||
				scope.Set("gorm:update_attrs", updateAttrs)
 | 
									scope.InstanceSet("gorm:update_attrs", updateAttrs)
 | 
				
			||||||
			} else if !hasUpdate {
 | 
								} else if !hasUpdate {
 | 
				
			||||||
				scope.SkipLeft()
 | 
									scope.SkipLeft()
 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
@ -44,7 +44,7 @@ func Update(scope *Scope) {
 | 
				
			|||||||
	if !scope.HasError() {
 | 
						if !scope.HasError() {
 | 
				
			||||||
		var sqls []string
 | 
							var sqls []string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		updateAttrs, ok := scope.Get("gorm:update_attrs")
 | 
							updateAttrs, ok := scope.InstanceGet("gorm:update_attrs")
 | 
				
			||||||
		if ok {
 | 
							if ok {
 | 
				
			||||||
			for key, value := range updateAttrs.(map[string]interface{}) {
 | 
								for key, value := range updateAttrs.(map[string]interface{}) {
 | 
				
			||||||
				sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value)))
 | 
									sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value)))
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										18
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										18
									
								
								main.go
									
									
									
									
									
								
							@ -185,7 +185,7 @@ func (s *DB) Rows() (*sql.Rows, error) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *DB) Scan(dest interface{}) *DB {
 | 
					func (s *DB) Scan(dest interface{}) *DB {
 | 
				
			||||||
	scope := s.clone().Set("gorm:query_destination", dest).NewScope(s.Value)
 | 
						scope := s.clone().NewScope(s.Value).InstanceSet("gorm:query_destination", dest)
 | 
				
			||||||
	Query(scope)
 | 
						Query(scope)
 | 
				
			||||||
	return scope.db
 | 
						return scope.db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -213,7 +213,7 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
		c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates)
 | 
							c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(s.parent.callback.creates)
 | 
				
			||||||
	} else if len(c.search.AssignAttrs) > 0 {
 | 
						} else if len(c.search.AssignAttrs) > 0 {
 | 
				
			||||||
		c.Set("gorm:update_interface", s.search.AssignAttrs).NewScope(out).callCallbacks(s.parent.callback.updates)
 | 
							c.NewScope(out).InstanceSet("gorm:update_interface", s.search.AssignAttrs).callCallbacks(s.parent.callback.updates)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return c
 | 
						return c
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -223,10 +223,9 @@ func (s *DB) Update(attrs ...interface{}) *DB {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
 | 
					func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB {
 | 
				
			||||||
	return s.clone().
 | 
						return s.clone().NewScope(s.Value).
 | 
				
			||||||
		Set("gorm:update_interface", values).
 | 
					 | 
				
			||||||
		Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
 | 
							Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0).
 | 
				
			||||||
		NewScope(s.Value).
 | 
							InstanceSet("gorm:update_interface", values).
 | 
				
			||||||
		callCallbacks(s.parent.callback.updates).db
 | 
							callCallbacks(s.parent.callback.updates).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -235,10 +234,9 @@ func (s *DB) UpdateColumn(attrs ...interface{}) *DB {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *DB) UpdateColumns(values interface{}) *DB {
 | 
					func (s *DB) UpdateColumns(values interface{}) *DB {
 | 
				
			||||||
	return s.clone().
 | 
						return s.clone().NewScope(s.Value).
 | 
				
			||||||
		Set("gorm:update_interface", values).
 | 
					 | 
				
			||||||
		Set("gorm:update_column", true).
 | 
							Set("gorm:update_column", true).
 | 
				
			||||||
		NewScope(s.Value).
 | 
							InstanceSet("gorm:update_interface", values).
 | 
				
			||||||
		callCallbacks(s.parent.callback.updates).db
 | 
							callCallbacks(s.parent.callback.updates).db
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -404,6 +402,10 @@ func (s *DB) Association(column string) *Association {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// Set set value by name
 | 
					// Set set value by name
 | 
				
			||||||
func (s *DB) Set(name string, value interface{}) *DB {
 | 
					func (s *DB) Set(name string, value interface{}) *DB {
 | 
				
			||||||
 | 
						return s.clone().set(name, value)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *DB) set(name string, value interface{}) *DB {
 | 
				
			||||||
	s.values[name] = value
 | 
						s.values[name] = value
 | 
				
			||||||
	return s
 | 
						return s
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -6,7 +6,11 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *DB) clone() *DB {
 | 
					func (s *DB) clone() *DB {
 | 
				
			||||||
	db := DB{db: s.db, parent: s.parent, logMode: s.logMode, Value: s.Value, Error: s.Error, values: s.values}
 | 
						db := DB{db: s.db, parent: s.parent, logMode: s.logMode, Value: s.Value, Error: s.Error, values: map[string]interface{}{}}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for key, value := range s.values {
 | 
				
			||||||
 | 
							db.values[key] = value
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if s.search == nil {
 | 
						if s.search == nil {
 | 
				
			||||||
		db.search = &search{}
 | 
							db.search = &search{}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										14
									
								
								main_test.go
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								main_test.go
									
									
									
									
									
								
							@ -494,6 +494,20 @@ func TestHstore(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestSetAndGet(t *testing.T) {
 | 
				
			||||||
 | 
						if value, ok := db.Set("hello", "world").Get("hello"); !ok {
 | 
				
			||||||
 | 
							t.Errorf("Should be able to get setting after set")
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							if value.(string) != "world" {
 | 
				
			||||||
 | 
								t.Errorf("Setted value should not be changed")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if _, ok := db.Get("non_existing"); ok {
 | 
				
			||||||
 | 
							t.Errorf("Get non existing key should return error")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestCompatibilityMode(t *testing.T) {
 | 
					func TestCompatibilityMode(t *testing.T) {
 | 
				
			||||||
	db, _ := gorm.Open("testdb", "")
 | 
						db, _ := gorm.Open("testdb", "")
 | 
				
			||||||
	testdb.SetQueryFunc(func(query string) (driver.Rows, error) {
 | 
						testdb.SetQueryFunc(func(query string) (driver.Rows, error) {
 | 
				
			||||||
 | 
				
			|||||||
@ -332,7 +332,6 @@ func TestCount(t *testing.T) {
 | 
				
			|||||||
func TestNot(t *testing.T) {
 | 
					func TestNot(t *testing.T) {
 | 
				
			||||||
	var users1, users2, users3, users4, users5, users6, users7, users8 []User
 | 
						var users1, users2, users3, users4, users5, users6, users7, users8 []User
 | 
				
			||||||
	db.Find(&users1)
 | 
						db.Find(&users1)
 | 
				
			||||||
 | 
					 | 
				
			||||||
	db.Not(users1[0].Id).Find(&users2)
 | 
						db.Not(users1[0].Id).Find(&users2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if len(users1)-len(users2) != 1 {
 | 
						if len(users1)-len(users2) != 1 {
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										26
									
								
								scope.go
									
									
									
									
									
								
							
							
						
						
									
										26
									
								
								scope.go
									
									
									
									
									
								
							@ -20,6 +20,7 @@ type Scope struct {
 | 
				
			|||||||
	db            *DB
 | 
						db            *DB
 | 
				
			||||||
	skipLeft      bool
 | 
						skipLeft      bool
 | 
				
			||||||
	primaryKey    string
 | 
						primaryKey    string
 | 
				
			||||||
 | 
						instanceId    string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) IndirectValue() reflect.Value {
 | 
					func (scope *Scope) IndirectValue() reflect.Value {
 | 
				
			||||||
@ -362,8 +363,9 @@ func (scope *Scope) Exec() *Scope {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Set set value by name
 | 
					// Set set value by name
 | 
				
			||||||
func (scope *Scope) Set(name string, value interface{}) {
 | 
					func (scope *Scope) Set(name string, value interface{}) *Scope {
 | 
				
			||||||
	scope.db.Set(name, value)
 | 
						scope.db.set(name, value)
 | 
				
			||||||
 | 
						return scope
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Get get value by name
 | 
					// Get get value by name
 | 
				
			||||||
@ -371,6 +373,22 @@ func (scope *Scope) Get(name string) (interface{}, bool) {
 | 
				
			|||||||
	return scope.db.Get(name)
 | 
						return scope.db.Get(name)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// InstanceId get InstanceId for scope
 | 
				
			||||||
 | 
					func (scope *Scope) InstanceId() string {
 | 
				
			||||||
 | 
						if scope.instanceId == "" {
 | 
				
			||||||
 | 
							scope.instanceId = fmt.Sprintf("%v", &scope)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return scope.instanceId
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (scope *Scope) InstanceSet(name string, value interface{}) *Scope {
 | 
				
			||||||
 | 
						return scope.Set(name+scope.InstanceId(), value)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (scope *Scope) InstanceGet(name string) (interface{}, bool) {
 | 
				
			||||||
 | 
						return scope.Get(name + scope.InstanceId())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Trace print sql log
 | 
					// Trace print sql log
 | 
				
			||||||
func (scope *Scope) Trace(t time.Time) {
 | 
					func (scope *Scope) Trace(t time.Time) {
 | 
				
			||||||
	if len(scope.Sql) > 0 {
 | 
						if len(scope.Sql) > 0 {
 | 
				
			||||||
@ -383,7 +401,7 @@ func (scope *Scope) Begin() *Scope {
 | 
				
			|||||||
	if db, ok := scope.DB().(sqlDb); ok {
 | 
						if db, ok := scope.DB().(sqlDb); ok {
 | 
				
			||||||
		if tx, err := db.Begin(); err == nil {
 | 
							if tx, err := db.Begin(); err == nil {
 | 
				
			||||||
			scope.db.db = interface{}(tx).(sqlCommon)
 | 
								scope.db.db = interface{}(tx).(sqlCommon)
 | 
				
			||||||
			scope.Set("gorm:started_transaction", true)
 | 
								scope.InstanceSet("gorm:started_transaction", true)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return scope
 | 
						return scope
 | 
				
			||||||
@ -391,7 +409,7 @@ func (scope *Scope) Begin() *Scope {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// CommitOrRollback commit current transaction if there is no error, otherwise rollback it
 | 
					// CommitOrRollback commit current transaction if there is no error, otherwise rollback it
 | 
				
			||||||
func (scope *Scope) CommitOrRollback() *Scope {
 | 
					func (scope *Scope) CommitOrRollback() *Scope {
 | 
				
			||||||
	if _, ok := scope.Get("gorm:started_transaction"); ok {
 | 
						if _, ok := scope.InstanceGet("gorm:started_transaction"); ok {
 | 
				
			||||||
		if db, ok := scope.db.db.(sqlTx); ok {
 | 
							if db, ok := scope.db.db.(sqlTx); ok {
 | 
				
			||||||
			if scope.HasError() {
 | 
								if scope.HasError() {
 | 
				
			||||||
				db.Rollback()
 | 
									db.Rollback()
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user