Support Raw func support use Limit method

This commit is contained in:
xuejipeng 2021-07-19 17:02:13 +08:00
parent 2202e99cbf
commit 5eb8039603
2 changed files with 51 additions and 1 deletions

View File

@ -28,13 +28,28 @@ func Query(db *gorm.DB) {
} }
func BuildQuerySQL(db *gorm.DB) { func BuildQuerySQL(db *gorm.DB) {
var hasLimitClause bool
sqlStr := db.Statement.SQL.String()
if db.Statement.Schema != nil && !db.Statement.Unscoped { if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.QueryClauses { for _, c := range db.Statement.Schema.QueryClauses {
db.Statement.AddClause(c) db.Statement.AddClause(c)
} }
} }
if db.Statement.SQL.String() == "" { for cla, _ := range db.Statement.Clauses {
if cla == "LIMIT" {
hasLimitClause = true
}
}
if sqlStr != "" && !strings.Contains(strings.ToUpper(sqlStr), "LIMIT") && hasLimitClause {
if !strings.HasSuffix(db.Statement.SQL.String(), " ") {
db.Statement.SQL.WriteByte(' ')
}
db.Statement.Build("LIMIT")
}
if sqlStr == "" {
db.Statement.SQL.Grow(100) db.Statement.SQL.Grow(100)
clauseSelect := clause.Select{Distinct: db.Statement.Distinct} clauseSelect := clause.Select{Distinct: db.Statement.Distinct}

35
chainable_api_test.go Normal file
View File

@ -0,0 +1,35 @@
package gorm_test
import (
"gorm.io/gorm/callbacks"
"gorm.io/gorm/utils/tests"
"strings"
"testing"
"gorm.io/gorm"
)
func TestDB_Raw_Limit(t *testing.T) {
var (
t1 *gorm.DB
t2 *gorm.DB
t3 *gorm.DB
)
tx, _ := gorm.Open(
tests.DummyDialector{},
//&gorm.Config{Logger: logger.Default.LogMode(logger.Info)},
)
t1 = tx.Table("pod_events").Limit(1)
t2 = tx.Raw("SELECT * FROM `pod_events`").Limit(1)
t3 = tx.Raw("SELECT * FROM `pod_events` LIMIT 1").Limit(1)
t1.Statement.BuildClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
callbacks.BuildQuerySQL(t1)
callbacks.BuildQuerySQL(t2)
callbacks.BuildQuerySQL(t3)
s1 := t1.Statement.SQL.String()
s2 := t2.Statement.SQL.String()
s3 := t3.Statement.SQL.String()
if !strings.EqualFold(s1, s2) || !strings.EqualFold(s1, s3) {
t.Errorf("s1 != s2 != s3\ns1 = %v\ns2 = %v\ns3 = %v\n", s1, s2, s3)
}
}