diff --git a/.gitignore b/.gitignore index f60b9c74..b7ae1495 100644 --- a/.gitignore +++ b/.gitignore @@ -166,3 +166,4 @@ codegen/ .vscode/ tests/unit/automator/_trial_temp/_trial_marker tests/unit/automator/_trial_temp/_trial_marker +.history \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 4f8fe193..154aeb27 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ requests >= 2.31.0 typing-extensions >= 4.2.0 astor >= 0.8.1 shortuuid >= 1.0.11 -dacite >= 1.8.1 \ No newline at end of file +dacite >= 1.8.1 +cachetools==4.2.1 \ No newline at end of file diff --git a/src/conductor/client/automator/django.py b/src/conductor/client/automator/django.py new file mode 100644 index 00000000..bf97350b --- /dev/null +++ b/src/conductor/client/automator/django.py @@ -0,0 +1,210 @@ +"""Django-specific customization.""" +import os +import sys +import warnings +from datetime import datetime, timezone +from importlib import import_module +from typing import IO, TYPE_CHECKING, Any, List, Optional, cast + +from kombu.utils.imports import symbol_by_name +from kombu.utils.objects import cached_property + +if TYPE_CHECKING: + from types import ModuleType + from typing import Protocol + + from django.db.backends.base.base import BaseDatabaseWrapper + from django.db.utils import ConnectionHandler + + class DjangoDBModule(Protocol): + connections: ConnectionHandler + + +__all__ = ('DjangoFixup', 'fixup') + +ERR_NOT_INSTALLED = """\ +Environment variable DJANGO_SETTINGS_MODULE is defined +but Django isn't installed. Won't apply Django fix-ups! 
+""" + + +def _maybe_close_fd(fh: IO) -> None: + try: + os.close(fh.fileno()) + except (AttributeError, OSError, TypeError): + # TypeError added for celery#962 + pass + + + +def fixup(app: "TaskRunner", env: str = 'DJANGO_SETTINGS_MODULE') -> Optional["DjangoFixup"]: + """Install Django fixup if settings module environment is set.""" + SETTINGS_MODULE = os.environ.get(env) + if SETTINGS_MODULE: + try: + import django + except ImportError: + warnings.warn(ERR_NOT_INSTALLED) + else: + # _verify_django_version(django) + return DjangoFixup(app).install() + return None + + +class DjangoFixup: + """Fixup installed when using Django.""" + + def __init__(self, app: "TaskRunner"): + self.app = app + # if _state.default_app is None: + # self.app.set_default() + self._worker_fixup: Optional["DjangoWorkerFixup"] = None + + def install(self) -> "DjangoFixup": + # Need to add project directory to path. + # The project directory has precedence over system modules, + # so we prepend it to the path. 
+ sys.path.insert(0, os.getcwd()) + + self._settings = symbol_by_name('django.conf:settings') + # self.app.loader.now = self.now + + # if not self.app._custom_task_cls_used: + # self.app.task_cls = 'celery.contrib.django.task:DjangoTask' + + # signals.import_modules.connect(self.on_import_modules) + # signals.worker_init.connect(self.on_worker_init) + self.on_worker_init() + return self + + @property + def worker_fixup(self) -> "DjangoWorkerFixup": + if self._worker_fixup is None: + self._worker_fixup = DjangoWorkerFixup(self.app) + return self._worker_fixup + + @worker_fixup.setter + def worker_fixup(self, value: "DjangoWorkerFixup") -> None: + self._worker_fixup = value + + def on_import_modules(self, **kwargs: Any) -> None: + # call django.setup() before task modules are imported + self.worker_fixup.validate_models() + + def on_worker_init(self, **kwargs: Any) -> None: + self.worker_fixup.install() + + def now(self, utc: bool = False) -> datetime: + return datetime.now(timezone.utc) if utc else self._now() + + def autodiscover_tasks(self) -> List[str]: + from django.apps import apps + return [config.name for config in apps.get_app_configs()] + + @cached_property + def _now(self) -> datetime: + return symbol_by_name('django.utils.timezone:now') + + +class DjangoWorkerFixup: + _db_recycles = 0 + + def __init__(self, app: "Celery") -> None: + self.app = app + # self.db_reuse_max = self.app.conf.get('CELERY_DB_REUSE_MAX', None) + self.db_reuse_max = 0 + self._db = cast("DjangoDBModule", import_module('django.db')) + self._cache = import_module('django.core.cache') + # self._settings = symbol_by_name('django.conf:settings') + + self.interface_errors = ( + symbol_by_name('django.db.utils.InterfaceError'), + ) + self.DatabaseError = symbol_by_name('django.db:DatabaseError') + + def django_setup(self) -> None: + import django + django.setup() + + def validate_models(self) -> None: + from django.core.checks import run_checks + self.django_setup() + if not 
os.environ.get('CELERY_SKIP_CHECKS'): + run_checks() + + def install(self) -> "DjangoWorkerFixup": + # signals.beat_embedded_init.connect(self.close_database) + # signals.task_prerun.connect(self.on_task_prerun) + # signals.task_postrun.connect(self.on_task_postrun) + # signals.worker_process_init.connect(self.on_worker_process_init) + self.close_database() + # self.close_cache() + return self + + def on_worker_process_init(self, **kwargs: Any) -> None: + # Child process must validate models again if on Windows, + # or if they were started using execv. + if os.environ.get('FORKED_BY_MULTIPROCESSING'): + self.validate_models() + + # close connections: + # the parent process may have established these, + # so need to close them. + + # calling db.close() on some DB connections will cause + # the inherited DB conn to also get broken in the parent + # process so we need to remove it without triggering any + # network IO that close() might cause. + for c in self._db.connections.all(): + if c and c.connection: + self._maybe_close_db_fd(c) + + # use the _ version to avoid DB_REUSE preventing the conn.close() call + self._close_database(force=True) + self.close_cache() + + def _maybe_close_db_fd(self, c: "BaseDatabaseWrapper") -> None: + try: + with c.wrap_database_errors: + _maybe_close_fd(c.connection) + except self.interface_errors: + pass + + def on_task_prerun(self, sender: "Task", **kwargs: Any) -> None: + """Called before every task.""" + if not getattr(sender.request, 'is_eager', False): + self.close_database() + + def on_task_postrun(self, **kwargs: Any) -> None: + # See https://groups.google.com/group/django-users/browse_thread/thread/78200863d0c07c6d/ + # if not getattr(sender.request, 'is_eager', False): + self.close_database() + self.close_cache() + + def close_database(self, **kwargs: Any) -> None: + if not self.db_reuse_max: + return self._close_database() + if self._db_recycles >= self.db_reuse_max * 2: + self._db_recycles = 0 + self._close_database() + 
self._db_recycles += 1 + + def _close_database(self, force: bool = False) -> None: + for conn in self._db.connections.all(): + try: + if force: + conn.close() + else: + conn.close_if_unusable_or_obsolete() + except self.interface_errors: + pass + except self.DatabaseError as exc: + str_exc = str(exc) + if 'closed' not in str_exc and 'not connected' not in str_exc: + raise + + def close_cache(self) -> None: + try: + self._cache.close_caches() + except (TypeError, AttributeError): + pass \ No newline at end of file diff --git a/src/conductor/client/automator/task_handler.py b/src/conductor/client/automator/task_handler.py index a187a71e..d5451250 100644 --- a/src/conductor/client/automator/task_handler.py +++ b/src/conductor/client/automator/task_handler.py @@ -5,7 +5,7 @@ from sys import platform from typing import List -from conductor.client.automator.task_runner import TaskRunner +from conductor.client.automator.task_runner import DjangoTaskRunner from conductor.client.configuration.configuration import Configuration from conductor.client.configuration.settings.metrics_settings import MetricsSettings from conductor.client.telemetry.metrics_collector import MetricsCollector @@ -142,7 +142,7 @@ def __create_task_runner_process( configuration: Configuration, metrics_settings: MetricsSettings ) -> None: - task_runner = TaskRunner(worker, configuration, metrics_settings) + task_runner = DjangoTaskRunner(worker, configuration, metrics_settings) process = Process(target=task_runner.run) self.task_runner_processes.append(process) diff --git a/src/conductor/client/automator/task_runner.py b/src/conductor/client/automator/task_runner.py index 18744294..af66c7fc 100644 --- a/src/conductor/client/automator/task_runner.py +++ b/src/conductor/client/automator/task_runner.py @@ -1,8 +1,10 @@ +from datetime import datetime import logging import os import sys import time import traceback +from cachetools import TTLCache, cached from conductor.client.configuration.configuration 
import Configuration from conductor.client.configuration.settings.metrics_settings import MetricsSettings @@ -255,3 +257,28 @@ def __get_property_value_from_env(self, prop, task_type): key_upper = prefix.upper() + "_" + task_type + "_" + prop.upper() value = os.getenv(key_small, os.getenv(key_upper, value_all)) return value + + +class DjangoTaskRunner(TaskRunner): + """ + Task runner takes care of closing/refreshing db connections. + """ + + def run(self): + self.django_fixup() + super().run() + + def django_fixup(self): + from .django import fixup + self.django = fixup(self) + + def run_once(self): + super().run_once() + if getattr(self, "django", None): + self.close_connections() + + @cached(TTLCache(maxsize=1, ttl=600)) + def close_connections(self): + self.django.worker_fixup.on_task_postrun() + + diff --git a/src/conductor/client/http/models/schema_def.py b/src/conductor/client/http/models/schema_def.py index d7aa9d98..9ec6082d 100644 --- a/src/conductor/client/http/models/schema_def.py +++ b/src/conductor/client/http/models/schema_def.py @@ -1,6 +1,6 @@ import pprint from enum import Enum - +from typing import Dict import six @@ -30,7 +30,7 @@ class SchemaDef(object): 'external_ref': 'externalRef' } - def __init__(self, name : str =None, version : int =1, type : SchemaType =None, data : dict[str, object] =None, + def __init__(self, name : str =None, version : int =1, type : SchemaType =None, data : Dict[str, object] =None, external_ref : str =None): # noqa: E501 self._name = None self._version = None @@ -104,7 +104,7 @@ def type(self, type:SchemaType): self._type = type @property - def data(self) -> dict[str, object]: + def data(self) -> Dict[str, object]: """Gets the data of this SchemaDef. # noqa: E501 :return: The data of this SchemaDef. # noqa: E501 @@ -113,7 +113,7 @@ def data(self) -> dict[str, object]: return self._data @data.setter - def data(self, data: dict[str, object]): + def data(self, data: Dict[str, object]): """Sets the data of this SchemaDef. 
:param data: The data of this SchemaDef. # noqa: E501 diff --git a/src/conductor/client/http/models/state_change_event.py b/src/conductor/client/http/models/state_change_event.py index f64b440f..7ca9c4a5 100644 --- a/src/conductor/client/http/models/state_change_event.py +++ b/src/conductor/client/http/models/state_change_event.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Union, List +from typing import Union, List, Dict from typing_extensions import Self @@ -14,7 +14,7 @@ class StateChangeEventType(Enum): class StateChangeEvent: swagger_types = { 'type': 'str', - 'payload': 'dict[str, object]' + 'payload': 'Dict[str, object]' } attribute_map = { @@ -22,7 +22,7 @@ class StateChangeEvent: 'payload': 'payload' } - def __init__(self, type: str, payload: dict[str, object]) -> None: + def __init__(self, type: str, payload: Dict[str, object]) -> None: self._type = type self._payload = payload @@ -39,7 +39,7 @@ def payload(self): return self._payload @payload.setter - def payload(self, payload: dict[str, object]) -> Self: + def payload(self, payload: Dict[str, object]) -> Self: self._payload = payload diff --git a/src/conductor/client/http/models/workflow_state_update.py b/src/conductor/client/http/models/workflow_state_update.py index 9769a7de..90788102 100644 --- a/src/conductor/client/http/models/workflow_state_update.py +++ b/src/conductor/client/http/models/workflow_state_update.py @@ -1,5 +1,6 @@ import pprint import re # noqa: F401 +from typing import Dict import six @@ -31,7 +32,7 @@ class WorkflowStateUpdate(object): } def __init__(self, task_reference_name: str = None, task_result: TaskResult = None, - variables: dict[str, object] = None): # noqa: E501 + variables: Dict[str, object] = None): # noqa: E501 """WorkflowStateUpdate - a model defined in Swagger""" # noqa: E501 self._task_reference_name = None self._task_result = None @@ -86,7 +87,7 @@ def task_result(self, task_result: TaskResult): self._task_result = task_result @property - def 
variables(self) -> dict[str, object]: + def variables(self) -> Dict[str, object]: """Gets the variables of this WorkflowStateUpdate. # noqa: E501 @@ -96,7 +97,7 @@ def variables(self) -> dict[str, object]: return self._variables @variables.setter - def variables(self, variables: dict[str, object]): + def variables(self, variables: Dict[str, object]): """Sets the variables of this WorkflowStateUpdate. diff --git a/src/conductor/client/http/models/workflow_task.py b/src/conductor/client/http/models/workflow_task.py index 5c3f2c75..3a6e50e9 100644 --- a/src/conductor/client/http/models/workflow_task.py +++ b/src/conductor/client/http/models/workflow_task.py @@ -1,6 +1,6 @@ import pprint import re # noqa: F401 -from typing import List +from typing import List, Dict import six @@ -124,7 +124,7 @@ def __init__(self, name=None, task_reference_name=None, description=None, input_ sub_workflow_param=None, join_on=None, sink=None, optional=None, task_definition : 'TaskDef' =None, rate_limited=None, default_exclusive_join_task=None, async_complete=None, loop_condition=None, loop_over=None, retry_count=None, evaluator_type=None, expression=None, - workflow_task_type=None, on_state_change: dict[str, StateChangeConfig] = None, + workflow_task_type=None, on_state_change: Dict[str, StateChangeConfig] = None, cache_config: CacheConfig = None): # noqa: E501 """WorkflowTask - a model defined in Swagger""" # noqa: E501 self._name = None @@ -850,7 +850,7 @@ def workflow_task_type(self, workflow_task_type): self._workflow_task_type = workflow_task_type @property - def on_state_change(self) -> dict[str, List[StateChangeEvent]]: + def on_state_change(self) -> Dict[str, List[StateChangeEvent]]: return self._on_state_change @on_state_change.setter diff --git a/src/conductor/client/orkes/orkes_secret_client.py b/src/conductor/client/orkes/orkes_secret_client.py index 20868cd2..02aed980 100644 --- a/src/conductor/client/orkes/orkes_secret_client.py +++ 
b/src/conductor/client/orkes/orkes_secret_client.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Set from conductor.client.configuration.configuration import Configuration from conductor.client.orkes.models.metadata_tag import MetadataTag @@ -16,7 +16,7 @@ def put_secret(self, key: str, value: str): def get_secret(self, key: str) -> str: return self.secretResourceApi.get_secret(key) - def list_all_secret_names(self) -> set[str]: + def list_all_secret_names(self) -> Set[str]: return set(self.secretResourceApi.list_all_secret_names()) def list_secrets_that_user_can_grant_access_to(self) -> List[str]: diff --git a/src/conductor/client/orkes/orkes_workflow_client.py b/src/conductor/client/orkes/orkes_workflow_client.py index 475cfbde..4c1c75c0 100644 --- a/src/conductor/client/orkes/orkes_workflow_client.py +++ b/src/conductor/client/orkes/orkes_workflow_client.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import Optional, List, Dict from conductor.client.configuration.configuration import Configuration from conductor.client.http.models import SkipTaskRequest, WorkflowStatus, \ @@ -24,7 +24,7 @@ def __init__( def start_workflow_by_name( self, name: str, - input: dict[str, object], + input: Dict[str, object], version: Optional[int] = None, correlationId: Optional[str] = None, priority: Optional[int] = None, @@ -130,7 +130,7 @@ def get_by_correlation_ids_in_batch( self, batch_request: CorrelationIdsSearchRequest, include_completed: bool = False, - include_tasks: bool = False) -> dict[str, List[Workflow]]: + include_tasks: bool = False) -> Dict[str, List[Workflow]]: """Given the list of correlation ids and list of workflow names, find and return workflows Returns a map with key as correlationId and value as a list of Workflows @@ -150,7 +150,7 @@ def get_by_correlation_ids( correlation_ids: List[str], include_completed: bool = False, include_tasks: bool = False - ) -> dict[str, List[Workflow]]: + ) -> Dict[str, List[Workflow]]: 
"""Lists workflows for the given correlation id list""" kwargs = {} if include_tasks: @@ -167,7 +167,7 @@ def get_by_correlation_ids( def remove_workflow(self, workflow_id: str): self.workflowResourceApi.delete(workflow_id) - def update_variables(self, workflow_id: str, variables: dict[str, object] = {}) -> None: + def update_variables(self, workflow_id: str, variables: Dict[str, object] = {}) -> None: self.workflowResourceApi.update_workflow_state(variables, workflow_id) def update_state(self, workflow_id: str, update_requesst: WorkflowStateUpdate, diff --git a/src/conductor/client/secret_client.py b/src/conductor/client/secret_client.py index 39c03597..cb9d8a15 100644 --- a/src/conductor/client/secret_client.py +++ b/src/conductor/client/secret_client.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List +from typing import List, Set from conductor.client.orkes.models.metadata_tag import MetadataTag @@ -13,7 +13,7 @@ def get_secret(self, key: str) -> str: pass @abstractmethod - def list_all_secret_names(self) -> set[str]: + def list_all_secret_names(self) -> Set[str]: pass @abstractmethod diff --git a/src/conductor/client/workflow/executor/workflow_executor.py b/src/conductor/client/workflow/executor/workflow_executor.py index 7b4d2e76..55c696f2 100644 --- a/src/conductor/client/workflow/executor/workflow_executor.py +++ b/src/conductor/client/workflow/executor/workflow_executor.py @@ -132,7 +132,7 @@ def get_by_correlation_ids( correlation_ids: List[str], include_closed: bool = None, include_tasks: bool = None - ) -> dict[str, List[Workflow]]: + ) -> Dict[str, List[Workflow]]: """Lists workflows for the given correlation id list""" return self.workflow_client.get_by_correlation_ids( correlation_ids=correlation_ids, diff --git a/src/conductor/client/workflow/task/simple_task.py b/src/conductor/client/workflow/task/simple_task.py index 9144b976..21414edd 100644 --- a/src/conductor/client/workflow/task/simple_task.py +++ 
b/src/conductor/client/workflow/task/simple_task.py @@ -1,3 +1,5 @@ +from typing import Dict + from typing_extensions import Self from conductor.client.workflow.task.task import TaskInterface @@ -13,7 +15,7 @@ def __init__(self, task_def_name: str, task_reference_name: str) -> Self: ) -def simple_task(task_def_name: str, task_reference_name: str, inputs: dict[str, object]) -> TaskInterface: +def simple_task(task_def_name: str, task_reference_name: str, inputs: Dict[str, object]) -> TaskInterface: task = SimpleTask(task_def_name=task_def_name, task_reference_name=task_reference_name) task.input_parameters.update(inputs) return task diff --git a/src/conductor/client/workflow_client.py b/src/conductor/client/workflow_client.py index 08884250..6285d0bf 100644 --- a/src/conductor/client/workflow_client.py +++ b/src/conductor/client/workflow_client.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional, List +from typing import Optional, List, Dict from conductor.client.http.models import WorkflowRun, SkipTaskRequest, WorkflowStatus, \ ScrollableSearchResultWorkflowSummary @@ -82,7 +82,7 @@ def get_by_correlation_ids_in_batch( self, batch_request: CorrelationIdsSearchRequest, include_completed: bool = False, - include_tasks: bool = False) -> dict[str, List[Workflow]]: + include_tasks: bool = False) -> Dict[str, List[Workflow]]: pass @abstractmethod @@ -92,7 +92,7 @@ def get_by_correlation_ids( correlation_ids: List[str], include_completed: bool = False, include_tasks: bool = False - ) -> dict[str, List[Workflow]]: + ) -> Dict[str, List[Workflow]]: pass @abstractmethod @@ -100,7 +100,7 @@ def remove_workflow(self, workflow_id: str): pass @abstractmethod - def update_variables(self, workflow_id: str, variables: dict[str, object] = {}) -> None: + def update_variables(self, workflow_id: str, variables: Dict[str, object] = {}) -> None: pass @abstractmethod diff --git a/tests/integration/metadata/test_workflow_definition.py 
b/tests/integration/metadata/test_workflow_definition.py index ba7528bc..bd737afd 100644 --- a/tests/integration/metadata/test_workflow_definition.py +++ b/tests/integration/metadata/test_workflow_definition.py @@ -1,3 +1,5 @@ +from typing import List + from conductor.client.http.models import TaskDef from conductor.client.http.models.start_workflow_request import StartWorkflowRequest from conductor.client.workflow.conductor_workflow import ConductorWorkflow @@ -23,7 +25,7 @@ def run_workflow_definition_tests(workflow_executor: WorkflowExecutor) -> None: test_kitchensink_workflow_registration(workflow_executor) -def generate_tasks_defs() -> list[TaskDef]: +def generate_tasks_defs() -> List[TaskDef]: python_simple_task_from_code = TaskDef( description="desc python_simple_task_from_code", owner_app="python_integration_test_app",