
Update typing using jaxtyping? #385

@bagibence

#365 introduces support for Optimistix solvers. Optimistix uses jaxtyping for its type annotations, so at first I copied their annotations and added a dependency on jaxtyping.
I rolled these changes back and used the Pytree alias that nemos already defines instead. The only thing lost is the parametrized annotations (e.g. aux_struct: PyTree[jax.ShapeDtypeStruct]), which I left in place as comments.

The question is whether you want to add jaxtyping as a dependency and update the type annotations to use it consistently. I'm not sure it's necessary.

If yes, I would suggest standardizing the spelling to PyTree everywhere. Currently, both "Pytree" and "PyTree" appear in nemos. (The inconsistency is common elsewhere too: in this thread about adding the annotation to JAX itself, and even in the official docs.)
KeyArrayLike could then also be replaced by jaxtyping.PRNGKeyArray.
