Merge branch 'master' into master

This commit is contained in:
Jinzhu 2019-12-06 11:34:54 +08:00 committed by GitHub
commit 0718ef0782
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 288 additions and 25 deletions

View File

@ -3,7 +3,7 @@ package gorm
import "fmt" import "fmt"
// DefaultCallback default callbacks defined by gorm // DefaultCallback default callbacks defined by gorm
var DefaultCallback = &Callback{} var DefaultCallback = &Callback{logger: nopLogger{}}
// Callback is a struct that contains all CRUD callbacks // Callback is a struct that contains all CRUD callbacks
// Field `creates` contains callbacks will be call when creating object // Field `creates` contains callbacks will be call when creating object
@ -101,12 +101,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *
} }
} }
if cp.logger != nil { cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum()))
// note cp.logger will be nil during the default gorm callback registrations
// as they occur within init() blocks. However, any user-registered callbacks
// will happen after cp.logger exists (as the default logger or user-specified).
cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum()))
}
cp.name = callbackName cp.name = callbackName
cp.processor = &callback cp.processor = &callback
cp.parent.processors = append(cp.parent.processors, cp) cp.parent.processors = append(cp.parent.processors, cp)

View File

@ -50,7 +50,7 @@ func updateTimeStampForCreateCallback(scope *Scope) {
// createCallback the callback used to insert data into database // createCallback the callback used to insert data into database
func createCallback(scope *Scope) { func createCallback(scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
defer scope.trace(scope.db.nowFunc()) defer scope.trace(NowFunc())
var ( var (
columns, placeholders []string columns, placeholders []string

View File

@ -24,7 +24,7 @@ func queryCallback(scope *Scope) {
return return
} }
defer scope.trace(scope.db.nowFunc()) defer scope.trace(NowFunc())
var ( var (
isSlice, isPtr bool isSlice, isPtr bool
@ -60,6 +60,11 @@ func queryCallback(scope *Scope) {
if !scope.HasError() { if !scope.HasError() {
scope.db.RowsAffected = 0 scope.db.RowsAffected = 0
if str, ok := scope.Get("gorm:query_hint"); ok {
scope.SQL = fmt.Sprint(str) + scope.SQL
}
if str, ok := scope.Get("gorm:query_option"); ok { if str, ok := scope.Get("gorm:query_option"); ok {
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
} }

View File

@ -23,6 +23,11 @@ type RowsQueryResult struct {
func rowQueryCallback(scope *Scope) { func rowQueryCallback(scope *Scope) {
if result, ok := scope.InstanceGet("row_query_result"); ok { if result, ok := scope.InstanceGet("row_query_result"); ok {
scope.prepareQuerySQL() scope.prepareQuerySQL()
if str, ok := scope.Get("gorm:query_hint"); ok {
scope.SQL = fmt.Sprint(str) + scope.SQL
}
if str, ok := scope.Get("gorm:query_option"); ok { if str, ok := scope.Get("gorm:query_option"); ok {
scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str))
} }

View File

@ -217,3 +217,33 @@ func TestGetCallback(t *testing.T) {
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok) t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
} }
} }
func TestUseDefaultCallback(t *testing.T) {
createCallbackName := "gorm:test_use_default_callback_for_create"
gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) {
// nop
})
if gorm.DefaultCallback.Create().Get(createCallbackName) == nil {
t.Errorf("`%s` expected non-nil, but got nil", createCallbackName)
}
gorm.DefaultCallback.Create().Remove(createCallbackName)
if gorm.DefaultCallback.Create().Get(createCallbackName) != nil {
t.Errorf("`%s` expected nil, but got non-nil", createCallbackName)
}
updateCallbackName := "gorm:test_use_default_callback_for_update"
scopeValueName := "gorm:test_use_default_callback_for_update_value"
gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) {
scope.Set(scopeValueName, 1)
})
gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) {
scope.Set(scopeValueName, 2)
})
scope := DB.NewScope(nil)
callback := gorm.DefaultCallback.Update().Get(updateCallbackName)
callback(scope)
if v, ok := scope.Get(scopeValueName); !ok || v != 2 {
t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok)
}
}

View File

@ -37,7 +37,7 @@ type Dialect interface {
ModifyColumn(tableName string, columnName string, typ string) error ModifyColumn(tableName string, columnName string, typ string) error
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
LimitAndOffsetSQL(limit, offset interface{}) string LimitAndOffsetSQL(limit, offset interface{}) (string, error)
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
SelectFromDummyTable() string SelectFromDummyTable() string
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT` // LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`

View File

@ -139,14 +139,19 @@ func (s commonDialect) CurrentDatabase() (name string) {
return return
} }
func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { // LimitAndOffsetSQL return generated SQL with Limit and Offset
func (s commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
if limit != nil { if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { if parsedLimit, err := s.parseInt(limit); err != nil {
return "", err
} else if parsedLimit >= 0 {
sql += fmt.Sprintf(" LIMIT %d", parsedLimit) sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
} }
} }
if offset != nil { if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { if parsedOffset, err := s.parseInt(offset); err != nil {
return "", err
} else if parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d", parsedOffset) sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
} }
} }
@ -181,6 +186,10 @@ func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (stri
return indexName, columnName return indexName, columnName
} }
func (commonDialect) parseInt(value interface{}) (int64, error) {
return strconv.ParseInt(fmt.Sprint(value), 0, 0)
}
// IsByteArrayOrSlice returns true of the reflected value is an array or slice // IsByteArrayOrSlice returns true of the reflected value is an array or slice
func IsByteArrayOrSlice(value reflect.Value) bool { func IsByteArrayOrSlice(value reflect.Value) bool {
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))

View File

@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"regexp" "regexp"
"strconv"
"strings" "strings"
"time" "time"
"unicode/utf8" "unicode/utf8"
@ -140,13 +139,21 @@ func (s mysql) ModifyColumn(tableName string, columnName string, typ string) err
return err return err
} }
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
if limit != nil { if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { parsedLimit, err := s.parseInt(limit)
if err != nil {
return "", err
}
if parsedLimit >= 0 {
sql += fmt.Sprintf(" LIMIT %d", parsedLimit) sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
if offset != nil { if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { parsedOffset, err := s.parseInt(offset)
if err != nil {
return "", err
}
if parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d", parsedOffset) sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
} }
} }

View File

@ -168,14 +168,22 @@ func (s mssql) CurrentDatabase() (name string) {
return return
} }
func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) { func parseInt(value interface{}) (int64, error) {
return strconv.ParseInt(fmt.Sprint(value), 0, 0)
}
func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
if offset != nil { if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 { if parsedOffset, err := parseInt(offset); err != nil {
return "", err
} else if parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset) sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset)
} }
} }
if limit != nil { if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 { if parsedLimit, err := parseInt(limit); err != nil {
return "", err
} else if parsedLimit >= 0 {
if sql == "" { if sql == "" {
// add default zero offset // add default zero offset
sql += " OFFSET 0 ROWS" sql += " OFFSET 0 ROWS"

2
go.mod
View File

@ -9,5 +9,5 @@ require (
github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/inflection v1.0.0
github.com/jinzhu/now v1.0.1 github.com/jinzhu/now v1.0.1
github.com/lib/pq v1.1.1 github.com/lib/pq v1.1.1
github.com/mattn/go-sqlite3 v1.11.0 github.com/mattn/go-sqlite3 v2.0.1+incompatible
) )

4
go.sum
View File

@ -54,6 +54,10 @@ github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4=
github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-sqlite3 v1.11.0 h1:LDdKkqtYlom37fkvqs8rMPFKAMe8+SgjbwZ6ex1/A/Q= github.com/mattn/go-sqlite3 v1.11.0 h1:LDdKkqtYlom37fkvqs8rMPFKAMe8+SgjbwZ6ex1/A/Q=
github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/mattn/go-sqlite3 v1.12.0 h1:u/x3mp++qUxvYfulZ4HKOvVO0JWhk7HtE8lWhbGz/Do=
github.com/mattn/go-sqlite3 v1.12.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/mattn/go-sqlite3 v2.0.1+incompatible h1:xQ15muvnzGBHpIpdrNi1DA5x0+TcBZzsIDwmw9uTHzw=
github.com/mattn/go-sqlite3 v2.0.1+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=

View File

@ -39,6 +39,15 @@ var LogFormatter = func(values ...interface{}) (messages []interface{}) {
messages = []interface{}{source, currentTime} messages = []interface{}{source, currentTime}
if len(values) == 2 {
//remove the line break
currentTime = currentTime[1:]
//remove the brackets
source = fmt.Sprintf("\033[35m%v\033[0m", values[1])
messages = []interface{}{currentTime, source}
}
if level == "sql" { if level == "sql" {
// duration // duration
messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0)) messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
@ -126,3 +135,7 @@ type Logger struct {
func (logger Logger) Print(values ...interface{}) { func (logger Logger) Print(values ...interface{}) {
logger.Println(LogFormatter(values...)...) logger.Println(LogFormatter(values...)...)
} }
type nopLogger struct{}
func (nopLogger) Print(values ...interface{}) {}

22
main.go
View File

@ -528,6 +528,28 @@ func (s *DB) Debug() *DB {
return s.clone().LogMode(true) return s.clone().LogMode(true)
} }
// Transaction start a transaction as a block,
// return error will rollback, otherwise to commit.
func (s *DB) Transaction(fc func(tx *DB) error) (err error) {
panicked := true
tx := s.Begin()
defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
tx.Rollback()
}
}()
err = fc(tx)
if err == nil {
err = tx.Commit().Error
}
panicked = false
return
}
// Begin begins a transaction // Begin begins a transaction
func (s *DB) Begin() *DB { func (s *DB) Begin() *DB {
return s.BeginTx(context.Background(), &sql.TxOptions{}) return s.BeginTx(context.Background(), &sql.TxOptions{})

View File

@ -8,6 +8,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"errors"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@ -469,6 +470,76 @@ func TestTransaction(t *testing.T) {
} }
} }
func assertPanic(t *testing.T, f func()) {
defer func() {
if r := recover(); r == nil {
t.Errorf("The code did not panic")
}
}()
f()
}
func TestTransactionWithBlock(t *testing.T) {
// rollback
err := DB.Transaction(func(tx *gorm.DB) error {
u := User{Name: "transcation"}
if err := tx.Save(&u).Error; err != nil {
t.Errorf("No error should raise")
}
if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil {
t.Errorf("Should find saved record")
}
return errors.New("the error message")
})
if err.Error() != "the error message" {
t.Errorf("Transaction return error will equal the block returns error")
}
if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil {
t.Errorf("Should not find record after rollback")
}
// commit
DB.Transaction(func(tx *gorm.DB) error {
u2 := User{Name: "transcation-2"}
if err := tx.Save(&u2).Error; err != nil {
t.Errorf("No error should raise")
}
if err := tx.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
t.Errorf("Should find saved record")
}
return nil
})
if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
t.Errorf("Should be able to find committed record")
}
// panic will rollback
assertPanic(t, func() {
DB.Transaction(func(tx *gorm.DB) error {
u3 := User{Name: "transcation-3"}
if err := tx.Save(&u3).Error; err != nil {
t.Errorf("No error should raise")
}
if err := tx.First(&User{}, "name = ?", "transcation-3").Error; err != nil {
t.Errorf("Should find saved record")
}
panic("force panic")
})
})
if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil {
t.Errorf("Should not find record after panic rollback")
}
}
func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) { func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) {
tx := DB.Begin() tx := DB.Begin()
u := User{Name: "transcation"} u := User{Name: "transcation"}
@ -1262,6 +1333,30 @@ func TestCountWithQueryOption(t *testing.T) {
} }
} }
func TestQueryHint1(t *testing.T) {
db := DB.New()
_, err := db.Model(User{}).Raw("select 1").Rows()
if err != nil {
t.Error("Unexpected error on query count with query_option")
}
}
func TestQueryHint2(t *testing.T) {
type TestStruct struct {
ID string `gorm:"primary_key"`
Name string
}
DB.DropTable(&TestStruct{})
DB.AutoMigrate(&TestStruct{})
data := TestStruct{ID: "uuid", Name: "hello"}
if err := DB.Set("gorm:query_hint", "/*master*/").Save(&data).Error; err != nil {
t.Error("Unexpected error on query count with query_option")
}
}
func TestFloatColumnPrecision(t *testing.T) { func TestFloatColumnPrecision(t *testing.T) {
if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" && dialect != "sqlite" { if dialect := os.Getenv("GORM_DIALECT"); dialect != "mysql" && dialect != "sqlite" {
t.Skip() t.Skip()

View File

@ -457,6 +457,74 @@ func TestOffset(t *testing.T) {
} }
} }
func TestLimitAndOffsetSQL(t *testing.T) {
user1 := User{Name: "TestLimitAndOffsetSQL1", Age: 10}
user2 := User{Name: "TestLimitAndOffsetSQL2", Age: 20}
user3 := User{Name: "TestLimitAndOffsetSQL3", Age: 30}
user4 := User{Name: "TestLimitAndOffsetSQL4", Age: 40}
user5 := User{Name: "TestLimitAndOffsetSQL5", Age: 50}
if err := DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5).Error; err != nil {
t.Fatal(err)
}
tests := []struct {
name string
limit, offset interface{}
users []*User
ok bool
}{
{
name: "OK",
limit: float64(2),
offset: float64(2),
users: []*User{
&User{Name: "TestLimitAndOffsetSQL3", Age: 30},
&User{Name: "TestLimitAndOffsetSQL2", Age: 20},
},
ok: true,
},
{
name: "Limit parse error",
limit: float64(1000000), // 1e+06
offset: float64(2),
ok: false,
},
{
name: "Offset parse error",
limit: float64(2),
offset: float64(1000000), // 1e+06
ok: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var users []*User
err := DB.Where("name LIKE ?", "TestLimitAndOffsetSQL%").Order("age desc").Limit(tt.limit).Offset(tt.offset).Find(&users).Error
if tt.ok {
if err != nil {
t.Errorf("error expected nil, but got %v", err)
}
if len(users) != len(tt.users) {
t.Errorf("users length expected %d, but got %d", len(tt.users), len(users))
}
for i := range tt.users {
if users[i].Name != tt.users[i].Name {
t.Errorf("users[%d] name expected %s, but got %s", i, tt.users[i].Name, users[i].Name)
}
if users[i].Age != tt.users[i].Age {
t.Errorf("users[%d] age expected %d, but got %d", i, tt.users[i].Age, users[i].Age)
}
}
} else {
if err == nil {
t.Error("error expected not nil, but got nil")
}
}
})
}
}
func TestOr(t *testing.T) { func TestOr(t *testing.T) {
user1 := User{Name: "OrUser1", Age: 1} user1 := User{Name: "OrUser1", Age: 1}
user2 := User{Name: "OrUser2", Age: 10} user2 := User{Name: "OrUser2", Age: 10}

View File

@ -358,7 +358,7 @@ func (scope *Scope) Raw(sql string) *Scope {
// Exec perform generated SQL // Exec perform generated SQL
func (scope *Scope) Exec() *Scope { func (scope *Scope) Exec() *Scope {
defer scope.trace(scope.db.nowFunc()) defer scope.trace(NowFunc())
if !scope.HasError() { if !scope.HasError() {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
@ -797,7 +797,9 @@ func (scope *Scope) orderSQL() string {
} }
func (scope *Scope) limitAndOffsetSQL() string { func (scope *Scope) limitAndOffsetSQL() string {
return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) sql, err := scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
scope.Err(err)
return sql
} }
func (scope *Scope) groupSQL() string { func (scope *Scope) groupSQL() string {
@ -932,7 +934,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin
} }
func (scope *Scope) row() *sql.Row { func (scope *Scope) row() *sql.Row {
defer scope.trace(scope.db.nowFunc()) defer scope.trace(NowFunc())
result := &RowQueryResult{} result := &RowQueryResult{}
scope.InstanceSet("row_query_result", result) scope.InstanceSet("row_query_result", result)
@ -942,7 +944,7 @@ func (scope *Scope) row() *sql.Row {
} }
func (scope *Scope) rows() (*sql.Rows, error) { func (scope *Scope) rows() (*sql.Rows, error) {
defer scope.trace(scope.db.nowFunc()) defer scope.trace(NowFunc())
result := &RowsQueryResult{} result := &RowsQueryResult{}
scope.InstanceSet("row_query_result", result) scope.InstanceSet("row_query_result", result)