Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/5921.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Introduce service, repository layer when handling agent status change event
22 changes: 22 additions & 0 deletions src/ai/backend/manager/data/agent/modifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from dataclasses import dataclass, field
from datetime import datetime
from typing import override

from ai.backend.manager.models.agent import AgentStatus
from ai.backend.manager.types import OptionalState, PartialModifier


@dataclass
class AgentStatusModifier(PartialModifier):
status: AgentStatus
status_changed: datetime
lost_at: OptionalState[datetime] = field(default_factory=lambda: OptionalState.nop())

@override
def fields_to_update(self) -> dict:
to_update = {
"status": self.status,
"status_changed": self.status_changed,
}
self.lost_at.update_dict(to_update, "lost_at")
return to_update
40 changes: 26 additions & 14 deletions src/ai/backend/manager/event_dispatcher/handlers/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from ai.backend.manager.errors.resource import InstanceNotFound
from ai.backend.manager.registry import AgentRegistry
from ai.backend.manager.services.agent.actions.handle_heartbeat import HandleHeartbeatAction
from ai.backend.manager.services.agent.actions.mark_agent_exit import MarkAgentExitAction
from ai.backend.manager.services.agent.actions.mark_agent_running import MarkAgentRunningAction
from ai.backend.manager.services.processors import Processors

