Skip to content

Commit 5ff11c8

Browse files
fregataa authored and Yaminyam committed
feat(BA-533): Add delete path API to storage watcher (#3548)
1 parent 50dd6c4 commit 5ff11c8

File tree

5 files changed

+71
-36
lines changed

5 files changed

+71
-36
lines changed

changes/3548.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add storage-watcher API to delete VFolders with elevated permissions

src/ai/backend/storage/abc.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
TreeUsage,
3232
VFolderID,
3333
)
34+
from .watcher import WatcherClient
3435

3536
# Available capabilities of a volume implementation
3637
CAP_VFOLDER: Final = "vfolder" # ability to create vfolder
@@ -201,6 +202,7 @@ def __init__(
201202
etcd: AsyncEtcd,
202203
event_dispatcher: EventDispatcher,
203204
event_producer: EventProducer,
205+
watcher: Optional[WatcherClient] = None,
204206
options: Optional[Mapping[str, Any]] = None,
205207
) -> None:
206208
self.local_config = local_config
@@ -209,6 +211,7 @@ def __init__(
209211
self.etcd = etcd
210212
self.event_dispatcher = event_dispatcher
211213
self.event_producer = event_producer
214+
self.watcher = watcher
212215

213216
async def init(self) -> None:
214217
self.fsop_model = await self.create_fsop_model()

src/ai/backend/storage/context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ async def get_volume(self, name: str) -> AsyncIterator[AbstractVolume]:
201201
etcd=self.etcd,
202202
event_dispatcher=self.event_dispatcher,
203203
event_producer=self.event_producer,
204+
watcher=self.watcher,
204205
)
205206

206207
await volume_obj.init()

src/ai/backend/storage/vfs/__init__.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
VFolderID,
4444
)
4545
from ..utils import fstime2datetime
46+
from ..watcher import DeletePathTask, WatcherClient
4647

4748
log = BraceStyleAdapter(logging.getLogger(__spec__.name))
4849

