Fix Scopes with Row, close #4465
This commit is contained in:
		
							parent
							
								
									3226937f68
								
							
						
					
					
						commit
						8e67a08774
					
				| @ -373,7 +373,7 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, | |||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	if tx.Statement.FullSaveAssociations { | 	if tx.Statement.FullSaveAssociations { | ||||||
| 		tx = tx.InstanceSet("gorm:update_track_time", true) | 		tx = tx.Set("gorm:update_track_time", true) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if len(selects) > 0 { | 	if len(selects) > 0 { | ||||||
|  | |||||||
| @ -243,9 +243,12 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { | |||||||
| 	default: | 	default: | ||||||
| 		var ( | 		var ( | ||||||
| 			selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) | 			selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) | ||||||
|  | 			_, updateTrackTime        = stmt.Get("gorm:update_track_time") | ||||||
| 			curTime                   = stmt.DB.NowFunc() | 			curTime                   = stmt.DB.NowFunc() | ||||||
| 			isZero                    bool | 			isZero                    bool | ||||||
| 		) | 		) | ||||||
|  | 		stmt.Settings.Delete("gorm:update_track_time") | ||||||
|  | 
 | ||||||
| 		values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))} | 		values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))} | ||||||
| 
 | 
 | ||||||
| 		for _, db := range stmt.Schema.DBNames { | 		for _, db := range stmt.Schema.DBNames { | ||||||
| @ -284,11 +287,9 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { | |||||||
| 							field.Set(rv, curTime) | 							field.Set(rv, curTime) | ||||||
| 							values.Values[i][idx], _ = field.ValueOf(rv) | 							values.Values[i][idx], _ = field.ValueOf(rv) | ||||||
| 						} | 						} | ||||||
| 					} else if field.AutoUpdateTime > 0 { | 					} else if field.AutoUpdateTime > 0 && updateTrackTime { | ||||||
| 						if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok { | 						field.Set(rv, curTime) | ||||||
| 							field.Set(rv, curTime) | 						values.Values[i][idx], _ = field.ValueOf(rv) | ||||||
| 							values.Values[i][idx], _ = field.ValueOf(rv) |  | ||||||
| 						} |  | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| @ -326,11 +327,9 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { | |||||||
| 						field.Set(stmt.ReflectValue, curTime) | 						field.Set(stmt.ReflectValue, curTime) | ||||||
| 						values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) | 						values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) | ||||||
| 					} | 					} | ||||||
| 				} else if field.AutoUpdateTime > 0 { | 				} else if field.AutoUpdateTime > 0 && updateTrackTime { | ||||||
| 					if _, ok := stmt.DB.InstanceGet("gorm:update_track_time"); ok { | 					field.Set(stmt.ReflectValue, curTime) | ||||||
| 						field.Set(stmt.ReflectValue, curTime) | 					values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) | ||||||
| 						values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) |  | ||||||
| 					} |  | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -9,7 +9,8 @@ func RowQuery(db *gorm.DB) { | |||||||
| 		BuildQuerySQL(db) | 		BuildQuerySQL(db) | ||||||
| 
 | 
 | ||||||
| 		if !db.DryRun { | 		if !db.DryRun { | ||||||
| 			if isRows, ok := db.InstanceGet("rows"); ok && isRows.(bool) { | 			if isRows, ok := db.Get("rows"); ok && isRows.(bool) { | ||||||
|  | 				db.Statement.Settings.Delete("rows") | ||||||
| 				db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | 				db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
| 			} else { | 			} else { | ||||||
| 				db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | 				db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) | ||||||
|  | |||||||
| @ -79,7 +79,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { | |||||||
| 		if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { | 		if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { | ||||||
| 			tx = tx.Clauses(clause.OnConflict{UpdateAll: true}) | 			tx = tx.Clauses(clause.OnConflict{UpdateAll: true}) | ||||||
| 		} | 		} | ||||||
| 		tx = tx.callbacks.Create().Execute(tx.InstanceSet("gorm:update_track_time", true)) | 		tx = tx.callbacks.Create().Execute(tx.Set("gorm:update_track_time", true)) | ||||||
| 	case reflect.Struct: | 	case reflect.Struct: | ||||||
| 		if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { | 		if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { | ||||||
| 			for _, pf := range tx.Statement.Schema.PrimaryFields { | 			for _, pf := range tx.Statement.Schema.PrimaryFields { | ||||||
| @ -426,7 +426,7 @@ func (db *DB) Count(count *int64) (tx *DB) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db *DB) Row() *sql.Row { | func (db *DB) Row() *sql.Row { | ||||||
| 	tx := db.getInstance().InstanceSet("rows", false) | 	tx := db.getInstance().Set("rows", false) | ||||||
| 	tx = tx.callbacks.Row().Execute(tx) | 	tx = tx.callbacks.Row().Execute(tx) | ||||||
| 	row, ok := tx.Statement.Dest.(*sql.Row) | 	row, ok := tx.Statement.Dest.(*sql.Row) | ||||||
| 	if !ok && tx.DryRun { | 	if !ok && tx.DryRun { | ||||||
| @ -436,7 +436,7 @@ func (db *DB) Row() *sql.Row { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (db *DB) Rows() (*sql.Rows, error) { | func (db *DB) Rows() (*sql.Rows, error) { | ||||||
| 	tx := db.getInstance().InstanceSet("rows", true) | 	tx := db.getInstance().Set("rows", true) | ||||||
| 	tx = tx.callbacks.Row().Execute(tx) | 	tx = tx.callbacks.Row().Execute(tx) | ||||||
| 	rows, ok := tx.Statement.Dest.(*sql.Rows) | 	rows, ok := tx.Statement.Dest.(*sql.Rows) | ||||||
| 	if !ok && tx.DryRun && tx.Error == nil { | 	if !ok && tx.DryRun && tx.Error == nil { | ||||||
|  | |||||||
| @ -124,7 +124,6 @@ func TestCount(t *testing.T) { | |||||||
| 
 | 
 | ||||||
| 	var count9 int64 | 	var count9 int64 | ||||||
| 	if err := DB.Debug().Scopes(func(tx *gorm.DB) *gorm.DB { | 	if err := DB.Debug().Scopes(func(tx *gorm.DB) *gorm.DB { | ||||||
| 		fmt.Println("kdkdkdkdk") |  | ||||||
| 		return tx.Table("users") | 		return tx.Table("users") | ||||||
| 	}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Count(&count9).Find(&users).Error; err != nil || count9 != 3 { | 	}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Count(&count9).Find(&users).Error; err != nil || count9 != 3 { | ||||||
| 		t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err)) | 		t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err)) | ||||||
|  | |||||||
| @ -1,6 +1,7 @@ | |||||||
| package tests_test | package tests_test | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| @ -62,4 +63,12 @@ func TestScopes(t *testing.T) { | |||||||
| 	if result.RowsAffected != 2 { | 	if result.RowsAffected != 2 { | ||||||
| 		t.Errorf("Should found two users's name in 1, 2, but got %v", result.RowsAffected) | 		t.Errorf("Should found two users's name in 1, 2, but got %v", result.RowsAffected) | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
|  | 	var maxId int64 | ||||||
|  | 	userTable := func(db *gorm.DB) *gorm.DB { | ||||||
|  | 		return db.WithContext(context.Background()).Table("users") | ||||||
|  | 	} | ||||||
|  | 	if err := DB.Scopes(userTable).Select("max(id)").Scan(&maxId).Error; err != nil { | ||||||
|  | 		t.Errorf("select max(id)") | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu