Merge branch 'master' into master

This commit is contained in:
Jinzhu 2019-12-06 08:47:53 +08:00 committed by GitHub
commit aeda586bd2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 368 additions and 22 deletions

View File

@ -3,7 +3,7 @@ package gorm
import "fmt"
// DefaultCallback default callbacks defined by gorm
var DefaultCallback = &Callback{}
var DefaultCallback = &Callback{logger: nopLogger{}}
// Callback is a struct that contains all CRUD callbacks
// Field `creates` contains callbacks will be call when creating object
@ -101,6 +101,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *
}
}
cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum()))
cp.name = callbackName
cp.processor = &callback
cp.parent.processors = append(cp.parent.processors, cp)

View File

@ -50,7 +50,7 @@ func updateTimeStampForCreateCallback(scope *Scope) {
// createCallback the callback used to insert data into database
func createCallback(scope *Scope) {
if !scope.HasError() {
defer scope.trace(scope.db.nowFunc())
defer scope.trace(NowFunc())
var (
columns, placeholders []string

View File

@ -24,7 +24,7 @@ func queryCallback(scope *Scope) {
return
}
defer scope.trace(scope.db.nowFunc())
defer scope.trace(NowFunc())
var (
isSlice, isPtr bool

View File

@ -217,3 +217,33 @@ func TestGetCallback(t *testing.T) {
t.Errorf("`gorm:test_callback_value` should be `3, true` but `%v, %v`", v, ok)
}
}
func TestUseDefaultCallback(t *testing.T) {
createCallbackName := "gorm:test_use_default_callback_for_create"
gorm.DefaultCallback.Create().Register(createCallbackName, func(*gorm.Scope) {
// nop
})
if gorm.DefaultCallback.Create().Get(createCallbackName) == nil {
t.Errorf("`%s` expected non-nil, but got nil", createCallbackName)
}
gorm.DefaultCallback.Create().Remove(createCallbackName)
if gorm.DefaultCallback.Create().Get(createCallbackName) != nil {
t.Errorf("`%s` expected nil, but got non-nil", createCallbackName)
}
updateCallbackName := "gorm:test_use_default_callback_for_update"
scopeValueName := "gorm:test_use_default_callback_for_update_value"
gorm.DefaultCallback.Update().Register(updateCallbackName, func(scope *gorm.Scope) {
scope.Set(scopeValueName, 1)
})
gorm.DefaultCallback.Update().Replace(updateCallbackName, func(scope *gorm.Scope) {
scope.Set(scopeValueName, 2)
})
scope := DB.NewScope(nil)
callback := gorm.DefaultCallback.Update().Get(updateCallbackName)
callback(scope)
if v, ok := scope.Get(scopeValueName); !ok || v != 2 {
t.Errorf("`%s` should be `2, true` but `%v, %v`", scopeValueName, v, ok)
}
}

View File

@ -37,7 +37,7 @@ type Dialect interface {
ModifyColumn(tableName string, columnName string, typ string) error
// LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case
LimitAndOffsetSQL(limit, offset interface{}) string
LimitAndOffsetSQL(limit, offset interface{}) (string, error)
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
SelectFromDummyTable() string
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`

View File

@ -139,14 +139,19 @@ func (s commonDialect) CurrentDatabase() (name string) {
return
}
func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
// LimitAndOffsetSQL return generated SQL with Limit and Offset
func (s commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
if parsedLimit, err := s.parseInt(limit); err != nil {
return "", err
} else if parsedLimit >= 0 {
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
}
}
if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
if parsedOffset, err := s.parseInt(offset); err != nil {
return "", err
} else if parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
}
}
@ -181,6 +186,10 @@ func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (stri
return indexName, columnName
}
func (commonDialect) parseInt(value interface{}) (int64, error) {
return strconv.ParseInt(fmt.Sprint(value), 0, 0)
}
// IsByteArrayOrSlice returns true of the reflected value is an array or slice
func IsByteArrayOrSlice(value reflect.Value) bool {
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))

View File

@ -6,7 +6,6 @@ import (
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"unicode/utf8"
@ -140,13 +139,21 @@ func (s mysql) ModifyColumn(tableName string, columnName string, typ string) err
return err
}
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
parsedLimit, err := s.parseInt(limit)
if err != nil {
return "", err
}
if parsedLimit >= 0 {
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
parsedOffset, err := s.parseInt(offset)
if err != nil {
return "", err
}
if parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
}
}
@ -165,7 +172,8 @@ func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool {
func (s mysql) HasTable(tableName string) bool {
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
var name string
if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM %s WHERE Tables_in_%s = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil {
// allow mysql database name with '-' character
if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil {
if err == sql.ErrNoRows {
return false
}

View File

@ -168,14 +168,22 @@ func (s mssql) CurrentDatabase() (name string) {
return
}
func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
func parseInt(value interface{}) (int64, error) {
return strconv.ParseInt(fmt.Sprint(value), 0, 0)
}
func (mssql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) {
if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
if parsedOffset, err := parseInt(offset); err != nil {
return "", err
} else if parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d ROWS", parsedOffset)
}
}
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
if parsedLimit, err := parseInt(limit); err != nil {
return "", err
} else if parsedLimit >= 0 {
if sql == "" {
// add default zero offset
sql += " OFFSET 0 ROWS"

2
go.mod
View File

@ -9,5 +9,5 @@ require (
github.com/jinzhu/inflection v1.0.0
github.com/jinzhu/now v1.0.1
github.com/lib/pq v1.1.1
github.com/mattn/go-sqlite3 v1.11.0
github.com/mattn/go-sqlite3 v2.0.1+incompatible
)

4
go.sum
View File

@ -54,6 +54,10 @@ github.com/lib/pq v1.1.1 h1:sJZmqHoEaY7f+NPP8pgLB/WxulyR3fewgCM2qaSlBb4=
github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-sqlite3 v1.11.0 h1:LDdKkqtYlom37fkvqs8rMPFKAMe8+SgjbwZ6ex1/A/Q=
github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/mattn/go-sqlite3 v1.12.0 h1:u/x3mp++qUxvYfulZ4HKOvVO0JWhk7HtE8lWhbGz/Do=
github.com/mattn/go-sqlite3 v1.12.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/mattn/go-sqlite3 v2.0.1+incompatible h1:xQ15muvnzGBHpIpdrNi1DA5x0+TcBZzsIDwmw9uTHzw=
github.com/mattn/go-sqlite3 v2.0.1+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=

View File

@ -39,6 +39,15 @@ var LogFormatter = func(values ...interface{}) (messages []interface{}) {
messages = []interface{}{source, currentTime}
if len(values) == 2 {
//remove the line break
currentTime = currentTime[1:]
//remove the brackets
source = fmt.Sprintf("\033[35m%v\033[0m", values[1])
messages = []interface{}{currentTime, source}
}
if level == "sql" {
// duration
messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0))
@ -126,3 +135,7 @@ type Logger struct {
func (logger Logger) Print(values ...interface{}) {
logger.Println(LogFormatter(values...)...)
}
type nopLogger struct{}
func (nopLogger) Print(values ...interface{}) {}

24
main.go
View File

@ -434,6 +434,7 @@ func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB {
}
// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update
// WARNING when update with struct, GORM will not update fields that with zero value
func (s *DB) Update(attrs ...interface{}) *DB {
return s.Updates(toSearchableMap(attrs...), true)
}
@ -480,6 +481,7 @@ func (s *DB) Create(value interface{}) *DB {
}
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
// WARNING If model has DeletedAt field, GORM will only set field DeletedAt's value to current time
func (s *DB) Delete(value interface{}, where ...interface{}) *DB {
return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db
}
@ -523,6 +525,28 @@ func (s *DB) Debug() *DB {
return s.clone().LogMode(true)
}
// Transaction start a transaction as a block,
// return error will rollback, otherwise to commit.
func (s *DB) Transaction(fc func(tx *DB) error) (err error) {
panicked := true
tx := s.Begin()
defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
tx.Rollback()
}
}()
err = fc(tx)
if err == nil {
err = tx.Commit().Error
}
panicked = false
return
}
// Begin begins a transaction
func (s *DB) Begin() *DB {
return s.BeginTx(context.Background(), &sql.TxOptions{})

View File

@ -8,6 +8,7 @@ import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"os"
"path/filepath"
@ -469,6 +470,76 @@ func TestTransaction(t *testing.T) {
}
}
func assertPanic(t *testing.T, f func()) {
defer func() {
if r := recover(); r == nil {
t.Errorf("The code did not panic")
}
}()
f()
}
func TestTransactionWithBlock(t *testing.T) {
// rollback
err := DB.Transaction(func(tx *gorm.DB) error {
u := User{Name: "transcation"}
if err := tx.Save(&u).Error; err != nil {
t.Errorf("No error should raise")
}
if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil {
t.Errorf("Should find saved record")
}
return errors.New("the error message")
})
if err.Error() != "the error message" {
t.Errorf("Transaction return error will equal the block returns error")
}
if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil {
t.Errorf("Should not find record after rollback")
}
// commit
DB.Transaction(func(tx *gorm.DB) error {
u2 := User{Name: "transcation-2"}
if err := tx.Save(&u2).Error; err != nil {
t.Errorf("No error should raise")
}
if err := tx.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
t.Errorf("Should find saved record")
}
return nil
})
if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil {
t.Errorf("Should be able to find committed record")
}
// panic will rollback
assertPanic(t, func() {
DB.Transaction(func(tx *gorm.DB) error {
u3 := User{Name: "transcation-3"}
if err := tx.Save(&u3).Error; err != nil {
t.Errorf("No error should raise")
}
if err := tx.First(&User{}, "name = ?", "transcation-3").Error; err != nil {
t.Errorf("Should find saved record")
}
panic("force panic")
})
})
if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil {
t.Errorf("Should not find record after panic rollback")
}
}
func TestTransaction_NoErrorOnRollbackAfterCommit(t *testing.T) {
tx := DB.Begin()
u := User{Name: "transcation"}

View File

@ -17,6 +17,10 @@ var DefaultTableNameHandler = func(db *DB, defaultTableName string) string {
return defaultTableName
}
// lock for mutating global cached model metadata
var structsLock sync.Mutex
// global cache of model metadata
var modelStructsMap sync.Map
// ModelStruct model definition
@ -419,8 +423,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
for idx, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil {
// source foreign keys
// mark field as foreignkey, use global lock to avoid race
structsLock.Lock()
foreignField.IsForeignKey = true
structsLock.Unlock()
// association foreign keys
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName)
@ -523,8 +531,12 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
for idx, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, toFields); foreignField != nil {
if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil {
// mark field as foreignkey, use global lock to avoid race
structsLock.Lock()
foreignField.IsForeignKey = true
// source foreign keys
structsLock.Unlock()
// association foreign keys
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name)
relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName)
@ -582,7 +594,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
for idx, foreignKey := range foreignKeys {
if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil {
if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil {
// mark field as foreignkey, use global lock to avoid race
structsLock.Lock()
foreignField.IsForeignKey = true
structsLock.Unlock()
// association foreign keys
relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name)

93
model_struct_test.go Normal file
View File

@ -0,0 +1,93 @@
package gorm_test
import (
"sync"
"testing"
"github.com/jinzhu/gorm"
)
type ModelA struct {
gorm.Model
Name string
ModelCs []ModelC `gorm:"foreignkey:OtherAID"`
}
type ModelB struct {
gorm.Model
Name string
ModelCs []ModelC `gorm:"foreignkey:OtherBID"`
}
type ModelC struct {
gorm.Model
Name string
OtherAID uint64
OtherA *ModelA `gorm:"foreignkey:OtherAID"`
OtherBID uint64
OtherB *ModelB `gorm:"foreignkey:OtherBID"`
}
// This test will try to cause a race condition on the model's foreignkey metadata
func TestModelStructRaceSameModel(t *testing.T) {
// use a WaitGroup to execute as much in-sync as possible
// it's more likely to hit a race condition than without
n := 32
start := sync.WaitGroup{}
start.Add(n)
// use another WaitGroup to know when the test is done
done := sync.WaitGroup{}
done.Add(n)
for i := 0; i < n; i++ {
go func() {
start.Wait()
// call GetStructFields, this had a race condition before we fixed it
DB.NewScope(&ModelA{}).GetStructFields()
done.Done()
}()
start.Done()
}
done.Wait()
}
// This test will try to cause a race condition on the model's foreignkey metadata
func TestModelStructRaceDifferentModel(t *testing.T) {
// use a WaitGroup to execute as much in-sync as possible
// it's more likely to hit a race condition than without
n := 32
start := sync.WaitGroup{}
start.Add(n)
// use another WaitGroup to know when the test is done
done := sync.WaitGroup{}
done.Add(n)
for i := 0; i < n; i++ {
i := i
go func() {
start.Wait()
// call GetStructFields, this had a race condition before we fixed it
if i%2 == 0 {
DB.NewScope(&ModelA{}).GetStructFields()
} else {
DB.NewScope(&ModelB{}).GetStructFields()
}
done.Done()
}()
start.Done()
}
done.Wait()
}

View File

@ -457,6 +457,74 @@ func TestOffset(t *testing.T) {
}
}
func TestLimitAndOffsetSQL(t *testing.T) {
user1 := User{Name: "TestLimitAndOffsetSQL1", Age: 10}
user2 := User{Name: "TestLimitAndOffsetSQL2", Age: 20}
user3 := User{Name: "TestLimitAndOffsetSQL3", Age: 30}
user4 := User{Name: "TestLimitAndOffsetSQL4", Age: 40}
user5 := User{Name: "TestLimitAndOffsetSQL5", Age: 50}
if err := DB.Save(&user1).Save(&user2).Save(&user3).Save(&user4).Save(&user5).Error; err != nil {
t.Fatal(err)
}
tests := []struct {
name string
limit, offset interface{}
users []*User
ok bool
}{
{
name: "OK",
limit: float64(2),
offset: float64(2),
users: []*User{
&User{Name: "TestLimitAndOffsetSQL3", Age: 30},
&User{Name: "TestLimitAndOffsetSQL2", Age: 20},
},
ok: true,
},
{
name: "Limit parse error",
limit: float64(1000000), // 1e+06
offset: float64(2),
ok: false,
},
{
name: "Offset parse error",
limit: float64(2),
offset: float64(1000000), // 1e+06
ok: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var users []*User
err := DB.Where("name LIKE ?", "TestLimitAndOffsetSQL%").Order("age desc").Limit(tt.limit).Offset(tt.offset).Find(&users).Error
if tt.ok {
if err != nil {
t.Errorf("error expected nil, but got %v", err)
}
if len(users) != len(tt.users) {
t.Errorf("users length expected %d, but got %d", len(tt.users), len(users))
}
for i := range tt.users {
if users[i].Name != tt.users[i].Name {
t.Errorf("users[%d] name expected %s, but got %s", i, tt.users[i].Name, users[i].Name)
}
if users[i].Age != tt.users[i].Age {
t.Errorf("users[%d] age expected %d, but got %d", i, tt.users[i].Age, users[i].Age)
}
}
} else {
if err == nil {
t.Error("error expected not nil, but got nil")
}
}
})
}
}
func TestOr(t *testing.T) {
user1 := User{Name: "OrUser1", Age: 1}
user2 := User{Name: "OrUser2", Age: 10}

View File

@ -358,7 +358,7 @@ func (scope *Scope) Raw(sql string) *Scope {
// Exec perform generated SQL
func (scope *Scope) Exec() *Scope {
defer scope.trace(scope.db.nowFunc())
defer scope.trace(NowFunc())
if !scope.HasError() {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
@ -797,7 +797,9 @@ func (scope *Scope) orderSQL() string {
}
func (scope *Scope) limitAndOffsetSQL() string {
return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
sql, err := scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset)
scope.Err(err)
return sql
}
func (scope *Scope) groupSQL() string {
@ -932,7 +934,7 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin
}
func (scope *Scope) row() *sql.Row {
defer scope.trace(scope.db.nowFunc())
defer scope.trace(NowFunc())
result := &RowQueryResult{}
scope.InstanceSet("row_query_result", result)
@ -942,7 +944,7 @@ func (scope *Scope) row() *sql.Row {
}
func (scope *Scope) rows() (*sql.Rows, error) {
defer scope.trace(scope.db.nowFunc())
defer scope.trace(NowFunc())
result := &RowsQueryResult{}
scope.InstanceSet("row_query_result", result)