Tesseract-JAX is a lightweight extension to Tesseract Core that makes Tesseracts look and feel like regular JAX primitives, and makes them jittable, differentiable, and composable.
The API of Tesseract-JAX consists of a single function, apply_tesseract(tesseract_client, inputs), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines:
@jax.jit
def vector_sum(x, y):
    res = apply_tesseract(vectoradd_tesseract, {"a": {"v": x}, "b": {"v": y}})
    return res["vector_add"]["result"].sum()

jax.grad(vector_sum)(x, y)  # 🎉
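The inputs passed to apply_tesseract, and the results it returns, are nested dictionaries. JAX treats such containers as pytrees and traces the arrays at their leaves, which is part of what lets jax.jit and jax.grad pass through them. The following is a minimal, JAX-only illustration (no Tesseract involved) that reuses the input layout of the snippet above; the array values are arbitrary.

```python
import jax
import jax.numpy as jnp

# Same nested-dict layout as the inputs passed to apply_tesseract above
inputs = {"a": {"v": jnp.ones((3,))}, "b": {"v": jnp.full((3,), 2.0)}}

# JAX flattens the nested dicts into a flat list of array leaves
# plus a description of the tree structure (dict keys are visited in sorted order)
leaves, treedef = jax.tree_util.tree_flatten(inputs)
print(len(leaves))  # 2

# The original structure can be rebuilt from the leaves; transformations
# such as jax.jit and jax.grad rely on this round trip for nested inputs
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
```

Because only the leaves are traced, the dictionary keys act as static structure while the arrays inside them carry the values being differentiated or compiled.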
Note
Before proceeding, make sure you have a working installation of Docker and a modern Python installation (Python 3.10+).
Important
For more detailed installation instructions, please refer to the Tesseract Core documentation.
- Install Tesseract-JAX:
$ pip install tesseract-jax
- Build an example Tesseract:
$ git clone https://github.com/pasteurlabs/tesseract-jax
$ tesseract build tesseract-jax/examples/simple/vectoradd_jax
- Use it as part of a JAX program via the JAX-native apply_tesseract function:

import jax
import jax.numpy as jnp
from tesseract_core import Tesseract
from tesseract_jax import apply_tesseract

# Load the Tesseract
t = Tesseract.from_image("vectoradd_jax")
t.serve()

# Run it with JAX
x = jnp.ones((1000,))
y = jnp.ones((1000,))

def vector_sum(x, y):
    res = apply_tesseract(t, {"a": {"v": x}, "b": {"v": y}})
    return res["vector_add"]["result"].sum()

vector_sum(x, y)  # success!

# You can also use it with JAX transformations like JIT and grad
vector_sum_jit = jax.jit(vector_sum)
vector_sum_jit(x, y)

vector_sum_grad = jax.grad(vector_sum)
vector_sum_grad(x, y)
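When checking the numbers a Tesseract returns, it can help to compare them against a pure-JAX reference. The sketch below assumes, based only on the example's name, that vectoradd_jax returns the elementwise sum of a.v and b.v under "vector_add" / "result"; vectoradd_reference is a hypothetical stand-in, not part of Tesseract-JAX. It also shows argnums=(0, 1), which makes jax.grad return gradients with respect to both x and y instead of only the first argument.

```python
import jax
import jax.numpy as jnp


def vectoradd_reference(inputs):
    # Hypothetical stand-in for the vectoradd_jax Tesseract,
    # assuming it adds the two input vectors elementwise
    return {"vector_add": {"result": inputs["a"]["v"] + inputs["b"]["v"]}}


def vector_sum_reference(x, y):
    res = vectoradd_reference({"a": {"v": x}, "b": {"v": y}})
    return res["vector_add"]["result"].sum()


x = jnp.ones((1000,))
y = jnp.ones((1000,))

value = jax.jit(vector_sum_reference)(x, y)  # 2000.0

# argnums=(0, 1) returns one gradient per argument;
# d(sum(x + y))/dx and d(sum(x + y))/dy are both vectors of ones
grad_x, grad_y = jax.grad(vector_sum_reference, argnums=(0, 1))(x, y)
```

With the Tesseract served as in the last step, a comparison such as jnp.allclose(vector_sum(x, y), vector_sum_reference(x, y)) would then flag any mismatch, provided the elementwise-sum assumption holds.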
Tip
Now you're ready to jump into our examples for more ways to use Tesseract-JAX.
- Arrays vs. array-like objects: Tesseract-JAX is stricter than Tesseract Core. All array inputs to Tesseracts must be JAX or NumPy arrays; other array-likes, such as Python floats or lists, are rejected. Convert your inputs, including scalar values, to JAX arrays (for example with jnp.asarray) before passing them to Tesseract-JAX.
import jax.numpy as jnp
from tesseract_core import Tesseract
from tesseract_jax import apply_tesseract

with Tesseract.from_image("vectoradd_jax") as tess:
    apply_tesseract(tess, {"a": {"v": [1.0]}, "b": {"v": [2.0]}})  # ❌ raises an error
    apply_tesseract(tess, {"a": {"v": jnp.array([1.0])}, "b": {"v": jnp.array([2.0])}})  # ✅ works
- Additional required endpoints: Tesseract-JAX requires the abstract_eval Tesseract endpoint to be defined for all operations, because JAX performs abstract evaluation (inferring output shapes and dtypes) of every operation before executing it. Additionally, many gradient transformations such as jax.grad require the vector_jacobian_product endpoint to be defined.
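One way to satisfy the array requirement for array-like inputs is to convert every leaf of the input dictionary before calling apply_tesseract. The helper below, to_jax_arrays, is a hypothetical sketch (not part of Tesseract-JAX) that assumes all non-dict values are numeric. It recurses into dicts explicitly because jax.tree_util.tree_map treats Python lists as containers and would convert their elements one by one rather than turning each list into a single array.

```python
import jax.numpy as jnp


def to_jax_arrays(tree):
    # Hypothetical helper: recurse into nested dicts and convert every
    # other value (lists, floats, NumPy arrays) with jnp.asarray.
    # Assumes all non-dict values are numeric.
    if isinstance(tree, dict):
        return {key: to_jax_arrays(value) for key, value in tree.items()}
    return jnp.asarray(tree)


# The array-like inputs that were rejected in the example above
inputs = to_jax_arrays({"a": {"v": [1.0]}, "b": {"v": [2.0]}})
print(inputs["a"]["v"].shape)  # (1,)

# Python scalars become 0-d arrays
scalar = to_jax_arrays(2.5)
print(scalar.shape)  # ()
```

The converted dictionary can then be passed directly, e.g. apply_tesseract(tess, inputs).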
Tip
When creating a new Tesseract based on a JAX function, use tesseract init --recipe jax to define all required endpoints automatically, including abstract_eval and vector_jacobian_product.
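The two required endpoints correspond to concepts that plain JAX exposes directly, which may help explain why Tesseract-JAX needs them. The sketch below uses an ordinary JAX function, not a Tesseract, and does not reproduce the Tesseract endpoint signatures: jax.eval_shape computes output shapes and dtypes without running the computation (the role abstract_eval plays), and jax.vjp returns a function that maps an output cotangent back to input cotangents (the quantity vector_jacobian_product supplies for jax.grad). The function add_vectors is purely illustrative.

```python
import jax
import jax.numpy as jnp


def add_vectors(a, b):
    # Plain JAX function used only to illustrate the concepts
    return a + b


a = jnp.ones((1000,))
b = jnp.ones((1000,))

# Abstract evaluation: only shapes and dtypes flow through, no arithmetic runs
out_spec = jax.eval_shape(add_vectors, a, b)
print(out_spec.shape, out_spec.dtype)  # (1000,) float32 under JAX's default 32-bit mode

# Vector-Jacobian product: pull an output cotangent back to the inputs
out, vjp_fn = jax.vjp(add_vectors, a, b)
grad_a, grad_b = vjp_fn(jnp.ones_like(out))

# jax.grad of a scalar-valued function is built from this kind of VJP,
# seeded with a cotangent of 1.0
```

Since a Tesseract runs outside JAX's tracer, it cannot infer these quantities on its own, which is why they must be provided as explicit endpoints.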
Tesseract-JAX is licensed under the Apache License 2.0 and is free to use, modify, and distribute (under the terms of the license).
Tesseract is a registered trademark of Pasteur Labs, Inc. and may not be used without permission.