Merge branch 'master' into master

This commit is contained in:
Jinzhu 2019-10-17 23:07:03 +08:00 committed by GitHub
commit ac3b7669cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 55 additions and 16 deletions

View File

@ -119,8 +119,8 @@ func (cp *CallbackProcessor) Remove(callbackName string) {
// Replace a registered callback with new callback // Replace a registered callback with new callback
// db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) { // db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) {
// scope.SetColumn("Created", now) // scope.SetColumn("CreatedAt", now)
// scope.SetColumn("Updated", now) // scope.SetColumn("UpdatedAt", now)
// }) // })
func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) {
cp.logger.Print(fmt.Sprintf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum())) cp.logger.Print(fmt.Sprintf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()))

View File

@ -101,10 +101,11 @@ func createCallback(scope *Scope) {
} }
lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns)
if len(columns) == 0 { if len(columns) == 0 {
scope.Raw(fmt.Sprintf( scope.Raw(fmt.Sprintf(
"INSERT %v INTO %v %v%v%v", "INSERT%v INTO %v %v%v%v",
addExtraSpaceIfExist(insertModifier), addExtraSpaceIfExist(insertModifier),
quotedTableName, quotedTableName,
scope.Dialect().DefaultValueStr(), scope.Dialect().DefaultValueStr(),
@ -113,18 +114,19 @@ func createCallback(scope *Scope) {
)) ))
} else { } else {
scope.Raw(fmt.Sprintf( scope.Raw(fmt.Sprintf(
"INSERT %v INTO %v (%v) VALUES (%v)%v%v", "INSERT%v INTO %v (%v)%v VALUES (%v)%v%v",
addExtraSpaceIfExist(insertModifier), addExtraSpaceIfExist(insertModifier),
scope.QuotedTableName(), scope.QuotedTableName(),
strings.Join(columns, ","), strings.Join(columns, ","),
addExtraSpaceIfExist(lastInsertIDOutputInterstitial),
strings.Join(placeholders, ","), strings.Join(placeholders, ","),
addExtraSpaceIfExist(extraOption), addExtraSpaceIfExist(extraOption),
addExtraSpaceIfExist(lastInsertIDReturningSuffix), addExtraSpaceIfExist(lastInsertIDReturningSuffix),
)) ))
} }
// execute create sql // execute create sql: no primaryField
if lastInsertIDReturningSuffix == "" || primaryField == nil { if primaryField == nil {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count // set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected() scope.db.RowsAffected, _ = result.RowsAffected()
@ -136,16 +138,35 @@ func createCallback(scope *Scope) {
} }
} }
} }
} else { return
if primaryField.Field.CanAddr() {
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
primaryField.IsBlank = false
scope.db.RowsAffected = 1
}
} else {
scope.Err(ErrUnaddressable)
}
} }
// execute create sql: lastInsertID implemention for majority of dialects
if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" {
if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
// set rows affected count
scope.db.RowsAffected, _ = result.RowsAffected()
// set primary value to primary field
if primaryField != nil && primaryField.IsBlank {
if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
scope.Err(primaryField.Set(primaryValue))
}
}
}
return
}
// execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql)
if primaryField.Field.CanAddr() {
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
primaryField.IsBlank = false
scope.db.RowsAffected = 1
}
} else {
scope.Err(ErrUnaddressable)
}
return
} }
} }

View File

@ -40,6 +40,8 @@ type Dialect interface {
LimitAndOffsetSQL(limit, offset interface{}) string LimitAndOffsetSQL(limit, offset interface{}) string
// SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL`
SelectFromDummyTable() string SelectFromDummyTable() string
// LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT`
LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string
// LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING`
LastInsertIDReturningSuffix(tableName, columnName string) string LastInsertIDReturningSuffix(tableName, columnName string) string
// DefaultValueStr // DefaultValueStr

View File

@ -157,6 +157,10 @@ func (commonDialect) SelectFromDummyTable() string {
return "" return ""
} }
func (commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
return ""
}
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
return "" return ""
} }

View File

@ -120,7 +120,7 @@ func (s *mysql) DataTypeOf(field *StructField) string {
} }
if sqlType == "" { if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String())) panic(fmt.Sprintf("invalid sql type %s (%s) in field %s for mysql", dataValue.Type().Name(), dataValue.Kind().String(), field.Name))
} }
if strings.TrimSpace(additionalType) == "" { if strings.TrimSpace(additionalType) == "" {

View File

@ -120,6 +120,10 @@ func (s postgres) CurrentDatabase() (name string) {
return return
} }
func (s postgres) LastInsertIDOutputInterstitial(tableName, key string, columns []string) string {
return ""
}
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string { func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
return fmt.Sprintf("RETURNING %v.%v", tableName, key) return fmt.Sprintf("RETURNING %v.%v", tableName, key)
} }

View File

@ -190,6 +190,14 @@ func (mssql) SelectFromDummyTable() string {
return "" return ""
} }
func (mssql) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string {
if len(columns) == 0 {
// No OUTPUT to query
return ""
}
return fmt.Sprintf("OUTPUT Inserted.%v", columnName)
}
func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string { func (mssql) LastInsertIDReturningSuffix(tableName, columnName string) string {
return "" return ""
} }