From 34d1d7de0b62e69be7d1ce9445bfd4f323648c2c Mon Sep 17 00:00:00 2001 From: Agah Date: Tue, 3 Sep 2024 00:12:39 -0400 Subject: [PATCH] unbundle rebundle --- api/neurolibre_celery_tasks.py | 16 +++++++++------- api/screening_client.py | 19 ++++++++++++++++++- 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/api/neurolibre_celery_tasks.py b/api/neurolibre_celery_tasks.py index d483e0e..c95fdac 100644 --- a/api/neurolibre_celery_tasks.py +++ b/api/neurolibre_celery_tasks.py @@ -72,20 +72,22 @@ def get_time(): """ class BaseNeuroLibreTask: - def __init__(self, celery_task, task_title, screening=None, payload=None): + def __init__(self, celery_task, screening=None, payload=None): self.celery_task = celery_task - self.task_title = task_title self.payload = payload self.task_id = celery_task.request.id if screening: - if not isinstance(screening, ScreeningClient): - raise TypeError("The 'screening' parameter must be an instance of ScreeningClient") - self.screening = screening + # If passed here, must be JSON serialization of ScreeningClient object. + # We need to unpack these to pass to ScreeningClient to initialize it as an object. + standard_attrs = ['task_name', 'issue_id', 'target_repo_url', 'task_id', 'comment_id', 'commit_hash'] + standard_dict = {key: screening.pop(key) for key in standard_attrs if key in screening} + extra_payload = screening + self.screening = ScreeningClient(**standard_dict, **extra_payload) self.owner_name, self.repo_name, self.provider_name = get_owner_repo_provider(screening.target_repo_url, provider_full_name=True) elif payload: - # This will be probably deprecated soon. + # This will be probably deprecated soon. For now, reserve for backward compatibility. self.screening = ScreeningClient( - self.task_title, + payload['task_name'], payload['issue_id'], payload['repo_url'], self.task_id, diff --git a/api/screening_client.py b/api/screening_client.py index b9d3e82..e8616de 100644 --- a/api/screening_client.py +++ b/api/screening_client.py @@ -25,6 +25,7 @@ def __init__(self, task_name, issue_id, target_repo_url = None, task_id="0000000 self.target_repo_url = target_repo_url self.commit_hash = commit_hash self.comment_id = comment_id + self.__extra_payload = extra_payload for key, value in extra_payload.items(): setattr(self, key, value) @@ -37,12 +38,28 @@ def __init__(self, task_name, issue_id, target_repo_url = None, task_id="0000000 else: self.repo_object = None + # If no comment ID is provided, create a new comment with a pending status if self.comment_id is None: self.comment_id = self.respond().PENDING("Awaiting task assignment...") + def to_dict(self): + # Convert the object to a dictionary to pass to Celery + result = { + 'task_name': self.task_name, + 'issue_id': self.issue_id, + 'target_repo_url': self.target_repo_url, + 'task_id': self.task_id, + 'comment_id': self.comment_id, + 'commit_hash': self.commit_hash, + } + result.update(self.__extra_payload) + return result + def start_celery_task(self, celery_task_func): - task_result = celery_task_func.apply_async(args=[self]) + # This trick is needed to pass the ScreeningClient object to the Celery task. + # This is because the ScreeningClient object cannot be serialized into JSON, which is required by Redis. + task_result = celery_task_func.apply_async(args=[self.to_dict()]) if task_result.task_id is not None: self.task_id = task_result.task_id