Added support for default fields on create. Unified logics of default values between update and create
This commit is contained in:
parent
5399fd879f
commit
2d2d926b6e
53
callback.go
53
callback.go
@ -2,6 +2,9 @@ package gorm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type callback struct {
|
type callback struct {
|
||||||
@ -203,4 +206,54 @@ func ForceReload(scope *Scope) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func escapeIfNeeded(scope *Scope, value string) string {
|
||||||
|
trimmed := strings.TrimSpace(value)
|
||||||
|
// default:'string value' OR
|
||||||
|
if (strings.HasPrefix(trimmed, "'") && strings.HasSuffix(trimmed, "'")) ||
|
||||||
|
strings.HasSuffix(trimmed, ")") { //sql expression, like: default:"(now() at timezone 'utc') or now() or user_defined_function(parameters.. )
|
||||||
|
return trimmed
|
||||||
|
}
|
||||||
|
|
||||||
|
lowered := strings.ToLower(trimmed)
|
||||||
|
if lowered == "null" || strings.HasPrefix(lowered, "current_") { // null and other sql reserved keyworks (used a default values) can't be placed between apices
|
||||||
|
return lowered
|
||||||
|
}
|
||||||
|
return scope.AddToVars(trimmed) // default:'something' like:default:'false' should be between quotes (what AddToVars do)
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleDefaultValue(scope *Scope, field *Field) string {
|
||||||
|
if field.IsBlank {
|
||||||
|
defaultValue := strings.TrimSpace(parseTagSetting(field.Tag.Get("sql"))["DEFAULT"])
|
||||||
|
switch field.Field.Kind() {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
if numericValue, err := strconv.ParseInt(defaultValue, 10, 64); err == nil {
|
||||||
|
if numericValue != field.Field.Int() {
|
||||||
|
return escapeIfNeeded(scope, fmt.Sprintf("%d", field.Field.Int()))
|
||||||
|
} else {
|
||||||
|
return escapeIfNeeded(scope, defaultValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
if numericValue, err := strconv.ParseUint(defaultValue, 10, 64); err == nil {
|
||||||
|
if numericValue != field.Field.Uint() {
|
||||||
|
return escapeIfNeeded(scope, escapeIfNeeded(scope, fmt.Sprintf("%d", field.Field.Int())))
|
||||||
|
} else {
|
||||||
|
return escapeIfNeeded(scope, defaultValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Bool:
|
||||||
|
if boolValue, err := strconv.ParseBool(defaultValue); err == nil {
|
||||||
|
if boolValue != field.Field.Bool() {
|
||||||
|
return escapeIfNeeded(scope, fmt.Sprintf("%t", field.Field.Bool()))
|
||||||
|
} else {
|
||||||
|
return escapeIfNeeded(scope, defaultValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return escapeIfNeeded(scope, defaultValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return scope.AddToVars(field.Field.Interface())
|
||||||
|
}
|
||||||
|
|
||||||
var DefaultCallback = &callback{processors: []*callbackProcessor{}}
|
var DefaultCallback = &callback{processors: []*callbackProcessor{}}
|
||||||
|
@ -29,12 +29,15 @@ func Create(scope *Scope) {
|
|||||||
if scope.changeableField(field) {
|
if scope.changeableField(field) {
|
||||||
if field.IsNormal {
|
if field.IsNormal {
|
||||||
if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) {
|
if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) {
|
||||||
if !field.IsBlank || !field.HasDefaultValue {
|
if !field.HasDefaultValue {
|
||||||
columns = append(columns, scope.Quote(field.DBName))
|
columns = append(columns, scope.Quote(field.DBName))
|
||||||
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
|
sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
|
||||||
} else if field.HasDefaultValue {
|
} else {
|
||||||
|
columns = append(columns, scope.Quote(field.DBName))
|
||||||
|
sqls = append(sqls, handleDefaultValue(scope, field))
|
||||||
scope.InstanceSet("gorm:force_reload", true)
|
scope.InstanceSet("gorm:force_reload", true)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
|
||||||
for _, dbName := range relationship.ForeignDBNames {
|
for _, dbName := range relationship.ForeignDBNames {
|
||||||
|
@ -37,14 +37,6 @@ func UpdateTimeStampWhenUpdate(scope *Scope) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func escapeIfNeeded(scope *Scope, value string) string {
|
|
||||||
// default:'string value' OR sql expression, like: default:"(now() at timezone 'utc')"
|
|
||||||
if (strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'")) || (strings.HasPrefix(value, "(") && strings.HasSuffix(value, ")")) {
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
return scope.AddToVars(value) // default:'something' like:default:'false' should be between quotes (what AddToVars do)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Update(scope *Scope) {
|
func Update(scope *Scope) {
|
||||||
if !scope.HasError() {
|
if !scope.HasError() {
|
||||||
var sqls []string
|
var sqls []string
|
||||||
@ -60,12 +52,7 @@ func Update(scope *Scope) {
|
|||||||
for _, field := range fields {
|
for _, field := range fields {
|
||||||
if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal {
|
if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal {
|
||||||
if field.HasDefaultValue {
|
if field.HasDefaultValue {
|
||||||
if field.IsBlank {
|
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), handleDefaultValue(scope, field)))
|
||||||
defaultValue := parseTagSetting(field.Tag.Get("sql"))["DEFAULT"]
|
|
||||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), escapeIfNeeded(scope, defaultValue)))
|
|
||||||
} else {
|
|
||||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
|
||||||
}
|
|
||||||
scope.InstanceSet("gorm:force_reload", true)
|
scope.InstanceSet("gorm:force_reload", true)
|
||||||
} else {
|
} else {
|
||||||
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
|
||||||
|
@ -157,3 +157,25 @@ func TestOmitWithCreate(t *testing.T) {
|
|||||||
t.Errorf("Should not create omited relationships")
|
t.Errorf("Should not create omited relationships")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test from: https://github.com/jinzhu/gorm/issues/689
|
||||||
|
func TestCreateWithBoolDefaultValue(t *testing.T) {
|
||||||
|
type Data struct {
|
||||||
|
ID int `gorm:"column:id;primary_key" json:"id"`
|
||||||
|
Name string `sql:"type:varchar(100);not null;unique" json:"name"`
|
||||||
|
DeleteAllowed bool `sql:"not null;DEFAULT:true" json:"delete_allowed"`
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.AutoMigrate(&Data{})
|
||||||
|
|
||||||
|
data := Data{
|
||||||
|
Name: "test",
|
||||||
|
DeleteAllowed: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
DB.Create(&data)
|
||||||
|
|
||||||
|
if data.DeleteAllowed {
|
||||||
|
t.Error("Test failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user