Skip to content

added task runner for django #273

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,4 @@ codegen/
.vscode/
tests/unit/automator/_trial_temp/_trial_marker
tests/unit/automator/_trial_temp/_trial_marker
.history
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ requests >= 2.31.0
typing-extensions >= 4.2.0
astor >= 0.8.1
shortuuid >= 1.0.11
dacite >= 1.8.1
dacite >= 1.8.1
cachetools==4.2.1
210 changes: 210 additions & 0 deletions src/conductor/client/automator/django.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
"""Django-specific customization."""
import os
import sys
import warnings
from datetime import datetime, timezone
from importlib import import_module
from typing import IO, TYPE_CHECKING, Any, List, Optional, cast

from kombu.utils.imports import symbol_by_name
from kombu.utils.objects import cached_property

if TYPE_CHECKING:
    # Everything in this branch exists for static type checkers only;
    # TYPE_CHECKING is False at runtime, so none of it is executed.
    from types import ModuleType
    from typing import Protocol

    from django.db.backends.base.base import BaseDatabaseWrapper
    from django.db.utils import ConnectionHandler

    class DjangoDBModule(Protocol):
        """Structural type for a module object exposing ``connections``."""

        connections: ConnectionHandler


# Names exported by a star-import of this module.
__all__ = ('DjangoFixup', 'fixup')

# Warning text used by fixup() when the settings environment variable is set
# but ``import django`` fails.
ERR_NOT_INSTALLED = """\
Environment variable DJANGO_SETTINGS_MODULE is defined
but Django isn't installed. Won't apply Django fix-ups!
"""


def _maybe_close_fd(fh: IO) -> None:
try:
os.close(fh.fileno())
except (AttributeError, OSError, TypeError):
# TypeError added for celery#962
pass



def fixup(app: "TaskRunner", env: str = 'DJANGO_SETTINGS_MODULE') -> Optional["DjangoFixup"]:
    """Install the Django fixup if the settings-module environment variable is set.

    Args:
        app: Object handed to ``DjangoFixup``; stored there as ``app``.
        env: Name of the environment variable that holds the Django
            settings module.

    Returns:
        The installed ``DjangoFixup``, or ``None`` when *env* is unset or
        empty, or when Django cannot be imported (a warning is emitted in
        that last case).
    """
    if not os.environ.get(env):
        return None
    try:
        import django  # noqa: F401 -- imported only to probe availability
    except ImportError:
        # This module has no dedicated warning class, so a plain UserWarning
        # is used; stacklevel=2 attributes the warning to the caller.
        warnings.warn(ERR_NOT_INSTALLED, UserWarning, stacklevel=2)
        return None
    # _verify_django_version(django)
    return DjangoFixup(app).install()


