Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add jax.tree_util.register_simple #21245

Closed
wants to merge 1 commit into from
Closed

Conversation

jakevdp
Copy link
Collaborator

@jakevdp jakevdp commented May 15, 2024

This PR proposes a new API, jax.tree_util.register_simple that will simplify pytree registration and promote best practices in the most common of cases. Example usage with dataclass, although this will work on any class that can be flattened and unflattened via simple attribute access:

import jax
import jax.numpy as jnp
from dataclasses import dataclass

@jax.tree_util.register_simple(
    dynamic_attributes=['x', 'y'],
    static_attributes=['val'])
@dataclass
class MyContainer:
    x: jax.Array
    y: jax.Array
    val: str

m = MyContainer(jnp.zeros(4), jnp.arange(4), val='name')

leaves, tree = jax.tree.flatten(m)
m2 = jax.tree.unflatten(tree, leaves)

print(m)
# MyContainer(x=Array([0., 0., 0., 0.], dtype=float32), y=Array([0, 1, 2, 3], dtype=int32), val='name')
print(m2)
# MyContainer(x=Array([0., 0., 0., 0.], dtype=float32), y=Array([0, 1, 2, 3], dtype=int32), val='name')

One benefit of this is that it's somewhat self-documenting: we mention in the documentation that children should contain dynamic data, and aux_data should contain static data, but in practice users often miss this detail (e.g. we've had several questions in recent weeks about problems arising after including arrays in aux_data). Calling these static_attributes and dynamic_attributes should hopefully point people in the right direction.

@jakevdp jakevdp self-assigned this May 15, 2024
@jakevdp jakevdp marked this pull request as draft May 15, 2024 17:48
@jakevdp
Copy link
Collaborator Author

jakevdp commented May 21, 2024

This is probably too duplicative of jax.tree_util.register_dataclass. Closing.

@jakevdp jakevdp closed this May 25, 2024
@jakevdp jakevdp deleted the pytree-simple branch May 25, 2024 21:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant