Skip to content

Commit 922ac06

Browse files
authored
Add flag for non-daemonic _read_thread in ParallelMapper
Differential Revision: D71839829 Pull Request resolved: #1468
1 parent edb2875 commit 922ac06

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

torchdata/nodes/map.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def __init__(
141141
max_concurrent: Optional[int],
142142
snapshot_frequency: int,
143143
initial_state: Optional[Dict[str, Any]],
144+
daemonic_reading: bool,
144145
):
145146
self.source = source
146147
self.map_fn = map_fn
@@ -149,6 +150,7 @@ def __init__(
149150
self.method = method
150151
self.mp_context = mp_context
151152
self.snapshot_frequency = snapshot_frequency
153+
self.daemonic_reading = daemonic_reading
152154

153155
self._in_q: Union[queue.Queue, mp.Queue] = queue.Queue() if method == "thread" else mp_context.Queue()
154156
self._intermed_q: Union[queue.Queue, mp.Queue] = queue.Queue() if method == "thread" else mp_context.Queue()
@@ -182,7 +184,7 @@ def __init__(
182184
self._stop,
183185
),
184186
name="read_thread(target=_populate_queue)",
185-
daemon=True,
187+
daemon=self.daemonic_reading,
186188
)
187189
self._workers: List[Union[threading.Thread, mp.Process]] = []
188190
for worker_id in range(self.num_workers):
@@ -311,6 +313,7 @@ def __init__(
311313
multiprocessing_context: Optional[str] = None,
312314
max_concurrent: Optional[int] = None,
313315
snapshot_frequency: int = 1,
316+
daemonic_reading: bool = True,
314317
):
315318
super().__init__()
316319
assert method in ["thread", "process"]
@@ -329,6 +332,7 @@ def __init__(
329332
raise ValueError(f"{max_concurrent=} should be <= {num_workers=}!")
330333
self.max_concurrent = max_concurrent
331334
self.snapshot_frequency = snapshot_frequency
335+
self.daemonic_reading = daemonic_reading
332336
self._it: Optional[Union[_InlineMapperIter[T], _ParallelMapperIter[T]]] = None
333337

334338
def reset(self, initial_state: Optional[Dict[str, Any]] = None):
@@ -355,6 +359,7 @@ def _parallel_reset(self, initial_state: Optional[Dict[str, Any]]):
355359
max_concurrent=self.max_concurrent,
356360
snapshot_frequency=self.snapshot_frequency,
357361
initial_state=initial_state,
362+
daemonic_reading=self.daemonic_reading,
358363
)
359364

360365
def next(self) -> T:
@@ -403,6 +408,7 @@ def __init__(
403408
max_concurrent: Optional[int] = None,
404409
snapshot_frequency: int = 1,
405410
prebatch: Optional[int] = None,
411+
daemonic_reading: bool = True,
406412
):
407413
super().__init__()
408414
assert method in ["thread", "process"]
@@ -416,6 +422,7 @@ def __init__(
416422
self.max_concurrent = max_concurrent
417423
self.snapshot_frequency = snapshot_frequency
418424
self.prebatch = prebatch
425+
self.daemonic_reading = daemonic_reading
419426
if prebatch is None:
420427
self.map_fn = map_fn
421428
self.source = source
@@ -434,6 +441,7 @@ def __init__(
434441
multiprocessing_context=self.multiprocessing_context,
435442
max_concurrent=self.max_concurrent,
436443
snapshot_frequency=self.snapshot_frequency,
444+
daemonic_reading=self.daemonic_reading,
437445
)
438446

439447
if self.prebatch is None:

0 commit comments

Comments
 (0)