Efficiency difference in using jax.lax.fori_loop vs looping over identical layers? #709

Open
hrbigelow opened this issue Jul 28, 2023 · 2 comments


hrbigelow commented Jul 28, 2023

This might be a question for jax, but I think it probably comes up in Haiku.

Supposing I have the code within some hk.Module:

out = input
# the code in each layer is identical, only the parameters differ
for layer in self.layers:
  out = layer(out)
return out

And, assume that each layer is an instance of the same derived hk.Module class that uses hk.get_parameter inside its __call__ method.

Given the situation that the code is identical in each layer, one could express it as a jax.lax.fori_loop, but it is quite awkward.

Would there be any efficiency gain doing so? Or would the jax compiler be smart enough to effectively do this anyhow?

# parameters previously defined by hk.get_parameter in the above, stacked across layers along axis 0
all_layer_params = ...

def layer_fn(i, out):
    # the code in any layer of above self.layers, using the parameters of layer i
    layer_params = jax.tree_util.tree_map(lambda p: p[i], all_layer_params)
    ...

return jax.lax.fori_loop(0, num_layers, layer_fn, input)

Is there a way to do this idiomatically in Haiku, to take advantage of the internal hk.get_parameter calls?

Thanks in advance!

@tomhennigan
Collaborator

Hey @hrbigelow, both versions should work and, in theory, should be equally efficient. However, we've seen a few cases (in particular with transformer models) where, if you use structured control flow, the XLA compiler does a better job of optimizing (in particular, reducing peak memory usage) and sometimes of overlapping communication with compute.
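To make the comparison concrete, here is a minimal pure-JAX sketch (no Haiku) of the same stack of layers written three ways: a Python loop, jax.lax.fori_loop, and jax.lax.scan. The layer body, the sizes, and the names `stacked`, `layer`, `python_loop`, `with_fori_loop`, and `with_scan` are all assumptions made for illustration. The sketch only checks that the three variants compute the same values; it does not measure memory or speed.

```python
import jax
import jax.numpy as jnp

num_layers, dim = 4, 8  # hypothetical sizes, for illustration only

# Stacked parameters: the leading axis indexes the layer.
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
stacked = {
    "w": jax.random.normal(key_w, (num_layers, dim, dim)) / jnp.sqrt(dim),
    "b": jnp.zeros((num_layers, dim)),
}
x = jax.random.normal(key_x, (2, dim))


def layer(params, h):
    # Stand-in for the body shared by every layer.
    return jnp.tanh(h @ params["w"] + params["b"])


def python_loop(stacked, h):
    # Unrolled at trace time: the traced program contains num_layers copies of the body.
    for i in range(num_layers):
        h = layer(jax.tree_util.tree_map(lambda p: p[i], stacked), h)
    return h


def with_fori_loop(stacked, h):
    def body(i, h):
        # i is a traced integer here, so this is a dynamic index into each leaf.
        params_i = jax.tree_util.tree_map(lambda p: p[i], stacked)
        return layer(params_i, h)

    return jax.lax.fori_loop(0, num_layers, body, h)


def with_scan(stacked, h):
    def body(h, params_i):
        # scan slices the leading axis of every leaf and passes the slice in.
        return layer(params_i, h), None

    h, _ = jax.lax.scan(body, h, stacked)
    return h


out_loop = jax.jit(python_loop)(stacked, x)
out_fori = jax.jit(with_fori_loop)(stacked, x)
out_scan = jax.jit(with_scan)(stacked, x)
```

With scan, the per-layer parameter slice arrives as an argument to the body, so no manual indexing is needed. That is one reason it tends to be the more convenient of the two structured forms when parameters are stacked along a leading axis.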

The recommended pattern in Haiku for repeated application of a block is to use hk.experimental.layer_stack.

The implementation of layer_stack is fairly complex (it handles quite a few edge cases), but it basically boils down to correctly applying jax.lax.scan to the per-layer init and apply functions.

@hrbigelow
Author

Thanks Tom. I'm looking at the examples for hk.experimental.layer_stack now. Just to make sure I understand: it doesn't seem possible to use layer_stack the same way you would use an ordinary hk.Module that has calls to hk.get_parameter, is that right?

Instead, if you wanted to build an hk.Module method that used layer_stack, you'd need to somehow obtain the pure function f to pass to stack(f)(...).

2 participants