diff --git a/forced_utc_test.go b/forced_utc_test.go new file mode 100644 index 00000000..3e17a74a --- /dev/null +++ b/forced_utc_test.go @@ -0,0 +1,86 @@ +package gorm_test + +import ( + "os" + "testing" + "time" +) + +// table that is not created automatically +type ExternalData struct { + Id int + Time time.Time +} + +func TestForcedUTC(t *testing.T) { + if dialect := os.Getenv("GORM_DIALECT"); dialect != "postgres" { + t.Skip("Skipping this because this is to test postgres issues with timestamps without timezones") + } + db := DB.New() + db.DropTableIfExists(&ExternalData{}) + db.Exec(` + CREATE TABLE IF NOT EXISTS external_data( + id serial PRIMARY KEY, + time timestamp without time zone NOT NULL + )`) + + tm := time.Date(2000, 1, 1, 1, 0, 0, 0, time.FixedZone("test location", +7200)) + + //Test without forcing utc + elem0 := ExternalData{Time: tm} + db.Save(&elem0) + + elem := ExternalData{} + db.Find(&elem, elem0.Id) + if elem.Time.Equal(tm) { + t.Errorf("Times should not be equal (timezones)") + } + + db.Model(&elem).Update("time", tm) + elem = ExternalData{} + db.Find(&elem, elem0.Id) + if elem.Time.Equal(tm) { + t.Errorf("Times should not be equal (timezones)") + } + + cnt := 0 + + db.Model(&ExternalData{}).Where("time = ?", tm).Count(&cnt) + if cnt == 0 { + t.Errorf("Timezone is cut off, data still should be found (timezones)") + } + + db.Model(&ExternalData{}).Where("time = ?", tm.UTC()).Count(&cnt) + if cnt != 0 { + t.Errorf("UTC normalized time should not be found (timezones)") + } + + //Test with forcing utc + db.ForceUTC(true) + + elem0 = ExternalData{Time: tm} + db.Save(&elem0) + + elem = ExternalData{} + db.Find(&elem, elem0.Id) + if !elem.Time.Equal(tm) { + t.Errorf("Times should be equal (forced UTC)") + } + + db.Model(&elem).Update("time", tm) + elem = ExternalData{} + db.Find(&elem, elem0.Id) + if !elem.Time.Equal(tm) { + t.Errorf("Times should be equal (forced UTC)") + } + + db.Model(&ExternalData{}).Where("time = ?", tm).Count(&cnt) + if cnt != 1 { + t.Errorf("Record should be found (forced UTC)") + } + + db.Model(&ExternalData{}).Where("time = ?", tm.UTC()).Count(&cnt) + if cnt != 1 { + t.Errorf("UTC normalized time should be found (forcedUTC)") + } +} diff --git a/main.go b/main.go index cf88e9c9..8b9d4a88 100644 --- a/main.go +++ b/main.go @@ -25,6 +25,7 @@ type DB struct { source string values map[string]interface{} joinTableHandlers map[string]JoinTableHandler + forceUTC bool } // Open initialize a new db connection, need to import driver first, e.g: @@ -142,6 +143,17 @@ func (s *DB) LogMode(enable bool) *DB { return s } +// ForceUTC set to `true` to force all time values to be normalized as UTC, `false` for regular behavior (default) +func (s *DB) ForceUTC(enable bool) *DB { + s.forceUTC = enable + return s +} + +// HasForcedUTC return the state of the forceUTC flag +func (s *DB) HasForcedUTC() bool { + return s.forceUTC +} + // SingularTable use singular table by default func (s *DB) SingularTable(enable bool) { modelStructsMap = newModelStructsMap() @@ -682,7 +694,7 @@ func (s *DB) GetErrors() []error { //////////////////////////////////////////////////////////////////////////////// func (s *DB) clone() *DB { - db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error} + db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error, forceUTC: s.forceUTC} for key, value := range s.values { db.values[key] = value diff --git a/scope.go b/scope.go index 4a962062..cb84dc05 100644 --- a/scope.go +++ b/scope.go @@ -256,12 +256,12 @@ func (scope *Scope) AddToVars(value interface{}) string { if expr, ok := value.(*expr); ok { exp := expr.expr for _, arg := range expr.args { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) + exp = strings.Replace(exp, "?", scope.AddToVars(scope.toUTC(arg)), 1) } return exp } - scope.SQLVars = append(scope.SQLVars, value) + scope.SQLVars = append(scope.SQLVars, scope.toUTC(value)) return scope.Dialect().BindVar(len(scope.SQLVars)) } @@ -1280,3 +1280,13 @@ func (scope *Scope) getColumnAsScope(column string) *Scope { } return nil } + +func (scope *Scope) toUTC(v interface{}) interface{} { + if !scope.db.HasForcedUTC() { + return v + } + if tm, ok := v.(time.Time); ok { + return tm.UTC() + } + return v +}