Skip to content

Commit 100be00

Browse files
authored
increase task_class input validation and log messages and add unit test (#191)
1 parent b6b0d49 commit 100be00

File tree

2 files changed

+98
-8
lines changed

2 files changed

+98
-8
lines changed

src/workflow_app/workflow/states.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,12 @@
99
from .settings import POSTPROCESS_ERROR, CATALOG_DATA_READY
1010
from .settings import REDUCTION_DATA_READY, REDUCTION_CATALOG_DATA_READY
1111
from .database import transactions
12+
13+
import importlib
14+
import inspect
1215
import json
1316
import logging
17+
import re
1418

1519

1620
class StateAction:
@@ -47,6 +51,30 @@ def _call_default_task(self, headers, message):
4751
action_cls = globals()[destination]
4852
action_cls(connection=self._send_connection)(headers, message)
4953

54+
def _get_class_from_path(self, class_path: str):
55+
"""
56+
Returns the class given by the class path
57+
:param class_path: the class, e.g. "module_name.ClassName"
58+
:return: class or None
59+
"""
60+
# check that the string is in the format "package_name.module_name.class_name"
61+
pattern = r"^[a-zA-Z0-9_\.]+\.[a-zA-Z0-9_]+$"
62+
if not re.match(pattern, class_path):
63+
logging.error(f"task_class {class_path} does not match pattern module_name.ClassName")
64+
return None
65+
module_name, class_name = class_path.rsplit(".", 1)
66+
67+
# try importing the class
68+
try:
69+
module = importlib.import_module(module_name)
70+
cls = getattr(module, class_name)
71+
if not inspect.isclass(cls):
72+
raise ValueError
73+
return cls
74+
except (ModuleNotFoundError, AttributeError, ValueError):
75+
logging.error(f"task_class {class_path} cannot be imported")
76+
return None
77+
5078
def _call_db_task(self, task_data, headers, message):
5179
"""
5280
:param task_data: JSON-encoded task definition
@@ -59,14 +87,12 @@ def _call_db_task(self, task_data, headers, message):
5987
and (task_def["task_class"] is not None)
6088
and len(task_def["task_class"].strip()) > 0
6189
):
62-
try:
63-
toks = task_def["task_class"].strip().split(".")
64-
module = ".".join(toks[: len(toks) - 1])
65-
cls = toks[len(toks) - 1]
66-
exec("from %s import %s as action_cls" % (module, cls))
67-
action_cls(connection=self._send_connection)(headers, message) # noqa: F821
68-
except: # noqa: E722
69-
logging.exception("Task [%s] failed:", headers["destination"])
90+
action_cls = self._get_class_from_path(task_def["task_class"])
91+
if action_cls:
92+
try:
93+
action_cls(connection=self._send_connection)(headers, message) # noqa: F821
94+
except: # noqa: E722
95+
logging.exception("Task [%s] failed:", headers["destination"])
7096
if "task_queues" in task_def:
7197
for item in task_def["task_queues"]:
7298
destination = "/queue/%s" % item

src/workflow_app/workflow/tests/test_states.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,20 @@
66
_ = [workflow]
77

88

9+
class FakeTestClass:
    """Task-class stand-in that can be constructed but always fails when called."""

    def __init__(self, connection):
        """Accept the ``connection`` keyword and ignore it."""

    def __call__(self, headers, message):
        """Raise ``ValueError`` unconditionally, whatever the arguments."""
        raise ValueError()
15+
16+
917
class StateActionTest(TestCase):
18+
19+
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
    """Store the ``caplog`` fixture on the instance so test methods can read it as ``self.caplog``."""
    self.caplog = caplog
22+
1023
def test_call_default_task(self):
1124
from workflow.states import StateAction
1225

@@ -47,6 +60,57 @@ def test_call(self, mock_get_task):
4760
sa(headers, message)
4861
assert mock_connection.send.call_count - original_call_count == 2 # one per task queue
4962

63+
@mock.patch("workflow.states.transactions.get_task")
def test_task_class_path(self, mock_get_task):
    """
    Check how StateAction handles the ``task_class`` entry of a DB task definition

    Each rejected or failing ``task_class`` must leave the expected message in
    the captured log; the final, valid one is checked through the mocked
    connection's ``send`` call count.
    """
    from workflow.states import StateAction

    connection = mock.Mock()
    action = StateAction(connection=connection, use_db_task=True)
    headers = {"destination": "test", "message-id": "test-0"}
    message = '{"facility": "SNS", "instrument": "arcs", "ipts": "IPTS-5", "run_number": 3, "data_file": "test"}'

    # (task definition returned by get_task, text expected in the log)
    failing_cases = (
        # task class "-" (inserted by Django admin interface when left empty)
        ('{"task_class": "-"}', "does not match pattern"),
        # task class that does not follow the pattern "module_name.ClassName"
        ('{"task_class": "FakeClass"}', "does not match pattern"),
        # module that does not exist
        ('{"task_class": "fake_module.FakeClass"}', "cannot be imported"),
        # module exists but class does not
        ('{"task_class": "workflow.states.FakeClass"}', "cannot be imported"),
        # module attribute is not a class
        ('{"task_class": "workflow.state_utilities.decode_message"}', "cannot be imported"),
        # calling the class fails
        ('{"task_class": "workflow.tests.test_states.FakeTestClass"}', "Task [test] failed"),
    )
    for task_definition, expected_log in failing_cases:
        mock_get_task.return_value = task_definition
        self.caplog.clear()
        action(headers, message)
        assert expected_log in self.caplog.text

    # valid class: the count is cumulative over every call made above
    mock_get_task.return_value = '{"task_class": "workflow.states.Reduction_request"}'
    self.caplog.clear()
    action(headers, message)
    assert connection.send.call_count == 1
113+
50114
@mock.patch("workflow.database.transactions.add_status_entry")
51115
def test_send(self, mockAddStatusEntry):
52116
from workflow.states import StateAction

0 commit comments

Comments
 (0)