Merge branch 'master' into master
This commit is contained in:
commit
a3c066f9f5
@ -3,7 +3,7 @@ package gorm
|
||||
import "fmt"
|
||||
|
||||
// DefaultCallback default callbacks defined by gorm
|
||||
var DefaultCallback = &Callback{}
|
||||
var DefaultCallback = &Callback{logger: nopLogger{}}
|
||||
|
||||
// Callback is a struct that contains all CRUD callbacks
|
||||
// Field `creates` contains callbacks will be call when creating object
|
||||
@ -101,12 +101,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *
|
||||
}
|
||||
}
|
||||
|
||||
if cp.logger != nil {
|
||||
// note cp.logger will be nil during the default gorm callback registrations
|
||||
// as they occur within init() blocks. However, any user-registered callbacks
|
||||
// will happen after cp.logger exists (as the default logger or user-specified).
|
||||
cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum()))
|
||||
}
|
||||
cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum()))
|
||||
cp.name = callbackName
|
||||
cp.processor = &callback
|
||||
cp.parent.processors = append(cp.parent.processors, cp)
|
||||
|
@ -51,7 +51,7 @@ func updateTimeStampForCreateCallback(scope *Scope) {
|
||||
// createCallback the callback used to insert data into database
|
||||
func createCallback(scope *Scope) {
|
||||
if !scope.HasError() {
|
||||
defer scope.trace(scope.db.nowFunc())
|
||||
defer scope.trace(NowFunc())
|
||||
|
||||
var (
|
||||
columns, placeholders []string
|
||||
|
@ -24,7 +24,7 @@ func queryCallback(scope *Scope) {
|
||||
return
|
||||
}
|
||||
|
||||
defer scope.trace(scope.db.nowFunc())
|
||||
defer scope.trace(NowFunc())
|
||||
|
||||
var (
|
||||
isSlice, isPtr bool
|
||||
@ -60,6 +60,11 @@ func queryCallback(scope *Scope) {
|
||||
|
||||
if !scope.HasError() {
|
||||
scope.db.RowsAffected = 0
|
||||
|
||||
if str, ok := scope.Get("gorm:query_hint"); ok {
|
||||
scope.SQL = fmt.Sprint(str) + scope.SQL
|
||||
}
|
||||
|
||||
if str, ok := scope.Get("gorm:query_option"); ok {
|
||||
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
|
||||
}
|
||||
|
@ -23,6 +23,11 @@ type RowsQueryResult struct {
|
||||
func rowQueryCallback(scope *Scope) {
|
||||
if result, ok := scope.InstanceGet("row_query_result"); ok {
|
||||
scope.prepareQuerySQL()
|
||||
|
||||
if str, ok := scope.Get("gorm:query_hint"); ok {
|
||||
scope.SQL = fmt.Sprint(str) + scope.SQL
|
||||
}
|
||||
|
||||
if str, ok := scope.Get("gorm:query_option"); ok {
|
||||
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
|
||||
}
|
||||
|
@ -217,3 +217,33 @@ func TestGetCallback(t *testing.T) {
|
||||
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseDefaultCallback(t *testing.T) {
|
||||
createCallbackName := "gorm:test_use_default_callback_for_create"
|
||||
gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) {
|
||||
// nop
|
||||
})
|
||||
if gorm.DefaultCallback.Create().Get(createCallbackName) == nil {
|
||||
t.Errorf("`%s` expected non-nil, but got nil", createCallbackName)
|
||||
}
|
||||
gorm.DefaultCallback.Create().Remove(createCallbackName)
|
||||
if gorm.DefaultCallback.Create().Get(createCallbackName) != nil {
|
||||
t.Errorf("`%s` expected nil, but got non-nil", createCallbackName)
|
||||
}
|
||||
|
||||
updateCallbackName := "gorm:test_use_default_callback_for_update"
|
||||
scopeValueName := "gorm:test_use_default_callback_for_update_value"
|
||||
gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) {
|
||||
scope.Set(scopeValueName, 1)
|
||||
})
|
||||
gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) {
|
||||
scope.Set(scopeValueName, 2)
|
||||
})
|
||||
|
||||
scope := DB.NewScope(nil)
|
||||
callback := gorm.DefaultCallback.Update().Get(updateCallbackName)
|
||||
callback(scope)
|
||||
if v, ok := scope.Get(scopeValueName); !ok || v != 2 {
|
||||
t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok)
|
||||
}
|
||||
}
|
||||
|
@ -37,7 +37,7 @@ type Dialect interface {
|
||||
ModifyColumn(tableName string, columnName string, typ string) error
|
||||
|
||||
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
|
||||
LimitAndOffsetSQL(limit, offset interface{}) string
|
||||
LimitAndOffsetSQL(limit, offset interface{}) (string, error)
|
||||
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
|
||||
SelectFromDummyTable() string
|
||||
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
|
||||
|
@ -139,14 +139,19 @@ func (s commonDialect) CurrentDatabase() (name string) {
|
||||
return
|
||||
}
|
||||
|
||||
func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
|
||||
// LimitAndOffsetSQL return generated SQL with Limit and Offset
|
||||
func (s commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
|
||||
if limit != nil {
|
||||
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
|
||||
if parsedLimit, err := s.parseInt(limit); err != nil {
|
||||
return "", err
|
||||
} else if parsedLimit >= 0 {
|
||||
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
|
||||
}
|
||||
}
|
||||
if offset != nil {
|
||||
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
|
||||
if parsedOffset, err := s.parseInt(offset); err != nil {
|
||||
return "", err
|
||||
} else if parsedOffset >= 0 {
|
||||
sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
|
||||
}
|
||||
}
|
||||
@ -181,6 +186,10 @@ func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (stri
|
||||
return indexName, columnName
|
||||
}
|
||||
|
||||
func (commonDialect) parseInt(value interface{}) (int64, error) {
|
||||
return strconv.ParseInt(fmt.Sprint(value), 0, 0)
|
||||
}
|
||||
|
||||
// IsByteArrayOrSlice returns true of the reflected value is an array or slice
|
||||
func IsByteArrayOrSlice(value reflect.Value) bool {
|
||||
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
@ -140,13 +139,21 @@ func (s mysql) ModifyColumn(tableName string, columnName string, typ string) err
|
||||
return err
|
||||
}
|
||||
|
||||
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
|
||||
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
|
||||
if limit != nil {
|
||||
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
|
||||
parsedLimit, err := s.parseInt(limit)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if parsedLimit >= 0 {
|
||||
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
|
||||
|
||||
if offset != nil {
|
||||
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
|
||||
parsedOffset, err := s.parseInt(offset)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if parsedOffset >= 0 {
|
||||
sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
|
||||
}
|
||||
}
|
||||
|
@ -168,14 +168,22 @@ func (s mssql) CurrentDatabase() (name string) {
|
||||
return
|
||||
}
|
||||
|
||||
func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
|
||||
func parseInt(value interface{}) (int64, error) {
|
||||
return strconv.ParseInt(fmt.Sprint(value), 0, 0)
|
||||
}
|
||||
|
||||
func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
|
||||
if offset != nil {
|
||||
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
|
||||
if parsedOffset, err := parseInt(offset); err != nil {
|
||||
return "", err
|
||||
} else if parsedOffset >= 0 {
|
||||
sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset)
|
||||
}
|
||||
}
|
||||
if limit != nil {
|
||||
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
|
||||
if parsedLimit, err := parseInt(limit); err != nil {
|
||||
return "", err
|
||||
} else if parsedLimit >= 0 {
|
||||
if sql == "" {
|
||||
// add default zero offset
|
||||
sql += " OFFSET 0 ROWS"
|
||||
|
2
go.mod
2
go.mod
@ -9,5 +9,5 @@ require (
|
||||
github.com/jinzhu/inflection v1.0.0
|
||||
github.com/jinzhu/now v1.0.1
|
||||
github.com/lib/pq v1.1.1
|
||||
github.com/mattn/go-sqlite3 v1.11.0
|
||||
github.com/mattn/go-sqlite3 v2.0.1+incompatible
|
||||
)
|
||||
|
4
go.sum
4
go.sum
@ -54,6 +54,10 @@ github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4=
|
||||
github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
github.com/mattn/go-sqlite3 v1.11.0 h1:LDdKkqtYlom37fkvqs8rMPFKAMe8+SgjbwZ6ex1/A/Q=
|
||||
github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||
github.com/mattn/go-sqlite3 v1.12.0 h1:u/x3mp++qUxvYfulZ4HKOvVO0JWhk7HtE8lWhbGz/Do=
|
||||
github.com/mattn/go-sqlite3 v1.12.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||
github.com/mattn/go-sqlite3 v2.0.1+incompatible h1:xQ15muvnzGBHpIpdrNi1DA5x0+TcBZzsIDwmw9uTHzw=
|
||||
github.com/mattn/go-sqlite3 v2.0.1+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
||||
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||
|
13
logger.go
13
logger.go
@ -39,6 +39,15 @@ var LogFormatter = func(values ...interface{}) (messages []interface{}) {
|
||||
|
||||
messages = []interface{}{source, currentTime}
|
||||
|
||||
if len(values) == 2 {
|
||||
//remove the line break
|
||||
currentTime = currentTime[1:]
|
||||
//remove the brackets
|
||||
source = fmt.Sprintf("\033[35m%v\033[0m", values[1])
|
||||
|
||||
messages = []interface{}{currentTime, source}
|
||||
}
|
||||
|
||||
if level == "sql" {
|
||||
// duration
|
||||
messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
|
||||
@ -126,3 +135,7 @@ type Logger struct {
|
||||
func (logger Logger) Print(values ...interface{}) {
|
||||
logger.Println(LogFormatter(values...)...)
|
||||
}
|
||||
|
||||
type nopLogger struct{}
|
||||
|
||||
func (nopLogger) Print(values ...interface{}) {}
|
||||
|
16
main.go
16
main.go
@ -124,7 +124,10 @@ func (s *DB) Close() error {
|
||||
// DB get `*sql.DB` from current connection
|
||||
// If the underlying database connection is not a *sql.DB, returns nil
|
||||
func (s *DB) DB() *sql.DB {
|
||||
db, _ := s.db.(*sql.DB)
|
||||
db, ok := s.db.(*sql.DB)
|
||||
if !ok {
|
||||
panic("can't support full GORM on currently status, maybe this is a TX instance.")
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
@ -528,12 +531,12 @@ func (s *DB) Debug() *DB {
|
||||
// Transaction start a transaction as a block,
|
||||
// return error will rollback, otherwise to commit.
|
||||
func (s *DB) Transaction(fc func(tx *DB) error) (err error) {
|
||||
panicked := true
|
||||
tx := s.Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("%s", r)
|
||||
// Make sure to rollback when panic, Block error or Commit error
|
||||
if panicked || err != nil {
|
||||
tx.Rollback()
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
@ -543,10 +546,7 @@ func (s *DB) Transaction(fc func(tx *DB) error) (err error) {
|
||||
err = tx.Commit().Error
|
||||
}
|
||||
|
||||
// Makesure rollback when Block error or Commit error
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
panicked = false
|
||||
return
|
||||
}
|
||||
|
||||
|
53
main_test.go
53
main_test.go
@ -470,6 +470,15 @@ func TestTransaction(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func assertPanic(t *testing.T, f func()) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("The code did not panic")
|
||||
}
|
||||
}()
|
||||
f()
|
||||
}
|
||||
|
||||
func TestTransactionWithBlock(t *testing.T) {
|
||||
// rollback
|
||||
err := DB.Transaction(func(tx *gorm.DB) error {
|
||||
@ -511,17 +520,19 @@ func TestTransactionWithBlock(t *testing.T) {
|
||||
}
|
||||
|
||||
// panic will rollback
|
||||
DB.Transaction(func(tx *gorm.DB) error {
|
||||
u3 := User{Name: "transcation-3"}
|
||||
if err := tx.Save(&u3).Error; err != nil {
|
||||
t.Errorf("No error should raise")
|
||||
}
|
||||
assertPanic(t, func() {
|
||||
DB.Transaction(func(tx *gorm.DB) error {
|
||||
u3 := User{Name: "transcation-3"}
|
||||
if err := tx.Save(&u3).Error; err != nil {
|
||||
t.Errorf("No error should raise")
|
||||
}
|
||||
|
||||
if err := tx.First(&User{}, "name = ?", "transcation-3").Error; err != nil {
|
||||
t.Errorf("Should find saved record")
|
||||
}
|
||||
if err := tx.First(&User{}, "name = ?", "transcation-3").Error; err != nil {
|
||||
t.Errorf("Should find saved record")
|
||||
}
|
||||
|
||||
panic("force panic")
|
||||
panic("force panic")
|
||||
})
|
||||
})
|
||||
|
||||
if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil {
|
||||
@ -1322,6 +1333,30 @@ func TestCountWithQueryOption(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryHint1(t *testing.T) {
|
||||
db := DB.New()
|
||||
|
||||
_, err := db.Model(User{}).Raw("select 1").Rows()
|
||||
|
||||
if err != nil {
|
||||
t.Error("Unexpected error on query count with query_option")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryHint2(t *testing.T) {
|
||||
type TestStruct struct {
|
||||
ID string `gorm:"primary_key"`
|
||||
Name string
|
||||
}
|
||||
DB.DropTable(&TestStruct{})
|
||||
DB.AutoMigrate(&TestStruct{})
|
||||
|
||||
data := TestStruct{ID: "uuid", Name: "hello"}
|
||||
if err := DB.Set("gorm:query_hint", "/*master*/").Save(&data).Error; err != nil {
|
||||
t.Error("Unexpected error on query count with query_option")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFloatColumnPrecision(t *testing.T) {
|
||||
if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" && dialect != "sqlite" {
|
||||
t.Skip()
|
||||
|
@ -505,6 +505,74 @@ func TestOffset(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimitAndOffsetSQL(t *testing.T) {
|
||||
user1 := User{Name: "TestLimitAndOffsetSQL1", Age: 10}
|
||||
user2 := User{Name: "TestLimitAndOffsetSQL2", Age: 20}
|
||||
user3 := User{Name: "TestLimitAndOffsetSQL3", Age: 30}
|
||||
user4 := User{Name: "TestLimitAndOffsetSQL4", Age: 40}
|
||||
user5 := User{Name: "TestLimitAndOffsetSQL5", Age: 50}
|
||||
if err := DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5).Error; err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
limit, offset interface{}
|
||||
users []*User
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
name: "OK",
|
||||
limit: float64(2),
|
||||
offset: float64(2),
|
||||
users: []*User{
|
||||
&User{Name: "TestLimitAndOffsetSQL3", Age: 30},
|
||||
&User{Name: "TestLimitAndOffsetSQL2", Age: 20},
|
||||
},
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "Limit parse error",
|
||||
limit: float64(1000000), // 1e+06
|
||||
offset: float64(2),
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
name: "Offset parse error",
|
||||
limit: float64(2),
|
||||
offset: float64(1000000), // 1e+06
|
||||
ok: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var users []*User
|
||||
err := DB.Where("name LIKE ?", "TestLimitAndOffsetSQL%").Order("age desc").Limit(tt.limit).Offset(tt.offset).Find(&users).Error
|
||||
if tt.ok {
|
||||
if err != nil {
|
||||
t.Errorf("error expected nil, but got %v", err)
|
||||
}
|
||||
if len(users) != len(tt.users) {
|
||||
t.Errorf("users length expected %d, but got %d", len(tt.users), len(users))
|
||||
}
|
||||
for i := range tt.users {
|
||||
if users[i].Name != tt.users[i].Name {
|
||||
t.Errorf("users[%d] name expected %s, but got %s", i, tt.users[i].Name, users[i].Name)
|
||||
}
|
||||
if users[i].Age != tt.users[i].Age {
|
||||
t.Errorf("users[%d] age expected %d, but got %d", i, tt.users[i].Age, users[i].Age)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Error("error expected not nil, but got nil")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOr(t *testing.T) {
|
||||
user1 := User{Name: "OrUser1", Age: 1}
|
||||
user2 := User{Name: "OrUser2", Age: 10}
|
||||
|
10
scope.go
10
scope.go
@ -362,7 +362,7 @@ func (scope *Scope) Raw(sql string) *Scope {
|
||||
|
||||
// Exec perform generated SQL
|
||||
func (scope *Scope) Exec() *Scope {
|
||||
defer scope.trace(scope.db.nowFunc())
|
||||
defer scope.trace(NowFunc())
|
||||
|
||||
if !scope.HasError() {
|
||||
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||
@ -809,7 +809,9 @@ func (scope *Scope) orderSQL() string {
|
||||
}
|
||||
|
||||
func (scope *Scope) limitAndOffsetSQL() string {
|
||||
return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
|
||||
sql, err := scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
|
||||
scope.Err(err)
|
||||
return sql
|
||||
}
|
||||
|
||||
func (scope *Scope) groupSQL() string {
|
||||
@ -952,7 +954,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin
|
||||
}
|
||||
|
||||
func (scope *Scope) row() *sql.Row {
|
||||
defer scope.trace(scope.db.nowFunc())
|
||||
defer scope.trace(NowFunc())
|
||||
|
||||
result := &RowQueryResult{}
|
||||
scope.InstanceSet("row_query_result", result)
|
||||
@ -962,7 +964,7 @@ func (scope *Scope) row() *sql.Row {
|
||||
}
|
||||
|
||||
func (scope *Scope) rows() (*sql.Rows, error) {
|
||||
defer scope.trace(scope.db.nowFunc())
|
||||
defer scope.trace(NowFunc())
|
||||
|
||||
result := &RowsQueryResult{}
|
||||
scope.InstanceSet("row_query_result", result)
|
||||
|
Loading…
x
Reference in New Issue
Block a user