Commit daabdac

ejguan authored and facebook-github-bot committed

Implement BatchAsyncMapper (#1044)

Summary:

### Changes
- Add `BatchAsyncMapper` to support async processing
- Add syntax sugar to combine `batch().async_map().flatmap()`
- Add `input_col`/`output_col` to align the behavior with `Mapper`
- Add unit tests

Pull Request resolved: #1044
Reviewed By: wenleix, NivekT
Differential Revision: D43573398
Pulled By: ejguan
fbshipit-source-id: 8a933c381c35d76712500e551c78047681e6e8cb

1 parent ca83b84 commit daabdac
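
For orientation, a minimal usage sketch of the new functional form (adapted from the docstring added in this commit; the one-second sleep stands in for real async work such as a network request):

import asyncio
from torchdata.datapipes.iter import IterableWrapper

async def mul_ten(x):
    await asyncio.sleep(1)  # placeholder for real async I/O
    return x * 10

dp = IterableWrapper(range(50))
# Groups 16 elements at a time, runs `mul_ten` on every element of the
# batch concurrently, then flattens back into a single stream.
dp = dp.async_map_batches(mul_ten, batch_size=16)
assert list(dp) == [i * 10 for i in range(50)]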

File tree: 5 files changed, +241 −4 lines changed


docs/source/torchdata.datapipes.iter.rst
Lines changed: 1 addition & 0 deletions

@@ -162,6 +162,7 @@ These DataPipes apply a given function to each element in the DataPipe.
     :toctree: generated/
     :template: class_template.rst
 
+    BatchAsyncMapper
     BatchMapper
     FlatMapper
     Mapper

test/test_iterdatapipe.py
Lines changed: 84 additions & 0 deletions

@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import asyncio
 import io
 import itertools
 import pickle
@@ -78,6 +79,16 @@ def _convert_to_tensor(data):
     return torch.tensor(data)
 
 
+async def _async_mul_ten(x):
+    await asyncio.sleep(1)
+    return x * 10
+
+
+async def _async_x_mul_y(x, y):
+    await asyncio.sleep(1)
+    return x * y
+
+
 class TestIterDataPipe(expecttest.TestCase):
     def test_in_memory_cache_holder_iterdatapipe(self) -> None:
         source_dp = IterableWrapper(range(10))
@@ -1517,6 +1528,79 @@ def test_pin_memory(self):
         )
         self.assertTrue(all(d0.is_pinned() and d1.is_pinned() for batch in dp for d0, d1 in batch))
 
+    def test_async_map_batches(self):
+        batch_size = 16
+
+        def _helper(input_data, exp_res, async_fn, input_col=None, output_col=None, max_concurrency=32):
+            dp = IterableWrapper(input_data)
+            dp = dp.async_map_batches(async_fn, batch_size, input_col, output_col, max_concurrency)
+            self.assertEqual(
+                exp_res,
+                list(dp),
+                msg=f"Async map test with {async_fn=}, {input_col=}, {output_col=}, {max_concurrency=}",
+            )
+
+        _helper(range(50), [i * 10 for i in range(50)], _async_mul_ten)
+
+        # Smaller max_concurrency
+        _helper(range(50), [i * 10 for i in range(50)], _async_mul_ten, max_concurrency=6)
+
+        # Tuple with input_col
+        _helper([(i, i) for i in range(50)], [(i * 10, i) for i in range(50)], _async_mul_ten, input_col=0)
+        _helper([(i, i) for i in range(50)], [(i, i * 10) for i in range(50)], _async_mul_ten, input_col=1)
+        # Tuple with input_col and output_col
+        _helper(
+            [(i, i) for i in range(50)], [(i, i * 10) for i in range(50)], _async_mul_ten, input_col=0, output_col=1
+        )
+        _helper(
+            [(i, i) for i in range(50)], [(i, i, i * 10) for i in range(50)], _async_mul_ten, input_col=0, output_col=-1
+        )
+
+        # Dict with input_col
+        _helper(
+            [{"a": i, "b": i} for i in range(50)],
+            [{"a": i, "b": i * 10} for i in range(50)],
+            _async_mul_ten,
+            input_col="b",
+        )
+        # Dict with input_col and output_col
+        _helper(
+            [{"a": i, "b": i} for i in range(50)],
+            [{"a": i * 10, "b": i} for i in range(50)],
+            _async_mul_ten,
+            input_col="b",
+            output_col="a",
+        )
+        _helper(
+            [{"a": i, "b": i} for i in range(50)],
+            [{"a": i, "b": i, "c": i * 10} for i in range(50)],
+            _async_mul_ten,
+            input_col="b",
+            output_col="c",
+        )
+
+        # Multiple input_col
+        _helper(
+            [(i - 1, i, i + 1) for i in range(50)],
+            [((i - 1) * (i + 1), i) for i in range(50)],
+            _async_x_mul_y,
+            input_col=(0, 2),
+        )
+        _helper(
+            [(i - 1, i, i + 1) for i in range(50)],
+            [(i, (i - 1) * (i + 1)) for i in range(50)],
+            _async_x_mul_y,
+            input_col=(2, 0),
+        )
+        # Multiple input_col with output_col
+        _helper(
+            [(i - 1, i, i + 1) for i in range(50)],
+            [(i - 1, (i - 1) * (i + 1), i + 1) for i in range(50)],
+            _async_x_mul_y,
+            input_col=(0, 2),
+            output_col=1,
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
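
A quick standalone check of the concurrency these tests rely on (a sketch, not part of the commit): each call to `_async_mul_ten` sleeps for one second, so a batch of 16 elements mapped under the default `max_concurrency=32` should finish in roughly one second rather than sixteen.

import asyncio
import time
from torchdata.datapipes.iter import IterableWrapper

async def _async_mul_ten(x):
    await asyncio.sleep(1)
    return x * 10

dp = IterableWrapper(range(16)).async_map_batches(_async_mul_ten, 16)
start = time.perf_counter()
assert list(dp) == [i * 10 for i in range(16)]
# All 16 coroutines sleep concurrently, so elapsed time is ~1s, not ~16s.
print(f"elapsed: {time.perf_counter() - start:.2f}s")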

tools/gen_pyi.py
Lines changed: 4 additions & 3 deletions

@@ -67,17 +67,18 @@ def gen_pyi() -> None:
     iterDP_files_to_exclude: Set[str] = {"__init__.py"}
     iterDP_deprecated_files: Set[str] = set()
     iterDP_method_to_special_output_type: Dict[str, str] = {
+        "async_map_batches": "IterDataPipe",
         "bucketbatch": "IterDataPipe",
         "dataframe": "torcharrow.DataFrame",
         "end_caching": "IterDataPipe",
-        "unzip": "List[IterDataPipe]",
+        "extract": "IterDataPipe",
         "random_split": "Union[IterDataPipe, List[IterDataPipe]]",
         "read_from_tar": "IterDataPipe",
         "read_from_xz": "IterDataPipe",
         "read_from_zip": "IterDataPipe",
-        "extract": "IterDataPipe",
-        "to_map_datapipe": "MapDataPipe",
         "round_robin_demux": "List[IterDataPipe]",
+        "to_map_datapipe": "MapDataPipe",
+        "unzip": "List[IterDataPipe]",
     }
     iter_method_name_exclusion: Set[str] = {"def extract", "read_from_tar", "read_from_xz", "read_from_zip"}
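
The new `"async_map_batches": "IterDataPipe"` entry tells the `.pyi` generator which return type to emit for the registered functional method; the rest of the churn just re-sorts the dict alphabetically. Roughly, the generated stub entry should look like this (a hypothetical sketch of the emitted signature, not verbatim generator output):

# Hypothetical shape of the stub gen_pyi would emit into the datapipe .pyi file
# (assumes Callable and IterDataPipe are imported at the top of the stub):
def async_map_batches(
    self,
    async_fn: Callable,
    batch_size: int,
    input_col=None,
    output_col=None,
    max_concurrency: int = 32,
) -> IterDataPipe: ...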

torchdata/datapipes/iter/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -66,6 +66,7 @@
     MaxTokenBucketizerIterDataPipe as MaxTokenBucketizer,
 )
 from torchdata.datapipes.iter.transform.callable import (
+    BatchAsyncMapperIterDataPipe as BatchAsyncMapper,
     BatchMapperIterDataPipe as BatchMapper,
     DropperIterDataPipe as Dropper,
     FlatMapperIterDataPipe as FlatMapper,
@@ -136,6 +137,7 @@
 __all__ = [
     "AISFileLister",
     "AISFileLoader",
+    "BatchAsyncMapper",
     "BatchMapper",
     "Batcher",
     "BucketBatcher",

torchdata/datapipes/iter/transform/callable.py
Lines changed: 150 additions & 1 deletion

@@ -4,11 +4,15 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import asyncio
+import inspect
 import warnings
+
 from typing import Callable, Hashable, Iterator, List, Optional, Set, Sized, TypeVar, Union
 
-from torch.utils.data import functional_datapipe, IterDataPipe
 from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, validate_input_col
+from torchdata.datapipes import functional_datapipe
+from torchdata.datapipes.iter import IterDataPipe
 
 T_co = TypeVar("T_co", covariant=True)
 
@@ -414,3 +418,148 @@ def __len__(self) -> int:
         if isinstance(self.datapipe, Sized):
             return len(self.datapipe)
         raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
+
+
+class _BatchAsyncMapperIterDataPipe(IterDataPipe):
+    datapipe: IterDataPipe
+    async_fn: Callable
+
+    def __init__(
+        self,
+        source_datapipe: IterDataPipe,
+        async_fn: Callable,
+        input_col=None,
+        output_col=None,
+        max_concurrency: int = 32,
+    ):
+        self.source_datapipe = source_datapipe
+        if not inspect.iscoroutinefunction(async_fn):
+            raise ValueError(f"Expected a coroutine function with an async def syntax, but got a {type(async_fn)}")
+        self.async_fn = async_fn  # type: ignore[assignment]
+        if input_col is None and output_col is not None:
+            raise ValueError("`output_col` must be None when `input_col` is None.")
+        self.input_col = input_col
+        if isinstance(output_col, (list, tuple)):
+            if len(output_col) > 1:
+                raise ValueError("`output_col` must be a single-element list or tuple")
+            output_col = output_col[0]
+        self.output_col = output_col
+        self.max_concurrency = max_concurrency
+
+    def __iter__(self):
+        for batch in self.source_datapipe:
+            new_batch = asyncio.run(self.processbatch(batch))
+            yield new_batch
+
+    async def processbatch(self, batch):
+        sem = asyncio.Semaphore(self.max_concurrency)
+
+        async def controlled_async_fn(async_fn, *data):
+            async with sem:
+                return await async_fn(*data)
+
+        coroutines = []
+        if self.input_col is None:
+            for data in batch:
+                coroutines.append(controlled_async_fn(self.async_fn, data))
+            results = await asyncio.gather(*coroutines)
+            return results
+
+        for data in batch:
+            if isinstance(self.input_col, (list, tuple)):
+                args = tuple(data[col] for col in self.input_col)
+                coroutines.append(controlled_async_fn(self.async_fn, *args))
+            else:
+                coroutines.append(controlled_async_fn(self.async_fn, data[self.input_col]))
+        results = await asyncio.gather(*coroutines)
+
+        new_batch = []
+        for data, res in zip(batch, results):
+            t_flag = isinstance(data, tuple)
+            if t_flag:
+                data = list(data)
+
+            if self.output_col is None:
+                if isinstance(self.input_col, (list, tuple)):
+                    data[self.input_col[0]] = res
+                    for idx in sorted(self.input_col[1:], reverse=True):
+                        del data[idx]
+                else:
+                    data[self.input_col] = res
+            elif self.output_col == -1:
+                data.append(res)
+            else:
+                data[self.output_col] = res
+
+            if t_flag:
+                data = tuple(data)
+
+            new_batch.append(data)
+        return new_batch
+
+
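`processbatch` above combines two standard asyncio building blocks: a `Semaphore` bounds how many coroutines run at once, and `asyncio.gather` returns results in input order, which is what keeps batch order deterministic. A minimal standalone illustration of the same pattern (names here are illustrative, not from the commit):

import asyncio

async def bounded_gather(fn, items, limit):
    sem = asyncio.Semaphore(limit)

    async def run_one(item):
        # At most `limit` coroutines are past this point at any moment.
        async with sem:
            return await fn(item)

    # gather() preserves the order of `items`, regardless of completion order.
    return await asyncio.gather(*(run_one(i) for i in items))

async def slow_double(x):
    await asyncio.sleep(0.1)
    return x * 2

assert asyncio.run(bounded_gather(slow_double, range(10), limit=4)) == [i * 2 for i in range(10)]

The diff continues with the public, user-facing class: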
+@functional_datapipe("async_map_batches")
+class BatchAsyncMapperIterDataPipe(IterDataPipe):
+    r"""
+    Combines elements from the source DataPipe into batches, applies a coroutine function
+    over each element within the batch concurrently, then flattens the output to a
+    single, unnested IterDataPipe (functional name: ``async_map_batches``).
+
+    Args:
+        source_datapipe: Source IterDataPipe
+        async_fn: The coroutine function to be applied to each batch of data
+        batch_size: The size of batch to be aggregated from ``source_datapipe``
+        input_col: Index or indices of data to which ``async_fn`` is applied, such as:
+            - ``None`` as default to apply ``async_fn`` to the data directly.
+            - Integer(s) for list/tuple.
+            - Key(s) for dict.
+        output_col: Index of data where the result of ``async_fn`` is placed. ``output_col`` can be
+            specified only when ``input_col`` is not ``None``
+            - ``None`` as default to replace the index that ``input_col`` specified; for ``input_col`` with
+              multiple indices, the left-most one is used, and the other indices will be removed.
+            - Integer for list/tuple. ``-1`` appends the result at the end.
+            - Key for dict. A new key is acceptable.
+        max_concurrency: Maximum concurrency to call async functions. (Default value: 32)
+
+    Example:
+        >>> from torchdata.datapipes.iter import IterableWrapper
+        >>> async def mul_ten(x):
+        ...     await asyncio.sleep(1)
+        ...     return x * 10
+        >>> dp = IterableWrapper(range(50))
+        >>> dp = dp.async_map_batches(mul_ten, 16)
+        >>> list(dp)
+        [0, 10, 20, 30, ...]
+        >>> dp = IterableWrapper([(i, i) for i in range(50)])
+        >>> dp = dp.async_map_batches(mul_ten, 16, input_col=1)
+        >>> list(dp)
+        [(0, 0), (1, 10), (2, 20), (3, 30), ...]
+        >>> dp = IterableWrapper([(i, i) for i in range(50)])
+        >>> dp = dp.async_map_batches(mul_ten, 16, input_col=1, output_col=-1)
+        >>> list(dp)
+        [(0, 0, 0), (1, 1, 10), (2, 2, 20), (3, 3, 30), ...]
+        # Async fetching html from remote
+        >>> from aiohttp import ClientSession
+        >>> async def fetch_html(url: str, **kwargs):
+        ...     async with ClientSession() as session:
+        ...         resp = await session.request(method="GET", url=url, **kwargs)
+        ...         resp.raise_for_status()
+        ...         html = await resp.text()
+        ...     return html
+        >>> dp = IterableWrapper(urls)
+        >>> dp = dp.async_map_batches(fetch_html, 16)
+    """
+
+    def __new__(
+        self,
+        source_datapipe,
+        async_fn: Callable,
+        batch_size: int,
+        input_col=None,
+        output_col=None,
+        max_concurrency: int = 32,
+    ):
+        dp = source_datapipe.batch(batch_size)
+        dp = _BatchAsyncMapperIterDataPipe(dp, async_fn, input_col, output_col, max_concurrency)
+        dp = dp.flatmap()
+        return dp
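
As `__new__` above shows, the public class is pure syntax sugar: it never instantiates itself, but returns an already-composed pipeline. The expansion of `dp.async_map_batches(fn, 16)`, spelled out (a sketch; `source_dp` and `fn` are placeholders):

from torchdata.datapipes.iter.transform.callable import _BatchAsyncMapperIterDataPipe

batched = source_dp.batch(16)                        # group the stream into lists of 16
mapped = _BatchAsyncMapperIterDataPipe(batched, fn)  # run fn concurrently within each batch
flattened = mapped.flatmap()                         # unnest back into single elements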
