Add postgres create returning all serial type columns support
This commit is contained in:
parent
7ea143b548
commit
656b6e22e4
@ -162,7 +162,18 @@ func createCallback(scope *Scope) {
|
||||
|
||||
// execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql)
|
||||
if primaryField.Field.CanAddr() {
|
||||
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
|
||||
dest := []interface{}{primaryField.Field.Addr().Interface()}
|
||||
if scope.Dialect().GetName() == "postgres" {
|
||||
for _, field := range scope.Fields() {
|
||||
typ, ok := field.TagSettingsGet("TYPE")
|
||||
if ok && typ == "serial" && field.Field.CanAddr() {
|
||||
dest = append(dest, field.Field.Addr().Interface())
|
||||
scope.SQL += fmt.Sprintf(", %s.%s", quotedTableName, scope.Quote(field.DBName))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(dest...); scope.Err(err) == nil {
|
||||
primaryField.IsBlank = false
|
||||
scope.db.RowsAffected = 1
|
||||
}
|
||||
|
@ -286,3 +286,37 @@ func TestCreateIgnore(t *testing.T) {
|
||||
t.Error("Should ignore duplicate user insert by insert modifier:IGNORE ")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresReturningMultipleColumns(t *testing.T) {
|
||||
type ReturningMultipleColumns struct {
|
||||
ID uint `gorm:"primary_key"`
|
||||
SerialA int `gorm:"type:serial;not null"`
|
||||
SerialB int `gorm:"type:serial;not null"`
|
||||
SerialC int `gorm:"type:serial;not null"`
|
||||
}
|
||||
|
||||
DB.LogMode(true)
|
||||
|
||||
err := DB.AutoMigrate(&ReturningMultipleColumns{}).Error
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
record := ReturningMultipleColumns{}
|
||||
err = DB.Omit("serial_a", "serial_b", "serial_c").Create(&record).Error
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if record.SerialA == 0 {
|
||||
t.Error("SerialA should not be 0")
|
||||
}
|
||||
|
||||
if record.SerialB == 0 {
|
||||
t.Error("SerialB should not be 0")
|
||||
}
|
||||
|
||||
if record.SerialC == 0 {
|
||||
t.Error("SerialC should not be 0")
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user