Skip to content

Commit f6197bc

Browse files
authored
Merge pull request #14240 from valentijnscholten/remove-dojo-async-task-base-task-bugfix
refactor dojo async task base task (bugfix branch)
2 parents c7221fe + b6e729e commit f6197bc

29 files changed

Lines changed: 818 additions & 311 deletions

dojo/api_v2/views.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
)
4747
from dojo.api_v2.prefetch.prefetcher import _Prefetcher
4848
from dojo.authorization.roles_permissions import Permissions
49+
from dojo.celery_dispatch import dojo_dispatch_task
4950
from dojo.cred.queries import get_authorized_cred_mappings
5051
from dojo.endpoint.queries import (
5152
get_authorized_endpoint_status,
@@ -679,13 +680,13 @@ def update_jira_epic(self, request, pk=None):
679680
try:
680681

681682
if engagement.has_jira_issue:
682-
jira_helper.update_epic(engagement.id, **request.data)
683+
dojo_dispatch_task(jira_helper.update_epic, engagement.id, **request.data)
683684
response = Response(
684685
{"info": "Jira Epic update query sent"},
685686
status=status.HTTP_200_OK,
686687
)
687688
else:
688-
jira_helper.add_epic(engagement.id, **request.data)
689+
dojo_dispatch_task(jira_helper.add_epic, engagement.id, **request.data)
689690
response = Response(
690691
{"info": "Jira Epic create query sent"},
691692
status=status.HTTP_200_OK,

dojo/celery.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,56 @@
1212
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "dojo.settings.settings")
1313

1414

15-
class PgHistoryTask(Task):
15+
class DojoAsyncTask(Task):
16+
17+
"""
18+
Base task class that provides dojo_async_task functionality without using a decorator.
19+
20+
This class:
21+
- Injects user context into task kwargs
22+
- Tracks task calls for performance testing
23+
- Supports all Celery features (signatures, chords, groups, chains)
24+
"""
25+
26+
def apply_async(self, args=None, kwargs=None, **options):
27+
"""Override apply_async to inject user context and track tasks."""
28+
from dojo.decorators import dojo_async_task_counter # noqa: PLC0415 circular import
29+
from dojo.utils import get_current_user # noqa: PLC0415 circular import
30+
31+
if kwargs is None:
32+
kwargs = {}
33+
34+
# Inject user context if not already present
35+
if "async_user" not in kwargs:
36+
kwargs["async_user"] = get_current_user()
37+
38+
# Control flag used for sync/async decision; never pass into the task itself
39+
kwargs.pop("sync", None)
40+
41+
# Track dispatch
42+
dojo_async_task_counter.incr(
43+
self.name,
44+
args=args,
45+
kwargs=kwargs,
46+
)
47+
48+
# Call parent to execute async
49+
return super().apply_async(args=args, kwargs=kwargs, **options)
50+
51+
52+
class PgHistoryTask(DojoAsyncTask):
1653

1754
"""
1855
Custom Celery base task that automatically applies pghistory context.
1956
20-
When a task is dispatched via dojo_async_task, the current pghistory
21-
context is captured and passed in kwargs as "_pgh_context". This base
22-
class extracts that context and applies it before running the task,
23-
ensuring all database events share the same context as the original
24-
request.
57+
This class inherits from DojoAsyncTask to provide:
58+
- User context injection and task tracking (from DojoAsyncTask)
59+
- Automatic pghistory context application (from this class)
60+
61+
When a task is dispatched via dojo_dispatch_task or dojo_async_task, the current
62+
pghistory context is captured and passed in kwargs as "_pgh_context". This base
63+
class extracts that context and applies it before running the task, ensuring all
64+
database events share the same context as the original request.
2565
"""
2666

2767
def __call__(self, *args, **kwargs):

dojo/celery_dispatch.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any, Protocol, cast
4+
5+
from celery.canvas import Signature
6+
7+
if TYPE_CHECKING:
8+
from collections.abc import Mapping
9+
10+
11+
class _SupportsSi(Protocol):
12+
def si(self, *args: Any, **kwargs: Any) -> Signature: ...
13+
14+
15+
class _SupportsApplyAsync(Protocol):
16+
def apply_async(self, args: Any | None = None, kwargs: Any | None = None, **options: Any) -> Any: ...
17+
18+
19+
def _inject_async_user(kwargs: Mapping[str, Any] | None) -> dict[str, Any]:
20+
result: dict[str, Any] = dict(kwargs or {})
21+
if "async_user" not in result:
22+
from dojo.utils import get_current_user # noqa: PLC0415 circular import
23+
24+
result["async_user"] = get_current_user()
25+
return result
26+
27+
28+
def _inject_pghistory_context(kwargs: Mapping[str, Any] | None) -> dict[str, Any]:
29+
"""Capture and inject pghistory context if available."""
30+
result: dict[str, Any] = dict(kwargs or {})
31+
if "_pgh_context" not in result:
32+
from dojo.pghistory_utils import get_serializable_pghistory_context # noqa: PLC0415 circular import
33+
34+
if pgh_context := get_serializable_pghistory_context():
35+
result["_pgh_context"] = pgh_context
36+
return result
37+
38+
39+
def dojo_create_signature(task_or_sig: _SupportsSi | Signature, *args: Any, **kwargs: Any) -> Signature:
40+
"""
41+
Build a Celery signature with DefectDojo user context and pghistory context injected.
42+
43+
- If passed a task, returns `task_or_sig.si(*args, **kwargs)`.
44+
- If passed an existing signature, returns a cloned signature with merged kwargs.
45+
"""
46+
injected = _inject_async_user(kwargs)
47+
injected = _inject_pghistory_context(injected)
48+
injected.pop("countdown", None)
49+
50+
if isinstance(task_or_sig, Signature):
51+
merged_kwargs = {**(task_or_sig.kwargs or {}), **injected}
52+
return task_or_sig.clone(kwargs=merged_kwargs)
53+
54+
return task_or_sig.si(*args, **injected)
55+
56+
57+
def dojo_dispatch_task(task_or_sig: _SupportsSi | _SupportsApplyAsync | Signature, *args: Any, **kwargs: Any) -> Any:
58+
"""
59+
Dispatch a task/signature using DefectDojo semantics.
60+
61+
- Inject `async_user` if missing.
62+
- Capture and inject pghistory context if available.
63+
- Respect `sync=True` (foreground execution) and user `block_execution`.
64+
- Support `countdown=<seconds>` for async dispatch.
65+
66+
Returns:
67+
- async: AsyncResult-like return from Celery
68+
- sync: underlying return value of the task
69+
70+
"""
71+
from dojo.decorators import dojo_async_task_counter, we_want_async # noqa: PLC0415 circular import
72+
73+
countdown = cast("int", kwargs.pop("countdown", 0))
74+
injected = _inject_async_user(kwargs)
75+
injected = _inject_pghistory_context(injected)
76+
77+
sig = dojo_create_signature(task_or_sig if isinstance(task_or_sig, Signature) else cast("_SupportsSi", task_or_sig), *args, **injected)
78+
sig_kwargs = dict(sig.kwargs or {})
79+
80+
if we_want_async(*sig.args, func=getattr(sig, "type", None), **sig_kwargs):
81+
# DojoAsyncTask.apply_async tracks async dispatch. Avoid double-counting here.
82+
return sig.apply_async(countdown=countdown)
83+
84+
# Track foreground execution as a "created task" as well (matches historical dojo_async_task behavior)
85+
dojo_async_task_counter.incr(str(sig.task), args=sig.args, kwargs=sig_kwargs)
86+
87+
sig_kwargs.pop("sync", None)
88+
sig = sig.clone(kwargs=sig_kwargs)
89+
eager = sig.apply()
90+
try:
91+
return eager.get(propagate=True)
92+
except RuntimeError:
93+
# Since we are intentionally running synchronously, we can propagate exceptions directly, and enable sync subtasks
94+
# If the requests desires this. Celery docs explain that this is a rare use case, but we support it _just in case_
95+
return eager.get(propagate=True, disable_sync_subtasks=False)

dojo/endpoint/views.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from dojo.authorization.authorization import user_has_permission_or_403
1919
from dojo.authorization.authorization_decorators import user_is_authorized
2020
from dojo.authorization.roles_permissions import Permissions
21+
from dojo.celery_dispatch import dojo_dispatch_task
2122
from dojo.endpoint.queries import get_authorized_endpoints_for_queryset
2223
from dojo.endpoint.utils import clean_hosts_run, endpoint_meta_import
2324
from dojo.filters import EndpointFilter, EndpointFilterWithoutObjectLookups
@@ -345,7 +346,7 @@ def endpoint_bulk_update_all(request, pid=None):
345346
product_calc = list(Product.objects.filter(endpoint__id__in=endpoints_to_update).distinct())
346347
endpoints.delete()
347348
for prod in product_calc:
348-
calculate_grade(prod.id)
349+
dojo_dispatch_task(calculate_grade, prod.id)
349350

350351
if skipped_endpoint_count > 0:
351352
add_error_message_to_response(f"Skipped deletion of {skipped_endpoint_count} endpoints because you are not authorized.")

dojo/engagement/services.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from django.dispatch import receiver
66

77
import dojo.jira_link.helper as jira_helper
8+
from dojo.celery_dispatch import dojo_dispatch_task
89
from dojo.models import Engagement
910

1011
logger = logging.getLogger(__name__)
@@ -16,7 +17,7 @@ def close_engagement(eng):
1617
eng.save()
1718

1819
if jira_helper.get_jira_project(eng):
19-
jira_helper.close_epic(eng.id, push_to_jira=True)
20+
dojo_dispatch_task(jira_helper.close_epic, eng.id, push_to_jira=True)
2021

2122

2223
def reopen_engagement(eng):

dojo/engagement/views.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from dojo.authorization.authorization import user_has_permission_or_403
3838
from dojo.authorization.authorization_decorators import user_is_authorized
3939
from dojo.authorization.roles_permissions import Permissions
40+
from dojo.celery_dispatch import dojo_dispatch_task
4041
from dojo.endpoint.utils import save_endpoints_to_add
4142
from dojo.engagement.queries import get_authorized_engagements
4243
from dojo.engagement.services import close_engagement, reopen_engagement
@@ -392,7 +393,7 @@ def copy_engagement(request, eid):
392393
form = DoneForm(request.POST)
393394
if form.is_valid():
394395
engagement_copy = engagement.copy()
395-
calculate_grade(product.id)
396+
dojo_dispatch_task(calculate_grade, product.id)
396397
messages.add_message(
397398
request,
398399
messages.SUCCESS,

dojo/finding/deduplication.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from django.db.models.query_utils import Q
99

1010
from dojo.celery import app
11-
from dojo.decorators import dojo_async_task
1211
from dojo.models import Finding, System_Settings
1312

1413
logger = logging.getLogger(__name__)
@@ -45,13 +44,11 @@ def get_finding_models_for_deduplication(finding_ids):
4544
)
4645

4746

48-
@dojo_async_task
4947
@app.task
5048
def do_dedupe_finding_task(new_finding_id, *args, **kwargs):
5149
return do_dedupe_finding_task_internal(Finding.objects.get(id=new_finding_id), *args, **kwargs)
5250

5351

54-
@dojo_async_task
5552
@app.task
5653
def do_dedupe_batch_task(finding_ids, *args, **kwargs):
5754
"""

dojo/finding/helper.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import dojo.jira_link.helper as jira_helper
1717
import dojo.risk_acceptance.helper as ra_helper
1818
from dojo.celery import app
19-
from dojo.decorators import dojo_async_task
2019
from dojo.endpoint.utils import endpoint_get_or_create, save_endpoints_to_add
2120
from dojo.file_uploads.helper import delete_related_files
2221
from dojo.finding.deduplication import (
@@ -395,7 +394,6 @@ def add_findings_to_auto_group(name, findings, group_by, *, create_finding_group
395394
finding_group.findings.add(*findings)
396395

397396

398-
@dojo_async_task
399397
@app.task
400398
def post_process_finding_save(finding_id, dedupe_option=True, rules_option=True, product_grading_option=True, # noqa: FBT002
401399
issue_updater_option=True, push_to_jira=False, user=None, *args, **kwargs): # noqa: FBT002 - this is bit hard to fix nice have this universally fixed
@@ -440,7 +438,9 @@ def post_process_finding_save_internal(finding, dedupe_option=True, rules_option
440438

441439
if product_grading_option:
442440
if system_settings.enable_product_grade:
443-
calculate_grade(finding.test.engagement.product.id)
441+
from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import
442+
443+
dojo_dispatch_task(calculate_grade, finding.test.engagement.product.id)
444444
else:
445445
deduplicationLogger.debug("skipping product grading because it's disabled in system settings")
446446

@@ -457,7 +457,6 @@ def post_process_finding_save_internal(finding, dedupe_option=True, rules_option
457457
jira_helper.push_to_jira(finding.finding_group)
458458

459459

460-
@dojo_async_task
461460
@app.task
462461
def post_process_findings_batch(finding_ids, *args, dedupe_option=True, rules_option=True, product_grading_option=True,
463462
issue_updater_option=True, push_to_jira=False, user=None, **kwargs):
@@ -500,7 +499,9 @@ def post_process_findings_batch(finding_ids, *args, dedupe_option=True, rules_op
500499
tool_issue_updater.async_tool_issue_update(finding)
501500

502501
if product_grading_option and system_settings.enable_product_grade:
503-
calculate_grade(findings[0].test.engagement.product.id)
502+
from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import
503+
504+
dojo_dispatch_task(calculate_grade, findings[0].test.engagement.product.id)
504505

505506
if push_to_jira:
506507
for finding in findings:

dojo/finding/views.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
user_is_authorized,
3939
)
4040
from dojo.authorization.roles_permissions import Permissions
41+
from dojo.celery_dispatch import dojo_dispatch_task
4142
from dojo.filters import (
4243
AcceptedFindingFilter,
4344
AcceptedFindingFilterWithoutObjectLookups,
@@ -1099,7 +1100,7 @@ def process_form(self, request: HttpRequest, finding: Finding, context: dict):
10991100
product = finding.test.engagement.product
11001101
finding.delete()
11011102
# Update the grade of the product async
1102-
calculate_grade(product.id)
1103+
dojo_dispatch_task(calculate_grade, product.id)
11031104
# Add a message to the request that the finding was successfully deleted
11041105
messages.add_message(
11051106
request,
@@ -1374,7 +1375,7 @@ def copy_finding(request, fid):
13741375
test = form.cleaned_data.get("test")
13751376
product = finding.test.engagement.product
13761377
finding_copy = finding.copy(test=test)
1377-
calculate_grade(product.id)
1378+
dojo_dispatch_task(calculate_grade, product.id)
13781379
messages.add_message(
13791380
request,
13801381
messages.SUCCESS,

0 commit comments

Comments
 (0)