diff --git a/callbacks/query.go b/callbacks/query.go index d0341284..480d58be 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -28,13 +28,28 @@ func Query(db *gorm.DB) { } func BuildQuerySQL(db *gorm.DB) { + var hasLimitClause bool + sqlStr := db.Statement.SQL.String() if db.Statement.Schema != nil && !db.Statement.Unscoped { for _, c := range db.Statement.Schema.QueryClauses { db.Statement.AddClause(c) } } - if db.Statement.SQL.String() == "" { + for cla, _ := range db.Statement.Clauses { + if cla == "LIMIT" { + hasLimitClause = true + } + } + + if sqlStr != "" && !strings.Contains(strings.ToUpper(sqlStr), "LIMIT") && hasLimitClause { + if !strings.HasSuffix(db.Statement.SQL.String(), " ") { + db.Statement.SQL.WriteByte(' ') + } + db.Statement.Build("LIMIT") + } + + if sqlStr == "" { db.Statement.SQL.Grow(100) clauseSelect := clause.Select{Distinct: db.Statement.Distinct} diff --git a/chainable_api_test.go b/chainable_api_test.go new file mode 100644 index 00000000..96f0d942 --- /dev/null +++ b/chainable_api_test.go @@ -0,0 +1,35 @@ +package gorm_test + +import ( + "gorm.io/gorm/callbacks" + "gorm.io/gorm/utils/tests" + "strings" + "testing" + + "gorm.io/gorm" +) + +func TestDB_Raw_Limit(t *testing.T) { + var ( + t1 *gorm.DB + t2 *gorm.DB + t3 *gorm.DB + ) + tx, _ := gorm.Open( + tests.DummyDialector{}, + //&gorm.Config{Logger: logger.Default.LogMode(logger.Info)}, + ) + t1 = tx.Table("pod_events").Limit(1) + t2 = tx.Raw("SELECT * FROM `pod_events`").Limit(1) + t3 = tx.Raw("SELECT * FROM `pod_events` LIMIT 1").Limit(1) + t1.Statement.BuildClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"} + callbacks.BuildQuerySQL(t1) + callbacks.BuildQuerySQL(t2) + callbacks.BuildQuerySQL(t3) + s1 := t1.Statement.SQL.String() + s2 := t2.Statement.SQL.String() + s3 := t3.Statement.SQL.String() + if !strings.EqualFold(s1, s2) || !strings.EqualFold(s1, s3) { + t.Errorf("s1 != s2 != s3\ns1 = %v\ns2 = %v\ns3 = %v\n", s1, s2, s3) + } +}