
Commit e419b07

improve vmap_method description
1 parent baa4c22 commit e419b07

2 files changed: +10 −4 lines

torch2jax/api.py

Lines changed: 5 additions & 2 deletions
@@ -150,9 +150,12 @@ def torch2jax(
         example_kw: Example keyword arguments. Defaults to None.
         output_shapes: Output shapes or shapes + dtype struct. Defaults to None.
         output_sharding_spec: jax.sharding.PartitionSpec specifying the sharding spec of the output, uses input mesh.
-        vmap_method: batching method, see https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap
+        vmap_method: batching method, see
+            https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap
+
         NOTE: only vmap_method="sequential" is supported non-experimentally
-        NOTE: try "expand_dims", "broadcast_all" if you want experimentally use pytorch-side batching
+
+        NOTE: try "expand_dims", "broadcast_all" if you want to experiment with pytorch-side batching
     Returns:
         Callable: JIT-compatible JAX function.
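For context, a minimal sketch of how the documented vmap_method argument might be used. The exact torch2jax call signature (positional example input, output_shapes as a jax.ShapeDtypeStruct) is assumed from the docstring above, not verified against the library:

```python
# Sketch only: wrapping a PyTorch function with torch2jax and batching it
# with jax.vmap. The positional example input and the output_shapes struct
# are assumptions based on the docstring, not a confirmed API.
import torch
import jax
import jax.numpy as jnp
from torch2jax import torch2jax

def torch_fn(x):
    # opaque PyTorch computation that JAX cannot trace
    return torch.sin(x) + x

jax_fn = torch2jax(
    torch_fn,
    jnp.zeros(4),  # example input used to fix shapes/dtypes
    output_shapes=jax.ShapeDtypeStruct((4,), jnp.float32),
    vmap_method="sequential",  # the only non-experimentally supported method
)

xs = jnp.ones((16, 4))
ys = jax.vmap(jax_fn)(xs)  # "sequential": the batch is looped over row by row
```

With vmap_method="sequential", jax.vmap is expected to fall back to calling the PyTorch function once per batch element; "expand_dims" and "broadcast_all" instead hand the batched arrays to PyTorch directly.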

torch2jax/gradients.py

Lines changed: 5 additions & 2 deletions
@@ -46,9 +46,12 @@ def torch2jax_with_vjp(
             library PyTorch code may need this fallback. Defaults to True (i.e., do not use fallback).
         output_sharding_spec: (not supported) sharding spec of the output, use shard_map instead for a device-local
             version of this function
-        vmap_method: batching method, see https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap
+        vmap_method: batching method, see
+            https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap
+
         NOTE: only vmap_method="sequential" is supported non-experimentally
-        NOTE: try "expand_dims", "broadcast_all" if you want experimentally use pytorch-side batching
+
+        NOTE: try "expand_dims", "broadcast_all" if you want to experiment with pytorch-side batching
     Returns:
         Callable: JIT-compatible JAX version of the torch function (VJP defined up to depth `depth`).
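Similarly, a hedged sketch for torch2jax_with_vjp using one of the experimental batching methods. The `depth` parameter comes from the Returns note above; the positional example argument and the inferred output shapes are assumptions:

```python
# Sketch only: differentiating and batching through a PyTorch loss with
# torch2jax_with_vjp. `depth` is documented above (VJP defined up to depth
# `depth`); the rest of the call signature is an assumption.
import torch
import jax
import jax.numpy as jnp
from torch2jax import torch2jax_with_vjp

def torch_loss(x):
    # scalar loss computed in PyTorch
    return (x ** 2).sum()

loss_fn = torch2jax_with_vjp(
    torch_loss,
    jnp.ones(3),  # example input used to fix shapes/dtypes
    depth=2,  # allows grad-of-grad, per the docstring
    vmap_method="expand_dims",  # experimental pytorch-side batching
)

grads = jax.grad(loss_fn)(jnp.arange(3.0))  # differentiates through PyTorch
xs = jnp.ones((8, 3))
losses = jax.vmap(loss_fn)(xs)  # experimental batching on the PyTorch side
```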
