Update updated_at when upserting with Create OnConflict
This commit is contained in:
		
							parent
							
								
									12bbde89e6
								
							
						
					
					
						commit
						da16a8aac6
					
				| @ -227,6 +227,8 @@ func AfterCreate(db *gorm.DB) { | |||||||
| 
 | 
 | ||||||
| // ConvertToCreateValues convert to create values
 | // ConvertToCreateValues convert to create values
 | ||||||
| func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { | func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { | ||||||
|  | 	curTime := stmt.DB.NowFunc() | ||||||
|  | 
 | ||||||
| 	switch value := stmt.Dest.(type) { | 	switch value := stmt.Dest.(type) { | ||||||
| 	case map[string]interface{}: | 	case map[string]interface{}: | ||||||
| 		values = ConvertMapToValuesForCreate(stmt, value) | 		values = ConvertMapToValuesForCreate(stmt, value) | ||||||
| @ -240,7 +242,6 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { | |||||||
| 		var ( | 		var ( | ||||||
| 			selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) | 			selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) | ||||||
| 			_, updateTrackTime        = stmt.Get("gorm:update_track_time") | 			_, updateTrackTime        = stmt.Get("gorm:update_track_time") | ||||||
| 			curTime                   = stmt.DB.NowFunc() |  | ||||||
| 			isZero                    bool | 			isZero                    bool | ||||||
| 		) | 		) | ||||||
| 		stmt.Settings.Delete("gorm:update_track_time") | 		stmt.Settings.Delete("gorm:update_track_time") | ||||||
| @ -352,13 +353,27 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { | |||||||
| 					if field := stmt.Schema.LookUpField(column.Name); field != nil { | 					if field := stmt.Schema.LookUpField(column.Name); field != nil { | ||||||
| 						if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { | 						if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { | ||||||
| 							if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 { | 							if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 { | ||||||
| 								columns = append(columns, column.Name) | 								if field.AutoUpdateTime > 0 { | ||||||
|  | 									assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime} | ||||||
|  | 									switch field.AutoUpdateTime { | ||||||
|  | 									case schema.UnixNanosecond: | ||||||
|  | 										assignment.Value = curTime.UnixNano() | ||||||
|  | 									case schema.UnixMillisecond: | ||||||
|  | 										assignment.Value = curTime.UnixNano() / 1e6 | ||||||
|  | 									case schema.UnixSecond: | ||||||
|  | 										assignment.Value = curTime.Unix() | ||||||
|  | 									} | ||||||
|  | 
 | ||||||
|  | 									onConflict.DoUpdates = append(onConflict.DoUpdates, assignment) | ||||||
|  | 								} else { | ||||||
|  | 									columns = append(columns, column.Name) | ||||||
|  | 								} | ||||||
| 							} | 							} | ||||||
| 						} | 						} | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 				onConflict.DoUpdates = clause.AssignmentColumns(columns) | 				onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...) | ||||||
| 
 | 
 | ||||||
| 				// use primary fields as default OnConflict columns
 | 				// use primary fields as default OnConflict columns
 | ||||||
| 				if len(onConflict.Columns) == 0 { | 				if len(onConflict.Columns) == 0 { | ||||||
|  | |||||||
| @ -21,9 +21,10 @@ type TimeType int64 | |||||||
| var TimeReflectType = reflect.TypeOf(time.Time{}) | var TimeReflectType = reflect.TypeOf(time.Time{}) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
| 	UnixSecond      TimeType = 1 | 	UnixTime        TimeType = 1 | ||||||
| 	UnixMillisecond TimeType = 2 | 	UnixSecond      TimeType = 2 | ||||||
| 	UnixNanosecond  TimeType = 3 | 	UnixMillisecond TimeType = 3 | ||||||
|  | 	UnixNanosecond  TimeType = 4 | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
| @ -251,7 +252,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { | 	if v, ok := field.TagSettings["AUTOCREATETIME"]; ok || (field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { | ||||||
| 		if strings.ToUpper(v) == "NANO" { | 		if field.DataType == Time { | ||||||
|  | 			field.AutoCreateTime = UnixTime | ||||||
|  | 		} else if strings.ToUpper(v) == "NANO" { | ||||||
| 			field.AutoCreateTime = UnixNanosecond | 			field.AutoCreateTime = UnixNanosecond | ||||||
| 		} else if strings.ToUpper(v) == "MILLI" { | 		} else if strings.ToUpper(v) == "MILLI" { | ||||||
| 			field.AutoCreateTime = UnixMillisecond | 			field.AutoCreateTime = UnixMillisecond | ||||||
| @ -261,7 +264,9 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { | 	if v, ok := field.TagSettings["AUTOUPDATETIME"]; ok || (field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) { | ||||||
| 		if strings.ToUpper(v) == "NANO" { | 		if field.DataType == Time { | ||||||
|  | 			field.AutoUpdateTime = UnixTime | ||||||
|  | 		} else if strings.ToUpper(v) == "NANO" { | ||||||
| 			field.AutoUpdateTime = UnixNanosecond | 			field.AutoUpdateTime = UnixNanosecond | ||||||
| 		} else if strings.ToUpper(v) == "MILLI" { | 		} else if strings.ToUpper(v) == "MILLI" { | ||||||
| 			field.AutoUpdateTime = UnixMillisecond | 			field.AutoUpdateTime = UnixMillisecond | ||||||
|  | |||||||
| @ -66,6 +66,26 @@ func TestUpsert(t *testing.T) { | |||||||
| 			t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) | 			t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
|  | 	var user = *GetUser("upsert_on_conflict", Config{}) | ||||||
|  | 	user.Age = 20 | ||||||
|  | 	if err := DB.Create(&user).Error; err != nil { | ||||||
|  | 		t.Errorf("failed to create user, got error %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var user2 User | ||||||
|  | 	DB.First(&user2, user.ID) | ||||||
|  | 	user2.Age = 30 | ||||||
|  | 	time.Sleep(time.Second) | ||||||
|  | 	if err := DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&user2).Error; err != nil { | ||||||
|  | 		t.Fatalf("failed to onconflict create user, got error %v", err) | ||||||
|  | 	} else { | ||||||
|  | 		var user3 User | ||||||
|  | 		DB.First(&user3, user.ID) | ||||||
|  | 		if user3.UpdatedAt.UnixNano() == user2.UpdatedAt.UnixNano() { | ||||||
|  | 			t.Fatalf("failed to update user's updated_at, old: %v, new: %v", user2.UpdatedAt, user3.UpdatedAt) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestUpsertSlice(t *testing.T) { | func TestUpsertSlice(t *testing.T) { | ||||||
| @ -152,29 +172,29 @@ func TestUpsertWithSave(t *testing.T) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// lang := Language{Code: "upsert-save-3", Name: "Upsert-save-3"}
 | 	lang := Language{Code: "upsert-save-3", Name: "Upsert-save-3"} | ||||||
| 	// if err := DB.Save(&lang).Error; err != nil {
 | 	if err := DB.Save(&lang).Error; err != nil { | ||||||
| 	// 	t.Errorf("Failed to create, got error %v", err)
 | 		t.Errorf("Failed to create, got error %v", err) | ||||||
| 	// }
 | 	} | ||||||
| 
 | 
 | ||||||
| 	// var result Language
 | 	var result Language | ||||||
| 	// if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil {
 | 	if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { | ||||||
| 	// 	t.Errorf("Failed to query lang, got error %v", err)
 | 		t.Errorf("Failed to query lang, got error %v", err) | ||||||
| 	// } else {
 | 	} else { | ||||||
| 	// 	AssertEqual(t, result, lang)
 | 		AssertEqual(t, result, lang) | ||||||
| 	// }
 | 	} | ||||||
| 
 | 
 | ||||||
| 	// lang.Name += "_new"
 | 	lang.Name += "_new" | ||||||
| 	// if err := DB.Save(&lang).Error; err != nil {
 | 	if err := DB.Save(&lang).Error; err != nil { | ||||||
| 	// 	t.Errorf("Failed to create, got error %v", err)
 | 		t.Errorf("Failed to create, got error %v", err) | ||||||
| 	// }
 | 	} | ||||||
| 
 | 
 | ||||||
| 	// var result2 Language
 | 	var result2 Language | ||||||
| 	// if err := DB.First(&result2, "code = ?", lang.Code).Error; err != nil {
 | 	if err := DB.First(&result2, "code = ?", lang.Code).Error; err != nil { | ||||||
| 	// 	t.Errorf("Failed to query lang, got error %v", err)
 | 		t.Errorf("Failed to query lang, got error %v", err) | ||||||
| 	// } else {
 | 	} else { | ||||||
| 	// 	AssertEqual(t, result2, lang)
 | 		AssertEqual(t, result2, lang) | ||||||
| 	// }
 | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestFindOrInitialize(t *testing.T) { | func TestFindOrInitialize(t *testing.T) { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user