Description
#365 introduces support for Optimistix solvers. Optimistix uses jaxtyping for its type annotations, so I initially copied those annotations and added jaxtyping as a dependency.
I then rolled these changes back and used the already defined Pytree instead. The only thing lost is the parametrized annotations (e.g. aux_struct: PyTree[jax.ShapeDtypeStruct]), which I kept in the code as comments.
The question is whether you want to add jaxtyping as a dependency and update the type annotations to use it consistently. I'm not sure it's necessary.
If so, I would suggest standardizing the spelling to PyTree everywhere. Currently, both "Pytree" and "PyTree" appear in nemos. (The same inconsistency shows up elsewhere, for that matter: in this thread about adding the annotation to JAX itself, and even in the official docs.)
Also, KeyArrayLike could then be replaced by jaxtyping.PRNGKeyArray.