
Using ShardingFilterIterDataPipe with MPRS may cause unnecessary batch drops. #1180

yuxinyuan opened this issue Jun 7, 2023 · 0 comments


🐛 Describe the bug

When using ShardingFilterIterDataPipe, the data in the datapipe is sharded evenly, element by element, across num_of_instances workers. However, if we call batch() later on the datapipe, this overly even distribution can cause workers to discard data that would not need to be discarded otherwise.

This might not be considered a bug, but it is somewhat unexpected. Moreover, the current ShardingFilterIterDataPipe produces different batches of data for different numbers of workers, which is also surprising.

import torchdata
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService

dp = torchdata.datapipes.iter.IterableWrapper(range(15)).sharding_filter().batch(5)

# Two workers: elements are sharded round-robin before batching.
loader = DataLoader2(dp, reading_service=MultiProcessingReadingService(num_workers=2))
for i in loader:
    print(i)
loader.shutdown()
print("++++++++++++++++++++++++++++++++++++++++++++++")

# One worker: no real sharding, so batches come out in order.
loader = DataLoader2(dp, reading_service=MultiProcessingReadingService(num_workers=1))
for i in loader:
    print(i)
loader.shutdown()

This gives the following result:

[0, 2, 4, 6, 8]
[1, 3, 5, 7, 9]
[10, 12, 14]  # These two ragged batches would be dropped if we set drop_last=True
[11, 13]
++++++++++++++++++++++++++++++++++++++++++++++
[0, 1, 2, 3, 4]
[5, 6, 7, 8, 9]
[10, 11, 12, 13, 14]
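
For context, here is a minimal plain-Python sketch (no torchdata involved, purely illustrative) of the element-level round-robin assignment that sharding_filter() performs, which is what produces the two ragged batches above with 2 workers:

# Illustrative only: element-level round-robin sharding followed by batching.
data = list(range(15))
num_of_instances, batch_size = 2, 5
for instance_id in range(num_of_instances):
    shard = [x for i, x in enumerate(data) if i % num_of_instances == instance_id]
    batches = [shard[j:j + batch_size] for j in range(0, len(shard), batch_size)]
    print(f"worker {instance_id}: {batches}")
# worker 0: [[0, 2, 4, 6, 8], [10, 12, 14]]
# worker 1: [[1, 3, 5, 7, 9], [11, 13]]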

One solution is to use a sharding filter that is aware of the downstream batch size. Maybe something like the following:

class BatchShardingFilterIterDataPipe(torchdata.datapipes.iter.ShardingFilter):
    def __init__(self, source_datapipe, sharding_group_filter=None):
        super().__init__(source_datapipe, sharding_group_filter)
        self.batch_size = 1

    def set_batch_size(self, batch_size, drop_last):
        # drop_last mirrors batch()'s signature but is not needed here.
        self.batch_size = batch_size

    def __iter__(self):
        # Shard whole batches round-robin instead of individual elements, so that
        # only the final (possibly incomplete) batch of the source is ever ragged.
        for i, batch_items in enumerate(
            self.source_datapipe.batch(batch_size=self.batch_size, drop_last=False)
        ):
            if i % self.num_of_instances == self.instance_id:
                yield from batch_items

set_batch_size() needs to be called once the batch size is determined.
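
For reference, a usage sketch of the class above (reusing the imports from the repro; the set_batch_size() call is done manually here, though in practice it would presumably be wired up by whatever builds the pipeline):

dp = BatchShardingFilterIterDataPipe(torchdata.datapipes.iter.IterableWrapper(range(15)))
dp.set_batch_size(5, drop_last=False)  # must match the batch size used below
batched = dp.batch(5)

loader = DataLoader2(batched, reading_service=MultiProcessingReadingService(num_workers=2))
for batch in loader:
    print(batch)
loader.shutdown()
# Expected: one worker yields [0..4] and [10..14], the other yields [5..9],
# so the batch contents match the single-worker case regardless of worker count.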

I wonder what the torchdata team thinks of the current sharding filter. Is its behavior expected?

Versions

torch 2.0.0
torchaudio 2.0.0
torchdata 0.6.0