class DjangoFixup:
    """Fixup installed when using Django."""

    def __init__(self, app: "Celery"):
        """Remember *app*; the worker fixup is created lazily on first use."""
        self.app = app
        # if _state.default_app is None:
        #     self.app.set_default()
        self._worker_fixup: Optional["DjangoWorkerFixup"] = None

    def install(self) -> "DjangoFixup":
        """Prepare ``sys.path``, resolve Django settings, run worker init.

        Returns:
            This instance, so the call can be chained.
        """
        # The project directory must win over system modules, hence it is
        # inserted at the very front of the import path.
        project_dir = os.getcwd()
        sys.path.insert(0, project_dir)

        self._settings = symbol_by_name('django.conf:settings')
        # self.app.loader.now = self.now

        # if not self.app._custom_task_cls_used:
        #     self.app.task_cls = 'celery.contrib.django.task:DjangoTask'

        # signals.import_modules.connect(self.on_import_modules)
        # signals.worker_init.connect(self.on_worker_init)
        # No signal wiring is active: the worker-init hook is invoked directly.
        self.on_worker_init()
        return self

    @property
    def worker_fixup(self) -> "DjangoWorkerFixup":
        """Return the ``DjangoWorkerFixup``, creating it on first access."""
        current = self._worker_fixup
        if current is None:
            current = self._worker_fixup = DjangoWorkerFixup(self.app)
        return current

    @worker_fixup.setter
    def worker_fixup(self, value: "DjangoWorkerFixup") -> None:
        self._worker_fixup = value

    def on_import_modules(self, **kwargs: Any) -> None:
        """Run model validation, which calls ``django.setup()`` first."""
        # call django.setup() before task modules are imported
        self.worker_fixup.validate_models()

    def on_worker_init(self, **kwargs: Any) -> None:
        """Install the worker-level fixup."""
        self.worker_fixup.install()

    def now(self, utc: bool = False) -> datetime:
        """Return the current time: UTC if *utc* is true, else Django's ``now()``."""
        if utc:
            return datetime.now(timezone.utc)
        return self._now()

    def autodiscover_tasks(self) -> List[str]:
        """Return the ``name`` of every app config known to Django."""
        from django.apps import apps

        names = []
        for app_config in apps.get_app_configs():
            names.append(app_config.name)
        return names

    @cached_property
    def _now(self) -> datetime:
        # Resolves to the ``django.utils.timezone.now`` callable; now() calls it.
        return symbol_by_name('django.utils.timezone:now')


class DjangoWorkerFixup:
    """Worker-side Django housekeeping: closes DB connections and caches."""

    # Counts close_database() calls since the last actual connection check.
    _db_recycles = 0

    def __init__(self, app: "Celery") -> None:
        """Store *app* and resolve the Django modules and exception types used later."""
        self.app = app
        # self.db_reuse_max = self.app.conf.get('CELERY_DB_REUSE_MAX', None)
        self.db_reuse_max = 0
        self._db = cast("DjangoDBModule", import_module('django.db'))
        self._cache = import_module('django.core.cache')
        # self._settings = symbol_by_name('django.conf:settings')

        interface_error = symbol_by_name('django.db.utils.InterfaceError')
        self.interface_errors = (interface_error,)
        self.DatabaseError = symbol_by_name('django.db:DatabaseError')

    def django_setup(self) -> None:
        """Call ``django.setup()``."""
        import django

        django.setup()

    def validate_models(self) -> None:
        """Set up Django, then run its checks unless ``CELERY_SKIP_CHECKS`` is set."""
        from django.core.checks import run_checks

        self.django_setup()
        if os.environ.get('CELERY_SKIP_CHECKS'):
            return
        run_checks()

    def install(self) -> "DjangoWorkerFixup":
        """Close database connections once and return this instance."""
        # signals.beat_embedded_init.connect(self.close_database)
        # signals.task_prerun.connect(self.on_task_prerun)
        # signals.task_postrun.connect(self.on_task_postrun)
        # signals.worker_process_init.connect(self.on_worker_process_init)
        self.close_database()
        # self.close_cache()
        return self

    def on_worker_process_init(self, **kwargs: Any) -> None:
        """Drop connections and caches that a child process may have inherited."""
        # Child process must validate models again if on Windows,
        # or if they were started using execv.
        if os.environ.get('FORKED_BY_MULTIPROCESSING'):
            self.validate_models()

        # close connections:
        # the parent process may have established these,
        # so need to close them.

        # calling db.close() on some DB connections will cause
        # the inherited DB conn to also get broken in the parent
        # process so we need to remove it without triggering any
        # network IO that close() might cause.
        for wrapper in self._db.connections.all():
            if not wrapper or not wrapper.connection:
                continue
            self._maybe_close_db_fd(wrapper)

        # use the _ version to avoid DB_REUSE preventing the conn.close() call
        self._close_database(force=True)
        self.close_cache()

    def _maybe_close_db_fd(self, c: "BaseDatabaseWrapper") -> None:
        """Close the descriptor of *c*'s raw connection, ignoring interface errors."""
        try:
            with c.wrap_database_errors:
                _maybe_close_fd(c.connection)
        except self.interface_errors:
            return

    def on_task_prerun(self, sender: "Task", **kwargs: Any) -> None:
        """Called before every task."""
        if getattr(sender.request, 'is_eager', False):
            return
        self.close_database()

    def on_task_postrun(self, **kwargs: Any) -> None:
        """Close database connections and caches after a task."""
        # See https://groups.google.com/group/django-users/browse_thread/thread/78200863d0c07c6d/
        # if not getattr(sender.request, 'is_eager', False):
        self.close_database()
        self.close_cache()

    def close_database(self, **kwargs: Any) -> None:
        """Check connections now, or only every ``db_reuse_max * 2`` calls.

        With a falsy ``db_reuse_max`` every call checks the connections;
        otherwise calls are counted and the check happens once the counter
        reaches twice ``db_reuse_max``.
        """
        reuse_limit = self.db_reuse_max
        if not reuse_limit:
            return self._close_database()
        if self._db_recycles >= reuse_limit * 2:
            self._db_recycles = 0
            self._close_database()
        self._db_recycles += 1

    def _close_database(self, force: bool = False) -> None:
        """Close every connection (*force*) or only unusable/obsolete ones.

        Interface errors are ignored; a ``DatabaseError`` is ignored only
        when its message contains 'closed' or 'not connected'.
        """
        for conn in self._db.connections.all():
            try:
                closer = conn.close if force else conn.close_if_unusable_or_obsolete
                closer()
            except self.interface_errors:
                continue
            except self.DatabaseError as exc:
                message = str(exc)
                if not ('closed' in message or 'not connected' in message):
                    raise

    def close_cache(self) -> None:
        """Call ``close_caches()`` on the cache module, ignoring Type/Attribute errors."""
        try:
            self._cache.close_caches()
        except (TypeError, AttributeError):
            return
