Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(BA-533): Add delete path API to storage watcher #3548

Open
wants to merge 2 commits into
base: topic/12-20-feat_enable_per-user_uid_gid_set_for_containers
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3548.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add storage-watcher API to delete VFolders with elevated permissions
3 changes: 3 additions & 0 deletions src/ai/backend/storage/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
TreeUsage,
VFolderID,
)
from .watcher import WatcherClient

# Available capabilities of a volume implementation
CAP_VFOLDER: Final = "vfolder" # ability to create vfolder
Expand Down Expand Up @@ -200,6 +201,7 @@ def __init__(
etcd: AsyncEtcd,
event_dispatcher: EventDispatcher,
event_producer: EventProducer,
watcher: Optional[WatcherClient] = None,
options: Optional[Mapping[str, Any]] = None,
) -> None:
self.local_config = local_config
Expand All @@ -208,6 +210,7 @@ def __init__(
self.etcd = etcd
self.event_dispatcher = event_dispatcher
self.event_producer = event_producer
self.watcher = watcher

async def init(self) -> None:
self.fsop_model = await self.create_fsop_model()
Expand Down
1 change: 1 addition & 0 deletions src/ai/backend/storage/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ async def get_volume(self, name: str) -> AsyncIterator[AbstractVolume]:
etcd=self.etcd,
event_dispatcher=self.event_dispatcher,
event_producer=self.event_producer,
watcher=self.watcher,
)

await volume_obj.init()
Expand Down
27 changes: 19 additions & 8 deletions src/ai/backend/storage/vfs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
VFolderID,
)
from ..utils import fstime2datetime
from ..watcher import DeletePathTask, WatcherClient

log = BraceStyleAdapter(logging.getLogger(__spec__.name))

