Description
Motivation
Replay buffers are used to store a lot of data and ideally we could put this data on a GPU.
Often raw sensory observations are stored in these buffers, such as images or audio, which consumes many gigabytes of precious VRAM.
In TorchRL, memmapping and prefetching could be considered fixes for this issue, but I have pursued compression instead.
I've found that zstd is effective at compressing Atari frames (reducing storage use by 150-250x), and my DQN implementation has gone from storing ~4M frames consuming ~20 GB to less than 1 GB.
I would like some pointers on where it might be sensible to integrate compression into the TorchRL replay buffer, in the hope that compressing raw sensory observations could be more widely adopted.
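To make the savings concrete, here is a minimal round-trip sketch. It uses zlib from the Python standard library as a stand-in for zstd (the third-party `zstandard` package would be a drop-in alternative), and a synthetic frame with large uniform regions rather than a real Atari observation:

```python
import zlib

import numpy as np


def compress_frame(frame: np.ndarray) -> bytes:
    """Compress a raw frame into a byte stream."""
    return zlib.compress(frame.tobytes())


def decompress_frame(data: bytes, shape, dtype) -> np.ndarray:
    """Rebuild the original array from the byte stream plus stored metadata."""
    return np.frombuffer(zlib.decompress(data), dtype=dtype).reshape(shape).copy()


# Synthetic 84x84 grayscale Atari-style frame: mostly background, one sprite.
frame = np.zeros((84, 84), dtype=np.uint8)
frame[20:40, 20:40] = 255

blob = compress_frame(frame)
restored = decompress_frame(blob, frame.shape, frame.dtype)
print(f"{frame.nbytes} bytes -> {len(blob)} bytes")
```

The compression ratio obviously depends on the content; frames with large flat regions compress far better than noisy ones.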
Solution
I would like feedback on my proposed solutions to implementing compression in the replay buffer.
A CompressedReplayBuffer class that inherits from the ReplayBuffer class could be one option.
This class would compress data "as it comes in" by modifying the `append` and `extend` functions.
The compressed data would then be passed to the parent ReplayBuffer class as a uint8 byte stream for storage, and decompressed again in the `sample` function.
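A minimal sketch of this wrapper idea, in plain Python/NumPy rather than the actual TorchRL classes (the `CompressedBuffer` name, the list-backed storage, and the zlib codec are all illustrative assumptions):

```python
import zlib

import numpy as np


class CompressedBuffer:
    """Illustrative wrapper, not the TorchRL API: compress on append/extend,
    decompress on sample. Shape and dtype are stored alongside each blob."""

    def __init__(self):
        self._storage = []  # list of (compressed bytes, shape, dtype)

    def append(self, obs: np.ndarray) -> None:
        self._storage.append((zlib.compress(obs.tobytes()), obs.shape, obs.dtype))

    def extend(self, batch) -> None:
        for obs in batch:
            self.append(obs)

    def sample(self, batch_size: int) -> np.ndarray:
        # Uniform sampling; decompress only the sampled items.
        idx = np.random.randint(0, len(self._storage), size=batch_size)
        frames = []
        for i in idx:
            blob, shape, dtype = self._storage[i]
            frames.append(np.frombuffer(zlib.decompress(blob), dtype=dtype).reshape(shape))
        return np.stack(frames)
```

Note that only sampled items pay the decompression cost, which is what makes the approach attractive for large buffers.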
The issue with this approach, though, is that the compressed byte stream can have a variable length: some images might compress to ~150 bytes, while others might take ~188 bytes.
The TensorStorage object expects to receive objects of a consistent shape, so perhaps this is an inconvenient place to implement compression.
A compressed storage type might make more sense, as ultimately we want a storage type that can tolerate storing data of different lengths.
One approach could be to use the standard replay buffer but with a `ListStorage` type.
By modifying its `get` and `set` functions we could compress and decompress data as it comes in and out.
Alternatively, `TensorStorage` could be modified to use a nested & jagged layout, which might be more efficient.
Additional context
Although I have used zstd as the compression algorithm in this issue, ideally we wouldn't lock users into a particular algorithm.
Compression just converts an arbitrary tensor into a 1-D uint8 tensor (the byte stream), and decompression rebuilds the original tensor, recovering its shape and datatype.
If the data is sensory observations from a gym-like environment, we can reconstruct the observation's original shape and datatype from the gym environment metadata.
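A sketch of that reconstruction, assuming only that the space object exposes `.shape` and `.dtype` (as Gym/Gymnasium `Box` spaces do); zlib again stands in for zstd, and `decompress_obs` is a hypothetical helper:

```python
import zlib
from types import SimpleNamespace

import numpy as np


def decompress_obs(blob: bytes, observation_space) -> np.ndarray:
    """Rebuild an observation from its compressed byte stream using the
    space's shape/dtype metadata instead of storing them per item."""
    flat = np.frombuffer(zlib.decompress(blob), dtype=observation_space.dtype)
    return flat.reshape(observation_space.shape).copy()


# Stand-in for env.observation_space; a real Box space exposes the same fields.
space = SimpleNamespace(shape=(2, 3), dtype=np.uint8)
obs = np.arange(6, dtype=np.uint8).reshape(2, 3)
blob = zlib.compress(obs.tobytes())
restored = decompress_obs(blob, space)
```

Storing shape/dtype once per buffer rather than once per item also keeps the per-item payload to just the raw byte stream.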
Checklist
- I have checked that there is no similar issue in the repo (required)