Skip to content

Commit cd7c460

Browse files
committed
formatting fix
1 parent e419b07 commit cd7c460

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

torch2jax/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def torch2jax(
150150
example_kw: Example keyword arguments. Defaults to None.
151151
output_shapes: Output shapes or shapes + dtype struct. Defaults to None.
152152
output_sharding_spec: jax.sharding.PartitionSpec specifying the sharding spec of the output, uses input mesh.
153-
vmap_method: batching method, see
153+
vmap_method: batching method, see
154154
[https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap](https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap)
155155
156156
NOTE: only vmap_method="sequntial" is supported non-experimentally

torch2jax/gradients.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def torch2jax_with_vjp(
4646
library PyTorch code may need this fallback. Defaults to True (i.e., do not use fallback).
4747
output_sharding_spec: (not supported) sharding spec of the output, use shard_map instead for a device-local
4848
version of this function
49-
vmap_method: batching method, see
49+
vmap_method: batching method, see
5050
[https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap](https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap)
5151
5252
NOTE: only vmap_method="sequntial" is supported non-experimentally

0 commit comments

Comments
 (0)