File tree Expand file tree Collapse file tree 2 files changed +2
-2
lines changed Expand file tree Collapse file tree 2 files changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -150,7 +150,7 @@ def torch2jax(
150
150
example_kw: Example keyword arguments. Defaults to None.
151
151
output_shapes: Output shapes or shapes + dtype struct. Defaults to None.
152
152
output_sharding_spec: jax.sharding.PartitionSpec specifying the sharding spec of the output, uses input mesh.
153
- vmap_method: batching method, see
153
+ vmap_method: batching method, see
154
154
[https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap](https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap)
155
155
156
156
NOTE: only vmap_method="sequntial" is supported non-experimentally
Original file line number Diff line number Diff line change @@ -46,7 +46,7 @@ def torch2jax_with_vjp(
46
46
library PyTorch code may need this fallback. Defaults to True (i.e., do not use fallback).
47
47
output_sharding_spec: (not supported) sharding spec of the output, use shard_map instead for a device-local
48
48
version of this function
49
- vmap_method: batching method, see
49
+ vmap_method: batching method, see
50
50
[https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap](https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap)
51
51
52
52
NOTE: only vmap_method="sequntial" is supported non-experimentally
You can’t perform that action at this time.
0 commit comments