@@ -190,9 +191,12 @@ async def delete_quota_scope(
190191

191192

192193
class BaseFSOpModel(AbstractFSOpModel):
193-
def __init__(self, mount_path: Path, scandir_limit: int) -> None:
194+
def __init__(
195+
self, mount_path: Path, scandir_limit: int, watcher: Optional[WatcherClient] = None
196+
) -> None:
194197
self.mount_path = mount_path
195198
self.scandir_limit = scandir_limit
199+
self.watcher = watcher
196200

197201
async def copy_tree(
198202
self,
@@ -225,11 +229,14 @@ async def delete_tree(
225229
self,
226230
path: Path,
227231
) -> None:
228-
loop = asyncio.get_running_loop()
229-
try:
230-
await loop.run_in_executor(None, lambda: shutil.rmtree(path))
231-
except FileNotFoundError:
232-
pass
232+
if self.watcher is not None:
233+
await self.watcher.request_task(DeletePathTask(path))
234+
else:
235+
loop = asyncio.get_running_loop()
236+
try:
237+
await loop.run_in_executor(None, lambda: shutil.rmtree(path))
238+
except FileNotFoundError:
239+
pass
233240

234241
def scan_tree(
235242
self,
@@ -372,6 +379,7 @@ async def create_fsop_model(self) -> AbstractFSOpModel:
372379
return BaseFSOpModel(
373380
self.mount_path,
374381
self.local_config["storage-proxy"]["scandir-limit"],
382+
self.watcher,
375383
)
376384

377385
async def get_capabilities(self) -> FrozenSet[str]:
@@ -682,5 +690,8 @@ async def delete_files(
682690
for p in target_paths:
683691
if p.is_dir() and recursive:
684692
await self.fsop_model.delete_tree(p)
685-
else:
686-
await aiofiles.os.remove(p)
693+
elif p.is_file():
694+
if self.watcher is not None:
695+
await self.watcher.request_task(DeletePathTask(p))
696+
else:
697+
await aiofiles.os.remove(p)

src/ai/backend/storage/watcher.py

Lines changed: 47 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33
import asyncio
44
import logging
55
import os
6+
import shutil
67
import traceback
78
from abc import ABCMeta, abstractmethod
89
from pathlib import Path
9-
from typing import Any, ClassVar, Sequence, Type
10+
from typing import Any, ClassVar, Self, Sequence
1011

12+
import aiofiles.os
1113
import attrs
1214
import zmq
1315
import zmq.asyncio
@@ -79,27 +81,19 @@ class AbstractTask(metaclass=ABCMeta):
7981
async def run(self) -> Any:
8082
pass
8183

82-
@classmethod
83-
def deserialize_from_request(cls, raw_data: Request) -> AbstractTask:
84-
serializer_name = str(raw_data.header, "utf8")
85-
values: tuple = msgpack.unpackb(raw_data.body)
86-
serializer_cls = SERIALIZER_MAP[serializer_name]
87-
return serializer_cls.deserialize(values)
88-
89-
def serialize_to_request(self) -> Request:
90-
assert self.name in SERIALIZER_MAP
91-
header = bytes(self.name, "utf8")
92-
return Request(header, self.serialize())
93-
9484
@abstractmethod
9585
def serialize(self) -> bytes:
9686
pass
9787

9888
@classmethod
9989
@abstractmethod
100-
def deserialize(cls, values: tuple) -> AbstractTask:
90+
def deserialize(cls, values: tuple) -> Self:
10191
pass
10292

93+
def serialize_to_request(self) -> Request:
94+
header = bytes(self.name, "utf8")
95+
return Request(header, self.serialize())
96+
10397

10498
@attrs.define(slots=True)
10599
class ChownTask(AbstractTask):
@@ -120,8 +114,8 @@ def serialize(self) -> bytes:
120114
))
121115

122116
@classmethod
123-
def deserialize(cls, values: tuple) -> ChownTask:
124-
return ChownTask(
117+
def deserialize(cls, values: tuple) -> Self:
118+
return cls(
125119
values[0],
126120
values[1],
127121
values[2],
@@ -161,7 +155,7 @@ async def run(self) -> Any:
161155
def from_event(
162156
cls, event: DoVolumeMountEvent, *, mount_path: Path, mount_prefix: str | None = None
163157
) -> MountTask:
164-
return MountTask(
158+
return cls(
165159
str(mount_path),
166160
event.quota_scope_id,
167161
event.fs_location,
@@ -187,8 +181,8 @@ def serialize(self) -> bytes:
187181
))
188182

189183
@classmethod
190-
def deserialize(cls, values: tuple) -> MountTask:
191-
return MountTask(
184+
def deserialize(cls, values: tuple) -> Self:
185+
return cls(
192186
values[0],
193187
QuotaScopeID.parse(values[1]),
194188
values[2],
@@ -236,7 +230,7 @@ def from_event(
236230
mount_prefix: str | None = None,
237231
timeout: float | None = None,
238232
) -> UmountTask:
239-
return UmountTask(
233+
return cls(
240234
str(mount_path),
241235
event.quota_scope_id,
242236
event.scaling_group,
@@ -258,8 +252,8 @@ def serialize(self) -> bytes:
258252
))
259253

260254
@classmethod
261-
def deserialize(cls, values: tuple) -> UmountTask:
262-
return UmountTask(
255+
def deserialize(cls, values: tuple) -> Self:
256+
return cls(
263257
values[0],
264258
QuotaScopeID.parse(values[1]),
265259
values[2],
@@ -270,11 +264,29 @@ def deserialize(cls, values: tuple) -> UmountTask:
270264
)
271265

272266

273-
SERIALIZER_MAP: dict[str, Type[AbstractTask]] = {
274-
MountTask.name: MountTask,
275-
UmountTask.name: UmountTask,
276-
ChownTask.name: ChownTask,
277-
}
267+
@attrs.define(slots=True)
268+
class DeletePathTask(AbstractTask):
269+
name = "delete-path"
270+
path: Path
271+
272+
async def run(self) -> Any:
273+
if self.path.is_dir():
274+
loop = asyncio.get_running_loop()
275+
try:
276+
await loop.run_in_executor(None, lambda: shutil.rmtree(self.path))
277+
except FileNotFoundError:
278+
pass
279+
else:
280+
await aiofiles.os.remove(self.path)
281+
282+
def serialize(self) -> bytes:
283+
return msgpack.packb((self.path,))
284+
285+
@classmethod
286+
def deserialize(cls, values: tuple) -> Self:
287+
return cls(
288+
values[0],
289+
)
278290

279291

280292
@attrs.define(slots=True)
@@ -355,6 +367,7 @@ def __init__(
355367

356368
self.outsock = zctx.socket(zmq.PUSH)
357369
self.outsock.bind(get_zmq_socket_file_path(output_sock_prefix, self.pidx))
370+
self.serializer_map = {cls.name: cls for cls in AbstractTask.__subclasses__()}
358371

359372
async def close(self) -> None:
360373
self.insock.close()
@@ -366,13 +379,19 @@ async def ack(self) -> None:
366379
async def respond(self, succeeded: bool, data: str) -> None:
367380
await Protocol.respond(self.outsock, Response(succeeded, data))
368381

382+
def _deserialize_from_request(self, raw_data: Request) -> AbstractTask:
383+
serializer_name = str(raw_data.header, "utf8")
384+
values: tuple = msgpack.unpackb(raw_data.body)
385+
serializer_cls = self.serializer_map[serializer_name]
386+
return serializer_cls.deserialize(values)
387+
369388
async def main(self) -> None:
370389
try:
371390
while True:
372391
client_request = await Protocol.listen_to_request(self.insock)
373392

374393
try:
375-
task = AbstractTask.deserialize_from_request(client_request)
394+
task = self._deserialize_from_request(client_request)
376395
result = await task.run()
377396
except Exception as e:
378397
log.exception(f"Error in watcher task. (e: {e})")

0 commit comments

Comments
 (0)