Execute + differentiate Tesseracts as part of JAX programs, with full support for function transformations like JIT, grad, and more. ⚡


Tesseract-JAX

Tesseract-JAX is a lightweight extension to Tesseract Core that makes Tesseracts look and feel like regular JAX primitives that are jittable, differentiable, and composable.

Read the docs | Explore the examples | Report an issue | Talk to the community | Contribute


The API of Tesseract-JAX consists of a single function, apply_tesseract(tesseract_client, inputs), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines:

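import jax
from tesseract_jax import apply_tesseract

# vectoradd_tesseract is a served Tesseract client; x and y are JAX arrays (see Quick start)
@jax.jit
def vector_sum(x, y):
    res = apply_tesseract(vectoradd_tesseract, {"a": {"v": x}, "b": {"v": y}})
    return res["vector_add"]["result"].sum()

jax.grad(vector_sum)(x, y)  # 🎉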
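To make "fully traceable" concrete without needing Docker, the sketch below swaps the Tesseract call for fake_apply_vectoradd, a hypothetical pure-JAX stand-in that only mimics the nested-dictionary inputs and outputs used above. It is not part of Tesseract-JAX. Because the stand-in is ordinary JAX code, jax.jit and jax.grad compose with it in the same way this README describes for apply_tesseract.

```python
import jax
import jax.numpy as jnp

def fake_apply_vectoradd(inputs):
    # Hypothetical pure-JAX stand-in: nested dict in, nested dict out,
    # mimicking only the input/output layout of the vectoradd example.
    return {"vector_add": {"result": inputs["a"]["v"] + inputs["b"]["v"]}}

@jax.jit
def vector_sum(x, y):
    res = fake_apply_vectoradd({"a": {"v": x}, "b": {"v": y}})
    return res["vector_add"]["result"].sum()

x = jnp.arange(3.0)
y = jnp.ones(3)
grad_x = jax.grad(vector_sum)(x, y)  # d/dx of sum(x + y) is a vector of ones
```

With a real Tesseract, the forward and backward computations would instead be carried out by the Tesseract's own endpoints rather than traced JAX code.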

Quick start

Note

Before proceeding, make sure you have a working Docker installation and Python 3.10 or newer.

Important

For more detailed installation instructions, please refer to the Tesseract Core documentation.

  1. Install Tesseract-JAX:

    $ pip install tesseract-jax
  2. Build an example Tesseract:

    $ git clone https://github.com/pasteurlabs/tesseract-jax
    $ tesseract build tesseract-jax/examples/simple/vectoradd_jax
  3. Use it as part of a JAX program via the JAX-native apply_tesseract function:

    import jax
    import jax.numpy as jnp
    from tesseract_core import Tesseract
    from tesseract_jax import apply_tesseract
    
    # Load the Tesseract
    t = Tesseract.from_image("vectoradd_jax")
    t.serve()
    
    # Run it with JAX
    x = jnp.ones((1000,))
    y = jnp.ones((1000,))
    
    def vector_sum(x, y):
        res = apply_tesseract(t, {"a": {"v": x}, "b": {"v": y}})
        return res["vector_add"]["result"].sum()
    
    vector_sum(x, y) # success!
    
    # You can also use it with JAX transformations like JIT and grad
    vector_sum_jit = jax.jit(vector_sum)
    vector_sum_jit(x, y)
    
    vector_sum_grad = jax.grad(vector_sum)
    vector_sum_grad(x, y)
    
    # Stop the served Tesseract when you are done
    t.teardown()
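Note that jax.grad differentiates only with respect to the first positional argument by default, so vector_sum_grad(x, y) above returns the gradient with respect to x alone. The sketch below uses a hypothetical pure-JAX stand-in (not a Tesseract) to show how the argnums parameter requests gradients for both inputs. The same argument should carry over to the Tesseract-backed vector_sum, assuming the Tesseract's vector_jacobian_product endpoint supports differentiating with respect to both inputs.

```python
import jax
import jax.numpy as jnp

def vector_sum(x, y):
    # Hypothetical pure-JAX stand-in; the weights make the two gradients distinguishable
    return (2.0 * x + 3.0 * y).sum()

x = jnp.ones((1000,))
y = jnp.ones((1000,))

grad_x_only = jax.grad(vector_sum)(x, y)  # default argnums=0: gradient w.r.t. x only
grad_x, grad_y = jax.grad(vector_sum, argnums=(0, 1))(x, y)  # gradients w.r.t. x and y
```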

Tip

Now you're ready to jump into our examples for more ways to use Tesseract-JAX.

Sharp edges

  • Arrays vs. array-like objects: Tesseract-JAX is stricter than Tesseract Core in that all array inputs to Tesseracts must be JAX or NumPy arrays, not arbitrary array-likes such as Python floats or lists. As a result, you may need to convert your inputs, including scalar values, to JAX arrays before passing them to Tesseract-JAX.

    import jax.numpy as jnp
    from tesseract_core import Tesseract
    from tesseract_jax import apply_tesseract
    
    with Tesseract.from_image("vectoradd_jax") as tess:
        apply_tesseract(tess, {"a": {"v": [1.0]}, "b": {"v": [2.0]}})  # ❌ raises an error (plain Python lists)
        apply_tesseract(tess, {"a": {"v": jnp.array([1.0])}, "b": {"v": jnp.array([2.0])}})  # ✅ works
  • Additional required endpoints: Tesseract-JAX requires the abstract_eval Tesseract endpoint to be defined for all operations, because JAX must know the output shapes and dtypes of every operation (abstract evaluation) before executing it. Additionally, reverse-mode transformations like jax.grad require the vector_jacobian_product endpoint to be defined.
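If your inputs arrive as nested dictionaries containing Python lists or scalars, one way to satisfy the array-only restriction is to convert every numeric leaf up front. The helper below, to_jax_arrays, is a hypothetical sketch and not part of Tesseract-JAX or Tesseract Core. It assumes every list, tuple, or number in the input dictionary is meant to be an array, which may not hold for Tesseracts whose input schemas contain genuine list-valued fields.

```python
import jax.numpy as jnp
import numpy as np

def to_jax_arrays(inputs):
    """Hypothetical helper: recursively convert numeric leaves of a nested
    input dict to JAX arrays, leaving other values (e.g. strings) untouched."""
    if isinstance(inputs, dict):
        return {key: to_jax_arrays(value) for key, value in inputs.items()}
    if isinstance(inputs, (list, tuple, int, float, np.ndarray, np.generic)):
        # Lists, tuples, Python scalars, and NumPy values become JAX arrays
        return jnp.asarray(inputs)
    return inputs  # JAX arrays, strings, None, ... pass through unchanged

inputs = to_jax_arrays({"a": {"v": [1.0]}, "b": {"v": [2.0]}})
```

The converted dictionary could then be passed to apply_tesseract in place of the raw one.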

Tip

When creating a new Tesseract based on a JAX function, use tesseract init --recipe jax to define all required endpoints automatically, including abstract_eval and vector_jacobian_product.
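The two required endpoints mirror two mechanisms that JAX itself relies on. The pure-JAX sketch below uses a hypothetical vector_add function (not a Tesseract) to illustrate them: jax.eval_shape performs abstract evaluation, propagating output shapes and dtypes from input specs without running any computation, and jax.vjp returns a function that pulls an output cotangent back to the inputs, which is the kind of quantity a vector_jacobian_product endpoint supplies. The actual Tesseract endpoint signatures differ; refer to the Tesseract Core documentation for those.

```python
import jax
import jax.numpy as jnp

def vector_add(a, b):
    # Hypothetical stand-in computation (pure JAX, not a Tesseract)
    return {"result": a + b}

# Abstract evaluation: only shapes and dtypes are propagated, no data is computed
spec = jax.ShapeDtypeStruct((1000,), jnp.float32)
out_shape = jax.eval_shape(vector_add, spec, spec)
print(out_shape["result"].shape, out_shape["result"].dtype)

# Vector-Jacobian product: pull an output cotangent back to both inputs
a = jnp.ones((1000,))
b = jnp.ones((1000,))
out, vjp_fn = jax.vjp(vector_add, a, b)
cotangent = {"result": jnp.full((1000,), 2.0)}
grad_a, grad_b = vjp_fn(cotangent)  # each equals the cotangent, since d(a+b)/da = I
```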

License

Tesseract-JAX is licensed under the Apache License 2.0 and is free to use, modify, and distribute (under the terms of the license).

Tesseract is a registered trademark of Pasteur Labs, Inc. and may not be used without permission.
