Merge branch 'go-gorm:master' into master
This commit is contained in:
commit
5d84ac618c
@ -118,6 +118,12 @@ func Delete(config *Config) func(db *gorm.DB) {
|
||||
return
|
||||
}
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
for _, c := range db.Statement.Schema.DeleteClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.Len() == 0 {
|
||||
db.Statement.SQL.Grow(100)
|
||||
db.Statement.AddClauseIfNotExists(clause.Delete{})
|
||||
@ -141,22 +147,11 @@ func Delete(config *Config) func(db *gorm.DB) {
|
||||
}
|
||||
|
||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
}
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
for _, c := range db.Statement.Schema.DeleteClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.Len() == 0 {
|
||||
db.Statement.Build(db.Statement.BuildClauses...)
|
||||
}
|
||||
|
||||
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil {
|
||||
db.AddError(gorm.ErrMissingWhereClause)
|
||||
return
|
||||
}
|
||||
checkMissingWhereConditions(db)
|
||||
|
||||
if !db.DryRun && db.Error == nil {
|
||||
ok, mode := hasReturning(db, supportReturning)
|
||||
|
@ -104,3 +104,19 @@ func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) {
|
||||
}
|
||||
return false, 0
|
||||
}
|
||||
|
||||
func checkMissingWhereConditions(db *gorm.DB) {
|
||||
if !db.AllowGlobalUpdate && db.Error == nil {
|
||||
where, withCondition := db.Statement.Clauses["WHERE"]
|
||||
if withCondition {
|
||||
if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete {
|
||||
whereClause, _ := where.Expression.(clause.Where)
|
||||
withCondition = len(whereClause.Exprs) > 1
|
||||
}
|
||||
}
|
||||
if !withCondition {
|
||||
db.AddError(gorm.ErrMissingWhereClause)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -59,6 +59,12 @@ func Update(config *Config) func(db *gorm.DB) {
|
||||
return
|
||||
}
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
for _, c := range db.Statement.Schema.UpdateClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.Len() == 0 {
|
||||
db.Statement.SQL.Grow(180)
|
||||
db.Statement.AddClauseIfNotExists(clause.Update{})
|
||||
@ -68,22 +74,10 @@ func Update(config *Config) func(db *gorm.DB) {
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
for _, c := range db.Statement.Schema.UpdateClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.Len() == 0 {
|
||||
db.Statement.Build(db.Statement.BuildClauses...)
|
||||
}
|
||||
|
||||
if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {
|
||||
db.AddError(gorm.ErrMissingWhereClause)
|
||||
return
|
||||
}
|
||||
checkMissingWhereConditions(db)
|
||||
|
||||
if !db.DryRun && db.Error == nil {
|
||||
if ok, mode := hasReturning(db, supportReturning); ok {
|
||||
@ -232,10 +226,10 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
|
||||
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
|
||||
} else if field.GORMDataType == schema.Time {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
|
||||
} else {
|
||||
} else if field.AutoUpdateTime == schema.UnixSecond {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
|
||||
} else {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -264,10 +258,10 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
value = stmt.DB.NowFunc().UnixNano()
|
||||
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
||||
value = stmt.DB.NowFunc().UnixNano() / 1e6
|
||||
} else if field.GORMDataType == schema.Time {
|
||||
value = stmt.DB.NowFunc()
|
||||
} else {
|
||||
} else if field.AutoUpdateTime == schema.UnixSecond {
|
||||
value = stmt.DB.NowFunc().Unix()
|
||||
} else {
|
||||
value = stmt.DB.NowFunc()
|
||||
}
|
||||
isZero = false
|
||||
}
|
||||
|
@ -75,10 +75,10 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
||||
case reflect.Bool:
|
||||
vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
|
||||
case reflect.String:
|
||||
vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
|
||||
default:
|
||||
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
|
||||
vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
|
||||
} else {
|
||||
vars[idx] = nullStr
|
||||
}
|
||||
@ -94,7 +94,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
||||
case float64, float32:
|
||||
vars[idx] = fmt.Sprintf("%.6f", v)
|
||||
case string:
|
||||
vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper
|
||||
vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper
|
||||
default:
|
||||
rv := reflect.ValueOf(v)
|
||||
if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
|
||||
@ -111,7 +111,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
||||
return
|
||||
}
|
||||
}
|
||||
vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ func (s ExampleStruct) Value() (driver.Value, error) {
|
||||
}
|
||||
|
||||
func format(v []byte, escaper string) string {
|
||||
return escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
|
||||
return escaper + strings.ReplaceAll(string(v), escaper, "\\"+escaper) + escaper
|
||||
}
|
||||
|
||||
func TestExplainSQL(t *testing.T) {
|
||||
|
@ -293,6 +293,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
}
|
||||
}
|
||||
|
||||
if field.GORMDataType == "" {
|
||||
field.GORMDataType = field.DataType
|
||||
}
|
||||
|
||||
if val, ok := field.TagSettings["TYPE"]; ok {
|
||||
switch DataType(strings.ToLower(val)) {
|
||||
case Bool, Int, Uint, Float, String, Time, Bytes:
|
||||
@ -302,10 +306,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
}
|
||||
}
|
||||
|
||||
if field.GORMDataType == "" {
|
||||
field.GORMDataType = field.DataType
|
||||
}
|
||||
|
||||
if field.Size == 0 {
|
||||
switch reflect.Indirect(fieldValue).Kind() {
|
||||
case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64:
|
||||
|
@ -3,7 +3,6 @@ package schema
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
@ -95,7 +94,7 @@ func (ns NamingStrategy) formatName(prefix, table, name string) string {
|
||||
h.Write([]byte(formattedName))
|
||||
bs := h.Sum(nil)
|
||||
|
||||
formattedName = fmt.Sprintf("%v%v%v", prefix, table, name)[0:56] + hex.EncodeToString(bs)[:8]
|
||||
formattedName = formattedName[0:56] + hex.EncodeToString(bs)[:8]
|
||||
}
|
||||
return formattedName
|
||||
}
|
||||
@ -174,7 +173,7 @@ func (ns NamingStrategy) toDBName(name string) string {
|
||||
}
|
||||
|
||||
func (ns NamingStrategy) toSchemaName(name string) string {
|
||||
result := strings.Replace(strings.Title(strings.Replace(name, "_", " ", -1)), " ", "", -1)
|
||||
result := strings.ReplaceAll(strings.Title(strings.ReplaceAll(name, "_", " ")), " ", "")
|
||||
for _, initialism := range commonInitialisms {
|
||||
result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
|
||||
}
|
||||
|
@ -193,7 +193,7 @@ func TestFormatNameWithStringLongerThan64Characters(t *testing.T) {
|
||||
ns := NamingStrategy{}
|
||||
|
||||
formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString")
|
||||
if formattedName != "prefixtablethisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLo180f2c67" {
|
||||
if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" {
|
||||
t.Errorf("invalid formatted name generated, got %v", formattedName)
|
||||
}
|
||||
}
|
||||
|
@ -576,3 +576,39 @@ func TestHasManySameForeignKey(t *testing.T) {
|
||||
References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}},
|
||||
})
|
||||
}
|
||||
|
||||
type Author struct {
|
||||
gorm.Model
|
||||
}
|
||||
|
||||
type Book struct {
|
||||
gorm.Model
|
||||
Author Author
|
||||
AuthorID uint
|
||||
}
|
||||
|
||||
func (Book) TableName() string {
|
||||
return "my_schema.a_very_very_very_very_very_very_very_very_long_table_name"
|
||||
}
|
||||
|
||||
func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) {
|
||||
s, err := schema.Parse(
|
||||
&Book{},
|
||||
&sync.Map{},
|
||||
schema.NamingStrategy{},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse schema")
|
||||
}
|
||||
|
||||
expectedConstraintName := "fk_my_schema_a_very_very_very_very_very_very_very_very_l4db13eec"
|
||||
constraint := s.Relationships.Relations["Author"].ParseConstraint()
|
||||
|
||||
if constraint.Name != expectedConstraintName {
|
||||
t.Fatalf(
|
||||
"expected constraint name %s, got %s",
|
||||
expectedConstraintName,
|
||||
constraint.Name,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -1,11 +1,12 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/gob"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
@ -32,6 +33,7 @@ func GetSerializer(name string) (serializer SerializerInterface, ok bool) {
|
||||
func init() {
|
||||
RegisterSerializer("json", JSONSerializer{})
|
||||
RegisterSerializer("unixtime", UnixSecondSerializer{})
|
||||
RegisterSerializer("gob", GobSerializer{})
|
||||
}
|
||||
|
||||
// Serializer field value serializer
|
||||
@ -83,7 +85,7 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
|
||||
case string:
|
||||
bytes = []byte(v)
|
||||
default:
|
||||
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", dbValue))
|
||||
return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue)
|
||||
}
|
||||
|
||||
err = json.Unmarshal(bytes, fieldValue.Interface())
|
||||
@ -123,3 +125,33 @@ func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GobSerializer gob serializer
|
||||
type GobSerializer struct {
|
||||
}
|
||||
|
||||
// Scan implements serializer interface
|
||||
func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
|
||||
fieldValue := reflect.New(field.FieldType)
|
||||
|
||||
if dbValue != nil {
|
||||
var bytesValue []byte
|
||||
switch v := dbValue.(type) {
|
||||
case []byte:
|
||||
bytesValue = v
|
||||
default:
|
||||
return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue)
|
||||
}
|
||||
decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue))
|
||||
err = decoder.Decode(fieldValue.Interface())
|
||||
}
|
||||
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
|
||||
return
|
||||
}
|
||||
|
||||
// Value implements serializer interface
|
||||
func (GobSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
err := gob.NewEncoder(buf).Encode(fieldValue)
|
||||
return buf.Bytes(), err
|
||||
}
|
||||
|
@ -104,9 +104,7 @@ func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) {
|
||||
|
||||
func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
|
||||
if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped {
|
||||
if _, ok := stmt.Clauses["WHERE"]; stmt.DB.AllowGlobalUpdate || ok {
|
||||
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
|
||||
}
|
||||
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
|
||||
}
|
||||
}
|
||||
|
||||
@ -152,12 +150,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok {
|
||||
stmt.DB.AddError(ErrMissingWhereClause)
|
||||
} else {
|
||||
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
|
||||
}
|
||||
|
||||
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
|
||||
stmt.AddClauseIfNotExists(clause.Update{})
|
||||
stmt.Build(stmt.DB.Callback().Update().Clauses...)
|
||||
}
|
||||
|
@ -9,11 +9,11 @@ require (
|
||||
github.com/lib/pq v1.10.4
|
||||
github.com/mattn/go-sqlite3 v1.14.11 // indirect
|
||||
golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect
|
||||
gorm.io/driver/mysql v1.3.1
|
||||
gorm.io/driver/mysql v1.3.2
|
||||
gorm.io/driver/postgres v1.3.1
|
||||
gorm.io/driver/sqlite v1.3.1
|
||||
gorm.io/driver/sqlserver v1.3.1
|
||||
gorm.io/gorm v1.23.0
|
||||
gorm.io/gorm v1.23.1
|
||||
)
|
||||
|
||||
replace gorm.io/gorm => ../
|
||||
|
@ -2,6 +2,7 @@ package tests_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
@ -15,9 +16,11 @@ func TestPostgres(t *testing.T) {
|
||||
|
||||
type Harumph struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"check:name_checker,name <> ''"`
|
||||
Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"`
|
||||
Things pq.StringArray `gorm:"type:text[]"`
|
||||
Name string `gorm:"check:name_checker,name <> ''"`
|
||||
Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"`
|
||||
CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"`
|
||||
UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"`
|
||||
Things pq.StringArray `gorm:"type:text[]"`
|
||||
}
|
||||
|
||||
if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil {
|
||||
@ -48,6 +51,15 @@ func TestPostgres(t *testing.T) {
|
||||
if err := DB.Where("id = $1", harumph.ID).First(&Harumph{}).Error; err != nil || harumph.Name != "jinzhu" {
|
||||
t.Errorf("No error should happen, but got %v", err)
|
||||
}
|
||||
|
||||
harumph.Name = "jinzhu1"
|
||||
if err := DB.Save(&harumph).Error; err != nil {
|
||||
t.Errorf("Failed to update date, got error %v", err)
|
||||
}
|
||||
|
||||
if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" {
|
||||
t.Errorf("No error should happen, but got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
type Post struct {
|
||||
|
@ -19,11 +19,20 @@ type SerializerStruct struct {
|
||||
Name []byte `gorm:"json"`
|
||||
Roles Roles `gorm:"serializer:json"`
|
||||
Contracts map[string]interface{} `gorm:"serializer:json"`
|
||||
JobInfo Job `gorm:"type:bytes;serializer:gob"`
|
||||
CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type
|
||||
EncryptedString EncryptedString
|
||||
}
|
||||
|
||||
type Roles []string
|
||||
|
||||
type Job struct {
|
||||
Title string
|
||||
Number int
|
||||
Location string
|
||||
IsIntern bool
|
||||
}
|
||||
|
||||
type EncryptedString string
|
||||
|
||||
func (es *EncryptedString) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) {
|
||||
@ -56,6 +65,12 @@ func TestSerializer(t *testing.T) {
|
||||
Contracts: map[string]interface{}{"name": "jinzhu", "age": 10},
|
||||
EncryptedString: EncryptedString("pass"),
|
||||
CreatedTime: createdAt.Unix(),
|
||||
JobInfo: Job{
|
||||
Title: "programmer",
|
||||
Number: 9920,
|
||||
Location: "Kenmawr",
|
||||
IsIntern: false,
|
||||
},
|
||||
}
|
||||
|
||||
if err := DB.Create(&data).Error; err != nil {
|
||||
|
@ -460,16 +460,16 @@ func assertEqualSQL(t *testing.T, expected string, actually string) {
|
||||
|
||||
func replaceQuoteInSQL(sql string) string {
|
||||
// convert single quote into double quote
|
||||
sql = strings.Replace(sql, `'`, `"`, -1)
|
||||
sql = strings.ReplaceAll(sql, `'`, `"`)
|
||||
|
||||
// convert dialect speical quote into double quote
|
||||
switch DB.Dialector.Name() {
|
||||
case "postgres":
|
||||
sql = strings.Replace(sql, `"`, `"`, -1)
|
||||
sql = strings.ReplaceAll(sql, `"`, `"`)
|
||||
case "mysql", "sqlite":
|
||||
sql = strings.Replace(sql, "`", `"`, -1)
|
||||
sql = strings.ReplaceAll(sql, "`", `"`)
|
||||
case "sqlserver":
|
||||
sql = strings.Replace(sql, `'`, `"`, -1)
|
||||
sql = strings.ReplaceAll(sql, `'`, `"`)
|
||||
}
|
||||
|
||||
return sql
|
||||
|
@ -645,7 +645,7 @@ func TestSave(t *testing.T) {
|
||||
|
||||
dryDB := DB.Session(&gorm.Session{DryRun: true})
|
||||
stmt := dryDB.Save(&user).Statement
|
||||
if !regexp.MustCompile(`.id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) {
|
||||
if !regexp.MustCompile(`.users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) {
|
||||
t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String())
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user