Multi GPU support #9
Ahh, that's a good idea. It was not on the roadmap, but I imagine something like it would not be too difficult. Do you think it would largely just involve
Also interested in this. +1
We adapted the PureJaxRL PPO+RNN implementation to multi-GPU training with pmap in XLand-MiniGrid, and it scales well (almost linearly from 1 up to 8 A100 GPUs)!
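For readers wondering what "multi-GPU with pmap" typically looks like, here is a minimal sketch of synchronous data-parallel training in JAX, assuming the common pattern of averaging gradients across devices with jax.lax.pmean. The linear model, loss_fn, LR, and the replicate/shard helpers are hypothetical stand-ins for the PPO+RNN loss; they are not taken from PureJaxRL or XLand-MiniGrid.

```python
import jax
import jax.numpy as jnp

LR = 0.1  # hypothetical learning rate


def loss_fn(params, x, y):
    # Hypothetical stand-in for the PPO+RNN loss: linear regression with MSE.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def sgd_step(params, x, y):
    # Single-device reference update on the full batch.
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)


def pmap_sgd_step(params, x, y):
    # Each device computes gradients on its own shard of the batch...
    grads = jax.grad(loss_fn)(params, x, y)
    # ...then an all-reduce averages them so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name="devices")
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)


p_step = jax.pmap(pmap_sgd_step, axis_name="devices")


def replicate(tree, n):
    # Hypothetical helper: copy parameters onto a leading device axis.
    return jax.tree_util.tree_map(lambda a: jnp.broadcast_to(a, (n,) + a.shape), tree)


def shard(a, n):
    # Hypothetical helper: split the batch dimension into (n_devices, per_device, ...).
    return a.reshape((n, a.shape[0] // n) + a.shape[1:])


n_dev = jax.local_device_count()
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (8 * n_dev, 3))
y = jax.random.normal(key_y, (8 * n_dev,))
params = {"w": jnp.zeros(3), "b": jnp.zeros(())}

new_params = p_step(replicate(params, n_dev), shard(x, n_dev), shard(y, n_dev))
```

Because every device holds an equal-sized shard, the averaged gradient equals the full-batch gradient, so all replicas stay identical. The sketch also runs on a single CPU device, where n_dev is 1.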
Awesome! I took a quick look: I see that environment steps per second scale linearly; however, do you know how performance scales with time?
@luchris429 It just takes a bit longer to compile in general (if I understood "time" correctly as the total number of timesteps). I didn't notice any other performance dips in the 10-minute and ~8-hour runs. GPU utilization stays at 100%, and no OOM errors occur.
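One way to separate the one-off compile cost from steady-state throughput is to time the first and second calls of a pmapped function and compute environment steps per second from the second. In this sketch fake_rollout, NUM_ENVS, and NUM_STEPS are hypothetical placeholders for real environment stepping; block_until_ready is needed because JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp

NUM_ENVS = 1024  # hypothetical number of environments per device
NUM_STEPS = 128  # hypothetical rollout length


def fake_rollout(obs):
    # Hypothetical placeholder for env stepping: NUM_STEPS cheap updates via scan.
    def body(carry, _):
        return jnp.tanh(carry * 1.01 + 0.1), None

    obs, _ = jax.lax.scan(body, obs, None, length=NUM_STEPS)
    return obs


p_rollout = jax.pmap(fake_rollout)
n_dev = jax.local_device_count()
obs = jnp.zeros((n_dev, NUM_ENVS, 4))


def timed(f, *args):
    # Wall-clock time of one call, waiting for the async result to finish.
    start = time.perf_counter()
    out = f(*args)
    out.block_until_ready()
    return out, time.perf_counter() - start


out, first_call = timed(p_rollout, obs)  # includes tracing + XLA compilation
out, steady = timed(p_rollout, obs)      # reuses the cached executable
sps = n_dev * NUM_ENVS * NUM_STEPS / steady
print(f"first call: {first_call:.3f}s, steady: {steady:.4f}s, SPS: {sps:,.0f}")
```

Only the steady-state time is used for SPS, so a longer compile on more devices does not distort the throughput figure.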
I was wondering whether there are any plans to release multi-GPU training code.
Naively pmapping and using DDPPO does not seem to scale well, as the GPUs remain idle while syncing gradients.
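A pattern that may reduce host-side idle time (an assumption, not a confirmed fix for the DDPPO case above) is to run all minibatch updates of an epoch inside a single pmapped jax.lax.scan, so the gradient all-reduces are issued from one compiled program rather than one host dispatch per step. It does not overlap communication with compute: each pmean still waits for all devices. loss_fn, LR, and NUM_MINIBATCHES are hypothetical.

```python
import jax
import jax.numpy as jnp

LR = 0.1             # hypothetical learning rate
NUM_MINIBATCHES = 4  # hypothetical minibatches per epoch


def loss_fn(params, x, y):
    # Hypothetical stand-in for the PPO loss.
    return jnp.mean((x @ params["w"] - y) ** 2)


def update_epoch(params, x, y):
    # Split this device's shard into (NUM_MINIBATCHES, minibatch_size, ...).
    xs = x.reshape((NUM_MINIBATCHES, -1) + x.shape[1:])
    ys = y.reshape((NUM_MINIBATCHES, -1))

    def minibatch_step(p, batch):
        mb_x, mb_y = batch
        loss, grads = jax.value_and_grad(loss_fn)(p, mb_x, mb_y)
        # On-device all-reduces inside the compiled scan; no host round-trip.
        grads = jax.lax.pmean(grads, axis_name="devices")
        loss = jax.lax.pmean(loss, axis_name="devices")
        p = jax.tree_util.tree_map(lambda a, g: a - LR * g, p, grads)
        return p, loss

    return jax.lax.scan(minibatch_step, params, (xs, ys))


p_update_epoch = jax.pmap(update_epoch, axis_name="devices")

n_dev = jax.local_device_count()
key_x, key_w = jax.random.split(jax.random.PRNGKey(1))
true_w = jax.random.normal(key_w, (3,))
x = jax.random.normal(key_x, (n_dev, 64, 3))  # one data shard per device
y = x @ true_w
params = {"w": jnp.zeros((n_dev, 3))}

history = []
for _ in range(20):
    params, losses = p_update_epoch(params, x, y)
    history.append(float(losses[0, -1]))
```

Each call to p_update_epoch performs NUM_MINIBATCHES synchronized updates with a single dispatch from Python.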