Refactor check missing where condition
This commit is contained in:
		
							parent
							
								
									3741f258d0
								
							
						
					
					
						commit
						6a18a15c93
					
				| @ -118,6 +118,12 @@ func Delete(config *Config) func(db *gorm.DB) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | 		if db.Statement.Schema != nil { | ||||||
|  | 			for _, c := range db.Statement.Schema.DeleteClauses { | ||||||
|  | 				db.Statement.AddClause(c) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		if db.Statement.SQL.Len() == 0 { | 		if db.Statement.SQL.Len() == 0 { | ||||||
| 			db.Statement.SQL.Grow(100) | 			db.Statement.SQL.Grow(100) | ||||||
| 			db.Statement.AddClauseIfNotExists(clause.Delete{}) | 			db.Statement.AddClauseIfNotExists(clause.Delete{}) | ||||||
| @ -141,22 +147,11 @@ func Delete(config *Config) func(db *gorm.DB) { | |||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			db.Statement.AddClauseIfNotExists(clause.From{}) | 			db.Statement.AddClauseIfNotExists(clause.From{}) | ||||||
| 		} |  | ||||||
| 
 | 
 | ||||||
| 		if db.Statement.Schema != nil { |  | ||||||
| 			for _, c := range db.Statement.Schema.DeleteClauses { |  | ||||||
| 				db.Statement.AddClause(c) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if db.Statement.SQL.Len() == 0 { |  | ||||||
| 			db.Statement.Build(db.Statement.BuildClauses...) | 			db.Statement.Build(db.Statement.BuildClauses...) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil { | 		checkMissingWhereConditions(db) | ||||||
| 			db.AddError(gorm.ErrMissingWhereClause) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 
 | 
 | ||||||
| 		if !db.DryRun && db.Error == nil { | 		if !db.DryRun && db.Error == nil { | ||||||
| 			ok, mode := hasReturning(db, supportReturning) | 			ok, mode := hasReturning(db, supportReturning) | ||||||
|  | |||||||
| @ -104,3 +104,19 @@ func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) { | |||||||
| 	} | 	} | ||||||
| 	return false, 0 | 	return false, 0 | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func checkMissingWhereConditions(db *gorm.DB) { | ||||||
|  | 	if !db.AllowGlobalUpdate && db.Error == nil { | ||||||
|  | 		where, withCondition := db.Statement.Clauses["WHERE"] | ||||||
|  | 		if withCondition { | ||||||
|  | 			if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete { | ||||||
|  | 				whereClause, _ := where.Expression.(clause.Where) | ||||||
|  | 				withCondition = len(whereClause.Exprs) > 1 | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		if !withCondition { | ||||||
|  | 			db.AddError(gorm.ErrMissingWhereClause) | ||||||
|  | 		} | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
| @ -59,6 +59,12 @@ func Update(config *Config) func(db *gorm.DB) { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | 		if db.Statement.Schema != nil { | ||||||
|  | 			for _, c := range db.Statement.Schema.UpdateClauses { | ||||||
|  | 				db.Statement.AddClause(c) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		if db.Statement.SQL.Len() == 0 { | 		if db.Statement.SQL.Len() == 0 { | ||||||
| 			db.Statement.SQL.Grow(180) | 			db.Statement.SQL.Grow(180) | ||||||
| 			db.Statement.AddClauseIfNotExists(clause.Update{}) | 			db.Statement.AddClauseIfNotExists(clause.Update{}) | ||||||
| @ -68,22 +74,10 @@ func Update(config *Config) func(db *gorm.DB) { | |||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if db.Statement.Schema != nil { |  | ||||||
| 			for _, c := range db.Statement.Schema.UpdateClauses { |  | ||||||
| 				db.Statement.AddClause(c) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if db.Statement.SQL.Len() == 0 { |  | ||||||
| 			db.Statement.Build(db.Statement.BuildClauses...) | 			db.Statement.Build(db.Statement.BuildClauses...) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { | 		checkMissingWhereConditions(db) | ||||||
| 			db.AddError(gorm.ErrMissingWhereClause) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 
 | 
 | ||||||
| 		if !db.DryRun && db.Error == nil { | 		if !db.DryRun && db.Error == nil { | ||||||
| 			if ok, mode := hasReturning(db, supportReturning); ok { | 			if ok, mode := hasReturning(db, supportReturning); ok { | ||||||
|  | |||||||
| @ -104,10 +104,8 @@ func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) { | |||||||
| 
 | 
 | ||||||
| func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) { | func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) { | ||||||
| 	if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { | 	if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped { | ||||||
| 		if _, ok := stmt.Clauses["WHERE"]; stmt.DB.AllowGlobalUpdate || ok { |  | ||||||
| 		SoftDeleteQueryClause(sd).ModifyStatement(stmt) | 		SoftDeleteQueryClause(sd).ModifyStatement(stmt) | ||||||
| 	} | 	} | ||||||
| 	} |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface { | func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface { | ||||||
| @ -152,12 +150,7 @@ func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) { | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if _, ok := stmt.Clauses["WHERE"]; !stmt.DB.AllowGlobalUpdate && !ok { |  | ||||||
| 			stmt.DB.AddError(ErrMissingWhereClause) |  | ||||||
| 		} else { |  | ||||||
| 		SoftDeleteQueryClause(sd).ModifyStatement(stmt) | 		SoftDeleteQueryClause(sd).ModifyStatement(stmt) | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		stmt.AddClauseIfNotExists(clause.Update{}) | 		stmt.AddClauseIfNotExists(clause.Update{}) | ||||||
| 		stmt.Build(stmt.DB.Callback().Update().Clauses...) | 		stmt.Build(stmt.DB.Callback().Update().Clauses...) | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -645,7 +645,7 @@ func TestSave(t *testing.T) { | |||||||
| 
 | 
 | ||||||
| 	dryDB := DB.Session(&gorm.Session{DryRun: true}) | 	dryDB := DB.Session(&gorm.Session{DryRun: true}) | ||||||
| 	stmt := dryDB.Save(&user).Statement | 	stmt := dryDB.Save(&user).Statement | ||||||
| 	if !regexp.MustCompile(`.id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) { | 	if !regexp.MustCompile(`.users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) { | ||||||
| 		t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) | 		t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String()) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user