@@ -141,6 +141,7 @@ def __init__(
141
141
max_concurrent : Optional [int ],
142
142
snapshot_frequency : int ,
143
143
initial_state : Optional [Dict [str , Any ]],
144
+ daemonic_reading : bool ,
144
145
):
145
146
self .source = source
146
147
self .map_fn = map_fn
@@ -149,6 +150,7 @@ def __init__(
149
150
self .method = method
150
151
self .mp_context = mp_context
151
152
self .snapshot_frequency = snapshot_frequency
153
+ self .daemonic_reading = daemonic_reading
152
154
153
155
self ._in_q : Union [queue .Queue , mp .Queue ] = queue .Queue () if method == "thread" else mp_context .Queue ()
154
156
self ._intermed_q : Union [queue .Queue , mp .Queue ] = queue .Queue () if method == "thread" else mp_context .Queue ()
@@ -182,7 +184,7 @@ def __init__(
182
184
self ._stop ,
183
185
),
184
186
name = "read_thread(target=_populate_queue)" ,
185
- daemon = True ,
187
+ daemon = self . daemonic_reading ,
186
188
)
187
189
self ._workers : List [Union [threading .Thread , mp .Process ]] = []
188
190
for worker_id in range (self .num_workers ):
@@ -311,6 +313,7 @@ def __init__(
311
313
multiprocessing_context : Optional [str ] = None ,
312
314
max_concurrent : Optional [int ] = None ,
313
315
snapshot_frequency : int = 1 ,
316
+ daemonic_reading : bool = True ,
314
317
):
315
318
super ().__init__ ()
316
319
assert method in ["thread" , "process" ]
@@ -329,6 +332,7 @@ def __init__(
329
332
raise ValueError (f"{ max_concurrent = } should be <= { num_workers = } !" )
330
333
self .max_concurrent = max_concurrent
331
334
self .snapshot_frequency = snapshot_frequency
335
+ self .daemonic_reading = daemonic_reading
332
336
self ._it : Optional [Union [_InlineMapperIter [T ], _ParallelMapperIter [T ]]] = None
333
337
334
338
def reset (self , initial_state : Optional [Dict [str , Any ]] = None ):
@@ -355,6 +359,7 @@ def _parallel_reset(self, initial_state: Optional[Dict[str, Any]]):
355
359
max_concurrent = self .max_concurrent ,
356
360
snapshot_frequency = self .snapshot_frequency ,
357
361
initial_state = initial_state ,
362
+ daemonic_reading = self .daemonic_reading ,
358
363
)
359
364
360
365
def next (self ) -> T :
@@ -403,6 +408,7 @@ def __init__(
403
408
max_concurrent : Optional [int ] = None ,
404
409
snapshot_frequency : int = 1 ,
405
410
prebatch : Optional [int ] = None ,
411
+ daemonic_reading : bool = True ,
406
412
):
407
413
super ().__init__ ()
408
414
assert method in ["thread" , "process" ]
@@ -416,6 +422,7 @@ def __init__(
416
422
self .max_concurrent = max_concurrent
417
423
self .snapshot_frequency = snapshot_frequency
418
424
self .prebatch = prebatch
425
+ self .daemonic_reading = daemonic_reading
419
426
if prebatch is None :
420
427
self .map_fn = map_fn
421
428
self .source = source
@@ -434,6 +441,7 @@ def __init__(
434
441
multiprocessing_context = self .multiprocessing_context ,
435
442
max_concurrent = self .max_concurrent ,
436
443
snapshot_frequency = self .snapshot_frequency ,
444
+ daemonic_reading = self .daemonic_reading ,
437
445
)
438
446
439
447
if self .prebatch is None :
0 commit comments