Merge branch 'master' into master
This commit is contained in:
commit
ad72b61594
@ -119,8 +119,8 @@ func (cp *CallbackProcessor) Remove(callbackName string) {
|
|||||||
|
|
||||||
// Replace a registered callback with new callback
|
// Replace a registered callback with new callback
|
||||||
// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
|
// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
|
||||||
// scope.SetColumn("Created", now)
|
// scope.SetColumn("CreatedAt", now)
|
||||||
// scope.SetColumn("Updated", now)
|
// scope.SetColumn("UpdatedAt", now)
|
||||||
// })
|
// })
|
||||||
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
|
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
|
||||||
cp.logger.Print(fmt.Sprintf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()))
|
cp.logger.Print(fmt.Sprintf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()))
|
||||||
|
@ -101,6 +101,7 @@ func createCallback(scope *Scope) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
|
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
|
||||||
|
lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns)
|
||||||
|
|
||||||
if len(columns) == 0 {
|
if len(columns) == 0 {
|
||||||
scope.Raw(fmt.Sprintf(
|
scope.Raw(fmt.Sprintf(
|
||||||
@ -113,18 +114,19 @@ func createCallback(scope *Scope) {
|
|||||||
))
|
))
|
||||||
} else {
|
} else {
|
||||||
scope.Raw(fmt.Sprintf(
|
scope.Raw(fmt.Sprintf(
|
||||||
"INSERT %v INTO %v (%v) VALUES (%v)%v%v",
|
"INSERT%v INTO %v (%v)%v VALUES (%v)%v%v",
|
||||||
addExtraSpaceIfExist(insertModifier),
|
addExtraSpaceIfExist(insertModifier),
|
||||||
scope.QuotedTableName(),
|
scope.QuotedTableName(),
|
||||||
strings.Join(columns, ","),
|
strings.Join(columns, ","),
|
||||||
|
addExtraSpaceIfExist(lastInsertIDOutputInterstitial),
|
||||||
strings.Join(placeholders, ","),
|
strings.Join(placeholders, ","),
|
||||||
addExtraSpaceIfExist(extraOption),
|
addExtraSpaceIfExist(extraOption),
|
||||||
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
|
addExtraSpaceIfExist(lastInsertIDReturningSuffix),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
// execute create sql
|
// execute create sql: no primaryField
|
||||||
if lastInsertIDReturningSuffix == "" || primaryField == nil {
|
if primaryField == nil {
|
||||||
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||||
// set rows affected count
|
// set rows affected count
|
||||||
scope.db.RowsAffected, _ = result.RowsAffected()
|
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||||
@ -136,7 +138,26 @@ func createCallback(scope *Scope) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// execute create sql: lastInsertID implemention for majority of dialects
|
||||||
|
if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" {
|
||||||
|
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
|
||||||
|
// set rows affected count
|
||||||
|
scope.db.RowsAffected, _ = result.RowsAffected()
|
||||||
|
|
||||||
|
// set primary value to primary field
|
||||||
|
if primaryField != nil && primaryField.IsBlank {
|
||||||
|
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
|
||||||
|
scope.Err(primaryField.Set(primaryValue))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql)
|
||||||
if primaryField.Field.CanAddr() {
|
if primaryField.Field.CanAddr() {
|
||||||
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
|
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
|
||||||
primaryField.IsBlank = false
|
primaryField.IsBlank = false
|
||||||
@ -145,7 +166,7 @@ func createCallback(scope *Scope) {
|
|||||||
} else {
|
} else {
|
||||||
scope.Err(ErrUnaddressable)
|
scope.Err(ErrUnaddressable)
|
||||||
}
|
}
|
||||||
}
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -40,6 +40,8 @@ type Dialect interface {
|
|||||||
LimitAndOffsetSQL(limit, offset interface{}) string
|
LimitAndOffsetSQL(limit, offset interface{}) string
|
||||||
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
|
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
|
||||||
SelectFromDummyTable() string
|
SelectFromDummyTable() string
|
||||||
|
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
|
||||||
|
LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string
|
||||||
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
|
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
|
||||||
LastInsertIDReturningSuffix(tableName, columnName string) string
|
LastInsertIDReturningSuffix(tableName, columnName string) string
|
||||||
// DefaultValueStr
|
// DefaultValueStr
|
||||||
|
@ -157,6 +157,10 @@ func (commonDialect) SelectFromDummyTable() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
@ -121,7 +121,7 @@ func (s *mysql) DataTypeOf(field *StructField) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if sqlType == "" {
|
if sqlType == "" {
|
||||||
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String()))
|
panic(fmt.Sprintf("invalid sql type %s (%s) in field %s for mysql", dataValue.Type().Name(), dataValue.Kind().String(), field.Name))
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.TrimSpace(additionalType) == "" {
|
if strings.TrimSpace(additionalType) == "" {
|
||||||
|
@ -120,6 +120,10 @@ func (s postgres) CurrentDatabase() (name string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s postgres) LastInsertIDOutputInterstitial(tableName, key string, columns []string) string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
|
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
|
||||||
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
|
||||||
}
|
}
|
||||||
|
@ -190,6 +190,14 @@ func (mssql) SelectFromDummyTable() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
|
||||||
|
if len(columns) == 0 {
|
||||||
|
// No OUTPUT to query
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("OUTPUT Inserted.%v", columnName)
|
||||||
|
}
|
||||||
|
|
||||||
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
@ -133,23 +133,6 @@ func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) {
|
|||||||
t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed, zip code is %v", address.ZipCode)
|
t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed, zip code is %v", address.ZipCode)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func TestStringAgainstIncompleteParentheses(t *testing.T) {
|
|
||||||
type AddressByZipCode struct {
|
|
||||||
ZipCode string `gorm:"primary_key"`
|
|
||||||
Address string
|
|
||||||
}
|
|
||||||
|
|
||||||
DB.AutoMigrate(&AddressByZipCode{})
|
|
||||||
DB.Create(&AddressByZipCode{ZipCode: "00502", Address: "Holtsville"})
|
|
||||||
|
|
||||||
var address AddressByZipCode
|
|
||||||
var addresses []AddressByZipCode
|
|
||||||
_ = DB.First(&address, "address_by_zip_codes=00502)) UNION ALL SELECT NULL,version(),current_database(),NULL,NULL,NULL,NULL,NULL--").Find(&addresses).GetErrors()
|
|
||||||
if len(addresses) > 0 {
|
|
||||||
t.Errorf("Fetch a record from with a string that has incomplete parentheses should be fail, zip code is %v", address.ZipCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFindAsSliceOfPointers(t *testing.T) {
|
func TestFindAsSliceOfPointers(t *testing.T) {
|
||||||
DB.Save(&User{Name: "user"})
|
DB.Save(&User{Name: "user"})
|
||||||
|
21
scope.go
21
scope.go
@ -277,23 +277,6 @@ func (scope *Scope) AddToVars(value interface{}) string {
|
|||||||
return scope.Dialect().BindVar(len(scope.SQLVars))
|
return scope.Dialect().BindVar(len(scope.SQLVars))
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsCompleteParentheses check if the string has complete parentheses to prevent SQL injection
|
|
||||||
func (scope *Scope) IsCompleteParentheses(value string) bool {
|
|
||||||
count := 0
|
|
||||||
for i, _ := range value {
|
|
||||||
if value[i] == 40 { // (
|
|
||||||
count++
|
|
||||||
} else if value[i] == 41 { // )
|
|
||||||
count--
|
|
||||||
}
|
|
||||||
if count < 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
return count == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// SelectAttrs return selected attributes
|
// SelectAttrs return selected attributes
|
||||||
func (scope *Scope) SelectAttrs() []string {
|
func (scope *Scope) SelectAttrs() []string {
|
||||||
if scope.selectAttrs == nil {
|
if scope.selectAttrs == nil {
|
||||||
@ -573,10 +556,6 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if value != "" {
|
if value != "" {
|
||||||
if !scope.IsCompleteParentheses(value) {
|
|
||||||
scope.Err(fmt.Errorf("incomplete parentheses found: %v", value))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !include {
|
if !include {
|
||||||
if comparisonRegexp.MatchString(value) {
|
if comparisonRegexp.MatchString(value) {
|
||||||
str = fmt.Sprintf("NOT (%v)", value)
|
str = fmt.Sprintf("NOT (%v)", value)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user