Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/5744.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add Action Processors to handle various types of actions
25 changes: 25 additions & 0 deletions src/ai/backend/manager/actions/action/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from .base import (
BaseAction,
BaseActionResult,
BaseActionResultMeta,
BaseActionTriggerMeta,
ProcessResult,
TAction,
TActionResult,
)
from .batch import (
BaseBatchAction,
BaseBatchActionResult,
)

__all__ = (
"BaseAction",
"BaseActionResult",
"BaseActionResultMeta",
"BaseBatchAction",
"BaseBatchActionResult",
"BaseActionTriggerMeta",
"ProcessResult",
"TAction",
"TActionResult",
)
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@


class BaseAction(ABC):
@abstractmethod
def entity_id(self) -> Optional[str]:
raise NotImplementedError
"""
Return the ID of the entity this action operates on.
This returns `None` by default because subclasses may not always need to specify an entity ID.
"""
return None

@classmethod
@abstractmethod
Expand All @@ -37,34 +40,12 @@ class BaseActionTriggerMeta:
started_at: datetime


class BaseBatchAction(ABC):
@abstractmethod
def entity_ids(self) -> list[str]:
raise NotImplementedError

@classmethod
@abstractmethod
def entity_type(cls) -> str:
raise NotImplementedError

@classmethod
@abstractmethod
def operation_type(cls) -> str:
raise NotImplementedError


class BaseActionResult(ABC):
@abstractmethod
def entity_id(self) -> Optional[str]:
raise NotImplementedError


class BaseBatchActionResult(ABC):
@abstractmethod
def entity_ids(self) -> list[str]:
raise NotImplementedError


@dataclass
class BaseActionResultMeta:
action_id: uuid.UUID
Expand Down
28 changes: 28 additions & 0 deletions src/ai/backend/manager/actions/action/batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from abc import abstractmethod
from typing import Optional, TypeVar, override

from .base import BaseAction, BaseActionResult


class BaseBatchAction(BaseAction):
@override
def entity_id(self) -> Optional[str]:
return None
Comment on lines +7 to +10
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the entity_id for backward compatibility?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, it overrides BaseAction method


@abstractmethod
def entity_ids(self) -> list[str]:
raise NotImplementedError


class BaseBatchActionResult(BaseActionResult):
@override
def entity_id(self) -> Optional[str]:
return None

@abstractmethod
def entity_ids(self) -> list[str]:
raise NotImplementedError


TBatchAction = TypeVar("TBatchAction", bound=BaseBatchAction)
TBatchActionResult = TypeVar("TBatchActionResult", bound=BaseBatchActionResult)
36 changes: 36 additions & 0 deletions src/ai/backend/manager/actions/action/scope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from abc import abstractmethod
from typing import Optional, TypeVar, override

from .base import BaseAction, BaseActionResult


class BaseScopeAction(BaseAction):
@override
def entity_id(self) -> Optional[str]:
return None

@abstractmethod
def scope_type(self) -> str:
raise NotImplementedError

@abstractmethod
def scope_id(self) -> str:
raise NotImplementedError


class BaseScopeActionResult(BaseActionResult):
@override
def entity_id(self) -> Optional[str]:
return None

@abstractmethod
def scope_type(self) -> str:
raise NotImplementedError

@abstractmethod
def scope_id(self) -> str:
raise NotImplementedError


TScopeAction = TypeVar("TScopeAction", bound=BaseScopeAction)
TScopeActionResult = TypeVar("TScopeActionResult", bound=BaseScopeActionResult)
28 changes: 28 additions & 0 deletions src/ai/backend/manager/actions/action/single_entity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from abc import abstractmethod
from typing import Optional, TypeVar, override

from .base import BaseAction, BaseActionResult


class BaseSingleEntityAction(BaseAction):
@override
def entity_id(self) -> Optional[str]:
return None

@abstractmethod
def target_entity_id(self) -> str:
raise NotImplementedError


class BaseSingleEntityActionResult(BaseActionResult):
@override
def entity_id(self) -> Optional[str]:
return None

@abstractmethod
def target_entity_id(self) -> str:
raise NotImplementedError


TSingleEntityAction = TypeVar("TSingleEntityAction", bound=BaseSingleEntityAction)
TSingleEntityActionResult = TypeVar("TSingleEntityActionResult", bound=BaseSingleEntityActionResult)
3 changes: 3 additions & 0 deletions src/ai/backend/manager/actions/processor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .base import ActionProcessor

__all__ = ("ActionProcessor",)
Original file line number Diff line number Diff line change
Expand Up @@ -6,56 +6,66 @@
from ai.backend.common.exception import BackendAIError, ErrorCode
from ai.backend.logging.utils import BraceStyleAdapter
from ai.backend.manager.actions.types import OperationStatus
from ai.backend.manager.actions.validators.validator import ActionValidator
from ai.backend.manager.actions.validator.base import ActionValidator

