
[AutoParallel] Fix dataloader in to_static mode. #64334

Merged
merged 11 commits into PaddlePaddle:develop on May 30, 2024

Conversation

GhostScreaming
Contributor

PR Category

Auto Parallel

PR Types

Bug fixes

Description

Pcard-73145
Fix dataloader in to_static mode. Getting data from the DataLoader iterator directly may affect the data-generation randomness of BatchSampler when `shuffle=True`, which may cause a difference in data feeding between dynamic and to_static mode.
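The randomness issue described above can be seen with a toy stand-in (plain Python `random`, not Paddle's actual BatchSampler; all names here are illustrative): a shuffling sampler consumes global RNG state each time it is iterated, so peeking one batch to derive input specs shifts every later shuffle relative to a run that never peeked.

```python
import random

class ToyShuffleSampler:
    """Toy stand-in for BatchSampler(shuffle=True): reshuffles indices
    on every iteration, consuming global RNG state."""

    def __init__(self, n, batch_size):
        self.n = n
        self.batch_size = batch_size

    def __iter__(self):
        idx = list(range(self.n))
        random.shuffle(idx)  # advances the global RNG
        for i in range(0, self.n, self.batch_size):
            yield idx[i:i + self.batch_size]

sampler = ToyShuffleSampler(8, 2)

random.seed(0)
baseline = list(sampler)           # batches dynamic mode would feed

random.seed(0)
peek = next(iter(sampler))         # "peek one batch" to build input specs
state_after_peek = random.getstate()

random.seed(0)
state_clean = random.getstate()

# The peek reproduces the first baseline batch, but it has already moved
# the RNG forward, so later shuffles no longer match the baseline run.
assert peek == baseline[0]
assert state_after_peek != state_clean
```

This is why the PR reads a sample from the dataset instead of pulling a batch through the iterator when preparing input specs.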


paddle-bot bot commented May 15, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

Contributor

@zhiqiu zhiqiu left a comment

add unittest for it

@@ -274,7 +274,8 @@ def __init__(
     def _prepare_data_spec_from_dataloader(self, dataloader):
         inputs_spec = []
         labels_spec = []
-        data = next(iter(dataloader))
+        data = dataloader._get_input_spec()
Contributor

if hasattr(dataloader, '_get_input_spec'):
    data = dataloader._get_input_spec()
else:
    data = next(iter(dataloader))
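The fallback zhiqiu suggests could be exercised like this minimal sketch (toy loader classes with hypothetical names, not Paddle's ShardDataloader): prefer the spec-extraction hook when the loader provides one, otherwise fall back to pulling a batch.

```python
class PlainLoader:
    """Toy loader without the spec hook."""
    def __iter__(self):
        return iter([[1, 2], [3, 4]])

class SpecLoader(PlainLoader):
    """Toy loader exposing a _get_input_spec hook."""
    def _get_input_spec(self):
        return "spec-from-hook"  # stand-in for real spec extraction

def peek_data(loader):
    # Use the hook if present; otherwise pull one batch (which, as noted
    # above, consumes sampler randomness).
    if hasattr(loader, '_get_input_spec'):
        return loader._get_input_spec()
    return next(iter(loader))

assert peek_data(SpecLoader()) == "spec-from-hook"
assert peek_data(PlainLoader()) == [1, 2]
```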

Contributor Author

Only ShardDataloader calls _prepare_data_spec_from_dataloader.

Contributor

Will the else branch below run into the shuffle problem?

Contributor Author

_prepare_data_spec reads from the dataset directly. An IterableDataset does not support shuffle, so the problem does not arise there. For an ordinary Dataset, fetching item 0 directly does not affect randomness.

# of BatchSampler when `Shuffle=True`. It may cause difference of data feeding
# between dynamic and to_static mode.
def _get_input_spec(self):
    batch_data = self._dataloader.batch_sampler.dataset.__getitem__(0)
Contributor

what about IterableDataset?

Contributor Author

IterableDataset also has a __getitem__(self, index) method.

Contributor Author

This will be revised; IterableDataset still needs special handling after all.

collate_fn = self._dataloader.collate_fn
batch_data = collate_fn(batch_data)
if isinstance(batch_data, dict):
    batch_data = [batch_data]
Contributor

batch data is already a tensor?

Contributor Author

__getitem__(0) returns an np.ndarray.

Contributor

Why doesn't this branch need conversion to tensor?

Contributor Author

The order here was written wrong. collate_fn iterates over the batch, so its input batch_data must be a list; after collate_fn finishes, _get_batch requires Tensor input.
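The ordering described here (wrap the single sample in a list, run collate_fn over it, then convert for downstream consumption) can be sketched with numpy as a stand-in collate function; `default_collate` is an illustrative name, not Paddle's actual API.

```python
import numpy as np

def default_collate(batch):
    # Minimal stand-in for a DataLoader collate_fn: expects a *list* of
    # samples and stacks them along a new leading batch dimension.
    return np.stack(batch, axis=0)

sample = np.arange(4, dtype=np.float32)   # what dataset.__getitem__(0) returns
batch_data = [sample]                     # collate_fn iterates a list of samples
batch_data = default_collate(batch_data)  # shape (1, 4): batch of one sample
# Only after collation is the batch handed on (converted to Tensor before
# _get_batch in the real code).
assert batch_data.shape == (1, 4)
```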

Contributor

@zhiqiu zhiqiu left a comment

LGTM

@GhostScreaming GhostScreaming merged commit f406545 into PaddlePaddle:develop May 30, 2024
32 checks passed