From 15c0f891e1291a4e4f37eee0829df68a63752547 Mon Sep 17 00:00:00 2001 From: slockij Date: Fri, 4 Nov 2016 11:54:28 +0100 Subject: [PATCH] In some cases, time fields are created without timezone. Then, forgetting to use UTC in code can result in time related issues, and gorm.Model time fields will not have correct data. Solution - force UTC on save --- forced_utc_test.go | 86 ++++++++++++++++++++++++++++++++++++++++++++++ main.go | 14 +++++++- scope.go | 14 ++++++-- 3 files changed, 111 insertions(+), 3 deletions(-) create mode 100644 forced_utc_test.go diff --git a/forced_utc_test.go b/forced_utc_test.go new file mode 100644 index 00000000..3e17a74a --- /dev/null +++ b/forced_utc_test.go @@ -0,0 +1,86 @@ +package gorm_test + +import ( + "os" + "testing" + "time" +) + +// table that is not created automatically +type ExternalData struct { + Id int + Time time.Time +} + +func TestForcedUTC(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" { + t.Skip("Skipping this because this is to test postgres issues with timestamps without timezones") + } + db := DB.New() + db.DropTableIfExists(&ExternalData{}) + db.Exec(` + CREATE TABLE IF NOT EXISTS external_data( + id serial PRIMARY KEY, + time timestamp without time zone NOT NULL + )`) + + tm := time.Date(2000, 1, 1, 1, 0, 0, 0, time.FixedZone("test location", +7200)) + + //Test without forcing utc + elem0 := ExternalData{Time: tm} + db.Save(&elem0) + + elem := ExternalData{} + db.Find(&elem, elem0.Id) + if elem.Time.Equal(tm) { + t.Errorf("Times should not be equal (timezones)") + } + + db.Model(&elem).Update("time", tm) + elem = ExternalData{} + db.Find(&elem, elem0.Id) + if elem.Time.Equal(tm) { + t.Errorf("Times should not be equal (timezones)") + } + + cnt := 0 + + db.Model(&ExternalData{}).Where("time = ?", tm).Count(&cnt) + if cnt == 0 { + t.Errorf("Timezone is cut off, data still should be found (timezones)") + } + + db.Model(&ExternalData{}).Where("time = ?", tm.UTC()).Count(&cnt) + if cnt != 0 { + t.Errorf("UTC normalized time should not be found (timezones)") + } + + //Test with forcing utc + db.ForceUTC(true) + + elem0 = ExternalData{Time: tm} + db.Save(&elem0) + + elem = ExternalData{} + db.Find(&elem, elem0.Id) + if !elem.Time.Equal(tm) { + t.Errorf("Times should be equal (forced UTC)") + } + + db.Model(&elem).Update("time", tm) + elem = ExternalData{} + db.Find(&elem, elem0.Id) + if !elem.Time.Equal(tm) { + t.Errorf("Times should be equal (forced UTC)") + } + + db.Model(&ExternalData{}).Where("time = ?", tm).Count(&cnt) + if cnt != 1 { + t.Errorf("Record should be found (forced UTC)") + } + + db.Model(&ExternalData{}).Where("time = ?", tm.UTC()).Count(&cnt) + if cnt != 1 { + t.Errorf("UTC normalized time should be found (forcedUTC)") + } +} diff --git a/main.go b/main.go index 192dbd7c..5bb1fc94 100644 --- a/main.go +++ b/main.go @@ -25,6 +25,7 @@ type DB struct { source string values map[string]interface{} joinTableHandlers map[string]JoinTableHandler + forceUTC bool } // Open initialize a new db connection, need to import driver first, e.g: @@ -142,6 +143,17 @@ func (s *DB) LogMode(enable bool) *DB { return s } +// ForceUTC set to `true` to force all time values to be normalized as UTC, `false` for regular behavior (default) +func (s *DB) ForceUTC(enable bool) *DB { + s.forceUTC = enable + return s +} + +// HasForcedUTC return the state of the forceUTC flag +func (s *DB) HasForcedUTC() bool { + return s.forceUTC +} + // SingularTable use singular table by default func (s *DB) SingularTable(enable bool) { modelStructsMap = newModelStructsMap() @@ -682,7 +694,7 @@ func (s *DB) GetErrors() (errors []error) { //////////////////////////////////////////////////////////////////////////////// func (s *DB) clone() *DB { - db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error} + db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error, forceUTC: s.forceUTC} for key, value := range s.values { db.values[key] = value diff --git a/scope.go b/scope.go index 4a962062..cb84dc05 100644 --- a/scope.go +++ b/scope.go @@ -256,12 +256,12 @@ func (scope *Scope) AddToVars(value interface{}) string { if expr, ok := value.(*expr); ok { exp := expr.expr for _, arg := range expr.args { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) + exp = strings.Replace(exp, "?", scope.AddToVars(scope.toUTC(arg)), 1) } return exp } - scope.SQLVars = append(scope.SQLVars, value) + scope.SQLVars = append(scope.SQLVars, scope.toUTC(value)) return scope.Dialect().BindVar(len(scope.SQLVars)) } @@ -1280,3 +1280,13 @@ func (scope *Scope) getColumnAsScope(column string) *Scope { } return nil } + +func (scope *Scope) toUTC(v interface{}) interface{} { + if !scope.db.HasForcedUTC() { + return v + } + if tm, ok := v.(time.Time); ok { + return tm.UTC() + } + return v +}