Merge branch 'master' into master

This commit is contained in:
Jinzhu 2019-10-17 23:07:43 +08:00 committed by GitHub
commit ad72b61594
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 55 additions and 54 deletions

View File

@ -119,8 +119,8 @@ func (cp *CallbackProcessor) Remove(callbackName string) {
// Replace a registered callback with new callback // Replace a registered callback with new callback
// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) { // db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
// scope.SetColumn("Created", now) // scope.SetColumn("CreatedAt", now)
// scope.SetColumn("Updated", now) // scope.SetColumn("UpdatedAt", now)
// }) // })
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
cp.logger.Print(fmt.Sprintf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())) cp.logger.Print(fmt.Sprintf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()))

View File

@ -101,10 +101,11 @@ func createCallback(scope *Scope) {
} }
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns)
if len(columns) == 0 { if len(columns) == 0 {
scope.Raw(fmt.Sprintf( scope.Raw(fmt.Sprintf(
"INSERT %v INTO %v %v%v%v", "INSERT%v INTO %v %v%v%v",
addExtraSpaceIfExist(insertModifier), addExtraSpaceIfExist(insertModifier),
quotedTableName, quotedTableName,
scope.Dialect().DefaultValueStr(), scope.Dialect().DefaultValueStr(),
@ -113,18 +114,19 @@ func createCallback(scope *Scope) {
)) ))
} else { } else {
scope.Raw(fmt.Sprintf( scope.Raw(fmt.Sprintf(
"INSERT %v INTO %v (%v) VALUES (%v)%v%v", "INSERT%v INTO %v (%v)%v VALUES (%v)%v%v",
addExtraSpaceIfExist(insertModifier), addExtraSpaceIfExist(insertModifier),
scope.QuotedTableName(), scope.QuotedTableName(),
strings.Join(columns, ","), strings.Join(columns, ","),
addExtraSpaceIfExist(lastInsertIDOutputInterstitial),
strings.Join(placeholders, ","), strings.Join(placeholders, ","),
addExtraSpaceIfExist(extraOption), addExtraSpaceIfExist(extraOption),
addExtraSpaceIfExist(lastInsertIDReturningSuffix), addExtraSpaceIfExist(lastInsertIDReturningSuffix),
)) ))
} }
// execute create sql // execute create sql: no primaryField
if lastInsertIDReturningSuffix == "" || primaryField == nil { if primaryField == nil {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count // set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected() scope.db.RowsAffected, _ = result.RowsAffected()
@ -136,16 +138,35 @@ func createCallback(scope *Scope) {
} }
} }
} }
} else { return
if primaryField.Field.CanAddr() {
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
primaryField.IsBlank = false
scope.db.RowsAffected = 1
}
} else {
scope.Err(ErrUnaddressable)
}
} }
// execute create sql: lastInsertID implemention for majority of dialects
if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected()
// set primary value to primary field
if primaryField != nil && primaryField.IsBlank {
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
scope.Err(primaryField.Set(primaryValue))
}
}
}
return
}
// execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql)
if primaryField.Field.CanAddr() {
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
primaryField.IsBlank = false
scope.db.RowsAffected = 1
}
} else {
scope.Err(ErrUnaddressable)
}
return
} }
} }

View File

@ -40,6 +40,8 @@ type Dialect interface {
LimitAndOffsetSQL(limit, offset interface{}) string LimitAndOffsetSQL(limit, offset interface{}) string
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
SelectFromDummyTable() string SelectFromDummyTable() string
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
LastInsertIDReturningSuffix(tableName, columnName string) string LastInsertIDReturningSuffix(tableName, columnName string) string
// DefaultValueStr // DefaultValueStr

View File

@ -157,6 +157,10 @@ func (commonDialect) SelectFromDummyTable() string {
return "" return ""
} }
func (commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
return ""
}
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
return "" return ""
} }

View File

@ -121,7 +121,7 @@ func (s *mysql) DataTypeOf(field *StructField) string {
} }
if sqlType == "" { if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String())) panic(fmt.Sprintf("invalid sql type %s (%s) in field %s for mysql", dataValue.Type().Name(), dataValue.Kind().String(), field.Name))
} }
if strings.TrimSpace(additionalType) == "" { if strings.TrimSpace(additionalType) == "" {

View File

@ -120,6 +120,10 @@ func (s postgres) CurrentDatabase() (name string) {
return return
} }
func (s postgres) LastInsertIDOutputInterstitial(tableName, key string, columns []string) string {
return ""
}
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string { func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", tableName, key) return fmt.Sprintf("RETURNING %v.%v", tableName, key)
} }

View File

@ -190,6 +190,14 @@ func (mssql) SelectFromDummyTable() string {
return "" return ""
} }
func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
if len(columns) == 0 {
// No OUTPUT to query
return ""
}
return fmt.Sprintf("OUTPUT Inserted.%v", columnName)
}
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
return "" return ""
} }

View File

@ -133,23 +133,6 @@ func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) {
t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed, zip code is %v", address.ZipCode) t.Errorf("Fetch a record from with a string primary key for a numeric value starting with zero should work, but failed, zip code is %v", address.ZipCode)
} }
} }
func TestStringAgainstIncompleteParentheses(t *testing.T) {
type AddressByZipCode struct {
ZipCode string `gorm:"primary_key"`
Address string
}
DB.AutoMigrate(&AddressByZipCode{})
DB.Create(&AddressByZipCode{ZipCode: "00502", Address: "Holtsville"})
var address AddressByZipCode
var addresses []AddressByZipCode
_ = DB.First(&address, "address_by_zip_codes=00502)) UNION ALL SELECT NULL,version(),current_database(),NULL,NULL,NULL,NULL,NULL--").Find(&addresses).GetErrors()
if len(addresses) > 0 {
t.Errorf("Fetch a record from with a string that has incomplete parentheses should be fail, zip code is %v", address.ZipCode)
}
}
func TestFindAsSliceOfPointers(t *testing.T) { func TestFindAsSliceOfPointers(t *testing.T) {
DB.Save(&User{Name: "user"}) DB.Save(&User{Name: "user"})

View File

@ -277,23 +277,6 @@ func (scope *Scope) AddToVars(value interface{}) string {
return scope.Dialect().BindVar(len(scope.SQLVars)) return scope.Dialect().BindVar(len(scope.SQLVars))
} }
// IsCompleteParentheses check if the string has complete parentheses to prevent SQL injection
func (scope *Scope) IsCompleteParentheses(value string) bool {
count := 0
for i, _ := range value {
if value[i] == 40 { // (
count++
} else if value[i] == 41 { // )
count--
}
if count < 0 {
break
}
i++
}
return count == 0
}
// SelectAttrs return selected attributes // SelectAttrs return selected attributes
func (scope *Scope) SelectAttrs() []string { func (scope *Scope) SelectAttrs() []string {
if scope.selectAttrs == nil { if scope.selectAttrs == nil {
@ -573,10 +556,6 @@ func (scope *Scope) buildCondition(clause map[string]interface{}, include bool)
} }
if value != "" { if value != "" {
if !scope.IsCompleteParentheses(value) {
scope.Err(fmt.Errorf("incomplete parentheses found: %v", value))
return
}
if !include { if !include {
if comparisonRegexp.MatchString(value) { if comparisonRegexp.MatchString(value) {
str = fmt.Sprintf("NOT (%v)", value) str = fmt.Sprintf("NOT (%v)", value)