4 changes: 2 additions & 2 deletions src/conductor/client/automator/task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sys import platform
from typing import List

from conductor.client.automator.task_runner import TaskRunner
from conductor.client.automator.task_runner import DjangoTaskRunner
from conductor.client.configuration.configuration import Configuration
from conductor.client.configuration.settings.metrics_settings import MetricsSettings
from conductor.client.telemetry.metrics_collector import MetricsCollector
Expand Down Expand Up @@ -142,7 +142,7 @@ def __create_task_runner_process(
configuration: Configuration,
metrics_settings: MetricsSettings
) -> None:
task_runner = TaskRunner(worker, configuration, metrics_settings)
task_runner = DjangoTaskRunner(worker, configuration, metrics_settings)
process = Process(target=task_runner.run)
self.task_runner_processes.append(process)

Expand Down
27 changes: 27 additions & 0 deletions src/conductor/client/automator/task_runner.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from datetime import datetime
import logging
import os
import sys
import time
import traceback
from cachetools import TTLCache, cached

from conductor.client.configuration.configuration import Configuration
from conductor.client.configuration.settings.metrics_settings import MetricsSettings
Expand Down Expand Up @@ -255,3 +257,28 @@ def __get_property_value_from_env(self, prop, task_type):
key_upper = prefix.upper() + "_" + task_type + "_" + prop.upper()
value = os.getenv(key_small, os.getenv(key_upper, value_all))
return value


class DjangoTaskRunner(TaskRunner):
    """
    Task runner takes care of closing/refreshing db connections.

    ``run()`` installs the Django fixup before entering the parent's run
    loop; ``run_once()`` then triggers connection cleanup after each
    iteration when a fixup is installed.
    """

    # Installed fixup, or None. Defined at class level so run_once() can be
    # called before run()/django_fixup() without raising AttributeError.
    django = None

    def run(self):
        """Install the Django fixup, then delegate to the parent ``run()``."""
        self.django_fixup()
        super().run()

    def django_fixup(self):
        """Store the result of ``fixup(self)``; it is None when no fixup was installed."""
        # Imported lazily, at call time rather than at module import.
        from .django import fixup
        self.django = fixup(self)

    def run_once(self):
        """Run one parent iteration, then clean up connections if a fixup exists."""
        super().run_once()
        if self.django:
            self.close_connections()

    # NOTE(review): the TTL cache is used as a throttle, not for its value.
    # cachetools memoises the (None) result per ``self``, so the body runs at
    # most once per 600 seconds; this requires ``self`` to be hashable.
    @cached(TTLCache(maxsize=1, ttl=600))
    def close_connections(self):
        """Run the worker fixup's post-task cleanup (database and cache close)."""
        self.django.worker_fixup.on_task_postrun()


