Add returning tests
This commit is contained in:
		
							parent
							
								
									835d7bde59
								
							
						
					
					
						commit
						e953880d19
					
				| @ -84,7 +84,10 @@ func Update(config *Config) func(db *gorm.DB) { | |||||||
| 		if !db.DryRun && db.Error == nil { | 		if !db.DryRun && db.Error == nil { | ||||||
| 			if ok, mode := hasReturning(db, supportReturning); ok { | 			if ok, mode := hasReturning(db, supportReturning); ok { | ||||||
| 				if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { | 				if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { | ||||||
|  | 					dest := db.Statement.Dest | ||||||
|  | 					db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface() | ||||||
| 					gorm.Scan(rows, db, mode) | 					gorm.Scan(rows, db, mode) | ||||||
|  | 					db.Statement.Dest = dest | ||||||
| 					rows.Close() | 					rows.Close() | ||||||
| 				} | 				} | ||||||
| 			} else { | 			} else { | ||||||
| @ -152,20 +155,23 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { | |||||||
| 	if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { | 	if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { | ||||||
| 		switch stmt.ReflectValue.Kind() { | 		switch stmt.ReflectValue.Kind() { | ||||||
| 		case reflect.Slice, reflect.Array: | 		case reflect.Slice, reflect.Array: | ||||||
| 			var primaryKeyExprs []clause.Expression | 			if size := stmt.ReflectValue.Len(); size > 0 { | ||||||
| 			for i := 0; i < stmt.ReflectValue.Len(); i++ { | 				var primaryKeyExprs []clause.Expression | ||||||
| 				var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) | 				for i := 0; i < stmt.ReflectValue.Len(); i++ { | ||||||
| 				var notZero bool | 					var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) | ||||||
| 				for idx, field := range stmt.Schema.PrimaryFields { | 					var notZero bool | ||||||
| 					value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) | 					for idx, field := range stmt.Schema.PrimaryFields { | ||||||
| 					exprs[idx] = clause.Eq{Column: field.DBName, Value: value} | 						value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) | ||||||
| 					notZero = notZero || !isZero | 						exprs[idx] = clause.Eq{Column: field.DBName, Value: value} | ||||||
| 				} | 						notZero = notZero || !isZero | ||||||
| 				if notZero { | 					} | ||||||
| 					primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...)) | 					if notZero { | ||||||
|  | 						primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...)) | ||||||
|  | 					} | ||||||
| 				} | 				} | ||||||
|  | 
 | ||||||
|  | 				stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}}) | ||||||
| 			} | 			} | ||||||
| 			stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}}) |  | ||||||
| 		case reflect.Struct: | 		case reflect.Struct: | ||||||
| 			for _, field := range stmt.Schema.PrimaryFields { | 			for _, field := range stmt.Schema.PrimaryFields { | ||||||
| 				if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { | 				if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { | ||||||
|  | |||||||
							
								
								
									
										16
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										16
									
								
								scan.go
									
									
									
									
									
								
							| @ -120,22 +120,6 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) { | |||||||
| 
 | 
 | ||||||
| 	switch dest := db.Statement.Dest.(type) { | 	switch dest := db.Statement.Dest.(type) { | ||||||
| 	case map[string]interface{}, *map[string]interface{}: | 	case map[string]interface{}, *map[string]interface{}: | ||||||
| 		if update && db.Statement.Schema != nil { |  | ||||||
| 			switch db.Statement.ReflectValue.Kind() { |  | ||||||
| 			case reflect.Struct: |  | ||||||
| 				fields := make([]*schema.Field, len(columns)) |  | ||||||
| 				for idx, column := range columns { |  | ||||||
| 					if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { |  | ||||||
| 						fields[idx] = field |  | ||||||
| 					} |  | ||||||
| 				} |  | ||||||
| 
 |  | ||||||
| 				if initialized || rows.Next() { |  | ||||||
| 					db.scanIntoStruct(db.Statement.Schema, rows, db.Statement.ReflectValue, values, columns, fields, nil) |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if initialized || rows.Next() { | 		if initialized || rows.Next() { | ||||||
| 			columnTypes, _ := rows.ColumnTypes() | 			columnTypes, _ := rows.ColumnTypes() | ||||||
| 			prepareValues(values, db, columnTypes, columns) | 			prepareValues(values, db, columnTypes, columns) | ||||||
|  | |||||||
| @ -159,6 +159,6 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { | |||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		stmt.AddClauseIfNotExists(clause.Update{}) | 		stmt.AddClauseIfNotExists(clause.Update{}) | ||||||
| 		stmt.Build("UPDATE", "SET", "WHERE") | 		stmt.Build(stmt.DB.Callback().Update().Clauses...) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
| @ -205,3 +205,54 @@ func TestDeleteSliceWithAssociations(t *testing.T) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // only sqlite, postgres support returning
 | ||||||
|  | func TestSoftDeleteReturning(t *testing.T) { | ||||||
|  | 	if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	users := []*User{ | ||||||
|  | 		GetUser("delete-returning-1", Config{}), | ||||||
|  | 		GetUser("delete-returning-2", Config{}), | ||||||
|  | 		GetUser("delete-returning-3", Config{}), | ||||||
|  | 	} | ||||||
|  | 	DB.Create(&users) | ||||||
|  | 
 | ||||||
|  | 	var results []User | ||||||
|  | 	DB.Where("name IN ?", []string{users[0].Name, users[1].Name}).Clauses(clause.Returning{}).Delete(&results) | ||||||
|  | 	if len(results) != 2 { | ||||||
|  | 		t.Errorf("failed to return delete data, got %v", results) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var count int64 | ||||||
|  | 	DB.Model(&User{}).Where("name IN ?", []string{users[0].Name, users[1].Name, users[2].Name}).Count(&count) | ||||||
|  | 	if count != 1 { | ||||||
|  | 		t.Errorf("failed to delete data, current count %v", count) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestDeleteReturning(t *testing.T) { | ||||||
|  | 	if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	companies := []Company{ | ||||||
|  | 		{Name: "delete-returning-1"}, | ||||||
|  | 		{Name: "delete-returning-2"}, | ||||||
|  | 		{Name: "delete-returning-3"}, | ||||||
|  | 	} | ||||||
|  | 	DB.Create(&companies) | ||||||
|  | 
 | ||||||
|  | 	var results []Company | ||||||
|  | 	DB.Where("name IN ?", []string{companies[0].Name, companies[1].Name}).Clauses(clause.Returning{}).Delete(&results) | ||||||
|  | 	if len(results) != 2 { | ||||||
|  | 		t.Errorf("failed to return delete data, got %v", results) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var count int64 | ||||||
|  | 	DB.Model(&Company{}).Where("name IN ?", []string{companies[0].Name, companies[1].Name, companies[2].Name}).Count(&count) | ||||||
|  | 	if count != 1 { | ||||||
|  | 		t.Errorf("failed to delete data, current count %v", count) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
| @ -7,8 +7,8 @@ require ( | |||||||
| 	github.com/jinzhu/now v1.1.2 | 	github.com/jinzhu/now v1.1.2 | ||||||
| 	github.com/lib/pq v1.10.3 | 	github.com/lib/pq v1.10.3 | ||||||
| 	gorm.io/driver/mysql v1.1.2 | 	gorm.io/driver/mysql v1.1.2 | ||||||
| 	gorm.io/driver/postgres v1.2.0 | 	gorm.io/driver/postgres v1.2.1 | ||||||
| 	gorm.io/driver/sqlite v1.2.0 | 	gorm.io/driver/sqlite v1.2.2 | ||||||
| 	gorm.io/driver/sqlserver v1.1.2 | 	gorm.io/driver/sqlserver v1.1.2 | ||||||
| 	gorm.io/gorm v1.22.0 | 	gorm.io/gorm v1.22.0 | ||||||
| ) | ) | ||||||
|  | |||||||
| @ -167,16 +167,13 @@ func TestUpdates(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// update with gorm exprs
 | 	// update with gorm exprs
 | ||||||
| 	if err := DB.Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { | 	if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { | ||||||
| 		t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) | 		t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) | ||||||
| 	} | 	} | ||||||
| 	var user4 User | 	var user4 User | ||||||
| 	DB.First(&user4, user3.ID) | 	DB.First(&user4, user3.ID) | ||||||
| 
 | 
 | ||||||
| 	// sqlite, postgres support returning
 | 	user3.Age += 100 | ||||||
| 	if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { |  | ||||||
| 		user3.Age += 100 |  | ||||||
| 	} |  | ||||||
| 	AssertObjEqual(t, user4, user3, "UpdatedAt", "Age") | 	AssertObjEqual(t, user4, user3, "UpdatedAt", "Age") | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -728,3 +725,35 @@ func TestSaveWithPrimaryValue(t *testing.T) { | |||||||
| 		t.Errorf("failed to find created record, got error: %v, result: %+v", err, result4) | 		t.Errorf("failed to find created record, got error: %v, result: %+v", err, result4) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // only sqlite, postgres support returning
 | ||||||
|  | func TestUpdateReturning(t *testing.T) { | ||||||
|  | 	if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	users := []*User{ | ||||||
|  | 		GetUser("update-returning-1", Config{}), | ||||||
|  | 		GetUser("update-returning-2", Config{}), | ||||||
|  | 		GetUser("update-returning-3", Config{}), | ||||||
|  | 	} | ||||||
|  | 	DB.Create(&users) | ||||||
|  | 
 | ||||||
|  | 	var results []User | ||||||
|  | 	DB.Model(&results).Where("name IN ?", []string{users[0].Name, users[1].Name}).Clauses(clause.Returning{}).Update("age", 88) | ||||||
|  | 	if len(results) != 2 || results[0].Age != 88 || results[1].Age != 88 { | ||||||
|  | 		t.Errorf("failed to return updated data, got %v", results) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Model(&results[0]).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { | ||||||
|  | 		t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Model(&results[1]).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { | ||||||
|  | 		t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if results[1].Age-results[0].Age != 100 { | ||||||
|  | 		t.Errorf("failed to return updated age column") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu