|
4 | 4 | # This source code is licensed under the BSD-style license found in the
|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 |
|
| 7 | +import asyncio |
| 8 | +import inspect |
7 | 9 | import warnings
|
| 10 | + |
8 | 11 | from typing import Callable, Hashable, Iterator, List, Optional, Set, Sized, TypeVar, Union
|
9 | 12 |
|
10 |
| -from torch.utils.data import functional_datapipe, IterDataPipe |
11 | 13 | from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, validate_input_col
|
| 14 | +from torchdata.datapipes import functional_datapipe |
| 15 | +from torchdata.datapipes.iter import IterDataPipe |
12 | 16 |
|
13 | 17 | T_co = TypeVar("T_co", covariant=True)
|
14 | 18 |
|
@@ -414,3 +418,148 @@ def __len__(self) -> int:
|
414 | 418 | if isinstance(self.datapipe, Sized):
|
415 | 419 | return len(self.datapipe)
|
416 | 420 | raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
|
| 421 | + |
| 422 | + |
class _BatchAsyncMapperIterDataPipe(IterDataPipe):
    r"""
    Internal DataPipe that applies a coroutine function concurrently over every
    element of each incoming batch, bounding concurrency with a semaphore.

    Each batch yielded by ``source_datapipe`` is processed inside its own
    ``asyncio.run`` event loop: one coroutine per element is scheduled and the
    results are gathered in order.

    Args:
        source_datapipe: Source IterDataPipe yielding batches (lists) of data
        async_fn: Coroutine function (``async def``) applied to each element
        input_col: Index/indices (or key/keys) of the data that ``async_fn`` is
            applied to; ``None`` applies ``async_fn`` to the element directly
        output_col: Index (or key) where the result is placed; may only be set
            when ``input_col`` is not ``None``
        max_concurrency: Maximum number of coroutines awaited at once
    """
    datapipe: IterDataPipe
    async_fn: Callable

    def __init__(
        self,
        source_datapipe: IterDataPipe,
        async_fn: Callable,
        input_col=None,
        output_col=None,
        max_concurrency: int = 32,
    ):
        self.source_datapipe = source_datapipe
        # Reject plain callables early: they would fail only when awaited.
        if not inspect.iscoroutinefunction(async_fn):
            raise ValueError(f"Expected a coroutine function with an async def syntax, but got a {type(async_fn)}")
        self.async_fn = async_fn  # type: ignore[assignment]
        if input_col is None and output_col is not None:
            raise ValueError("`output_col` must be None when `input_col` is None.")
        self.input_col = input_col
        if isinstance(output_col, (list, tuple)):
            if len(output_col) > 1:
                raise ValueError("`output_col` must be a single-element list or tuple")
            output_col = output_col[0]
        self.output_col = output_col
        self.max_concurrency = max_concurrency

    def __iter__(self):
        for batch in self.source_datapipe:
            # One event loop per batch; all element coroutines run inside it.
            new_batch = asyncio.run(self.processbatch(batch))
            yield new_batch

    async def processbatch(self, batch):
        """Run ``async_fn`` over every element of ``batch`` concurrently and
        return the batch with results placed per ``input_col``/``output_col``."""
        # Semaphore caps the number of coroutines awaited simultaneously.
        sem = asyncio.Semaphore(self.max_concurrency)

        async def controlled_async_fn(async_fn, *data):
            async with sem:
                return await async_fn(*data)

        coroutines = []
        if self.input_col is None:
            # Whole element is the argument; results replace elements 1:1.
            for data in batch:
                coroutines.append(controlled_async_fn(self.async_fn, data))
            results = await asyncio.gather(*coroutines)
            return results

        for data in batch:
            if isinstance(self.input_col, (list, tuple)):
                args = tuple(data[col] for col in self.input_col)
                coroutines.append(controlled_async_fn(self.async_fn, *args))
            else:
                coroutines.append(controlled_async_fn(self.async_fn, data[self.input_col]))
        # gather preserves submission order, so results align with batch.
        results = await asyncio.gather(*coroutines)

        new_batch = []
        for data, res in zip(batch, results):
            # Tuples are immutable; convert to list for in-place edits.
            t_flag = isinstance(data, tuple)
            if t_flag:
                data = list(data)

            if self.output_col is None:
                if isinstance(self.input_col, (list, tuple)):
                    # Result replaces the left-most input column; the other
                    # input columns are removed (right-to-left so indices stay valid).
                    data[self.input_col[0]] = res
                    for idx in sorted(self.input_col[1:], reverse=True):
                        del data[idx]
                else:
                    data[self.input_col] = res
            elif self.output_col == -1:
                # -1 means "append at the end" rather than replace index -1.
                data.append(res)
            else:
                data[self.output_col] = res

            if t_flag:
                data = tuple(data)

            new_batch.append(data)
        return new_batch
| 499 | + |
| 500 | + |
@functional_datapipe("async_map_batches")
class BatchAsyncMapperIterDataPipe(IterDataPipe):
    r"""
    Combines elements from the source DataPipe to batches and applies a coroutine function
    over each element within the batch concurrently, then flattens the outputs to a
    single, unnested IterDataPipe (functional name: ``async_map_batches``).

    Args:
        source_datapipe: Source IterDataPipe
        async_fn: The coroutine function to be applied to each batch of data
        batch_size: The size of batch to be aggregated from ``source_datapipe``
        input_col: Index or indices of data which ``fn`` is applied, such as:
            - ``None`` as default to apply ``fn`` to the data directly.
            - Integer(s) is used for list/tuple.
            - Key(s) is used for dict.
        output_col: Index of data where result of ``fn`` is placed. ``output_col`` can be specified
            only when ``input_col`` is not ``None``
            - ``None`` as default to replace the index that ``input_col`` specified; For ``input_col`` with
              multiple indices, the left-most one is used, and other indices will be removed.
            - Integer is used for list/tuple. ``-1`` represents to append result at the end.
            - Key is used for dict. New key is acceptable.
        max_concurrency: Maximum concurrency to call async functions. (Default value: 32)

    Example:
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> async def mul_ten(x):
        ...     await asyncio.sleep(1)
        ...     return x * 10
        >>> dp = IterableWrapper(range(50))
        >>> dp = dp.async_map_batches(mul_ten, 16)
        >>> list(dp)
        [0, 10, 20, 30, ...]
        >>> dp = IterableWrapper([(i, i) for i in range(50)])
        >>> dp = dp.async_map_batches(mul_ten, 16, input_col=1)
        >>> list(dp)
        [(0, 0), (1, 10), (2, 20), (3, 30), ...]
        >>> dp = IterableWrapper([(i, i) for i in range(50)])
        >>> dp = dp.async_map_batches(mul_ten, 16, input_col=1, output_col=-1)
        >>> list(dp)
        [(0, 0, 0), (1, 1, 10), (2, 2, 20), (3, 3, 30), ...]
        # Async fetching html from remote
        >>> from aiohttp import ClientSession
        >>> async def fetch_html(url: str, **kwargs):
        ...     async with ClientSession() as session:
        ...         resp = await session.request(method="GET", url=url, **kwargs)
        ...         resp.raise_for_status()
        ...         html = await resp.text()
        ...     return html
        >>> dp = IterableWrapper(urls)
        >>> dp = dp.async_map_batches(fetch_html, 16)
    """

    # __new__ returns a composed pipeline rather than an instance of this
    # class, so this class is only a functional-API entry point.
    def __new__(
        cls,
        source_datapipe,
        async_fn: Callable,
        batch_size: int,
        input_col=None,
        output_col=None,
        max_concurrency: int = 32,
    ):
        # batch -> concurrent async map over each batch -> flatten back out.
        dp = source_datapipe.batch(batch_size)
        dp = _BatchAsyncMapperIterDataPipe(dp, async_fn, input_col, output_col, max_concurrency)
        dp = dp.flatmap()
        return dp
0 commit comments