
Modules got silently "reused" with hk.vmap #740

Open
jjyyxx opened this issue Oct 8, 2023 · 2 comments

Comments

@jjyyxx
Contributor

jjyyxx commented Oct 8, 2023

I have to admit that I do not fully understand why hk.vmap is needed instead of jax.vmap. Nevertheless, whenever I need to vmap a function that calls Haiku modules, I use hk.vmap. This worked fine until I was debugging the poor performance of a transformer model. Things boil down to the following snippet:

import jax, haiku as hk
jax.config.update("jax_platforms", "cpu")

def f1(x):
    def g(x):
        return hk.Linear(2)(x)
    x = g(x)
    x = g(x)
    return x

def f2(x):
    def g(x):
        return hk.Linear(2)(x)
    x = hk.vmap(g, split_rng=False)(x)
    x = hk.vmap(g, split_rng=False)(x)
    return x

key = jax.random.PRNGKey(42)
x = jax.random.normal(key, (1, 2))
w1 = hk.transform(f1).init(key, x)
w2 = hk.transform(f2).init(key, x)
print("w1:", w1.keys())
print("w2:", w2.keys())
# w1: dict_keys(['linear', 'linear_1'])
# w2: dict_keys(['linear'])

It turns out that when g is vmapped, each module created inside g silently reuses a previously created module (and its parameters). Sometimes this fails immediately because of incompatible shapes, but in other cases (in my case, the transformer layers all have the same shapes) things go wrong silently.

My question: Is this behavior intended? Could the documentation be improved on this topic? Or am I missing something?

@tomhennigan
Collaborator

Hi @jjyyxx, this is indeed confusing behaviour.

Changing this would be backwards incompatible with all existing usages of hk.vmap, so I think for now we will need to work around it.

hk.vmap is mostly useful when you make use of hk.{g,s}et_state. If you are developing a transformer, you are unlikely to be using these APIs, and I think it would be safe to unconditionally use jax.vmap, which I believe does what you want (it creates a new instance of the module on each call to the mapped function).

If you need to use hk.vmap, you can define a version of it with the semantics you want by wrapping the mapped function in a module. The only caveat is that this adds a prefix to the module names to disambiguate them:

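import jax
import haiku as hk

def vmap_with_reuse(f, *, name: str | None = None):
  # split_rng is False during init (parameters are not mapped, so they are
  # created from a single key) and True during apply.
  f = hk.vmap(f, split_rng=(not hk.running_init()))
  # Wrapping in a module gives each call its own name scope (g, g_1, ...).
  f = hk.to_module(f)
  return lambda *a, **k: f(name=name)(*a, **k)

def f3(x):
  def g(x):
    return hk.Linear(2)(x)
  x = vmap_with_reuse(g)(x)
  x = vmap_with_reuse(g)(x)
  return x

key = jax.random.PRNGKey(42)
x = jax.random.normal(key, (1, 2))
w3 = hk.transform(f3).init(key, x)
print("w3:", w3.keys())
# w3: dict_keys(['g/linear', 'g_1/linear'])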

@jjyyxx
Contributor Author

jjyyxx commented Oct 9, 2023

Thanks for your suggestion! Indeed, I had found that jax.vmap works fine before filing this issue. But I was worried by the documentation, which describes hk.vmap as "Equivalent to jax.vmap() with module parameters/state not mapped." To me, this implies that hk.vmap handles both parameters and state, so I kept using hk.vmap at the time.

However, you mentioned that

hk.vmap is mostly useful when you make use of hk.{g,s}et_state

So, if only hk.get_parameter is used, there is no need to use hk.vmap? Also, what about the behavior of hk.next_rng_key inside jax.vmap?