from ...models.agent import AgentStatus, agents
Expand Down Expand Up @@ -78,11 +80,12 @@ async def handle_agent_started(
event: AgentStartedEvent,
) -> None:
log.info("instance_lifecycle: ag:{0} joined (via event, {1})", source, event.reason)
await self._registry.update_instance(
source,
{
"status": AgentStatus.ALIVE,
},
processors = await self.get_processors()
await processors.agent.mark_agent_running.wait_for_complete(
MarkAgentRunningAction(
agent_id=source,
agent_status=AgentStatus.ALIVE,
)
)

async def handle_agent_terminated(
Expand All @@ -91,22 +94,31 @@ async def handle_agent_terminated(
source: AgentId,
event: AgentTerminatedEvent,
) -> None:
processors = await self.get_processors()
if event.reason == "agent-lost":
await self._registry.mark_agent_terminated(source, AgentStatus.LOST)
self._registry.agent_cache.discard(source)
await processors.agent.mark_agent_exit.wait_for_complete(
MarkAgentExitAction(
agent_id=source,
agent_status=AgentStatus.LOST,
)
)
elif event.reason == "agent-restart":
log.info("agent@{0} restarting for maintenance.", source)
await self._registry.update_instance(
source,
{
"status": AgentStatus.RESTARTING,
},
await processors.agent.mark_agent_running.wait_for_complete(
MarkAgentRunningAction(
agent_id=source,
agent_status=AgentStatus.RESTARTING,
)
)
else:
# On normal instance termination, kernel_terminated events were already
# triggered by the agent.
await self._registry.mark_agent_terminated(source, AgentStatus.TERMINATED)
self._registry.agent_cache.discard(source)
await processors.agent.mark_agent_exit.wait_for_complete(
MarkAgentExitAction(
agent_id=source,
agent_status=AgentStatus.TERMINATED,
)
)

async def handle_agent_heartbeat(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,15 @@ async def update_agent_last_seen(self, agent_id: AgentId, time: datetime) -> Non
await self._valkey_live.update_agent_last_seen(agent_id, time.timestamp())
except Exception as e:
log.debug("Failed to update last seen for agent: {}, error: {}", agent_id, e)

async def remove_agent_last_seen(self, agent_id: AgentId) -> None:
try:
await self._valkey_live.remove_agent_last_seen(agent_id)
except Exception as e:
log.debug("Failed to remove last seen for agent: {}, error: {}", agent_id, e)

async def remove_agent_from_all_images(self, agent_id: AgentId) -> None:
try:
await self._valkey_image.remove_agent_from_all_images(agent_id)
except Exception as e:
log.debug("Failed to remove agent: {} from all images, error: {}", agent_id, e)
39 changes: 38 additions & 1 deletion src/ai/backend/manager/repositories/agent/db_source/db_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from ai.backend.common.exception import ScalingGroupNotFoundError
from ai.backend.common.types import AgentId
from ai.backend.logging.utils import BraceStyleAdapter
from ai.backend.manager.data.agent.modifier import AgentStatusModifier
from ai.backend.manager.data.agent.types import (
AgentHeartbeatUpsert,
UpsertResult,
)
from ai.backend.manager.models import agents
from ai.backend.manager.models.agent import AgentRow
from ai.backend.manager.models.agent import AgentRow, AgentStatus
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine
from ai.backend.manager.services.agent.types import AgentData

Expand Down Expand Up @@ -52,3 +53,39 @@ async def upsert_agent_with_state(self, upsert_data: AgentHeartbeatUpsert) -> Up
raise ScalingGroupNotFoundError(upsert_data.scaling_group)

return upsert_result

async def update_agent_status_exit(
self, agent_id: AgentId, modifier: AgentStatusModifier
) -> None:
async with self._db.begin() as conn:
fetch_query = (
sa.select([
agents.c.status,
agents.c.addr,
])
.select_from(agents)
.where(agents.c.id == agent_id)
.with_for_update()
)
result = await conn.execute(fetch_query)
row = result.first()
prev_status = row["status"]
if prev_status in (None, AgentStatus.LOST, AgentStatus.TERMINATED):
return

if modifier.status == AgentStatus.LOST:
log.warning("agent {0} heartbeat timeout detected.", agent_id)
elif modifier.status == AgentStatus.TERMINATED:
log.info("agent {0} has terminated.", agent_id)

update_query = (
sa.update(agents).values(modifier.fields_to_update()).where(agents.c.id == agent_id)
)
await conn.execute(update_query)

async def update_agent_status(self, agent_id: AgentId, modifier: AgentStatusModifier) -> None:
async with self._db.begin() as conn:
query = (
sa.update(agents).values(modifier.fields_to_update()).where(agents.c.id == agent_id)
)
await conn.execute(query)
11 changes: 11 additions & 0 deletions src/ai/backend/manager/repositories/agent/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ai.backend.logging.utils import BraceStyleAdapter
from ai.backend.manager.agent_cache import AgentRPCCache
from ai.backend.manager.config.provider import ManagerConfigProvider
from ai.backend.manager.data.agent.modifier import AgentStatusModifier
from ai.backend.manager.data.agent.types import (
AgentHeartbeatUpsert,
AgentStateSyncData,
Expand Down Expand Up @@ -75,3 +76,13 @@ async def sync_agent_heartbeat(
)

return upsert_result

@repository_decorator()
async def cleanup_agent_on_exit(self, agent_id: AgentId, modifier: AgentStatusModifier) -> None:
await self._cache_source.remove_agent_last_seen(agent_id)
await self._db_source.update_agent_status_exit(agent_id, modifier)
await self._cache_source.remove_agent_from_all_images(agent_id)

@repository_decorator()
async def update_agent_status(self, agent_id: AgentId, modifier: AgentStatusModifier) -> None:
await self._db_source.update_agent_status(agent_id, modifier)
31 changes: 31 additions & 0 deletions src/ai/backend/manager/services/agent/actions/mark_agent_exit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from dataclasses import dataclass
from typing import Literal, Optional, override

from ai.backend.common.types import AgentId
from ai.backend.manager.actions.action import BaseActionResult
from ai.backend.manager.models.agent import AgentStatus
from ai.backend.manager.services.agent.actions.base import AgentAction


@dataclass
class MarkAgentExitAction(AgentAction):
agent_id: AgentId
agent_status: Literal[AgentStatus.LOST, AgentStatus.TERMINATED]

@override
def entity_id(self) -> Optional[str]:
return str(self.agent_id)

@override
@classmethod
def operation_type(cls) -> str:
return "mark_agent_exit"


@dataclass
class MarkAgentExitActionResult(BaseActionResult):
agent_id: AgentId

@override
def entity_id(self) -> Optional[str]:
return str(self.agent_id)
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from dataclasses import dataclass
from typing import Literal, Optional, override

from ai.backend.common.types import AgentId
from ai.backend.manager.actions.action import BaseActionResult
from ai.backend.manager.models.agent import AgentStatus
from ai.backend.manager.services.agent.actions.base import AgentAction


@dataclass
class MarkAgentRunningAction(AgentAction):
agent_id: AgentId
agent_status: Literal[AgentStatus.ALIVE, AgentStatus.RESTARTING]

@override
def entity_id(self) -> Optional[str]:
return str(self.agent_id)

@override
@classmethod
def operation_type(cls) -> str:
return "mark_agent_running"


@dataclass
class MarkAgentRunningActionResult(BaseActionResult):
agent_id: AgentId

@override
def entity_id(self) -> Optional[str]:
return str(self.agent_id)
14 changes: 14 additions & 0 deletions src/ai/backend/manager/services/agent/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@
HandleHeartbeatAction,
HandleHeartbeatActionResult,
)
from ai.backend.manager.services.agent.actions.mark_agent_exit import (
MarkAgentExitAction,
MarkAgentExitActionResult,
)
from ai.backend.manager.services.agent.actions.mark_agent_running import (
MarkAgentRunningAction,
MarkAgentRunningActionResult,
)
from ai.backend.manager.services.agent.actions.recalculate_usage import (
RecalculateUsageAction,
RecalculateUsageActionResult,
Expand Down Expand Up @@ -44,6 +52,8 @@ class AgentProcessors(AbstractProcessorPackage):
watcher_agent_stop: ActionProcessor[WatcherAgentStopAction, WatcherAgentStopActionResult]
recalculate_usage: ActionProcessor[RecalculateUsageAction, RecalculateUsageActionResult]
handle_heartbeat: ActionProcessor[HandleHeartbeatAction, HandleHeartbeatActionResult]
mark_agent_exit: ActionProcessor[MarkAgentExitAction, MarkAgentExitActionResult]
mark_agent_running: ActionProcessor[MarkAgentRunningAction, MarkAgentRunningActionResult]

def __init__(self, service: AgentService, action_monitors: list[ActionMonitor]) -> None:
self.sync_agent_registry = ActionProcessor(service.sync_agent_registry, action_monitors)
Expand All @@ -53,6 +63,8 @@ def __init__(self, service: AgentService, action_monitors: list[ActionMonitor])
self.watcher_agent_stop = ActionProcessor(service.watcher_agent_stop, action_monitors)
self.recalculate_usage = ActionProcessor(service.recalculate_usage, action_monitors)
self.handle_heartbeat = ActionProcessor(service.handle_heartbeat, action_monitors)
self.mark_agent_exit = ActionProcessor(service.mark_agent_exit, action_monitors)
self.mark_agent_running = ActionProcessor(service.mark_agent_running, action_monitors)

@override
def supported_actions(self) -> list[ActionSpec]:
Expand All @@ -64,4 +76,6 @@ def supported_actions(self) -> list[ActionSpec]:
WatcherAgentStopAction.spec(),
RecalculateUsageAction.spec(),
HandleHeartbeatAction.spec(),
MarkAgentExitAction.spec(),
MarkAgentRunningAction.spec(),
]
34 changes: 34 additions & 0 deletions src/ai/backend/manager/services/agent/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ai.backend.logging.utils import BraceStyleAdapter
from ai.backend.manager.agent_cache import AgentRPCCache
from ai.backend.manager.config.provider import ManagerConfigProvider
from ai.backend.manager.data.agent.modifier import AgentStatusModifier
from ai.backend.manager.data.agent.types import (
AgentHeartbeatUpsert,
AgentStateSyncData,
Expand All @@ -36,6 +37,14 @@
HandleHeartbeatAction,
HandleHeartbeatActionResult,
)
from ai.backend.manager.services.agent.actions.mark_agent_exit import (
MarkAgentExitAction,
MarkAgentExitActionResult,
)
from ai.backend.manager.services.agent.actions.mark_agent_running import (
MarkAgentRunningAction,
MarkAgentRunningActionResult,
)
from ai.backend.manager.services.agent.actions.recalculate_usage import (
RecalculateUsageAction,
RecalculateUsageActionResult,
Expand All @@ -56,6 +65,7 @@
WatcherAgentStopAction,
WatcherAgentStopActionResult,
)
from ai.backend.manager.types import OptionalState

log = BraceStyleAdapter(logging.getLogger(__spec__.name))

Expand Down Expand Up @@ -224,3 +234,27 @@ async def handle_heartbeat(self, action: HandleHeartbeatAction) -> HandleHeartbe
),
)
return HandleHeartbeatActionResult(agent_id=action.agent_id)

async def mark_agent_exit(self, action: MarkAgentExitAction) -> MarkAgentExitActionResult:
now = datetime.now(tzutc())
self._agent_repository.cleanup_agent_on_exit(
agent_id=action.agent_id,
modifier=AgentStatusModifier(
status=action.agent_status, status_changed=now, lost_at=OptionalState.update(now)
),
)
self._agent_cache.discard(action.agent_id)
return MarkAgentExitActionResult(agent_id=action.agent_id)

async def mark_agent_running(
self, action: MarkAgentRunningAction
) -> MarkAgentRunningActionResult:
now = datetime.now(tzutc())
await self._agent_repository.update_agent_status(
agent_id=action.agent_id,
modifier=AgentStatusModifier(
status=action.agent_status,
status_changed=now,
),
)
return MarkAgentRunningActionResult(agent_id=action.agent_id)
Loading