fixed wrong param substitution order
This commit is contained in:
		
							parent
							
								
									0a51f6cdc5
								
							
						
					
					
						commit
						ede35514a4
					
				
							
								
								
									
										24
									
								
								main_test.go
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								main_test.go
									
									
									
									
									
								
							@ -633,6 +633,30 @@ func TestQueryBuilderSubselectInWhere(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestQueryBuilderRawQueryWithSubquery(t *testing.T) {
 | 
			
		||||
	user := User{Name: "subquery_test_user1", Age: 10}
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
	user = User{Name: "subquery_test_user2", Age: 11}
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
	user = User{Name: "subquery_test_user2", Age: 12}
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
 | 
			
		||||
	var count int
 | 
			
		||||
	err := DB.Raw("select count(*) from (?) tmp",
 | 
			
		||||
		DB.Table("users").
 | 
			
		||||
			Select("name").
 | 
			
		||||
			Where("age >= ? and name in (?)", 10, []string{"subquery_test_user1", "subquery_test_user2"}).
 | 
			
		||||
			Group("name").
 | 
			
		||||
			QueryExpr(),
 | 
			
		||||
	).Count(&count).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Errorf("Expected to get no errors, but got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	if count != 2 {
 | 
			
		||||
		t.Errorf("Row count must be 2, instead got %d", count)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestQueryBuilderSubselectInHaving(t *testing.T) {
 | 
			
		||||
	user := User{Name: "query_expr_having_ruser1", Email: "root@user1.com", Age: 64}
 | 
			
		||||
	DB.Save(&user)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										46
									
								
								scope.go
									
									
									
									
									
								
							
							
						
						
									
										46
									
								
								scope.go
									
									
									
									
									
								
							@ -1,16 +1,16 @@
 | 
			
		||||
package gorm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	gobytes "bytes"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"reflect"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Scope contain current operation's information when you perform any operation on the database
 | 
			
		||||
@ -555,29 +555,44 @@ func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str stri
 | 
			
		||||
		return strings.Join(sqls, " AND ")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	replacements := []string{}
 | 
			
		||||
	args := clause["args"].([]interface{})
 | 
			
		||||
	for _, arg := range args {
 | 
			
		||||
		switch reflect.ValueOf(arg).Kind() {
 | 
			
		||||
		case reflect.Slice: // For where("id in (?)", []int64{1,2})
 | 
			
		||||
			if bytes, ok := arg.([]byte); ok {
 | 
			
		||||
				str = strings.Replace(str, "?", scope.AddToVars(bytes), 1)
 | 
			
		||||
				replacements = append(replacements, scope.AddToVars(bytes))
 | 
			
		||||
			} else if values := reflect.ValueOf(arg); values.Len() > 0 {
 | 
			
		||||
				var tempMarks []string
 | 
			
		||||
				for i := 0; i < values.Len(); i++ {
 | 
			
		||||
					tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
 | 
			
		||||
				}
 | 
			
		||||
				str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
 | 
			
		||||
				replacements = append(replacements, strings.Join(tempMarks, ","))
 | 
			
		||||
			} else {
 | 
			
		||||
				str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1)
 | 
			
		||||
				replacements = append(replacements, scope.AddToVars(Expr("NULL")))
 | 
			
		||||
			}
 | 
			
		||||
		default:
 | 
			
		||||
			if valuer, ok := interface{}(arg).(driver.Valuer); ok {
 | 
			
		||||
				arg, _ = valuer.Value()
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
 | 
			
		||||
			replacements = append(replacements, scope.AddToVars(arg))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	buff := gobytes.NewBuffer([]byte{})
 | 
			
		||||
	i := 0
 | 
			
		||||
	for pos := range str {
 | 
			
		||||
		if str[pos] == '?' {
 | 
			
		||||
			buff.WriteString(replacements[i])
 | 
			
		||||
			i++
 | 
			
		||||
		} else {
 | 
			
		||||
			buff.WriteByte(str[pos])
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	str = buff.String()
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -661,6 +676,7 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	args := clause["args"].([]interface{})
 | 
			
		||||
	replacements := []string{}
 | 
			
		||||
	for _, arg := range args {
 | 
			
		||||
		switch reflect.ValueOf(arg).Kind() {
 | 
			
		||||
		case reflect.Slice:
 | 
			
		||||
@ -669,14 +685,28 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string)
 | 
			
		||||
			for i := 0; i < values.Len(); i++ {
 | 
			
		||||
				tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface()))
 | 
			
		||||
			}
 | 
			
		||||
			str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1)
 | 
			
		||||
			replacements = append(replacements, strings.Join(tempMarks, ","))
 | 
			
		||||
		default:
 | 
			
		||||
			if valuer, ok := interface{}(arg).(driver.Valuer); ok {
 | 
			
		||||
				arg, _ = valuer.Value()
 | 
			
		||||
			}
 | 
			
		||||
			str = strings.Replace(str, "?", scope.AddToVars(arg), 1)
 | 
			
		||||
			replacements = append(replacements, scope.AddToVars(arg))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	buff := gobytes.NewBuffer([]byte{})
 | 
			
		||||
	i := 0
 | 
			
		||||
	for pos := range str {
 | 
			
		||||
		if str[pos] == '?' {
 | 
			
		||||
			buff.WriteString(replacements[i])
 | 
			
		||||
			i++
 | 
			
		||||
		} else {
 | 
			
		||||
			buff.WriteByte(str[pos])
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	str = buff.String()
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user