-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AutoParallel] Fix dataloader in to_static mode. #64334
[AutoParallel] Fix dataloader in to_static mode. #64334
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add unittest for it
@@ -274,7 +274,8 @@ def __init__( | |||
def _prepare_data_spec_from_dataloader(self, dataloader): | |||
inputs_spec = [] | |||
labels_spec = [] | |||
data = next(iter(dataloader)) | |||
|
|||
data = dataloader._get_input_spec() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if hasattr(dataloader, _get_input_spec):
data = dataloader._get_input_spec()
else:
data = next(iter(dataloader))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
下面的else分之会出现shuffle的问题吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# of BatchSampler when `Shuffle=True`. It may cause difference of data feeding | ||
# between dynamic and to_static mode. | ||
def _get_input_spec(self): | ||
batch_data = self._dataloader.batch_sampler.dataset.__getitem__(0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what about IterableDataset?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IterableDataset
也有__getitem__(self, index)
方法
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里会再修改一下,对IterableDataset
还是需要特殊处理一下。
collate_fn = self._dataloader.collate_fn | ||
batch_data = collate_fn(batch_data) | ||
if isinstance(batch_data, dict): | ||
batch_data = [batch_data] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
batch data is already a tensor?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
getitem(0)取出来的是一个np.ndarray
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为啥这个分支不需要转to tensor呢?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的顺序写错了。collate_fn
会把 batch 遍历一遍,需要输入 batch_data
是一个 list。完成collate_fn
后,_get_batch
需要输入为 Tensor 。
… fix_dataloader
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR Category
Auto Parallel
PR Types
Bug fixes
Description
Pcard-73145
Fix dataloader in to_static mode. Get data from DataLoader iterator directly may affect data generation randomness of BatchSampler when
Shuffle=True
. It may cause difference of data feeding between dynamic and to_static mode.