Skip to content

Commit

Permalink
Improved API consistency and more endpoints. Piggyback: don't sleep i…
Browse files Browse the repository at this point in the history
…f module took any task from queue (and maybe skipped it) (#1075)
  • Loading branch information
kazet authored Jun 13, 2024
1 parent 8f54637 commit 2900d19
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 29 deletions.
71 changes: 52 additions & 19 deletions artemis/api.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,7 @@
import datetime
from typing import Annotated, Any, Dict, List, Optional

from fastapi import (
APIRouter,
Body,
Depends,
Form,
Header,
HTTPException,
Query,
Request,
)
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, Request
from fastapi.responses import RedirectResponse
from karton.core.backend import KartonBackend
from karton.core.config import Config as KartonConfig
Expand All @@ -20,6 +11,7 @@

from artemis.config import Config
from artemis.db import DB, ColumnOrdering, TaskFilter
from artemis.frontend import get_binds_that_can_be_disabled
from artemis.modules.classifier import Classifier
from artemis.producer import create_tasks
from artemis.reporting.base.language import Language
Expand Down Expand Up @@ -60,14 +52,35 @@ def verify_api_token(x_api_token: Annotated[str, Header()]) -> None:
@router.post("/add", dependencies=[Depends(verify_api_token)])
def add(
targets: List[str],
tag: Annotated[Optional[str], Body()] = None,
disabled_modules: List[str] = Config.Miscellaneous.MODULES_DISABLED_BY_DEFAULT,
tag: str | None = Body(default=None),
disabled_modules: Optional[List[str]] = Body(default=None),
enabled_modules: Optional[List[str]] = Body(default=None),
) -> Dict[str, Any]:
"""Add targets to be scanned."""
if disabled_modules and enabled_modules:
raise HTTPException(
status_code=400, detail="It's not possible to set both disabled_modules and enabled_modules."
)

for task in targets:
if not Classifier.is_supported(task):
return {"error": f"Invalid task: {task}"}

identities_that_can_be_disabled = set([bind.identity for bind in get_binds_that_can_be_disabled()])

if enabled_modules:
if len(set(enabled_modules) - identities_that_can_be_disabled) > 0:
raise HTTPException(
status_code=400,
detail=f"The following modules from enabled_modules either don't exist or must always be enabled: {','.join(set(enabled_modules) - identities_that_can_be_disabled)}",
)

if enabled_modules:
# Let's disable all modules that can be disabled and aren't included in enabled_modules
disabled_modules = list(identities_that_can_be_disabled - set(enabled_modules))
elif not disabled_modules:
disabled_modules = Config.Miscellaneous.MODULES_DISABLED_BY_DEFAULT

create_tasks(targets, tag, disabled_modules=disabled_modules)

return {"ok": True}
Expand All @@ -76,7 +89,13 @@ def add(
@router.get("/analyses", dependencies=[Depends(verify_api_token)])
def list_analysis() -> List[Dict[str, Any]]:
"""Returns the list of analysed targets. Any scanned target would be listed here."""
return db.list_analysis()
analyses = db.list_analysis()
karton_state = KartonState(backend=KartonBackend(config=KartonConfig()))
for analysis in analyses:
analysis["num_pending_tasks"] = (
len(karton_state.analyses[analysis["id"]].pending_tasks) if analysis["id"] in karton_state.analyses else 0
)
return analyses


@router.get("/num-queued-tasks", dependencies=[Depends(verify_api_token)])
Expand All @@ -97,7 +116,7 @@ def num_queued_tasks(karton_names: Optional[List[str]] = None) -> int:

@router.get("/task-results", dependencies=[Depends(verify_api_token)])
def get_task_results(
only_interesting: bool = False,
only_interesting: bool = True,
page: int = 1,
page_size: int = 100,
analysis_id: Optional[str] = None,
Expand All @@ -114,6 +133,20 @@ def get_task_results(
).data


@router.post("/stop-and-delete-analysis", dependencies=[Depends(verify_api_token)])
def stop_and_delete_analysis(analysis_id: str) -> Dict[str, bool]:
backend = KartonBackend(config=KartonConfig())

for task in backend.get_all_tasks():
if task.root_uid == analysis_id:
backend.delete_task(task)

if db.get_analysis_by_id(analysis_id):
db.delete_analysis(analysis_id)

return {"ok": True}


@router.get("/exports", dependencies=[Depends(verify_api_token)])
def get_exports() -> List[ReportGenerationTaskModel]:
"""List all exports. An export is a request to create human-readable messages that may be sent to scanned entities."""
Expand Down Expand Up @@ -150,12 +183,12 @@ async def post_export_delete(id: int) -> Dict[str, Any]:
}


@router.post("/export")
@router.post("/export", dependencies=[Depends(verify_api_token)])
async def post_export(
language: str = Form(),
skip_previously_exported: bool = Form(),
tag: Optional[str] = Form(None),
comment: Optional[str] = Form(None),
language: str = Body(),
skip_previously_exported: bool = Body(),
tag: Optional[str] = Body(None),
comment: Optional[str] = Body(None),
) -> Dict[str, Any]:
"""Create a new export. An export is a request to create human-readable messages that may be sent to scanned entities."""
db.create_report_generation_task(
Expand Down
3 changes: 0 additions & 3 deletions artemis/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,10 @@
from artemis.db import DB
from artemis.frontend import error_content_not_found
from artemis.frontend import router as router_front
from artemis.utils import read_template

app = FastAPI(
docs_url="/docs" if Config.Miscellaneous.API_TOKEN else None,
redoc_url=None,
# This will be displayed as the additional text in Swagger docs
description=read_template("components/generating_reports_hint.jinja2"),
)
app.exception_handler(CsrfProtectError)(csrf.csrf_protect_exception_handler)
app.exception_handler(404)(error_content_not_found)
Expand Down
10 changes: 5 additions & 5 deletions artemis/module_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def _single_iteration(self) -> int:
else:
resource_lock = None

tasks, locks = self._take_and_lock_tasks(self.task_max_batch_size)
tasks, locks, num_task_removed_from_queue = self._take_and_lock_tasks(self.task_max_batch_size)
self._log_tasks(tasks)

for task in tasks:
Expand All @@ -202,15 +202,15 @@ def _single_iteration(self) -> int:
if resource_lock:
resource_lock.release()

return len(tasks)
return num_task_removed_from_queue

def _take_and_lock_tasks(self, num_tasks: int) -> Tuple[List[Task], List[Optional[ResourceLock]]]:
def _take_and_lock_tasks(self, num_tasks: int) -> Tuple[List[Task], List[Optional[ResourceLock]], int]:
self.log.debug("[taking tasks] Acquiring lock to take tasks from queue")
try:
self.taking_tasks_from_queue_lock.acquire()
except FailedToAcquireLockException:
self.log.info("[taking tasks] Failed to acquire lock to take tasks from queue")
return [], []
return [], [], 0

try:
tasks = []
Expand Down Expand Up @@ -300,7 +300,7 @@ def _take_and_lock_tasks(self, num_tasks: int) -> Tuple[List[Task], List[Optiona
self.log.debug(
"[taking tasks] Tasks from queue taken and filtered, %d left after filtering", len(tasks_not_blocklisted)
)
return tasks_not_blocklisted, locks_for_tasks_not_blocklisted
return tasks_not_blocklisted, locks_for_tasks_not_blocklisted, len(tasks)

def _is_blocklisted(self, task: Task) -> bool:
if self.identity == "classifier":
Expand Down
2 changes: 1 addition & 1 deletion test/e2e/test_automated_interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_automated_interaction(self) -> None:

analyses = requests.get(BACKEND_URL + "api/analyses", headers={"X-API-Token": "api-token"}).json()
self.assertEqual(len(analyses), 1)
self.assertEqual(set(analyses[0].keys()), {"stopped", "target", "created_at", "id", "tag"})
self.assertEqual(set(analyses[0].keys()), {"stopped", "target", "created_at", "id", "tag", "num_pending_tasks"})
self.assertEqual(analyses[0]["stopped"], False)
self.assertEqual(analyses[0]["target"], "test-smtp-server.artemis")
self.assertEqual(analyses[0]["tag"], "automated-interaction")
Expand Down
2 changes: 1 addition & 1 deletion test/e2e/test_exporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_exporting_api(self) -> None:
self.assertEqual(
requests.post(
BACKEND_URL + "api/export",
data={"skip_previously_exported": True, "language": "pl_PL"},
json={"skip_previously_exported": True, "language": "pl_PL"},
headers={"X-Api-Token": "api-token"},
).json(),
{"ok": True},
Expand Down

0 comments on commit 2900d19

Please sign in to comment.