Skip to content

Commit

Permalink
fix issues pointed by ruff.
Browse files Browse the repository at this point in the history
  • Loading branch information
xmnlab committed May 30, 2024
1 parent 9bd428b commit 5dc74ff
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 14 deletions.
58 changes: 45 additions & 13 deletions src/retsu/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,33 @@

from datetime import datetime
from time import sleep
from typing import Any
from typing import Any, Callable, Optional

import redis


class TaskMetadataManager:
"""Manage task metadata."""

def __init__(self, client: redis.Redis):
"""Initialize TaskMetadataManager."""
self.client = client
self.step = StepMetadataManager(self.client)

def get_all(self, task_id: str) -> dict[str, Any]:
def get_all(self, task_id: str) -> dict[str, bytes]:
"""Get the entire metadata for a given task."""
return self.client.hgetall(f"task:{task_id}:metadata")

def get(self, task_id: str, attribute: str) -> Any:
def get(self, task_id: str, attribute: str) -> bytes:
"""Get a specific metadata attribute for a given task."""
return self.client.hget(f"task:{task_id}:metadata", attribute)

def set(self, task_id: str, metadata: dict[str, Any]) -> None:
def create(self, task_id: str, metadata: dict[str, Any]) -> None:
"""Create an initial metadata for given task."""
self.client.hset(f"task:{task_id}:metadata", mapping=metadata)

def update(self, task_id: str, attribute: str, value: Any) -> None:
"""Update the value of given attribute for a given task."""
self.client.hset(f"task:{task_id}:metadata", attribute, value)
self.client.hset(
f"task:{task_id}:metadata",
Expand All @@ -36,23 +43,30 @@ def update(self, task_id: str, attribute: str, value: Any) -> None:


class StepMetadataManager:
"""Manage metadata for steps of a task."""

def __init__(self, redis_client: redis.Redis):
"""Initialize StepMetadataManager."""
self.client = redis_client

def get_all(self, task_id: str, step_id: str) -> dict[str, Any]:
def get_all(self, task_id: str, step_id: str) -> dict[str, bytes]:
"""Get the whole metadata for a given task and step."""
return self.client.hgetall(f"task:{task_id}:step:{step_id}")

def get(self, task_id: str, step_id: str, attribute: str) -> Any:
def get(self, task_id: str, step_id: str, attribute: str) -> bytes:
"""Get the value of a given attribute for a given task and step."""
return self.client.hget(f"task:{task_id}:step:{step_id}", attribute)

def create(
self, task_id: str, step_id: str, metadata: dict[str, Any]
) -> None:
"""Create an initial metadata for given task and step."""
self.client.hset(f"task:{task_id}:step:{step_id}", mapping=metadata)

def update(
self, task_id: str, step_id: str, attribute: str, value: Any
) -> None:
"""Update the value of given attribute for a given task and step."""
if attribute == "status" and value not in ["started", "completed"]:
raise Exception("Status should be started or completed.")

Expand All @@ -65,13 +79,17 @@ def update(


class ResultTaskManager:
"""Manage the result and metadata from tasks."""

def __init__(self, host="localhost", port=6379, db=0):
"""Initialize ResultTaskManager."""
self.client = redis.Redis(
host=host, port=port, db=db, decode_responses=False
)
self.metadata = TaskMetadataManager(self.client)

def get(self, task_id: str, timeout: Optional[int] = None) -> Any:
"""Get the result for a given task."""
time_step = 0.5
if timeout:
while self.status(task_id) != "completed":
Expand All @@ -87,36 +105,50 @@ def get(self, task_id: str, timeout: Optional[int] = None) -> Any:
return pickle.loads(result) if result else result

def load(self, task_id: str) -> dict[str, Any]:
"""Load the whole metadata for a given task."""
return self.metadata.get_all(task_id)

def create(self, task_id: str, metadata: dict[str, Any]) -> None:
self.metadata.set(task_id, metadata)
"""Create a new metadata for a given task."""
self.metadata.create(task_id, metadata)

def save(self, task_id: str, result: Any) -> None:
"""Save the result for a given task."""
self.metadata.update(task_id, "result", pickle.dumps(result))

def status(self, task_id: str) -> str:
"""Get the status for a given task."""
status = self.metadata.get(task_id, "status")
return status.decode("utf8")


def create_result_task_manager() -> ResultTaskManager:
"""Create a ResultTaskManager with parameters from the environment."""
redis_host: str = os.getenv("RETSU_REDIS_HOST", "localhost")
redis_port: int = int(os.getenv("RETSU_REDIS_PORT", 6379))
redis_db: int = int(os.getenv("RETSU_REDIS_DB", 0))

return ResultTaskManager(host=redis_host, port=redis_port, db=redis_db)


def track_task(redis_manager: TaskMetadataManager):
def decorator(task_func: Callable[(Any,), Any]):
def wrapper(self, *args, **kwargs):
def track_step(task_metadata: TaskMetadataManager) -> Callable[(Any,), Any]:
"""Decorate a function with TaskMetadataManager."""

def decorator(task_func: Callable[(Any,), Any]) -> Callable[(Any,), Any]:
"""Return a decorator for the given task."""

def wrapper(self, *args, **kwargs) -> Any:
"""Wrap a function for registering the task metadata."""
task_id = kwargs["task_id"]
step_id = kwargs.get("step_id", task_func.__name__)
redis_manager.update(task_id, step_id, "status", "started")

step_metadata = task_metadata.step

step_metadata.update(task_id, step_id, "status", "started")
result = task_func(self, *args, **kwargs)
redis_manager.update(task_id, step_id, "status", "completed")
redis_manager.set(task_id, step_id, "result", result)
step_metadata.update(task_id, step_id, "status", "completed")
result_pickled = pickle.dumps(result)
step_metadata.update(task_id, step_id, "result", result_pickled)
return result

return wrapper
Expand Down
10 changes: 9 additions & 1 deletion tests/test_task_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,35 @@


class MyTask(SerialTask):
"""Task for the test."""

def task(self, *args, task_id: str, **kwargs) -> None: # type: ignore
"""Task that sum the given 2 numbers."""
a = kwargs.pop("a", 0)
b = kwargs.pop("b", 0)
return a + b


class SetupTask:
"""Setup the task workflow for the test."""

@classmethod
def setup_class(cls) -> None:
"""Set the class configuration for the test."""
cls.task = MyTask()
cls.task.start()

@classmethod
def teardown_class(cls) -> None:
"""Set the class teardown step for the test."""
cls.task.stop()


class TestSerialTask(SetupTask):
"""TestSerialTask."""

def test_serial(self):
def test_serial_simple(self):
"""Run simple test for a serial task."""
results: dict[str, int] = {}

for i in range(10):
Expand Down

0 comments on commit 5dc74ff

Please sign in to comment.