Expand Down Expand Up @@ -189,9 +190,12 @@ async def delete_quota_scope(


class BaseFSOpModel(AbstractFSOpModel):
def __init__(self, mount_path: Path, scandir_limit: int) -> None:
def __init__(
self, mount_path: Path, scandir_limit: int, watcher: Optional[WatcherClient] = None
) -> None:
self.mount_path = mount_path
self.scandir_limit = scandir_limit
self.watcher = watcher

async def copy_tree(
self,
Expand Down Expand Up @@ -224,11 +228,14 @@ async def delete_tree(
self,
path: Path,
) -> None:
loop = asyncio.get_running_loop()
try:
await loop.run_in_executor(None, lambda: shutil.rmtree(path))
except FileNotFoundError:
pass
if self.watcher is not None:
await self.watcher.request_task(DeletePathTask(path))
else:
loop = asyncio.get_running_loop()
try:
await loop.run_in_executor(None, lambda: shutil.rmtree(path))
except FileNotFoundError:
pass

def scan_tree(
self,
Expand Down Expand Up @@ -371,6 +378,7 @@ async def create_fsop_model(self) -> AbstractFSOpModel:
return BaseFSOpModel(
self.mount_path,
self.local_config["storage-proxy"]["scandir-limit"],
self.watcher,
)

async def get_capabilities(self) -> FrozenSet[str]:
Expand Down Expand Up @@ -676,5 +684,8 @@ async def delete_files(
for p in target_paths:
if p.is_dir() and recursive:
await self.fsop_model.delete_tree(p)
else:
await aiofiles.os.remove(p)
elif p.is_file():
if self.watcher is not None:
await self.watcher.request_task(DeletePathTask(p))
else:
await aiofiles.os.remove(p)
Comment on lines +687 to +691
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe there is a cleaner way to decide whether an operation should go through the watcher than branching with `if` statements at each call site.
At the very least, could the file removal and write operations be wrapped once, at the time of creation?

75 changes: 47 additions & 28 deletions src/ai/backend/storage/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import asyncio
import logging
import os
import shutil
import traceback
from abc import ABCMeta, abstractmethod
from pathlib import Path
from typing import Any, ClassVar, Sequence, Type
from typing import Any, ClassVar, Self, Sequence

import aiofiles.os
import attrs
import zmq
import zmq.asyncio
Expand Down Expand Up @@ -79,27 +81,19 @@ class AbstractTask(metaclass=ABCMeta):
async def run(self) -> Any:
pass

@classmethod
def deserialize_from_request(cls, raw_data: Request) -> AbstractTask:
serializer_name = str(raw_data.header, "utf8")
values: tuple = msgpack.unpackb(raw_data.body)
serializer_cls = SERIALIZER_MAP[serializer_name]
return serializer_cls.deserialize(values)

def serialize_to_request(self) -> Request:
assert self.name in SERIALIZER_MAP
header = bytes(self.name, "utf8")
return Request(header, self.serialize())

@abstractmethod
def serialize(self) -> bytes:
pass

@classmethod
@abstractmethod
def deserialize(cls, values: tuple) -> AbstractTask:
def deserialize(cls, values: tuple) -> Self:
pass

def serialize_to_request(self) -> Request:
    """Wrap this task in a ``Request``.

    The header is the task's ``name`` encoded as UTF-8 and the body is
    whatever ``serialize()`` returns for this task.
    """
    return Request(bytes(self.name, "utf8"), self.serialize())


@attrs.define(slots=True)
class ChownTask(AbstractTask):
Expand All @@ -120,8 +114,8 @@ def serialize(self) -> bytes:
))

@classmethod
def deserialize(cls, values: tuple) -> ChownTask:
return ChownTask(
def deserialize(cls, values: tuple) -> Self:
return cls(
values[0],
values[1],
values[2],
Expand Down Expand Up @@ -161,7 +155,7 @@ async def run(self) -> Any:
def from_event(
cls, event: DoVolumeMountEvent, *, mount_path: Path, mount_prefix: str | None = None
) -> MountTask:
return MountTask(
return cls(
str(mount_path),
event.quota_scope_id,
event.fs_location,
Expand All @@ -187,8 +181,8 @@ def serialize(self) -> bytes:
))

@classmethod
def deserialize(cls, values: tuple) -> MountTask:
return MountTask(
def deserialize(cls, values: tuple) -> Self:
return cls(
values[0],
QuotaScopeID.parse(values[1]),
values[2],
Expand Down Expand Up @@ -236,7 +230,7 @@ def from_event(
mount_prefix: str | None = None,
timeout: float | None = None,
) -> UmountTask:
return UmountTask(
return cls(
str(mount_path),
event.quota_scope_id,
event.scaling_group,
Expand All @@ -258,8 +252,8 @@ def serialize(self) -> bytes:
))

@classmethod
def deserialize(cls, values: tuple) -> UmountTask:
return UmountTask(
def deserialize(cls, values: tuple) -> Self:
return cls(
values[0],
QuotaScopeID.parse(values[1]),
values[2],
Expand All @@ -270,11 +264,29 @@ def deserialize(cls, values: tuple) -> UmountTask:
)


SERIALIZER_MAP: dict[str, Type[AbstractTask]] = {
MountTask.name: MountTask,
UmountTask.name: UmountTask,
ChownTask.name: ChownTask,
}
@attrs.define(slots=True)
class DeletePathTask(AbstractTask):
    """Watcher task that deletes a single filesystem path.

    A directory is removed recursively; any other path is removed as a
    single file.  A path that no longer exists is treated as already
    deleted, so running the task is idempotent.
    """

    name = "delete-path"
    # Target to delete: either a directory tree or a single file.
    path: Path

    async def run(self) -> Any:
        """Delete ``self.path``; a missing path is not an error.

        Returns:
            None.

        Raises:
            OSError: for failures other than the path being absent
                (for example, insufficient permissions).
        """
        loop = asyncio.get_running_loop()
        try:
            if self.path.is_dir():
                # rmtree is blocking, so push it to the default executor.
                await loop.run_in_executor(None, lambda: shutil.rmtree(self.path))
            else:
                await aiofiles.os.remove(self.path)
        except FileNotFoundError:
            # Path.is_dir() is False for a nonexistent path, so an
            # already-deleted target ends up in the remove() branch.
            # Guarding both branches keeps deletion idempotent, also when
            # the path vanishes between the check and the removal.
            pass

    def serialize(self) -> bytes:
        """Pack the task's fields as a one-element msgpack tuple."""
        # NOTE(review): assumes the msgpack layer in use can encode Path
        # objects -- confirm against the project's msgpack wrapper.
        return msgpack.packb((self.path,))

    @classmethod
    def deserialize(cls, values: tuple) -> Self:
        """Rebuild the task from the unpacked tuple; ``values[0]`` is the path."""
        return cls(values[0])


@attrs.define(slots=True)
Expand Down Expand Up @@ -355,6 +367,7 @@ def __init__(

self.outsock = zctx.socket(zmq.PUSH)
self.outsock.bind(get_zmq_socket_file_path(output_sock_prefix, self.pidx))
self.serializer_map = {cls.name: cls for cls in AbstractTask.__subclasses__()}

async def close(self) -> None:
self.insock.close()
Expand All @@ -366,13 +379,19 @@ async def ack(self) -> None:
async def respond(self, succeeded: bool, data: str) -> None:
await Protocol.respond(self.outsock, Response(succeeded, data))

def _deserialize_from_request(self, raw_data: Request) -> AbstractTask:
    """Turn a raw request into the task object it names.

    The request header holds the task name as UTF-8 text and the body holds
    the msgpack-packed field values.  The name is looked up in
    ``self.serializer_map`` and the matching class rebuilds the task.

    Raises:
        KeyError: if no task class is registered under the given name.
    """
    task_name = str(raw_data.header, "utf8")
    field_values: tuple = msgpack.unpackb(raw_data.body)
    return self.serializer_map[task_name].deserialize(field_values)

async def main(self) -> None:
try:
while True:
client_request = await Protocol.listen_to_request(self.insock)

try:
task = AbstractTask.deserialize_from_request(client_request)
task = self._deserialize_from_request(client_request)
result = await task.run()
except Exception as e:
log.exception(f"Error in watcher task. (e: {e})")
Expand Down
Loading