Efficiency difference in using jax.lax.fori_loop vs looping over identical layers? #709
Comments
Hey @hrbigelow, both versions should work and in theory should be equally efficient. However, we've seen a few cases (in particular with transformer models) where, if you use structured control flow, the XLA compiler does a better job of optimizing (in particular, reducing peak memory usage) and sometimes of overlapping communication with compute. The recommended pattern in Haiku for repeated application of a block is to use a layer stack. The implementation of layer stack is kind of complex (it handles quite a few edge cases), but it basically boils down to using …
Thanks Tom. Actually I'm looking at the examples for … Instead, if you wanted to build an …
This might be a question for jax, but I think it probably comes up in Haiku.
Supposing I have the code within some `hk.Module`:
And, assume that each `layer` is an instance of the same derived `hk.Module` class that uses `hk.get_parameter` inside its `__call__` method.
Given the situation that the code is identical in each layer, one could express it as a `jax.lax.fori_loop`, but it is quite awkward.
Would there be any efficiency gain doing so? Or would the jax compiler be smart enough to effectively do this anyhow?
Is there a way to do this idiomatically in Haiku, to take advantage of the internal `hk.get_parameter` calls?
Thanks in advance!
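To make the comparison in the question concrete, here is a rough pure-JAX sketch (no Haiku) of the two styles. It assumes the per-layer parameters have been stacked by hand along a leading axis; the `layer` function, `NUM_LAYERS`, and `WIDTH` are hypothetical stand-ins, not the code from the original post.

```python
import jax
import jax.numpy as jnp

NUM_LAYERS = 4  # hypothetical depth, for illustration only
WIDTH = 8       # hypothetical feature width, for illustration only


def layer(w, b, x):
    # Stand-in for one layer; every layer runs this same computation.
    return jnp.tanh(x @ w + b)


def python_loop(ws, bs, x):
    # Unrolled at trace time: the compiled program contains NUM_LAYERS copies.
    for i in range(NUM_LAYERS):
        x = layer(ws[i], bs[i], x)
    return x


def fori_version(ws, bs, x):
    # Structured control flow: the body is traced once, and the loop index
    # selects that layer's slice of the stacked parameters.
    def body(i, h):
        return layer(ws[i], bs[i], h)

    return jax.lax.fori_loop(0, NUM_LAYERS, body, x)


key = jax.random.PRNGKey(0)
ws = 0.1 * jax.random.normal(key, (NUM_LAYERS, WIDTH, WIDTH))
bs = jnp.zeros((NUM_LAYERS, WIDTH))
x = jnp.ones((2, WIDTH))

y_loop = jax.jit(python_loop)(ws, bs, x)
y_fori = jax.jit(fori_version)(ws, bs, x)
print(jnp.max(jnp.abs(y_loop - y_fori)))
```

The awkward part is visible here: the loop body needs every layer's parameters stacked into single arrays up front, which is not how per-module `hk.get_parameter` calls lay them out by default.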