Respect update permission for OnConflict Create
This commit is contained in:
		
							parent
							
								
									0329b800b0
								
							
						
					
					
						commit
						2ec7043818
					
				| @ -37,7 +37,6 @@ func Create(config *Config) func(db *gorm.DB) { | |||||||
| 
 | 
 | ||||||
| 	return func(db *gorm.DB) { | 	return func(db *gorm.DB) { | ||||||
| 		if db.Error != nil { | 		if db.Error != nil { | ||||||
| 			// maybe record logger TODO
 |  | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| @ -64,11 +63,9 @@ func Create(config *Config) func(db *gorm.DB) { | |||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			db.RowsAffected, _ = result.RowsAffected() | 			db.RowsAffected, _ = result.RowsAffected() | ||||||
| 			if !(db.RowsAffected > 0) { |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 
 | 
 | ||||||
| 			if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { | 			if db.RowsAffected != 0 && db.Statement.Schema != nil && | ||||||
|  | 				db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { | ||||||
| 				if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { | 				if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { | ||||||
| 					switch db.Statement.ReflectValue.Kind() { | 					switch db.Statement.ReflectValue.Kind() { | ||||||
| 					case reflect.Slice, reflect.Array: | 					case reflect.Slice, reflect.Array: | ||||||
| @ -107,7 +104,6 @@ func Create(config *Config) func(db *gorm.DB) { | |||||||
| 					db.AddError(err) | 					db.AddError(err) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 
 |  | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| @ -349,11 +345,15 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { | |||||||
| 	if c, ok := stmt.Clauses["ON CONFLICT"]; ok { | 	if c, ok := stmt.Clauses["ON CONFLICT"]; ok { | ||||||
| 		if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll { | 		if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll { | ||||||
| 			if stmt.Schema != nil && len(values.Columns) > 1 { | 			if stmt.Schema != nil && len(values.Columns) > 1 { | ||||||
|  | 				selectColumns, restricted := stmt.SelectAndOmitColumns(true, true) | ||||||
|  | 
 | ||||||
| 				columns := make([]string, 0, len(values.Columns)-1) | 				columns := make([]string, 0, len(values.Columns)-1) | ||||||
| 				for _, column := range values.Columns { | 				for _, column := range values.Columns { | ||||||
| 					if field := stmt.Schema.LookUpField(column.Name); field != nil { | 					if field := stmt.Schema.LookUpField(column.Name); field != nil { | ||||||
| 						if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 { | 						if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { | ||||||
| 							columns = append(columns, column.Name) | 							if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 { | ||||||
|  | 								columns = append(columns, column.Name) | ||||||
|  | 							} | ||||||
| 						} | 						} | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
|  | |||||||
| @ -1,9 +1,11 @@ | |||||||
| package tests_test | package tests_test | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"regexp" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"gorm.io/gorm" | ||||||
| 	"gorm.io/gorm/clause" | 	"gorm.io/gorm/clause" | ||||||
| 	. "gorm.io/gorm/utils/tests" | 	. "gorm.io/gorm/utils/tests" | ||||||
| ) | ) | ||||||
| @ -51,6 +53,19 @@ func TestUpsert(t *testing.T) { | |||||||
| 	if err := DB.Find(&result, "code = ?", lang.Code).Error; err != nil || result.Name != lang.Name { | 	if err := DB.Find(&result, "code = ?", lang.Code).Error; err != nil || result.Name != lang.Name { | ||||||
| 		t.Fatalf("failed to upsert, got name %v", result.Name) | 		t.Fatalf("failed to upsert, got name %v", result.Name) | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
|  | 	if name := DB.Dialector.Name(); name != "sqlserver" { | ||||||
|  | 		type RestrictedLanguage struct { | ||||||
|  | 			Code string `gorm:"primarykey"` | ||||||
|  | 			Name string | ||||||
|  | 			Lang string `gorm:"<-:create"` | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		r := DB.Session(&gorm.Session{DryRun: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(&RestrictedLanguage{Code: "upsert_code", Name: "upsert_name", Lang: "upsert_lang"}) | ||||||
|  | 		if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.[^\w]*$`).MatchString(r.Statement.SQL.String()) { | ||||||
|  | 			t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestUpsertSlice(t *testing.T) { | func TestUpsertSlice(t *testing.T) { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu