Merge 387c00d2b5d6099dd4686401775f432c976ff967 into aa45da15421ddb1549a7060e5c189f8350a29166
This commit is contained in:
		
						commit
						3de6d5f115
					
				
							
								
								
									
										67
									
								
								callback.go
									
									
									
									
									
								
							
							
						
						
									
										67
									
								
								callback.go
									
									
									
									
									
								
							@ -2,6 +2,9 @@ package gorm
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type callback struct {
 | 
					type callback struct {
 | 
				
			||||||
@ -197,4 +200,68 @@ func (c *callback) sort() {
 | 
				
			|||||||
	c.rowQueries = sortProcessors(rowQueries)
 | 
						c.rowQueries = sortProcessors(rowQueries)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func ForceReload(scope *Scope) {
 | 
				
			||||||
 | 
						if _, ok := scope.InstanceGet("gorm:force_reload"); ok {
 | 
				
			||||||
 | 
							scope.DB().New().First(scope.Value)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func escapeIfNeeded(scope *Scope, value string) string {
 | 
				
			||||||
 | 
						trimmed := strings.TrimSpace(value)
 | 
				
			||||||
 | 
						// default:'string value' OR
 | 
				
			||||||
 | 
						if (strings.HasPrefix(trimmed, "'") && strings.HasSuffix(trimmed, "'")) ||
 | 
				
			||||||
 | 
							strings.HasSuffix(trimmed, ")") { //sql expression, like: default:"(now() at timezone 'utc') or now() or user_defined_function(parameters.. )
 | 
				
			||||||
 | 
							return trimmed
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						lowered := strings.ToLower(trimmed)
 | 
				
			||||||
 | 
						if lowered == "null" || strings.HasPrefix(lowered, "current_") { // null and other sql reserved keyworks (used a default values) can't be placed between apices
 | 
				
			||||||
 | 
							return lowered
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return scope.AddToVars(trimmed) // default:'something'  like:default:'false' should be between quotes (what AddToVars do)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func handleDefaultValue(scope *Scope, field *Field) string {
 | 
				
			||||||
 | 
						if field.IsBlank {
 | 
				
			||||||
 | 
							defaultValue := strings.TrimSpace(parseTagSetting(field.Tag.Get("sql"))["DEFAULT"])
 | 
				
			||||||
 | 
							switch field.Field.Kind() {
 | 
				
			||||||
 | 
							case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
 | 
				
			||||||
 | 
								if numericValue, err := strconv.ParseInt(defaultValue, 10, 64); err == nil {
 | 
				
			||||||
 | 
									if numericValue != field.Field.Int() {
 | 
				
			||||||
 | 
										return escapeIfNeeded(scope, fmt.Sprintf("%d", field.Field.Int()))
 | 
				
			||||||
 | 
									} else {
 | 
				
			||||||
 | 
										return escapeIfNeeded(scope, defaultValue)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
 | 
				
			||||||
 | 
								if numericValue, err := strconv.ParseUint(defaultValue, 10, 64); err == nil {
 | 
				
			||||||
 | 
									if numericValue != field.Field.Uint() {
 | 
				
			||||||
 | 
										return escapeIfNeeded(scope, escapeIfNeeded(scope, fmt.Sprintf("%d", field.Field.Int())))
 | 
				
			||||||
 | 
									} else {
 | 
				
			||||||
 | 
										return escapeIfNeeded(scope, defaultValue)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case reflect.Bool:
 | 
				
			||||||
 | 
								if boolValue, err := strconv.ParseBool(defaultValue); err == nil {
 | 
				
			||||||
 | 
									if boolValue != field.Field.Bool() {
 | 
				
			||||||
 | 
										return escapeIfNeeded(scope, fmt.Sprintf("%t", field.Field.Bool()))
 | 
				
			||||||
 | 
									} else {
 | 
				
			||||||
 | 
										return escapeIfNeeded(scope, defaultValue)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								/*
 | 
				
			||||||
 | 
									case reflect.String:
 | 
				
			||||||
 | 
										if defaultValue != field.Field.String() {
 | 
				
			||||||
 | 
											return escapeIfNeeded(scope, fmt.Sprintf("%s", field.Field.String()))
 | 
				
			||||||
 | 
										} else {
 | 
				
			||||||
 | 
											return escapeIfNeeded(scope, defaultValue)
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
								*/
 | 
				
			||||||
 | 
							default:
 | 
				
			||||||
 | 
								return escapeIfNeeded(scope, defaultValue)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return scope.AddToVars(field.Field.Interface())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var DefaultCallback = &callback{processors: []*callbackProcessor{}}
 | 
					var DefaultCallback = &callback{processors: []*callbackProcessor{}}
 | 
				
			||||||
 | 
				
			|||||||
@ -29,17 +29,15 @@ func Create(scope *Scope) {
 | 
				
			|||||||
			if scope.changeableField(field) {
 | 
								if scope.changeableField(field) {
 | 
				
			||||||
				if field.IsNormal {
 | 
									if field.IsNormal {
 | 
				
			||||||
					if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) {
 | 
										if !field.IsPrimaryKey || (field.IsPrimaryKey && !field.IsBlank) {
 | 
				
			||||||
						if !field.IsBlank || !field.HasDefaultValue {
 | 
											if !field.HasDefaultValue {
 | 
				
			||||||
							columns = append(columns, scope.Quote(field.DBName))
 | 
												columns = append(columns, scope.Quote(field.DBName))
 | 
				
			||||||
							sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
 | 
												sqls = append(sqls, scope.AddToVars(field.Field.Interface()))
 | 
				
			||||||
						} else if field.HasDefaultValue {
 | 
											} else {
 | 
				
			||||||
							var hasDefaultValueColumns []string
 | 
												columns = append(columns, scope.Quote(field.DBName))
 | 
				
			||||||
							if oldHasDefaultValueColumns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok {
 | 
												sqls = append(sqls, handleDefaultValue(scope, field))
 | 
				
			||||||
								hasDefaultValueColumns = oldHasDefaultValueColumns.([]string)
 | 
												scope.InstanceSet("gorm:force_reload", true)
 | 
				
			||||||
							}
 | 
					 | 
				
			||||||
							hasDefaultValueColumns = append(hasDefaultValueColumns, field.DBName)
 | 
					 | 
				
			||||||
							scope.InstanceSet("gorm:force_reload_after_create_attrs", hasDefaultValueColumns)
 | 
					 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
 | 
									} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
 | 
				
			||||||
					for _, dbName := range relationship.ForeignDBNames {
 | 
										for _, dbName := range relationship.ForeignDBNames {
 | 
				
			||||||
@ -102,12 +100,6 @@ func Create(scope *Scope) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func ForceReloadAfterCreate(scope *Scope) {
 | 
					 | 
				
			||||||
	if columns, ok := scope.InstanceGet("gorm:force_reload_after_create_attrs"); ok {
 | 
					 | 
				
			||||||
		scope.DB().New().Select(columns.([]string)).First(scope.Value)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func AfterCreate(scope *Scope) {
 | 
					func AfterCreate(scope *Scope) {
 | 
				
			||||||
	scope.CallMethodWithErrorCheck("AfterCreate")
 | 
						scope.CallMethodWithErrorCheck("AfterCreate")
 | 
				
			||||||
	scope.CallMethodWithErrorCheck("AfterSave")
 | 
						scope.CallMethodWithErrorCheck("AfterSave")
 | 
				
			||||||
@ -119,7 +111,7 @@ func init() {
 | 
				
			|||||||
	DefaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations)
 | 
						DefaultCallback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations)
 | 
				
			||||||
	DefaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate)
 | 
						DefaultCallback.Create().Register("gorm:update_time_stamp_when_create", UpdateTimeStampWhenCreate)
 | 
				
			||||||
	DefaultCallback.Create().Register("gorm:create", Create)
 | 
						DefaultCallback.Create().Register("gorm:create", Create)
 | 
				
			||||||
	DefaultCallback.Create().Register("gorm:force_reload_after_create", ForceReloadAfterCreate)
 | 
						DefaultCallback.Create().Register("gorm:force_reload", ForceReload)
 | 
				
			||||||
	DefaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations)
 | 
						DefaultCallback.Create().Register("gorm:save_after_associations", SaveAfterAssociations)
 | 
				
			||||||
	DefaultCallback.Create().Register("gorm:after_create", AfterCreate)
 | 
						DefaultCallback.Create().Register("gorm:after_create", AfterCreate)
 | 
				
			||||||
	DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
 | 
						DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
 | 
				
			||||||
 | 
				
			|||||||
@ -43,7 +43,7 @@ func Update(scope *Scope) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
 | 
							if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok {
 | 
				
			||||||
			for key, value := range updateAttrs.(map[string]interface{}) {
 | 
								for key, value := range updateAttrs.(map[string]interface{}) {
 | 
				
			||||||
				if scope.changeableDBColumn(key) {
 | 
									if scope.isChangeableDBColumn(key) {
 | 
				
			||||||
					sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value)))
 | 
										sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(key), scope.AddToVars(value)))
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@ -51,7 +51,12 @@ func Update(scope *Scope) {
 | 
				
			|||||||
			fields := scope.Fields()
 | 
								fields := scope.Fields()
 | 
				
			||||||
			for _, field := range fields {
 | 
								for _, field := range fields {
 | 
				
			||||||
				if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal {
 | 
									if scope.changeableField(field) && !field.IsPrimaryKey && field.IsNormal {
 | 
				
			||||||
					sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
 | 
										if field.HasDefaultValue {
 | 
				
			||||||
 | 
											sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), handleDefaultValue(scope, field)))
 | 
				
			||||||
 | 
											scope.InstanceSet("gorm:force_reload", true)
 | 
				
			||||||
 | 
										} else {
 | 
				
			||||||
 | 
											sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface())))
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
				} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
 | 
									} else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" {
 | 
				
			||||||
					for _, dbName := range relationship.ForeignDBNames {
 | 
										for _, dbName := range relationship.ForeignDBNames {
 | 
				
			||||||
						if relationField := fields[dbName]; !scope.changeableField(relationField) && !relationField.IsBlank {
 | 
											if relationField := fields[dbName]; !scope.changeableField(relationField) && !relationField.IsBlank {
 | 
				
			||||||
@ -92,4 +97,5 @@ func init() {
 | 
				
			|||||||
	DefaultCallback.Update().Register("gorm:save_after_associations", SaveAfterAssociations)
 | 
						DefaultCallback.Update().Register("gorm:save_after_associations", SaveAfterAssociations)
 | 
				
			||||||
	DefaultCallback.Update().Register("gorm:after_update", AfterUpdate)
 | 
						DefaultCallback.Update().Register("gorm:after_update", AfterUpdate)
 | 
				
			||||||
	DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
 | 
						DefaultCallback.Update().Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
 | 
				
			||||||
 | 
						DefaultCallback.Update().Register("gorm:force_reload", ForceReload)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -162,3 +162,25 @@ func TestOmitWithCreate(t *testing.T) {
 | 
				
			|||||||
		t.Errorf("Should not create omited relationships")
 | 
							t.Errorf("Should not create omited relationships")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Test from: https://github.com/jinzhu/gorm/issues/689
 | 
				
			||||||
 | 
					func TestCreateWithBoolDefaultValue(t *testing.T) {
 | 
				
			||||||
 | 
						type Data struct {
 | 
				
			||||||
 | 
							ID            int    `gorm:"column:id;primary_key" json:"id"`
 | 
				
			||||||
 | 
							Name          string `sql:"type:varchar(100);not null;unique" json:"name"`
 | 
				
			||||||
 | 
							DeleteAllowed bool   `sql:"not null;DEFAULT:true" json:"delete_allowed"`
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						DB.AutoMigrate(&Data{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						data := Data{
 | 
				
			||||||
 | 
							Name:          "test",
 | 
				
			||||||
 | 
							DeleteAllowed: false,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						DB.Create(&data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if data.DeleteAllowed {
 | 
				
			||||||
 | 
							t.Error("Test failed")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										2
									
								
								scope.go
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								scope.go
									
									
									
									
									
								
							@ -411,7 +411,7 @@ func (scope *Scope) OmitAttrs() []string {
 | 
				
			|||||||
	return scope.Search.omits
 | 
						return scope.Search.omits
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (scope *Scope) changeableDBColumn(column string) bool {
 | 
					func (scope *Scope) isChangeableDBColumn(column string) bool {
 | 
				
			||||||
	selectAttrs := scope.SelectAttrs()
 | 
						selectAttrs := scope.SelectAttrs()
 | 
				
			||||||
	omitAttrs := scope.OmitAttrs()
 | 
						omitAttrs := scope.OmitAttrs()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -136,6 +136,7 @@ type Animal struct {
 | 
				
			|||||||
	unexported string    // unexported value
 | 
						unexported string    // unexported value
 | 
				
			||||||
	CreatedAt  time.Time
 | 
						CreatedAt  time.Time
 | 
				
			||||||
	UpdatedAt  time.Time
 | 
						UpdatedAt  time.Time
 | 
				
			||||||
 | 
						Cool       bool `sql:"default:false"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type JoinTable struct {
 | 
					type JoinTable struct {
 | 
				
			||||||
 | 
				
			|||||||
@ -87,10 +87,11 @@ func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
 | 
				
			|||||||
	DB.Save(&animal)
 | 
						DB.Save(&animal)
 | 
				
			||||||
	updatedAt1 := animal.UpdatedAt
 | 
						updatedAt1 := animal.UpdatedAt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Sleep for a second and than update a field
 | 
				
			||||||
 | 
						time.Sleep(1000 * time.Millisecond)
 | 
				
			||||||
	DB.Save(&animal).Update("name", "Francis")
 | 
						DB.Save(&animal).Update("name", "Francis")
 | 
				
			||||||
 | 
					 | 
				
			||||||
	if updatedAt1.Format(time.RFC3339Nano) == animal.UpdatedAt.Format(time.RFC3339Nano) {
 | 
						if updatedAt1.Format(time.RFC3339Nano) == animal.UpdatedAt.Format(time.RFC3339Nano) {
 | 
				
			||||||
		t.Errorf("updatedAt should not be updated if nothing changed")
 | 
							t.Errorf("updatedAt should be updated when changing a field")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var animals []Animal
 | 
						var animals []Animal
 | 
				
			||||||
@ -101,7 +102,6 @@ func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	animal = Animal{From: "somewhere"}              // No name fields, should be filled with the default value (galeone)
 | 
						animal = Animal{From: "somewhere"}              // No name fields, should be filled with the default value (galeone)
 | 
				
			||||||
	DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched
 | 
						DB.Save(&animal).Update("From", "a nice place") // The name field shoul be untouched
 | 
				
			||||||
	DB.First(&animal, animal.Counter)
 | 
					 | 
				
			||||||
	if animal.Name != "galeone" {
 | 
						if animal.Name != "galeone" {
 | 
				
			||||||
		t.Errorf("Name fiels shouldn't be changed if untouched, but got %v", animal.Name)
 | 
							t.Errorf("Name fiels shouldn't be changed if untouched, but got %v", animal.Name)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -109,17 +109,28 @@ func TestUpdateWithNoStdPrimaryKeyAndDefaultValues(t *testing.T) {
 | 
				
			|||||||
	// When changing a field with a default value, the change must occur
 | 
						// When changing a field with a default value, the change must occur
 | 
				
			||||||
	animal.Name = "amazing horse"
 | 
						animal.Name = "amazing horse"
 | 
				
			||||||
	DB.Save(&animal)
 | 
						DB.Save(&animal)
 | 
				
			||||||
	DB.First(&animal, animal.Counter)
 | 
					 | 
				
			||||||
	if animal.Name != "amazing horse" {
 | 
						if animal.Name != "amazing horse" {
 | 
				
			||||||
		t.Errorf("Update a filed with a default value should occur. But got %v\n", animal.Name)
 | 
							t.Errorf("Update a filed with a default value should occur. But got %v\n", animal.Name)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// When changing a field with a default value with blank value
 | 
						// When changing a field with a default value with blank value, the DBMS should insert the default value. Not the empty one.
 | 
				
			||||||
	animal.Name = ""
 | 
						animal.Name = ""
 | 
				
			||||||
	DB.Save(&animal)
 | 
						DB.Save(&animal)
 | 
				
			||||||
	DB.First(&animal, animal.Counter)
 | 
						if animal.Name == "" {
 | 
				
			||||||
	if animal.Name != "" {
 | 
							t.Errorf("Update a filed with an associated default value should not occur when trying to insert an empty field. The default one should be inserted\n")
 | 
				
			||||||
		t.Errorf("Update a filed to blank with a default value should occur. But got %v\n", animal.Name)
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Animal.Cool has a default value thats equal to the Zero of its type. (false) I have to update this field to true and false without problems
 | 
				
			||||||
 | 
						animal.Cool = true
 | 
				
			||||||
 | 
						DB.Save(&animal)
 | 
				
			||||||
 | 
						if !animal.Cool {
 | 
				
			||||||
 | 
							t.Errorf("I should update a field with a default value to someother value")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						animal.Cool = false
 | 
				
			||||||
 | 
						DB.Save(&animal)
 | 
				
			||||||
 | 
						if animal.Cool {
 | 
				
			||||||
 | 
							t.Errorf("I should update a field with an associated blank value to its blank value")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user