Merge branch 'go-gorm:master' into master
This commit is contained in:
		
						commit
						5d84ac618c
					
				@ -118,6 +118,12 @@ func Delete(config *Config) func(db *gorm.DB) {
 | 
				
			|||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if db.Statement.Schema != nil {
 | 
				
			||||||
 | 
								for _, c := range db.Statement.Schema.DeleteClauses {
 | 
				
			||||||
 | 
									db.Statement.AddClause(c)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if db.Statement.SQL.Len() == 0 {
 | 
							if db.Statement.SQL.Len() == 0 {
 | 
				
			||||||
			db.Statement.SQL.Grow(100)
 | 
								db.Statement.SQL.Grow(100)
 | 
				
			||||||
			db.Statement.AddClauseIfNotExists(clause.Delete{})
 | 
								db.Statement.AddClauseIfNotExists(clause.Delete{})
 | 
				
			||||||
@ -141,22 +147,11 @@ func Delete(config *Config) func(db *gorm.DB) {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			db.Statement.AddClauseIfNotExists(clause.From{})
 | 
								db.Statement.AddClauseIfNotExists(clause.From{})
 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if db.Statement.Schema != nil {
 | 
					 | 
				
			||||||
			for _, c := range db.Statement.Schema.DeleteClauses {
 | 
					 | 
				
			||||||
				db.Statement.AddClause(c)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if db.Statement.SQL.Len() == 0 {
 | 
					 | 
				
			||||||
			db.Statement.Build(db.Statement.BuildClauses...)
 | 
								db.Statement.Build(db.Statement.BuildClauses...)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil {
 | 
							checkMissingWhereConditions(db)
 | 
				
			||||||
			db.AddError(gorm.ErrMissingWhereClause)
 | 
					 | 
				
			||||||
			return
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if !db.DryRun && db.Error == nil {
 | 
							if !db.DryRun && db.Error == nil {
 | 
				
			||||||
			ok, mode := hasReturning(db, supportReturning)
 | 
								ok, mode := hasReturning(db, supportReturning)
 | 
				
			||||||
 | 
				
			|||||||
@ -104,3 +104,19 @@ func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	return false, 0
 | 
						return false, 0
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func checkMissingWhereConditions(db *gorm.DB) {
 | 
				
			||||||
 | 
						if !db.AllowGlobalUpdate && db.Error == nil {
 | 
				
			||||||
 | 
							where, withCondition := db.Statement.Clauses["WHERE"]
 | 
				
			||||||
 | 
							if withCondition {
 | 
				
			||||||
 | 
								if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete {
 | 
				
			||||||
 | 
									whereClause, _ := where.Expression.(clause.Where)
 | 
				
			||||||
 | 
									withCondition = len(whereClause.Exprs) > 1
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if !withCondition {
 | 
				
			||||||
 | 
								db.AddError(gorm.ErrMissingWhereClause)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -59,6 +59,12 @@ func Update(config *Config) func(db *gorm.DB) {
 | 
				
			|||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if db.Statement.Schema != nil {
 | 
				
			||||||
 | 
								for _, c := range db.Statement.Schema.UpdateClauses {
 | 
				
			||||||
 | 
									db.Statement.AddClause(c)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if db.Statement.SQL.Len() == 0 {
 | 
							if db.Statement.SQL.Len() == 0 {
 | 
				
			||||||
			db.Statement.SQL.Grow(180)
 | 
								db.Statement.SQL.Grow(180)
 | 
				
			||||||
			db.Statement.AddClauseIfNotExists(clause.Update{})
 | 
								db.Statement.AddClauseIfNotExists(clause.Update{})
 | 
				
			||||||
@ -68,22 +74,10 @@ func Update(config *Config) func(db *gorm.DB) {
 | 
				
			|||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if db.Statement.Schema != nil {
 | 
					 | 
				
			||||||
			for _, c := range db.Statement.Schema.UpdateClauses {
 | 
					 | 
				
			||||||
				db.Statement.AddClause(c)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if db.Statement.SQL.Len() == 0 {
 | 
					 | 
				
			||||||
			db.Statement.Build(db.Statement.BuildClauses...)
 | 
								db.Statement.Build(db.Statement.BuildClauses...)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {
 | 
							checkMissingWhereConditions(db)
 | 
				
			||||||
			db.AddError(gorm.ErrMissingWhereClause)
 | 
					 | 
				
			||||||
			return
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if !db.DryRun && db.Error == nil {
 | 
							if !db.DryRun && db.Error == nil {
 | 
				
			||||||
			if ok, mode := hasReturning(db, supportReturning); ok {
 | 
								if ok, mode := hasReturning(db, supportReturning); ok {
 | 
				
			||||||
@ -232,10 +226,10 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
 | 
				
			|||||||
							set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
 | 
												set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
 | 
				
			||||||
						} else if field.AutoUpdateTime == schema.UnixMillisecond {
 | 
											} else if field.AutoUpdateTime == schema.UnixMillisecond {
 | 
				
			||||||
							set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
 | 
												set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
 | 
				
			||||||
						} else if field.GORMDataType == schema.Time {
 | 
											} else if field.AutoUpdateTime == schema.UnixSecond {
 | 
				
			||||||
							set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
 | 
					 | 
				
			||||||
						} else {
 | 
					 | 
				
			||||||
							set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
 | 
												set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
 | 
				
			||||||
 | 
											} else {
 | 
				
			||||||
 | 
												set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
 | 
				
			||||||
						}
 | 
											}
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
@ -264,10 +258,10 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
 | 
				
			|||||||
									value = stmt.DB.NowFunc().UnixNano()
 | 
														value = stmt.DB.NowFunc().UnixNano()
 | 
				
			||||||
								} else if field.AutoUpdateTime == schema.UnixMillisecond {
 | 
													} else if field.AutoUpdateTime == schema.UnixMillisecond {
 | 
				
			||||||
									value = stmt.DB.NowFunc().UnixNano() / 1e6
 | 
														value = stmt.DB.NowFunc().UnixNano() / 1e6
 | 
				
			||||||
								} else if field.GORMDataType == schema.Time {
 | 
													} else if field.AutoUpdateTime == schema.UnixSecond {
 | 
				
			||||||
									value = stmt.DB.NowFunc()
 | 
					 | 
				
			||||||
								} else {
 | 
					 | 
				
			||||||
									value = stmt.DB.NowFunc().Unix()
 | 
														value = stmt.DB.NowFunc().Unix()
 | 
				
			||||||
 | 
													} else {
 | 
				
			||||||
 | 
														value = stmt.DB.NowFunc()
 | 
				
			||||||
								}
 | 
													}
 | 
				
			||||||
								isZero = false
 | 
													isZero = false
 | 
				
			||||||
							}
 | 
												}
 | 
				
			||||||
 | 
				
			|||||||
@ -75,10 +75,10 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
 | 
				
			|||||||
			case reflect.Bool:
 | 
								case reflect.Bool:
 | 
				
			||||||
				vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
 | 
									vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
 | 
				
			||||||
			case reflect.String:
 | 
								case reflect.String:
 | 
				
			||||||
				vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper
 | 
									vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
 | 
				
			||||||
			default:
 | 
								default:
 | 
				
			||||||
				if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
 | 
									if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
 | 
				
			||||||
					vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper
 | 
										vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
 | 
				
			||||||
				} else {
 | 
									} else {
 | 
				
			||||||
					vars[idx] = nullStr
 | 
										vars[idx] = nullStr
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
@ -94,7 +94,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
 | 
				
			|||||||
		case float64, float32:
 | 
							case float64, float32:
 | 
				
			||||||
			vars[idx] = fmt.Sprintf("%.6f", v)
 | 
								vars[idx] = fmt.Sprintf("%.6f", v)
 | 
				
			||||||
		case string:
 | 
							case string:
 | 
				
			||||||
			vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper
 | 
								vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper
 | 
				
			||||||
		default:
 | 
							default:
 | 
				
			||||||
			rv := reflect.ValueOf(v)
 | 
								rv := reflect.ValueOf(v)
 | 
				
			||||||
			if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
 | 
								if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
 | 
				
			||||||
@ -111,7 +111,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
 | 
				
			|||||||
						return
 | 
											return
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper
 | 
									vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
@ -31,7 +31,7 @@ func (s ExampleStruct) Value() (driver.Value, error) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func format(v []byte, escaper string) string {
 | 
					func format(v []byte, escaper string) string {
 | 
				
			||||||
	return escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
 | 
						return escaper + strings.ReplaceAll(string(v), escaper, "\\"+escaper) + escaper
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestExplainSQL(t *testing.T) {
 | 
					func TestExplainSQL(t *testing.T) {
 | 
				
			||||||
 | 
				
			|||||||
@ -293,6 +293,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if field.GORMDataType == "" {
 | 
				
			||||||
 | 
							field.GORMDataType = field.DataType
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if val, ok := field.TagSettings["TYPE"]; ok {
 | 
						if val, ok := field.TagSettings["TYPE"]; ok {
 | 
				
			||||||
		switch DataType(strings.ToLower(val)) {
 | 
							switch DataType(strings.ToLower(val)) {
 | 
				
			||||||
		case Bool, Int, Uint, Float, String, Time, Bytes:
 | 
							case Bool, Int, Uint, Float, String, Time, Bytes:
 | 
				
			||||||
@ -302,10 +306,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if field.GORMDataType == "" {
 | 
					 | 
				
			||||||
		field.GORMDataType = field.DataType
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if field.Size == 0 {
 | 
						if field.Size == 0 {
 | 
				
			||||||
		switch reflect.Indirect(fieldValue).Kind() {
 | 
							switch reflect.Indirect(fieldValue).Kind() {
 | 
				
			||||||
		case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64:
 | 
							case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64:
 | 
				
			||||||
 | 
				
			|||||||
@ -3,7 +3,6 @@ package schema
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"crypto/sha1"
 | 
						"crypto/sha1"
 | 
				
			||||||
	"encoding/hex"
 | 
						"encoding/hex"
 | 
				
			||||||
	"fmt"
 | 
					 | 
				
			||||||
	"regexp"
 | 
						"regexp"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"unicode/utf8"
 | 
						"unicode/utf8"
 | 
				
			||||||
@ -95,7 +94,7 @@ func (ns NamingStrategy) formatName(prefix, table, name string) string {
 | 
				
			|||||||
		h.Write([]byte(formattedName))
 | 
							h.Write([]byte(formattedName))
 | 
				
			||||||
		bs := h.Sum(nil)
 | 
							bs := h.Sum(nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		formattedName = fmt.Sprintf("%v%v%v", prefix, table, name)[0:56] + hex.EncodeToString(bs)[:8]
 | 
							formattedName = formattedName[0:56] + hex.EncodeToString(bs)[:8]
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return formattedName
 | 
						return formattedName
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -174,7 +173,7 @@ func (ns NamingStrategy) toDBName(name string) string {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (ns NamingStrategy) toSchemaName(name string) string {
 | 
					func (ns NamingStrategy) toSchemaName(name string) string {
 | 
				
			||||||
	result := strings.Replace(strings.Title(strings.Replace(name, "_", " ", -1)), " ", "", -1)
 | 
						result := strings.ReplaceAll(strings.Title(strings.ReplaceAll(name, "_", " ")), " ", "")
 | 
				
			||||||
	for _, initialism := range commonInitialisms {
 | 
						for _, initialism := range commonInitialisms {
 | 
				
			||||||
		result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
 | 
							result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
@ -193,7 +193,7 @@ func TestFormatNameWithStringLongerThan64Characters(t *testing.T) {
 | 
				
			|||||||
	ns := NamingStrategy{}
 | 
						ns := NamingStrategy{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
 | 
						formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
 | 
				
			||||||
	if formattedName != "prefixtablethisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLo180f2c67" {
 | 
						if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" {
 | 
				
			||||||
		t.Errorf("invalid formatted name generated, got %v", formattedName)
 | 
							t.Errorf("invalid formatted name generated, got %v", formattedName)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -576,3 +576,39 @@ func TestHasManySameForeignKey(t *testing.T) {
 | 
				
			|||||||
		References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}},
 | 
							References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}},
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Author struct {
 | 
				
			||||||
 | 
						gorm.Model
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Book struct {
 | 
				
			||||||
 | 
						gorm.Model
 | 
				
			||||||
 | 
						Author   Author
 | 
				
			||||||
 | 
						AuthorID uint
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (Book) TableName() string {
 | 
				
			||||||
 | 
						return "my_schema.a_very_very_very_very_very_very_very_very_long_table_name"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) {
 | 
				
			||||||
 | 
						s, err := schema.Parse(
 | 
				
			||||||
 | 
							&Book{},
 | 
				
			||||||
 | 
							&sync.Map{},
 | 
				
			||||||
 | 
							schema.NamingStrategy{},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("Failed to parse schema")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						expectedConstraintName := "fk_my_schema_a_very_very_very_very_very_very_very_very_l4db13eec"
 | 
				
			||||||
 | 
						constraint := s.Relationships.Relations["Author"].ParseConstraint()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if constraint.Name != expectedConstraintName {
 | 
				
			||||||
 | 
							t.Fatalf(
 | 
				
			||||||
 | 
								"expected constraint name %s, got %s",
 | 
				
			||||||
 | 
								expectedConstraintName,
 | 
				
			||||||
 | 
								constraint.Name,
 | 
				
			||||||
 | 
							)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -1,11 +1,12 @@
 | 
				
			|||||||
package schema
 | 
					package schema
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
	"database/sql"
 | 
						"database/sql"
 | 
				
			||||||
	"database/sql/driver"
 | 
						"database/sql/driver"
 | 
				
			||||||
 | 
						"encoding/gob"
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"errors"
 | 
					 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
@ -32,6 +33,7 @@ func GetSerializer(name string) (serializer SerializerInterface, ok bool) {
 | 
				
			|||||||
func init() {
 | 
					func init() {
 | 
				
			||||||
	RegisterSerializer("json", JSONSerializer{})
 | 
						RegisterSerializer("json", JSONSerializer{})
 | 
				
			||||||
	RegisterSerializer("unixtime", UnixSecondSerializer{})
 | 
						RegisterSerializer("unixtime", UnixSecondSerializer{})
 | 
				
			||||||
 | 
						RegisterSerializer("gob", GobSerializer{})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Serializer field value serializer
 | 
					// Serializer field value serializer
 | 
				
			||||||
@ -83,7 +85,7 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
 | 
				
			|||||||
		case string:
 | 
							case string:
 | 
				
			||||||
			bytes = []byte(v)
 | 
								bytes = []byte(v)
 | 
				
			||||||
		default:
 | 
							default:
 | 
				
			||||||
			return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", dbValue))
 | 
								return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		err = json.Unmarshal(bytes, fieldValue.Interface())
 | 
							err = json.Unmarshal(bytes, fieldValue.Interface())
 | 
				
			||||||
@ -123,3 +125,33 @@ func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// GobSerializer gob serializer
 | 
				
			||||||
 | 
					type GobSerializer struct {
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Scan implements serializer interface
 | 
				
			||||||
 | 
					func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
 | 
				
			||||||
 | 
						fieldValue := reflect.New(field.FieldType)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if dbValue != nil {
 | 
				
			||||||
 | 
							var bytesValue []byte
 | 
				
			||||||
 | 
							switch v := dbValue.(type) {
 | 
				
			||||||
 | 
							case []byte:
 | 
				
			||||||
 | 
								bytesValue = v
 | 
				
			||||||
 | 
							default:
 | 
				
			||||||
 | 
								return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue))
 | 
				
			||||||
 | 
							err = decoder.Decode(fieldValue.Interface())
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Value implements serializer interface
 | 
				
			||||||
 | 
					func (GobSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
 | 
				
			||||||
 | 
						buf := new(bytes.Buffer)
 | 
				
			||||||
 | 
						err := gob.NewEncoder(buf).Encode(fieldValue)
 | 
				
			||||||
 | 
						return buf.Bytes(), err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -104,11 +104,9 @@ func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
 | 
					func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
 | 
				
			||||||
	if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped {
 | 
						if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped {
 | 
				
			||||||
		if _, ok := stmt.Clauses["WHERE"]; stmt.DB.AllowGlobalUpdate || ok {
 | 
					 | 
				
			||||||
		SoftDeleteQueryClause(sd).ModifyStatement(stmt)
 | 
							SoftDeleteQueryClause(sd).ModifyStatement(stmt)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
 | 
					func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
 | 
				
			||||||
	return []clause.Interface{SoftDeleteDeleteClause{Field: f}}
 | 
						return []clause.Interface{SoftDeleteDeleteClause{Field: f}}
 | 
				
			||||||
@ -152,12 +150,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok {
 | 
					 | 
				
			||||||
			stmt.DB.AddError(ErrMissingWhereClause)
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
		SoftDeleteQueryClause(sd).ModifyStatement(stmt)
 | 
							SoftDeleteQueryClause(sd).ModifyStatement(stmt)
 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		stmt.AddClauseIfNotExists(clause.Update{})
 | 
							stmt.AddClauseIfNotExists(clause.Update{})
 | 
				
			||||||
		stmt.Build(stmt.DB.Callback().Update().Clauses...)
 | 
							stmt.Build(stmt.DB.Callback().Update().Clauses...)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
				
			|||||||
@ -9,11 +9,11 @@ require (
 | 
				
			|||||||
	github.com/lib/pq v1.10.4
 | 
						github.com/lib/pq v1.10.4
 | 
				
			||||||
	github.com/mattn/go-sqlite3 v1.14.11 // indirect
 | 
						github.com/mattn/go-sqlite3 v1.14.11 // indirect
 | 
				
			||||||
	golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect
 | 
						golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect
 | 
				
			||||||
	gorm.io/driver/mysql v1.3.1
 | 
						gorm.io/driver/mysql v1.3.2
 | 
				
			||||||
	gorm.io/driver/postgres v1.3.1
 | 
						gorm.io/driver/postgres v1.3.1
 | 
				
			||||||
	gorm.io/driver/sqlite v1.3.1
 | 
						gorm.io/driver/sqlite v1.3.1
 | 
				
			||||||
	gorm.io/driver/sqlserver v1.3.1
 | 
						gorm.io/driver/sqlserver v1.3.1
 | 
				
			||||||
	gorm.io/gorm v1.23.0
 | 
						gorm.io/gorm v1.23.1
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
replace gorm.io/gorm => ../
 | 
					replace gorm.io/gorm => ../
 | 
				
			||||||
 | 
				
			|||||||
@ -2,6 +2,7 @@ package tests_test
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/google/uuid"
 | 
						"github.com/google/uuid"
 | 
				
			||||||
	"github.com/lib/pq"
 | 
						"github.com/lib/pq"
 | 
				
			||||||
@ -17,6 +18,8 @@ func TestPostgres(t *testing.T) {
 | 
				
			|||||||
		gorm.Model
 | 
							gorm.Model
 | 
				
			||||||
		Name      string         `gorm:"check:name_checker,name <> ''"`
 | 
							Name      string         `gorm:"check:name_checker,name <> ''"`
 | 
				
			||||||
		Test      uuid.UUID      `gorm:"type:uuid;not null;default:gen_random_uuid()"`
 | 
							Test      uuid.UUID      `gorm:"type:uuid;not null;default:gen_random_uuid()"`
 | 
				
			||||||
 | 
							CreatedAt time.Time      `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"`
 | 
				
			||||||
 | 
							UpdatedAt time.Time      `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"`
 | 
				
			||||||
		Things    pq.StringArray `gorm:"type:text[]"`
 | 
							Things    pq.StringArray `gorm:"type:text[]"`
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -48,6 +51,15 @@ func TestPostgres(t *testing.T) {
 | 
				
			|||||||
	if err := DB.Where("id = $1", harumph.ID).First(&Harumph{}).Error; err != nil || harumph.Name != "jinzhu" {
 | 
						if err := DB.Where("id = $1", harumph.ID).First(&Harumph{}).Error; err != nil || harumph.Name != "jinzhu" {
 | 
				
			||||||
		t.Errorf("No error should happen, but got %v", err)
 | 
							t.Errorf("No error should happen, but got %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						harumph.Name = "jinzhu1"
 | 
				
			||||||
 | 
						if err := DB.Save(&harumph).Error; err != nil {
 | 
				
			||||||
 | 
							t.Errorf("Failed to update date, got error %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" {
 | 
				
			||||||
 | 
							t.Errorf("No error should happen, but got %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Post struct {
 | 
					type Post struct {
 | 
				
			||||||
 | 
				
			|||||||
@ -19,11 +19,20 @@ type SerializerStruct struct {
 | 
				
			|||||||
	Name            []byte                 `gorm:"json"`
 | 
						Name            []byte                 `gorm:"json"`
 | 
				
			||||||
	Roles           Roles                  `gorm:"serializer:json"`
 | 
						Roles           Roles                  `gorm:"serializer:json"`
 | 
				
			||||||
	Contracts       map[string]interface{} `gorm:"serializer:json"`
 | 
						Contracts       map[string]interface{} `gorm:"serializer:json"`
 | 
				
			||||||
 | 
						JobInfo         Job                    `gorm:"type:bytes;serializer:gob"`
 | 
				
			||||||
	CreatedTime     int64                  `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type
 | 
						CreatedTime     int64                  `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type
 | 
				
			||||||
	EncryptedString EncryptedString
 | 
						EncryptedString EncryptedString
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type Roles []string
 | 
					type Roles []string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Job struct {
 | 
				
			||||||
 | 
						Title    string
 | 
				
			||||||
 | 
						Number   int
 | 
				
			||||||
 | 
						Location string
 | 
				
			||||||
 | 
						IsIntern bool
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type EncryptedString string
 | 
					type EncryptedString string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (es *EncryptedString) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) {
 | 
					func (es *EncryptedString) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) {
 | 
				
			||||||
@ -56,6 +65,12 @@ func TestSerializer(t *testing.T) {
 | 
				
			|||||||
		Contracts:       map[string]interface{}{"name": "jinzhu", "age": 10},
 | 
							Contracts:       map[string]interface{}{"name": "jinzhu", "age": 10},
 | 
				
			||||||
		EncryptedString: EncryptedString("pass"),
 | 
							EncryptedString: EncryptedString("pass"),
 | 
				
			||||||
		CreatedTime:     createdAt.Unix(),
 | 
							CreatedTime:     createdAt.Unix(),
 | 
				
			||||||
 | 
							JobInfo: Job{
 | 
				
			||||||
 | 
								Title:    "programmer",
 | 
				
			||||||
 | 
								Number:   9920,
 | 
				
			||||||
 | 
								Location: "Kenmawr",
 | 
				
			||||||
 | 
								IsIntern: false,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := DB.Create(&data).Error; err != nil {
 | 
						if err := DB.Create(&data).Error; err != nil {
 | 
				
			||||||
 | 
				
			|||||||
@ -460,16 +460,16 @@ func assertEqualSQL(t *testing.T, expected string, actually string) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func replaceQuoteInSQL(sql string) string {
 | 
					func replaceQuoteInSQL(sql string) string {
 | 
				
			||||||
	// convert single quote into double quote
 | 
						// convert single quote into double quote
 | 
				
			||||||
	sql = strings.Replace(sql, `'`, `"`, -1)
 | 
						sql = strings.ReplaceAll(sql, `'`, `"`)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// convert dialect speical quote into double quote
 | 
						// convert dialect speical quote into double quote
 | 
				
			||||||
	switch DB.Dialector.Name() {
 | 
						switch DB.Dialector.Name() {
 | 
				
			||||||
	case "postgres":
 | 
						case "postgres":
 | 
				
			||||||
		sql = strings.Replace(sql, `"`, `"`, -1)
 | 
							sql = strings.ReplaceAll(sql, `"`, `"`)
 | 
				
			||||||
	case "mysql", "sqlite":
 | 
						case "mysql", "sqlite":
 | 
				
			||||||
		sql = strings.Replace(sql, "`", `"`, -1)
 | 
							sql = strings.ReplaceAll(sql, "`", `"`)
 | 
				
			||||||
	case "sqlserver":
 | 
						case "sqlserver":
 | 
				
			||||||
		sql = strings.Replace(sql, `'`, `"`, -1)
 | 
							sql = strings.ReplaceAll(sql, `'`, `"`)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return sql
 | 
						return sql
 | 
				
			||||||
 | 
				
			|||||||
@ -645,7 +645,7 @@ func TestSave(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	dryDB := DB.Session(&gorm.Session{DryRun: true})
 | 
						dryDB := DB.Session(&gorm.Session{DryRun: true})
 | 
				
			||||||
	stmt := dryDB.Save(&user).Statement
 | 
						stmt := dryDB.Save(&user).Statement
 | 
				
			||||||
	if !regexp.MustCompile(`.id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) {
 | 
						if !regexp.MustCompile(`.users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) {
 | 
				
			||||||
		t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String())
 | 
							t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user