8 changes: 4 additions & 4 deletions src/conductor/client/http/models/schema_def.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pprint
from enum import Enum

from typing import Dict
import six


Expand Down Expand Up @@ -30,7 +30,7 @@ class SchemaDef(object):
'external_ref': 'externalRef'
}

def __init__(self, name : str =None, version : int =1, type : SchemaType =None, data : dict[str, object] =None,
def __init__(self, name : str =None, version : int =1, type : SchemaType =None, data : Dict[str, object] =None,
external_ref : str =None): # noqa: E501
self._name = None
self._version = None
Expand Down Expand Up @@ -104,7 +104,7 @@ def type(self, type:SchemaType):
self._type = type

@property
def data(self) -> dict[str, object]:
def data(self) -> Dict[str, object]:
"""Gets the data of this SchemaDef. # noqa: E501

:return: The data of this SchemaDef. # noqa: E501
Expand All @@ -113,7 +113,7 @@ def data(self) -> dict[str, object]:
return self._data

@data.setter
def data(self, data: dict[str, object]):
def data(self, data: Dict[str, object]):
"""Sets the data of this SchemaDef.

:param data: The data of this SchemaDef. # noqa: E501
Expand Down
8 changes: 4 additions & 4 deletions src/conductor/client/http/models/state_change_event.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Union, List
from typing import Union, List, Dict
from typing_extensions import Self


Expand All @@ -14,15 +14,15 @@ class StateChangeEventType(Enum):
class StateChangeEvent:
swagger_types = {
'type': 'str',
'payload': 'dict[str, object]'
'payload': 'Dict[str, object]'
}

attribute_map = {
'type': 'type',
'payload': 'payload'
}

def __init__(self, type: str, payload: dict[str, object]) -> None:
def __init__(self, type: str, payload: Dict[str, object]) -> None:
self._type = type
self._payload = payload

Expand All @@ -39,7 +39,7 @@ def payload(self):
return self._payload

@payload.setter
def payload(self, payload: dict[str, object]) -> Self:
def payload(self, payload: Dict[str, object]) -> Self:
self._payload = payload


Expand Down
7 changes: 4 additions & 3 deletions src/conductor/client/http/models/workflow_state_update.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pprint
import re # noqa: F401
from typing import Dict

import six

Expand Down Expand Up @@ -31,7 +32,7 @@ class WorkflowStateUpdate(object):
}

def __init__(self, task_reference_name: str = None, task_result: TaskResult = None,
variables: dict[str, object] = None): # noqa: E501
variables: Dict[str, object] = None): # noqa: E501
"""WorkflowStateUpdate - a model defined in Swagger""" # noqa: E501
self._task_reference_name = None
self._task_result = None
Expand Down Expand Up @@ -86,7 +87,7 @@ def task_result(self, task_result: TaskResult):
self._task_result = task_result

@property
def variables(self) -> dict[str, object]:
def variables(self) -> Dict[str, object]:
"""Gets the variables of this WorkflowStateUpdate. # noqa: E501


Expand All @@ -96,7 +97,7 @@ def variables(self) -> dict[str, object]:
return self._variables

@variables.setter
def variables(self, variables: dict[str, object]):
def variables(self, variables: Dict[str, object]):
"""Sets the variables of this WorkflowStateUpdate.


Expand Down
Loading