
Consider using torch.compile(model, fullgraph=True, mode="reduce-overhead") #6

Open
lezcano opened this issue Jun 11, 2024 · 11 comments

Comments

@lezcano

lezcano commented Jun 11, 2024

fullgraph=True will make sure that there are no graph breaks (this may already be the case).
mode="reduce-overhead" will use CUDA graphs if possible. See in [these benchmarks] how going from regular torch.compile to reduce-overhead gives a solid 70-100% additional speed-up.
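For reference, a minimal sketch of what this would look like in the training script (the model and variable names are illustrative, not taken from this repo):

```python
import torch

model = GPT(config).to("cuda")  # illustrative model; any nn.Module works the same way

# fullgraph=True turns any graph break into a hard error instead of silently splitting the graph;
# mode="reduce-overhead" asks Inductor to record and replay CUDA graphs where it can
model = torch.compile(model, fullgraph=True, mode="reduce-overhead")
```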

@lezcano

lezcano commented Jun 11, 2024

For performance reasons, you might also want to avoid synchronising after every loop iteration and instead synchronise only every 10 or 20 iterations, averaging out the result (see the sketch below). That being said, I understand this would affect QoL for the script, so fair enough.
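Something along these lines (a rough sketch; the interval, `num_steps` and `train_step` are placeholders, not names from the script):

```python
import time
import torch

sync_every = 20  # illustrative interval: synchronise every N steps instead of every step
t0 = time.time()
for step in range(num_steps):
    loss = train_step()  # placeholder for one forward/backward/optimizer step
    if (step + 1) % sync_every == 0:
        torch.cuda.synchronize()  # wait for the queued kernels before reading the clock
        dt = (time.time() - t0) / sync_every
        print(f"avg step time over the last {sync_every} steps: {dt * 1000:.1f} ms")
        t0 = time.time()
```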

@zeux

zeux commented Jun 12, 2024

On RTX 4090 / PyTorch nightly this reduces the throughput slightly (from 130k tok/s to 127k tok/s, or equivalently from 4030ms dt to 4116ms dt; using B=16 to make sure training fits into 24 GB VRAM). This is specifically attributable to reduce-overhead mode; fullgraph=True works fine without changing performance (as I understand it, it merely turns graph breaks into compile errors, and there are no graph breaks here).

@lezcano

lezcano commented Jun 12, 2024

Perhaps some tweaks are needed to make CUDA graphs run. You can see whether they were enabled or not by running your program with TORCH_LOGS=cudagraphs.
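For example (the script name is just a placeholder for whatever entry point you run):

```bash
TORCH_LOGS=cudagraphs python train.py
```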

@zeux

zeux commented Jun 12, 2024

Yes, the logs indicate that mode=reduce-overhead uses CUDA graphs and that by default they are not used. I assume there are some restrictions on kernel compilation/fusion when CUDA graphs are enabled, and these outweigh the CPU-overhead savings in this case, since an individual step is fairly expensive anyway.

@lezcano

lezcano commented Jun 12, 2024

Within PyTorch there are no heuristics on whether to use CUDA graphs or not: if reduce-overhead is on, PyTorch will try its best to use CUDA graphs. There are some limitations on which programs we can enable CUDA graphs for, though, in terms of input mutations, graph dynamism and so on. Sometimes the implementation needs to be tweaked a little (often not much) to make it amenable to CUDA graphs.
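A typical tweak, as an illustrative sketch (not code from this repo): clone any output you keep across iterations, since CUDA graph replay reuses static output buffers, and/or mark iteration boundaries explicitly:

```python
import torch

compiled = torch.compile(model, fullgraph=True, mode="reduce-overhead")

for batch in loader:  # placeholder training loop
    torch.compiler.cudagraph_mark_step_begin()  # tell the CUDA graphs runtime a new iteration starts
    out = compiled(batch)
    # clone if the value is needed beyond the next iteration, because the
    # graph's output buffer will be overwritten on the next replay
    kept = out.detach().clone()
```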

@zeux

zeux commented Jun 12, 2024

Sure - my point is that whatever else reduce-overhead changes in the compilation process, it's more detrimental to overall performance on this workload on the 4090 than CUDA graphs are beneficial.

@JohannesVod

Someone has to try this out on an A100; it would probably boost performance quite a lot. There are also other flags worth trying.

@marib00

marib00 commented Jun 18, 2024

Tried on an H100. Goes down from ~277k tok/sec to ~269k tok/sec on nightly 2.5.0.dev20240616+cu124 🤷‍♂️

@lezcano

lezcano commented Jun 18, 2024

A few points:

  • reduce-overhead tries to turn on CUDA graphs. Sometimes it can't, and it will simply fall back to eager.
  • To see the reasons why this may have failed, run the program with TORCH_LOGS=cudagraphs.
  • If PyTorch could not run the model with CUDA graphs enabled, you might need to make some minor modifications to the model for it to run.
  • All in all, I'd be very surprised if the model with CUDA graphs enabled actually ran slower than without them.

If you find the culprit of why it didn't run in the first place, feel free to tag @eellison in that PR. He's the maintainer of CUDA graphs within PyTorch.

@JohannesVod

@marib00 very nice! Can you try "max-autotune" as well, maybe? It is documented in https://pytorch.org/docs/stable/generated/torch.compile.html and might be even faster. Anyway, someone should create a PR.
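For reference, that would just be the following (illustrative; note that max-autotune can take much longer to compile):

```python
model = torch.compile(model, fullgraph=True, mode="max-autotune")
```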

@marib00

marib00 commented Jun 18, 2024

@JohannesVod I did try "max-autotune" already and saw no change; it was compiling (i.e. autotuning) forever, though.
