3
3
import asyncio
4
4
import logging
5
5
import os
6
+ import shutil
6
7
import traceback
7
8
from abc import ABCMeta , abstractmethod
8
9
from pathlib import Path
9
- from typing import Any , ClassVar , Sequence , Type
10
+ from typing import Any , ClassVar , Self , Sequence
10
11
12
+ import aiofiles .os
11
13
import attrs
12
14
import zmq
13
15
import zmq .asyncio
@@ -79,27 +81,19 @@ class AbstractTask(metaclass=ABCMeta):
79
81
async def run (self ) -> Any :
80
82
pass
81
83
82
- @classmethod
83
- def deserialize_from_request (cls , raw_data : Request ) -> AbstractTask :
84
- serializer_name = str (raw_data .header , "utf8" )
85
- values : tuple = msgpack .unpackb (raw_data .body )
86
- serializer_cls = SERIALIZER_MAP [serializer_name ]
87
- return serializer_cls .deserialize (values )
88
-
89
- def serialize_to_request (self ) -> Request :
90
- assert self .name in SERIALIZER_MAP
91
- header = bytes (self .name , "utf8" )
92
- return Request (header , self .serialize ())
93
-
94
84
@abstractmethod
95
85
def serialize (self ) -> bytes :
96
86
pass
97
87
98
88
@classmethod
99
89
@abstractmethod
100
- def deserialize (cls , values : tuple ) -> AbstractTask :
90
+ def deserialize (cls , values : tuple ) -> Self :
101
91
pass
102
92
93
+ def serialize_to_request (self ) -> Request :
94
+ header = bytes (self .name , "utf8" )
95
+ return Request (header , self .serialize ())
96
+
103
97
104
98
@attrs .define (slots = True )
105
99
class ChownTask (AbstractTask ):
@@ -120,8 +114,8 @@ def serialize(self) -> bytes:
120
114
))
121
115
122
116
@classmethod
123
- def deserialize (cls , values : tuple ) -> ChownTask :
124
- return ChownTask (
117
+ def deserialize (cls , values : tuple ) -> Self :
118
+ return cls (
125
119
values [0 ],
126
120
values [1 ],
127
121
values [2 ],
@@ -161,7 +155,7 @@ async def run(self) -> Any:
161
155
def from_event (
162
156
cls , event : DoVolumeMountEvent , * , mount_path : Path , mount_prefix : str | None = None
163
157
) -> MountTask :
164
- return MountTask (
158
+ return cls (
165
159
str (mount_path ),
166
160
event .quota_scope_id ,
167
161
event .fs_location ,
@@ -187,8 +181,8 @@ def serialize(self) -> bytes:
187
181
))
188
182
189
183
@classmethod
190
- def deserialize (cls , values : tuple ) -> MountTask :
191
- return MountTask (
184
+ def deserialize (cls , values : tuple ) -> Self :
185
+ return cls (
192
186
values [0 ],
193
187
QuotaScopeID .parse (values [1 ]),
194
188
values [2 ],
@@ -236,7 +230,7 @@ def from_event(
236
230
mount_prefix : str | None = None ,
237
231
timeout : float | None = None ,
238
232
) -> UmountTask :
239
- return UmountTask (
233
+ return cls (
240
234
str (mount_path ),
241
235
event .quota_scope_id ,
242
236
event .scaling_group ,
@@ -258,8 +252,8 @@ def serialize(self) -> bytes:
258
252
))
259
253
260
254
@classmethod
261
- def deserialize (cls , values : tuple ) -> UmountTask :
262
- return UmountTask (
255
+ def deserialize (cls , values : tuple ) -> Self :
256
+ return cls (
263
257
values [0 ],
264
258
QuotaScopeID .parse (values [1 ]),
265
259
values [2 ],
@@ -270,11 +264,29 @@ def deserialize(cls, values: tuple) -> UmountTask:
270
264
)
271
265
272
266
273
- SERIALIZER_MAP : dict [str , Type [AbstractTask ]] = {
274
- MountTask .name : MountTask ,
275
- UmountTask .name : UmountTask ,
276
- ChownTask .name : ChownTask ,
277
- }
267
+ @attrs .define (slots = True )
268
+ class DeletePathTask (AbstractTask ):
269
+ name = "delete-path"
270
+ path : Path
271
+
272
+ async def run (self ) -> Any :
273
+ if self .path .is_dir ():
274
+ loop = asyncio .get_running_loop ()
275
+ try :
276
+ await loop .run_in_executor (None , lambda : shutil .rmtree (self .path ))
277
+ except FileNotFoundError :
278
+ pass
279
+ else :
280
+ await aiofiles .os .remove (self .path )
281
+
282
+ def serialize (self ) -> bytes :
283
+ return msgpack .packb ((self .path ,))
284
+
285
+ @classmethod
286
+ def deserialize (cls , values : tuple ) -> Self :
287
+ return cls (
288
+ values [0 ],
289
+ )
278
290
279
291
280
292
@attrs .define (slots = True )
@@ -355,6 +367,7 @@ def __init__(
355
367
356
368
self .outsock = zctx .socket (zmq .PUSH )
357
369
self .outsock .bind (get_zmq_socket_file_path (output_sock_prefix , self .pidx ))
370
+ self .serializer_map = {cls .name : cls for cls in AbstractTask .__subclasses__ ()}
358
371
359
372
async def close (self ) -> None :
360
373
self .insock .close ()
@@ -366,13 +379,19 @@ async def ack(self) -> None:
366
379
async def respond (self , succeeded : bool , data : str ) -> None :
367
380
await Protocol .respond (self .outsock , Response (succeeded , data ))
368
381
382
+ def _deserialize_from_request (self , raw_data : Request ) -> AbstractTask :
383
+ serializer_name = str (raw_data .header , "utf8" )
384
+ values : tuple = msgpack .unpackb (raw_data .body )
385
+ serializer_cls = self .serializer_map [serializer_name ]
386
+ return serializer_cls .deserialize (values )
387
+
369
388
async def main (self ) -> None :
370
389
try :
371
390
while True :
372
391
client_request = await Protocol .listen_to_request (self .insock )
373
392
374
393
try :
375
- task = AbstractTask . deserialize_from_request (client_request )
394
+ task = self . _deserialize_from_request (client_request )
376
395
result = await task .run ()
377
396
except Exception as e :
378
397
log .exception (f"Error in watcher task. (e: { e } )" )
0 commit comments