
What is the right way to serialize DataLoader2 so that a pipeline with shuffle can resume from the right place? #1177

Open
zhengwy888 opened this issue Jun 2, 2023 · 2 comments

Comments

@zhengwy888

🐛 Describe the bug

I tried all of the approaches below; the only one that worked was the last one, but it's too hacky. Is there a better way?

    # NOTE: the import path for serialize_datapipe/deserialize_datapipe and the
    # get_all_pipes helper below are a best guess; adjust for your torchdata version.
    from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe
    from torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration
    from torch.utils.data.graph import traverse_dps
    from torchdata.dataloader2 import DataLoader2, InProcessReadingService
    from torchdata.dataloader2.graph._serialization import deserialize_datapipe, serialize_datapipe
    from torchdata.datapipes.iter import IterableWrapper

    def get_all_pipes(dp):
        # Local helper (not a torchdata API): flatten the datapipe graph into a list.
        pipes = []

        def _walk(graph):
            for _, (pipe, sub_graph) in graph.items():
                pipes.append(pipe)
                _walk(sub_graph)

        _walk(traverse_dps(dp))
        return pipes

    dp = IterableWrapper(list(range(20)))
    dp = dp.shuffle()
    rs = InProcessReadingService()
    dl = DataLoader2(dp, reading_service=rs)
    iter1 = iter(dl)
    for _ in range(4):
        next(iter1)

    # 16 elements left in dl
    state = dl.state_dict()

    # Attempt 1: restore directly from the DataLoader2 state.
    dl2 = DataLoader2.from_state(state, reading_service=rs)
    # assert len(list(dl2)) == 20 - 4  # got 20

    # Attempt 2: round-trip the datapipe through serialize/deserialize only.
    dp2 = deserialize_datapipe(serialize_datapipe(dl.datapipe))
    # assert len(list(dp2)) == 20 - 4  # got 20

    # Attempt 3: deserialize, then fast-forward with the snapshot helper.
    dp3 = deserialize_datapipe(serialize_datapipe(dl.datapipe))
    _simple_graph_snapshot_restoration(dp3, dp3._number_of_samples_yielded)
    ret3 = list(dp3)
    assert len(ret3) == 20 - 4
    # but content is not the same

    # Attempt 4: restore from state, then fast-forward.
    dl4 = DataLoader2.from_state(state, reading_service=rs)
    _simple_graph_snapshot_restoration(dl4.datapipe, dl.datapipe._number_of_samples_yielded)
    ret4 = list(dl4)
    assert len(ret4) == 20 - 4
    # but content is not the same

    # Attempt 5 (the hack that works): snapshot the shuffler's buffer and RNG
    # state before fast-forwarding, then put them back afterwards.
    dp5 = deserialize_datapipe(serialize_datapipe(dl.datapipe))
    pipes = get_all_pipes(dp5)
    for pipe in pipes:
        if isinstance(pipe, ShufflerIterDataPipe):
            buffer_cache = pipe._buffer[:]
            assert len(buffer_cache) == 20 - 4
            rng_state = pipe._rng.getstate()
    _simple_graph_snapshot_restoration(dp5, dl.datapipe._number_of_samples_yielded)
    dp5._buffer = buffer_cache[:]
    dp5._rng.setstate(rng_state)
    it5 = iter(dp5)
    ret5 = list(it5)
    assert len(ret5) == 20 - 4

    expected = list(iter1)
    # ret5 is the only approach that worked
    # assert ret3 == expected
    # assert ret4 == expected
    assert ret5 == expected

Versions

PyTorch version: 2.0.0a0+gite9ebda2
Is debug build: False
CUDA used to build PyTorch: 12.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: 12.0.1 (https://github.com/conda-forge/clangdev-feedstock d44358f44aef33e9fa7c5f93e2481ee8f1a04ab6)
CMake version: version 3.19.1
Libc version: glibc-2.31

Python version: 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:04:10)  [GCC 10.3.0] (64-bit runtime)
Python platform: Linux-5.4.0-64-generic-x86_64-with-glibc2.10
Is CUDA available: False
CUDA runtime version: 12.0.140
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: False

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] mypy-protobuf==3.3.0
[pip3] numpy==1.23.5
[pip3] pytorch3d==0.6.2
[pip3] torch==2.0.1+1684801906.cuda120.cudnn891.nccl218.ap
[pip3] torch-mlir==1684442443
[pip3] torch-scatter==2.1.0
[pip3] torch-tb-profiler==0.4.1
[pip3] torchdata==0.7.0.dev20230601
[pip3] torchfile==0.1.0
[pip3] torchvision==0.15.1a0+42759b1
[conda] magma-cuda121             2.6.1                         1    pytorch
[conda] mkl                       2020.4             h726a3e6_304    conda-forge
[conda] mkl-include               2023.1.0         h84fe81f_48680    conda-forge
[conda] numpy                     1.23.5           py38h7042d01_0    conda-forge
[conda] pytorch3d                 0.6.2                    pypi_0    pypi
[conda] torch                     2.0.1+1684801906.cuda120.cudnn891.nccl218.ap          pypi_0    pypi
[conda] torch-mlir                1684442443               pypi_0    pypi
[conda] torch-scatter             2.1.0                    pypi_0    pypi
[conda] torch-tb-profiler         0.4.1                    pypi_0    pypi
[conda] torchfile                 0.1.0                    pypi_0    pypi
[conda] torchvision               0.15.1a0+42759b1          pypi_0    pypi
@ejguan
Contributor

ejguan commented Jun 2, 2023

I think you can rely on dlv2.state_dict() to get the state. But it's still in prototype mode, so it might have some errors.
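
For reference, a minimal sketch of the pattern being suggested, assuming the prototype DataLoader2.state_dict() / DataLoader2.from_state() APIs used in the repro above; resuming is expected to continue mid-epoch, which is exactly what this issue reports as broken for shuffled pipelines:

    from torchdata.dataloader2 import DataLoader2, InProcessReadingService
    from torchdata.datapipes.iter import IterableWrapper

    # Build a small shuffled pipeline and consume part of an epoch.
    dp = IterableWrapper(list(range(20))).shuffle()
    dl = DataLoader2(dp, reading_service=InProcessReadingService())
    it = iter(dl)
    consumed = [next(it) for _ in range(4)]

    # Capture the loader state mid-epoch and rebuild a loader from it.
    state = dl.state_dict()
    dl_resumed = DataLoader2.from_state(state, reading_service=InProcessReadingService())

    # Ideally this yields only the remaining 16 elements, in the same shuffled
    # order the original loader would have produced them.
    remaining = list(dl_resumed)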

@zhengwy888
Author

zhengwy888 commented Jun 8, 2023

But it didn't work; see the first example above.
