hk.custom_getter with_sharding_constraint #666

Open

kavorite opened this issue Jun 5, 2023 · 0 comments

kavorite commented Jun 5, 2023

I have this code:

from collections import defaultdict
from typing import Callable

import haiku as hk
import jax
import numpy as np

# One position per device, arranged as a (num_devices / 2, 2) grid.
shards = jax.sharding.PositionalSharding(np.array(jax.devices())).reshape(-1, 2)

class ShardGetter:
    def __init__(self):
        self.transpose = True
        self.placement_cache = defaultdict(dict)

    def __call__(
        self,
        next_getter: Callable,
        value: jax.Array,
        context: hk.GetterContext,
    ):
        if value.ndim == 2:
            if context.name in self.placement_cache[context.module_name]:
                placement = self.placement_cache[context.module_name][context.name]
            else:
                # Alternate between the two shardings on each new matrix seen.
                if self.transpose:
                    placement = shards.replicate(0).T
                else:
                    placement = shards.replicate(0)
                self.placement_cache[context.module_name][context.name] = placement
                self.transpose = not self.transpose
        else:
            # Non-matrix parameters (biases, etc.) are fully replicated.
            placement = shards.replicate()
        value = jax.lax.with_sharding_constraint(value, placement)
        return next_getter(value)

What I want is for each matrix to be sharded 'vertically' if its access index is even and 'horizontally' if its access index is odd (an approach inspired by https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html). Sharding according to how the weights are accessed, rather than how they are created, minimizes I/O overhead and the associated idleness in the forward pass.
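
For concreteness, here is a minimal sketch of the two matrix placements the getter alternates between, reusing shards from the snippet above (illustrative only; the array and variable names are made up for this example, and it assumes the same device grid):

import jax
import jax.numpy as jnp

x = jnp.ones((8, 8))
# shards.replicate(0) collapses the first device axis into replication,
# leaving a (1, 2) grid: x's columns are split in two, and each half is
# replicated over the remaining devices.
x_cols = jax.device_put(x, shards.replicate(0))
# Transposing that sharding splits x's rows instead.
x_rows = jax.device_put(x, shards.replicate(0).T)
jax.debug.visualize_array_sharding(x_cols)
jax.debug.visualize_array_sharding(x_rows)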

However, this code does not work: everything just ends up on TPU 0. That is, after running the following, every parameter is still placed on the first device:

# Batch, model, optimizer, TrainState, and batches are application-specific
# and omitted here.
shard_getter = ShardGetter()

@hk.without_apply_rng
@hk.transform
def objective(inputs: Batch):
    # ... do some application-specific, error-computing stuff...
    with hk.custom_getter(shard_getter):
        loss = model(inputs)
    return loss

@jax.jit
def train_init(rng, inputs):
    params = objective.init(rng, inputs)
    opt_st = optimizer().init(params)
    loss = 0.0
    step = 0
    return TrainState(params, opt_st, loss, step)

inputs = next(batches)
inputs = jax.device_put(inputs, shards.replicate(-1))  # meticulously arrange everything _just so..._
tstate = train_init(jax.device_put(jax.random.PRNGKey(42), shards.replicate()), inputs)
# should be fully sharded -- somehow not?
jax.tree_util.tree_map(jax.debug.visualize_array_sharding, tstate.params)
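
One way to see the same thing textually (an illustrative check, not part of the original snippet) is to print each leaf's committed sharding; with the behaviour described above, every leaf reports a single-device placement:

import jax.tree_util as jtu

# Each jax.Array carries its placement on the .sharding attribute.
jtu.tree_map(lambda leaf: print(leaf.sharding), tstate.params)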

What I would like is for this code to apply the sharding constraints specified in ShardGetter.__call__! As it stands, my monkey-patch for this limitation is to just do this:

import jax.tree_util as jtu
from haiku import data_structures as hkds

@jax.jit
def train_init(rng, inputs):
    params = objective.init(rng, inputs)
    # Start from a fully replicated placement for every parameter, then
    # overlay the per-parameter placements recorded by the getter.
    shtree = dict(shard_getter.placement_cache)
    shtree = hkds.merge(hkds.map(lambda *_: shards.replicate(), params), shtree)
    params = jtu.tree_map(jax.lax.with_sharding_constraint, params, shtree)
    opt_st = optimizer().init(params)
    loss = 0.0
    step = 0
    return TrainState(params, opt_st, loss, step)

which works fine. I understand this is mainly an aesthetic concern (correcting it only required adding three lines). Still, I'm mystified: what is erasing the sharding constraints after the getter interceptor is applied?
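
(Incidentally, an illustrative check that is not part of the original report: the cache is only written from inside ShardGetter.__call__, so a non-empty cache after init confirms that the getter did run and the constraints were at least requested.)

# If this prints the recorded placements, the getter was invoked during init.
print(dict(shard_getter.placement_cache))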
