You need to be very careful when loading from a PyTorch state dict, as Flax and PyTorch may have slightly different representations of the weights (for example, one could be the transpose of the other). It's worth validating the output of your PyTorch model against your JAX model to make sure the conversion is correct.
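A minimal sketch of such a check, assuming ``torch_model``, ``flax_model``, and ``params`` are the objects from the conversion steps in this tutorial:

.. code-block:: python

    import numpy as np
    import torch
    import jax.numpy as jnp

    # Example input; adjust the shape to whatever your model expects
    x = np.random.randn(1, 16).astype(np.float32)

    with torch.no_grad():
        torch_out = torch_model(torch.from_numpy(x)).numpy()
    jax_out = flax_model.apply(params, jnp.asarray(x))

    # Expect small numerical differences between frameworks, not exact equality
    np.testing.assert_allclose(torch_out, np.asarray(jax_out), rtol=1e-4, atol=1e-4)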
Option 2: converting a PyTorch model to a JAX model
To avoid performing the conversion every time the model is loaded, you might want to save a JAX-compatible version of the weights and model to disk:
.. code-block:: python

    import numpy as np
    import jax

    # A sketch of the export step: ``flax_model``, ``params``, and
    # ``input_tensor`` come from the conversion above.
    # jax.export uses StableHLO to serialize the model to a binary format
    exported = jax.export.export(jax.jit(flax_model.apply))(params, input_tensor)
    with open("flax_model.bin", "wb") as f:
        f.write(exported.serialize())
By decorating the loading function with ``functools.lru_cache(maxsize=1)``, its result (the loaded model) is stored in the cache and is only re-loaded if the function is called with a different ``path``.
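A minimal sketch of such a loader (the function name is illustrative; ``jax.export.deserialize`` restores the model serialized above):

.. code-block:: python

    import functools
    import jax

    @functools.lru_cache(maxsize=1)
    def load_exported_model(path: str):
        # Re-reads and re-deserializes only when called with a new ``path``
        with open(path, "rb") as f:
            return jax.export.deserialize(bytearray(f.read()))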
**JITting model calls**: In general, you should make sure that the forward call of your model is JITted:
.. code-block:: python

    output_tensor = jax.jit(flax_model.apply)(params, input_tensor)  # Good
    output_tensor = flax_model.apply(params, input_tensor)  # Bad