
Commit edfa9a4

Add JIT and minor text corrections

1 parent 1955059 commit edfa9a4

File tree

docs/interfacing_with_surrogates.rst

1 file changed: 19 additions & 7 deletions
@@ -68,7 +68,7 @@ Consider a PyTorch neural network,
 
     torch_model = PyTorchMLP(hidden_dim, n_hidden, output_dim, input_dim)
 
-This model can be converted to a Flax model as follows:
+This model can be replicated in Flax as follows:
 
 .. code-block:: python
 
@@ -117,15 +117,19 @@ For loading weights from a PyTorch checkpoint, you might do something like:
 
     params = {'params': params}
 
-
 The model can then be called like any Flax model,
 
 .. code-block:: python
 
-    output_tensor = flax_model.apply(params, input_tensor)
+    output_tensor = jax.jit(flax_model.apply)(params, input_tensor)
+
+
+.. warning::
+    You need to be very careful when loading from a PyTorch state dict, as Flax and PyTorch may have slightly different representations of the weights (for example, one could be the transpose of the other). It's worth validating the output of your PyTorch model against your JAX model to make sure they agree.
+
 
 
-Option 2: converting a Pytorch model to a JAX model
+Option 2: converting a PyTorch model to a JAX model
 ===================================================
 
 .. warning::
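A minimal sketch of the validation the new warning recommends, assuming the ``torch_model``, ``flax_model``, ``params``, and ``input_dim`` defined earlier on this page (not part of this commit):

.. code-block:: python

    import jax
    import numpy as np
    import torch

    # Feed the same random input to both frameworks
    x = np.random.default_rng(0).standard_normal((1, input_dim)).astype(np.float32)

    with torch.no_grad():
        torch_output = torch_model(torch.from_numpy(x)).numpy()
    flax_output = jax.jit(flax_model.apply)(params, x)

    # Raises if the outputs disagree, e.g. because a kernel was loaded transposed
    np.testing.assert_allclose(torch_output, np.asarray(flax_output), rtol=1e-5, atol=1e-6)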
@@ -145,7 +149,7 @@ The model can then be called as a pure JAX function:
 
 .. code-block:: python
 
-    output_tensor = jax_model_from_torch(params, input_tensor)
+    output_tensor = jax.jit(jax_model_from_torch)(params, input_tensor)
 
 To remove the need for performing the conversion every time the model is loaded, you might want to save a JAX-compatible version of the weights and model to disk:
 
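To see the effect of the jitted call, one could time it roughly as follows (a sketch, assuming the model returns a single array; ``block_until_ready()`` forces JAX's asynchronous dispatch to finish before the timer stops):

.. code-block:: python

    import time
    import jax

    jitted_model = jax.jit(jax_model_from_torch)
    jitted_model(params, input_tensor).block_until_ready()  # warm-up: triggers compilation

    start = time.perf_counter()
    jitted_model(params, input_tensor).block_until_ready()
    print(f"jitted call took {time.perf_counter() - start:.6f} s")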
@@ -155,7 +159,7 @@ To remove the need for performing the conversion every time the model is loaded,
     import numpy as np
 
     # jax.export uses StableHLO to serialize the model to a binary format
-    exported_model = jax.export(jax_model_from_torch)
+    exported_model = jax.export.export(jax.jit(jax_model_from_torch))(params, input_tensor)
     with open("model.hlo", "wb") as f:
         f.write(exported_model.serialize())
 
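For completeness, the loading side of this round trip might look like the following sketch (not part of this commit), using ``jax.export.deserialize`` to rebuild the exported computation from the ``model.hlo`` file written above:

.. code-block:: python

    import jax

    with open("model.hlo", "rb") as f:
        restored_model = jax.export.deserialize(bytearray(f.read()))

    # Exported.call() runs the deserialized computation; the arguments must
    # match the shapes and dtypes the model was traced with at export time.
    output_tensor = restored_model.call(params, input_tensor)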
@@ -210,7 +214,7 @@ To convert the ONNX model to a JAX representation, you can use the `jaxonnxrunti
 
     jax_model_from_onnx = ONNXJaxBackend.prepare(onnx_model)
     # NOTE: run() returns a list of output tensors, in order of the output nodes
-    output_tensors = jax_model_from_onnx.run({"input": jnp.asarray(input_tensor, dtype=jnp.float32)})
+    output_tensors = jax.jit(jax_model_from_onnx.run)({"input": jnp.asarray(input_tensor, dtype=jnp.float32)})
 
 
 Best practices
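Since ``run()`` returns a list, a single-output model would typically be unwrapped as in this sketch (assuming the ``jax_model_from_onnx`` object and the ``"input"`` node name used above; not part of this commit):

.. code-block:: python

    import jax
    import jax.numpy as jnp

    jitted_run = jax.jit(jax_model_from_onnx.run)
    outputs = jitted_run({"input": jnp.asarray(input_tensor, dtype=jnp.float32)})
    output_tensor = outputs[0]  # first (and only) output node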
@@ -239,6 +243,14 @@ where
 
 By decorating with ``functools.lru_cache(maxsize=1)``, the result of this function - the loaded model - is stored in the cache and is only re-loaded if the function is called with a different ``path``.
 
+**JITting model calls**: In general, you should make sure that the forward call of your model is JITted:
+
+.. code-block:: python
+
+    output_tensor = jax.jit(flax_model.apply)(params, input_tensor)  # Good
+    output_tensor = flax_model.apply(params, input_tensor)  # Bad
+
+This is vital for fast performance.
 
 .. _Flax Linen: https://flax-linen.readthedocs.io/en/latest/index.html
 .. _Flax documentation: https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html#defining-your-own-models
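The caching pattern that the ``functools.lru_cache(maxsize=1)`` paragraph refers to might look like this sketch (the ``load_model`` name and the use of ``model.hlo`` are illustrative, not part of this commit):

.. code-block:: python

    import functools
    import jax

    @functools.lru_cache(maxsize=1)
    def load_model(path: str):
        # The body runs only on a cache miss, i.e. the first call for a given path
        with open(path, "rb") as f:
            return jax.export.deserialize(bytearray(f.read()))

    model = load_model("model.hlo")  # reads from disk and caches the result
    model = load_model("model.hlo")  # returns the cached object, no disk I/O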
