diff --git a/callbacks/delete.go b/callbacks/delete.go index 1fb5261c..84f446a3 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -118,6 +118,12 @@ func Delete(config *Config) func(db *gorm.DB) { return } + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.DeleteClauses { + db.Statement.AddClause(c) + } + } + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(100) db.Statement.AddClauseIfNotExists(clause.Delete{}) @@ -141,22 +147,11 @@ func Delete(config *Config) func(db *gorm.DB) { } db.Statement.AddClauseIfNotExists(clause.From{}) - } - if db.Statement.Schema != nil { - for _, c := range db.Statement.Schema.DeleteClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.Len() == 0 { db.Statement.Build(db.Statement.BuildClauses...) } - if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil { - db.AddError(gorm.ErrMissingWhereClause) - return - } + checkMissingWhereConditions(db) if !db.DryRun && db.Error == nil { ok, mode := hasReturning(db, supportReturning) diff --git a/callbacks/helper.go b/callbacks/helper.go index a59e1880..a5eb047e 100644 --- a/callbacks/helper.go +++ b/callbacks/helper.go @@ -104,3 +104,19 @@ func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) { } return false, 0 } + +func checkMissingWhereConditions(db *gorm.DB) { + if !db.AllowGlobalUpdate && db.Error == nil { + where, withCondition := db.Statement.Clauses["WHERE"] + if withCondition { + if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete { + whereClause, _ := where.Expression.(clause.Where) + withCondition = len(whereClause.Exprs) > 1 + } + } + if !withCondition { + db.AddError(gorm.ErrMissingWhereClause) + } + return + } +} diff --git a/callbacks/update.go b/callbacks/update.go index 4f07ca30..da03261e 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -59,6 +59,12 @@ func Update(config *Config) func(db *gorm.DB) { return } + if db.Statement.Schema != nil { + for _, c := range db.Statement.Schema.UpdateClauses { + db.Statement.AddClause(c) + } + } + if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Update{}) @@ -68,22 +74,10 @@ func Update(config *Config) func(db *gorm.DB) { return } - } - - if db.Statement.Schema != nil { - for _, c := range db.Statement.Schema.UpdateClauses { - db.Statement.AddClause(c) - } - } - - if db.Statement.SQL.Len() == 0 { db.Statement.Build(db.Statement.BuildClauses...) } - if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { - db.AddError(gorm.ErrMissingWhereClause) - return - } + checkMissingWhereConditions(db) if !db.DryRun && db.Error == nil { if ok, mode := hasReturning(db, supportReturning); ok { @@ -232,10 +226,10 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) } else if field.AutoUpdateTime == schema.UnixMillisecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6}) - } else if field.GORMDataType == schema.Time { - set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) - } else { + } else if field.AutoUpdateTime == schema.UnixSecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) + } else { + set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) } } } @@ -264,10 +258,10 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { value = stmt.DB.NowFunc().UnixNano() } else if field.AutoUpdateTime == schema.UnixMillisecond { value = stmt.DB.NowFunc().UnixNano() / 1e6 - } else if field.GORMDataType == schema.Time { - value = stmt.DB.NowFunc() - } else { + } else if field.AutoUpdateTime == schema.UnixSecond { value = stmt.DB.NowFunc().Unix() + } else { + value = stmt.DB.NowFunc() } isZero = false } diff --git a/logger/sql.go b/logger/sql.go index 04a2dbd4..c8b194c3 100644 --- a/logger/sql.go +++ b/logger/sql.go @@ -75,10 +75,10 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a case reflect.Bool: vars[idx] = fmt.Sprintf("%t", reflectValue.Interface()) case reflect.String: - vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper default: if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { - vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper } else { vars[idx] = nullStr } @@ -94,7 +94,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a case float64, float32: vars[idx] = fmt.Sprintf("%.6f", v) case string: - vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper + vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper default: rv := reflect.ValueOf(v) if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() { @@ -111,7 +111,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a return } } - vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper + vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper } } } diff --git a/logger/sql_test.go b/logger/sql_test.go index 71aa841a..c5b181a9 100644 --- a/logger/sql_test.go +++ b/logger/sql_test.go @@ -31,7 +31,7 @@ func (s ExampleStruct) Value() (driver.Value, error) { } func format(v []byte, escaper string) string { - return escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper + return escaper + strings.ReplaceAll(string(v), escaper, "\\"+escaper) + escaper } func TestExplainSQL(t *testing.T) { diff --git a/schema/field.go b/schema/field.go index 319f3693..8c793f93 100644 --- a/schema/field.go +++ b/schema/field.go @@ -293,6 +293,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } + if field.GORMDataType == "" { + field.GORMDataType = field.DataType + } + if val, ok := field.TagSettings["TYPE"]; ok { switch DataType(strings.ToLower(val)) { case Bool, Int, Uint, Float, String, Time, Bytes: @@ -302,10 +306,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { } } - if field.GORMDataType == "" { - field.GORMDataType = field.DataType - } - if field.Size == 0 { switch reflect.Indirect(fieldValue).Kind() { case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64: diff --git a/schema/naming.go b/schema/naming.go index a4e3a75b..47a2b363 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -3,7 +3,6 @@ package schema import ( "crypto/sha1" "encoding/hex" - "fmt" "regexp" "strings" "unicode/utf8" @@ -95,7 +94,7 @@ func (ns NamingStrategy) formatName(prefix, table, name string) string { h.Write([]byte(formattedName)) bs := h.Sum(nil) - formattedName = fmt.Sprintf("%v%v%v", prefix, table, name)[0:56] + hex.EncodeToString(bs)[:8] + formattedName = formattedName[0:56] + hex.EncodeToString(bs)[:8] } return formattedName } @@ -174,7 +173,7 @@ func (ns NamingStrategy) toDBName(name string) string { } func (ns NamingStrategy) toSchemaName(name string) string { - result := strings.Replace(strings.Title(strings.Replace(name, "_", " ", -1)), " ", "", -1) + result := strings.ReplaceAll(strings.Title(strings.ReplaceAll(name, "_", " ")), " ", "") for _, initialism := range commonInitialisms { result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1") } diff --git a/schema/naming_test.go b/schema/naming_test.go index 1fdab9a0..3f598c33 100644 --- a/schema/naming_test.go +++ b/schema/naming_test.go @@ -193,7 +193,7 @@ func TestFormatNameWithStringLongerThan64Characters(t *testing.T) { ns := NamingStrategy{} formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString") - if formattedName != "prefixtablethisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLo180f2c67" { + if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" { t.Errorf("invalid formatted name generated, got %v", formattedName) } } diff --git a/schema/relationship_test.go b/schema/relationship_test.go index e2cf11a9..40ffc324 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -576,3 +576,39 @@ func TestHasManySameForeignKey(t *testing.T) { References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}}, }) } + +type Author struct { + gorm.Model +} + +type Book struct { + gorm.Model + Author Author + AuthorID uint +} + +func (Book) TableName() string { + return "my_schema.a_very_very_very_very_very_very_very_very_long_table_name" +} + +func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) { + s, err := schema.Parse( + &Book{}, + &sync.Map{}, + schema.NamingStrategy{}, + ) + if err != nil { + t.Fatalf("Failed to parse schema") + } + + expectedConstraintName := "fk_my_schema_a_very_very_very_very_very_very_very_very_l4db13eec" + constraint := s.Relationships.Relations["Author"].ParseConstraint() + + if constraint.Name != expectedConstraintName { + t.Fatalf( + "expected constraint name %s, got %s", + expectedConstraintName, + constraint.Name, + ) + } +} diff --git a/schema/serializer.go b/schema/serializer.go index 68597538..09da6d9e 100644 --- a/schema/serializer.go +++ b/schema/serializer.go @@ -1,11 +1,12 @@ package schema import ( + "bytes" "context" "database/sql" "database/sql/driver" + "encoding/gob" "encoding/json" - "errors" "fmt" "reflect" "strings" @@ -32,6 +33,7 @@ func GetSerializer(name string) (serializer SerializerInterface, ok bool) { func init() { RegisterSerializer("json", JSONSerializer{}) RegisterSerializer("unixtime", UnixSecondSerializer{}) + RegisterSerializer("gob", GobSerializer{}) } // Serializer field value serializer @@ -83,7 +85,7 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, case string: bytes = []byte(v) default: - return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", dbValue)) + return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue) } err = json.Unmarshal(bytes, fieldValue.Interface()) @@ -123,3 +125,33 @@ func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect } return } + +// GobSerializer gob serializer +type GobSerializer struct { +} + +// Scan implements serializer interface +func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { + fieldValue := reflect.New(field.FieldType) + + if dbValue != nil { + var bytesValue []byte + switch v := dbValue.(type) { + case []byte: + bytesValue = v + default: + return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue) + } + decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue)) + err = decoder.Decode(fieldValue.Interface()) + } + field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) + return +} + +// Value implements serializer interface +func (GobSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { + buf := new(bytes.Buffer) + err := gob.NewEncoder(buf).Encode(fieldValue) + return buf.Bytes(), err +} diff --git a/soft_delete.go b/soft_delete.go index ba6d2118..6d646288 100644 --- a/soft_delete.go +++ b/soft_delete.go @@ -104,9 +104,7 @@ func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) { func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) { if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { - if _, ok := stmt.Clauses["WHERE"]; stmt.DB.AllowGlobalUpdate || ok { - SoftDeleteQueryClause(sd).ModifyStatement(stmt) - } + SoftDeleteQueryClause(sd).ModifyStatement(stmt) } } @@ -152,12 +150,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { } } - if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok { - stmt.DB.AddError(ErrMissingWhereClause) - } else { - SoftDeleteQueryClause(sd).ModifyStatement(stmt) - } - + SoftDeleteQueryClause(sd).ModifyStatement(stmt) stmt.AddClauseIfNotExists(clause.Update{}) stmt.Build(stmt.DB.Callback().Update().Clauses...) } diff --git a/tests/go.mod b/tests/go.mod index 1c1fb238..cefe6f96 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,11 +9,11 @@ require ( github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.11 // indirect golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect - gorm.io/driver/mysql v1.3.1 + gorm.io/driver/mysql v1.3.2 gorm.io/driver/postgres v1.3.1 gorm.io/driver/sqlite v1.3.1 gorm.io/driver/sqlserver v1.3.1 - gorm.io/gorm v1.23.0 + gorm.io/gorm v1.23.1 ) replace gorm.io/gorm => ../ diff --git a/tests/postgres_test.go b/tests/postgres_test.go index 85671864..418b713e 100644 --- a/tests/postgres_test.go +++ b/tests/postgres_test.go @@ -2,6 +2,7 @@ package tests_test import ( "testing" + "time" "github.com/google/uuid" "github.com/lib/pq" @@ -15,9 +16,11 @@ func TestPostgres(t *testing.T) { type Harumph struct { gorm.Model - Name string `gorm:"check:name_checker,name <> ''"` - Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"` - Things pq.StringArray `gorm:"type:text[]"` + Name string `gorm:"check:name_checker,name <> ''"` + Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"` + CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"` + UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"` + Things pq.StringArray `gorm:"type:text[]"` } if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil { @@ -48,6 +51,15 @@ func TestPostgres(t *testing.T) { if err := DB.Where("id = $1", harumph.ID).First(&Harumph{}).Error; err != nil || harumph.Name != "jinzhu" { t.Errorf("No error should happen, but got %v", err) } + + harumph.Name = "jinzhu1" + if err := DB.Save(&harumph).Error; err != nil { + t.Errorf("Failed to update date, got error %v", err) + } + + if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" { + t.Errorf("No error should happen, but got %v", err) + } } type Post struct { diff --git a/tests/serializer_test.go b/tests/serializer_test.go index 3ed733d9..a8a4e28f 100644 --- a/tests/serializer_test.go +++ b/tests/serializer_test.go @@ -19,11 +19,20 @@ type SerializerStruct struct { Name []byte `gorm:"json"` Roles Roles `gorm:"serializer:json"` Contracts map[string]interface{} `gorm:"serializer:json"` + JobInfo Job `gorm:"type:bytes;serializer:gob"` CreatedTime int64 `gorm:"serializer:unixtime;type:time"` // store time in db, use int as field type EncryptedString EncryptedString } type Roles []string + +type Job struct { + Title string + Number int + Location string + IsIntern bool +} + type EncryptedString string func (es *EncryptedString) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) { @@ -56,6 +65,12 @@ func TestSerializer(t *testing.T) { Contracts: map[string]interface{}{"name": "jinzhu", "age": 10}, EncryptedString: EncryptedString("pass"), CreatedTime: createdAt.Unix(), + JobInfo: Job{ + Title: "programmer", + Number: 9920, + Location: "Kenmawr", + IsIntern: false, + }, } if err := DB.Create(&data).Error; err != nil { diff --git a/tests/sql_builder_test.go b/tests/sql_builder_test.go index 897f687f..bc917c32 100644 --- a/tests/sql_builder_test.go +++ b/tests/sql_builder_test.go @@ -460,16 +460,16 @@ func assertEqualSQL(t *testing.T, expected string, actually string) { func replaceQuoteInSQL(sql string) string { // convert single quote into double quote - sql = strings.Replace(sql, `'`, `"`, -1) + sql = strings.ReplaceAll(sql, `'`, `"`) // convert dialect speical quote into double quote switch DB.Dialector.Name() { case "postgres": - sql = strings.Replace(sql, `"`, `"`, -1) + sql = strings.ReplaceAll(sql, `"`, `"`) case "mysql", "sqlite": - sql = strings.Replace(sql, "`", `"`, -1) + sql = strings.ReplaceAll(sql, "`", `"`) case "sqlserver": - sql = strings.Replace(sql, `'`, `"`, -1) + sql = strings.ReplaceAll(sql, `'`, `"`) } return sql diff --git a/tests/update_test.go b/tests/update_test.go index b471ba9b..41ea5d27 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -645,7 +645,7 @@ func TestSave(t *testing.T) { dryDB := DB.Session(&gorm.Session{DryRun: true}) stmt := dryDB.Save(&user).Statement - if !regexp.MustCompile(`.id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) { + if !regexp.MustCompile(`.users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) { t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) }