File tree Expand file tree Collapse file tree 2 files changed +10
-4
lines changed Expand file tree Collapse file tree 2 files changed +10
-4
lines changed Original file line number Diff line number Diff line change @@ -150,9 +150,12 @@ def torch2jax(
150
150
example_kw: Example keyword arguments. Defaults to None.
151
151
output_shapes: Output shapes or shapes + dtype struct. Defaults to None.
152
152
output_sharding_spec: jax.sharding.PartitionSpec specifying the sharding spec of the output, uses input mesh.
153
- vmap_method: batching method, see https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap
153
+ vmap_method: batching method, see
154
+ [https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap](https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap)
155
+
154
156
NOTE: only vmap_method="sequntial" is supported non-experimentally
155
- NOTE: try "expand_dims", "broadcast_all" if you want experimentally use pytorch-side batching
157
+
158
+ NOTE: try "expand_dims", "broadcast_all" if you want to experiment with pytorch-side batching
156
159
Returns:
157
160
Callable: JIT-compatible JAX function.
158
161
Original file line number Diff line number Diff line change @@ -46,9 +46,12 @@ def torch2jax_with_vjp(
46
46
library PyTorch code may need this fallback. Defaults to True (i.e., do not use fallback).
47
47
output_sharding_spec: (not supported) sharding spec of the output, use shard_map instead for a device-local
48
48
version of this function
49
- vmap_method: batching method, see https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap
49
+ vmap_method: batching method, see
50
+ [https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap](https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap)
51
+
50
52
NOTE: only vmap_method="sequntial" is supported non-experimentally
51
- NOTE: try "expand_dims", "broadcast_all" if you want experimentally use pytorch-side batching
53
+
54
+ NOTE: try "expand_dims", "broadcast_all" if you want to experiment with pytorch-side batching
52
55
Returns:
53
56
Callable: JIT-compatible JAX version of the torch function (VJP defined up to depth `depth`).
54
57
You can’t perform that action at this time.
0 commit comments