from .action import (
from ..action import (
BaseActionResultMeta,
BaseActionTriggerMeta,
ProcessResult,
TAction,
TActionResult,
)
from .monitors.monitor import ActionMonitor
from ..monitors.monitor import ActionMonitor

log = BraceStyleAdapter(logging.getLogger(__spec__.name))


class ActionProcessor(Generic[TAction, TActionResult]):
_monitors: list[ActionMonitor]
_validators: list[ActionValidator]
class ActionRunner(Generic[TAction, TActionResult]):
_func: Callable[[TAction], Awaitable[TActionResult]]
_monitors: list[ActionMonitor]

def __init__(
self,
func: Callable[[TAction], Awaitable[TActionResult]],
monitors: Optional[list[ActionMonitor]] = None,
validators: Optional[list[ActionValidator]] = None,
monitors: Optional[list[ActionMonitor]],
) -> None:
self._func = func
self._monitors = monitors or []
self._validators = validators or []

async def _run(self, action: TAction) -> TActionResult:
started_at = datetime.now()
status = OperationStatus.UNKNOWN
description: str = "unknown"
result: Optional[TActionResult] = None
error_code: Optional[ErrorCode] = None

action_id = uuid.uuid4()
action_trigger_meta = BaseActionTriggerMeta(action_id=action_id, started_at=started_at)
async def _start_monitors(
self, action: TAction, action_trigger_meta: BaseActionTriggerMeta
) -> None:
for monitor in self._monitors:
try:
await monitor.prepare(action, action_trigger_meta)
except Exception as e:
log.warning("Error in monitor prepare method: {}", e)

async def _finalize_monitors(
self,
action: TAction,
meta: BaseActionResultMeta,
) -> None:
process_result = ProcessResult(meta=meta)
for monitor in reversed(self._monitors):
try:
await monitor.done(action, process_result)
except Exception as e:
log.warning("Error in monitor done method: {}", e)

async def run(
self, action: TAction, action_trigger_meta: BaseActionTriggerMeta
) -> TActionResult:
started_at = action_trigger_meta.started_at
action_id = action_trigger_meta.action_id
status = OperationStatus.UNKNOWN
description: str = "unknown"
result: Optional[TActionResult] = None
error_code: Optional[ErrorCode] = None

await self._start_monitors(action, action_trigger_meta)
try:
for validator in self._validators:
await validator.validate(action, action_trigger_meta)
result = await self._func(action)
status = OperationStatus.SUCCESS
description = "Success"
return result
except BackendAIError as e:
log.exception("Action processing error: {}", e)
status = OperationStatus.ERROR
Expand All @@ -68,12 +78,16 @@ async def _run(self, action: TAction) -> TActionResult:
description = str(e)
error_code = ErrorCode.default()
raise
else:
status = OperationStatus.SUCCESS
description = "Success"
return result
finally:
ended_at = datetime.now()
duration = ended_at - started_at
entity_id = action.entity_id()
if entity_id is None and result is not None:
entity_id = result.entity_id()
ended_at = datetime.now()
duration = ended_at - started_at
meta = BaseActionResultMeta(
action_id=action_id,
entity_id=entity_id,
Expand All @@ -84,12 +98,35 @@ async def _run(self, action: TAction) -> TActionResult:
duration=duration,
error_code=error_code,
)
process_result = ProcessResult(meta=meta)
for monitor in reversed(self._monitors):
try:
await monitor.done(action, process_result)
except Exception as e:
log.warning("Error in monitor done method: {}", e)
await self._finalize_monitors(
action,
meta,
)


class ActionProcessor(Generic[TAction, TActionResult]):
_validators: list[ActionValidator]

_runner: ActionRunner[TAction, TActionResult]

def __init__(
self,
func: Callable[[TAction], Awaitable[TActionResult]],
monitors: Optional[list[ActionMonitor]] = None,
validators: Optional[list[ActionValidator]] = None,
) -> None:
self._runner = ActionRunner(func, monitors)

self._validators = validators or []

async def _run(self, action: TAction) -> TActionResult:
started_at = datetime.now()
action_id = uuid.uuid4()
action_trigger_meta = BaseActionTriggerMeta(action_id=action_id, started_at=started_at)
for validator in self._validators:
await validator.validate(action, action_trigger_meta)

return await self._runner.run(action, action_trigger_meta)

async def wait_for_complete(self, action: TAction) -> TActionResult:
return await self._run(action)
44 changes: 44 additions & 0 deletions src/ai/backend/manager/actions/processor/batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import logging
import uuid
from datetime import datetime
from typing import Awaitable, Callable, Generic, Optional

from ai.backend.logging.utils import BraceStyleAdapter
from ai.backend.manager.actions.validator.batch import BatchActionValidator

from ..action import (
BaseActionTriggerMeta,
)
from ..action.batch import TBatchAction, TBatchActionResult
from ..monitors.monitor import ActionMonitor
from .base import ActionRunner

log = BraceStyleAdapter(logging.getLogger(__spec__.name))


class BatchActionProcessor(Generic[TBatchAction, TBatchActionResult]):
_validators: list[BatchActionValidator]

_runner: ActionRunner[TBatchAction, TBatchActionResult]

def __init__(
self,
func: Callable[[TBatchAction], Awaitable[TBatchActionResult]],
monitors: Optional[list[ActionMonitor]] = None,
validators: Optional[list[BatchActionValidator]] = None,
) -> None:
self._runner = ActionRunner(func, monitors)

self._validators = validators or []

async def _run(self, action: TBatchAction) -> TBatchActionResult:
started_at = datetime.now()
action_id = uuid.uuid4()
action_trigger_meta = BaseActionTriggerMeta(action_id=action_id, started_at=started_at)
for validator in self._validators:
await validator.validate(action, action_trigger_meta)

return await self._runner.run(action, action_trigger_meta)

async def wait_for_complete(self, action: TBatchAction) -> TBatchActionResult:
return await self._run(action)
Loading
Loading