
Commit b11ddb1

experimental vmap_method support
1 parent 862da53 commit b11ddb1


7 files changed: +112 −8 lines changed

README.md

Lines changed: 4 additions & 0 deletions

@@ -310,6 +310,10 @@ the GPU.
 
 # Changelog
 
+- version 0.6.1
+  - added `vmap_method=` support for experimental pytorch-side batching,
+    see [https://github.com/rdyro/torch2jax/issues/28](https://github.com/rdyro/torch2jax/issues/28)
+
 - version 0.6.0
   - proper multi-GPU support mostly with `shard_map` but also via `jax.jit` automatic sharding
   - `shard_map` and automatic `jax.jit` device parallelization should work, but `pmap` doesn't work

docs/changelog.md

Lines changed: 4 additions & 0 deletions

@@ -1,5 +1,9 @@
 # Changelog
 
+- version 0.6.1
+  - added `vmap_method=` support for experimental pytorch-side batching,
+    see [https://github.com/rdyro/torch2jax/issues/28](https://github.com/rdyro/torch2jax/issues/28)
+
 - version 0.6.0
   - proper multi-GPU support mostly with `shard_map` but also via `jax.jit` automatic sharding
   - `shard_map` and automatic `jax.jit` device parallelization should work, but `pmap` doesn't work

docs/index.md

Lines changed: 4 additions & 0 deletions

@@ -310,6 +310,10 @@ the GPU.
 
 # Changelog
 
+- version 0.6.1
+  - added `vmap_method=` support for experimental pytorch-side batching,
+    see [https://github.com/rdyro/torch2jax/issues/28](https://github.com/rdyro/torch2jax/issues/28)
+
 - version 0.6.0
   - proper multi-GPU support mostly with `shard_map` but also via `jax.jit` automatic sharding
   - `shard_map` and automatic `jax.jit` device parallelization should work, but `pmap` doesn't work

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "torch2jax"
-version = "0.6.0"
+version = "0.6.1"
 authors = [
   { name="Robert Dyro", email="[email protected]" },
 ]

tests/test_vmap.py

Lines changed: 63 additions & 1 deletion

@@ -7,14 +7,15 @@
 import torch
 import jax
 from jax import numpy as jnp
+from jax import random
 from jax.scipy.linalg import cho_factor, cho_solve
 
 paths = [Path(__file__).absolute().parents[1], Path(__file__).absolute().parent]
 for path in paths:
     if str(path) not in sys.path:
         sys.path.append(str(path))
 
-from torch2jax import torch2jax_with_vjp  # noqa: E402
+from torch2jax import torch2jax, torch2jax_with_vjp  # noqa: E402
 from utils import jax_randn  # noqa: E402
 
 ####################################################################################################

@@ -45,6 +46,67 @@ def expected_fn(A, x):
         err = jnp.linalg.norm(sol - sol_expected) / jnp.linalg.norm(sol_expected)
         assert err < 1e-3
 
+    @parameterized.product(device=["cuda", "cpu"], dtype=[jnp.float32, jnp.float64])
+    def test_simple_vmap(self, device, dtype):
+        if device == "cuda" and not torch.cuda.is_available():
+            self.skipTest("Skipping CUDA tests when CUDA is not available")
+
+        device = jax.devices(device)[0]
+        keys = iter(random.split(random.key(17), 1024))
+
+        torch_counter = 0
+
+        def torch_fn(x):
+            nonlocal torch_counter
+            torch_counter += 1
+            print(f"torch_counter: {torch_counter}")
+            return 2 * x
+
+        x = jax_randn((1024,), device=device, dtype=jnp.float32)
+        X = jax_randn((572, 1024), dtype=jnp.float32, device=device)
+
+        # test sequential
+        fn = torch2jax(torch_fn, x, output_shapes=x, vmap_method="sequential")
+        current_counter_val = torch_counter
+        y = fn(x)
+        assert current_counter_val + 1 == torch_counter
+        err = jnp.linalg.norm(y - 2 * x, axis=None)
+        assert err < 1e-6
+
+        current_counter_val = torch_counter
+        Y = jax.vmap(fn)(X)
+        assert current_counter_val + X.shape[0] == torch_counter
+        err = jnp.linalg.norm(Y - 2 * X, axis=None)
+        assert err < 1e-6
+
+        # test broadcast_all
+        fn = torch2jax(torch_fn, x, output_shapes=x, vmap_method="broadcast_all")
+        current_counter_val = torch_counter
+        y = fn(x)
+        assert current_counter_val + 1 == torch_counter
+        err = jnp.linalg.norm(y - 2 * x, axis=None)
+        assert err < 1e-6
+
+        current_counter_val = torch_counter
+        Y = jax.vmap(fn)(X)
+        assert current_counter_val + 1 == torch_counter
+        err = jnp.linalg.norm(Y - 2 * X, axis=None)
+        assert err < 1e-6
+
+        # test expand_dims
+        fn = torch2jax(torch_fn, x, output_shapes=x, vmap_method="expand_dims")
+        current_counter_val = torch_counter
+        y = fn(x)
+        assert current_counter_val + 1 == torch_counter
+        err = jnp.linalg.norm(y - 2 * x, axis=None)
+        assert err < 1e-6
+
+        current_counter_val = torch_counter
+        Y = jax.vmap(fn)(X)
+        assert current_counter_val + 1 == torch_counter
+        err = jnp.linalg.norm(Y - 2 * X, axis=None)
+        assert err < 1e-6
+
 
 if __name__ == "__main__":
     absltest.main()

torch2jax/api.py

Lines changed: 22 additions & 5 deletions

@@ -24,10 +24,15 @@
 from .utils import find_unique_id, dtype_t2j, normalize_shapes, warn_once
 
 
-def _gen_ffi_call(outshapes):
+def _gen_ffi_call(outshapes, vmap_method: str):
     if signature(ffi.ffi_call).return_annotation.startswith("Callable"):
-        fn_ = ffi.ffi_call("torch_call", outshapes, vmap_method="sequential")
+        fn_ = ffi.ffi_call("torch_call", outshapes, vmap_method=vmap_method)
     else:
+        if vmap_method != "sequential":
+            raise ValueError(
+                f"You specified {vmap_method=}, but your jax version {jax.__version__} does not support the new style of"
+                " `vmap_method=` specification. Please upgrade your JAX version to use this feature."
+            )
         fn_ = lambda *args_flat, fn_id: ffi.ffi_call("torch_call", outshapes, *args_flat, vectorized=False, fn_id=fn_id)
     return fn_
 

@@ -37,6 +42,7 @@ def _torch2jax_flat(
     input_shapes: list[jax.Array | Tensor | ShapeDtypeStruct] = None,
     output_shapes: list[jax.Array | Tensor | ShapeDtypeStruct] = None,
     output_sharding_spec: PartitionSpec | None = None,
+    vmap_method: str = "sequential",
 ) -> Callable:
     """Define a jit-compatible JAX function that calls a PyTorch function. Flat
     arguments and outputs.

@@ -69,7 +75,7 @@ def torch_call_fn_(args: list[torch.Tensor]):
     @jax.jit
     def wrapped_flat_fn(*args_flat):
         nonlocal inshapes, outshapes
-        fn_ = _gen_ffi_call(outshapes)
+        fn_ = _gen_ffi_call(outshapes, vmap_method=vmap_method)
 
         if output_sharding_spec is None:
             fn_id = f"{id:d}"

@@ -114,7 +120,7 @@ def _map_outshape(outshape: jax.ShapeDtypeStruct, result_info, result_sharding):
             return jax.ShapeDtypeStruct(new_outshape, dtype=outshape.dtype)
 
         new_outshapes = jax.tree.map(_map_outshape, outshapes, result_info, result_sharding)
-        fn_part_ = _gen_ffi_call(new_outshapes)
+        fn_part_ = _gen_ffi_call(new_outshapes, vmap_method=vmap_method)
         return fn_part_(*args_flat, fn_id=fn_id)
 
     return mesh, _partitioned_fn_, result_sharding, args_sharding

@@ -133,6 +139,7 @@ def torch2jax(
     example_kw: Any | None = None,
     output_shapes: Any = None,
     output_sharding_spec: PartitionSpec | None = None,
+    vmap_method: str = "sequential",
 ) -> Callable:
     """Define a jit-compatible JAX function that calls a PyTorch function. Arbitrary nesting of
     arguments and outputs is supported.

@@ -143,6 +150,12 @@
         example_kw: Example keyword arguments. Defaults to None.
         output_shapes: Output shapes or shapes + dtype struct. Defaults to None.
         output_sharding_spec: jax.sharding.PartitionSpec specifying the sharding spec of the output, uses input mesh.
+        vmap_method: batching method, see
+            [https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap](https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap)
+
+            NOTE: only vmap_method="sequential" is supported non-experimentally
+
+            NOTE: try "expand_dims", "broadcast_all" if you want to experiment with pytorch-side batching
     Returns:
         Callable: JIT-compatible JAX function.
 

@@ -214,7 +227,11 @@ def flat_fn(*args_flat):
 
     # define the wrapped function using flat interface
     wrapped_fn_flat = _torch2jax_flat(
-        flat_fn, input_shapes=None, output_shapes=output_shapes, output_sharding_spec=output_sharding_spec_flat
+        flat_fn,
+        input_shapes=None,
+        output_shapes=output_shapes,
+        output_sharding_spec=output_sharding_spec_flat,
+        vmap_method=vmap_method,
    )
 
     # define the actual wrapper function
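
To make the new keyword concrete, here is a minimal usage sketch (not part of the commit; the PyTorch function, shapes, and variable names are illustrative). It mirrors the added test: with `vmap_method="broadcast_all"` the PyTorch function is invoked once on inputs carrying the batch dimension, so it must handle batched arrays itself, while the default `"sequential"` falls back to one PyTorch call per batch element.

```python
# Usage sketch only (assumed example, not from this commit): torch_fn, softplus,
# and the shapes below are illustrative; the torch2jax signature matches the one
# extended in this file.
import torch
import jax
from jax import numpy as jnp

from torch2jax import torch2jax

def torch_fn(x):
    # elementwise op, so an extra leading batch dimension is handled naturally
    return torch.nn.functional.softplus(x)

x = jnp.zeros(16)  # example input used only for shape/dtype inference
fn = torch2jax(torch_fn, x, output_shapes=x, vmap_method="broadcast_all")

X = jnp.ones((8, 16))   # a batch of 8 inputs
Y = jax.vmap(fn)(X)     # a single PyTorch call for the whole batch
print(Y.shape)          # (8, 16)
```

With `vmap_method="sequential"` the same `jax.vmap` call would instead invoke `torch_fn` eight times, which the new test verifies via a call counter.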

torch2jax/gradients.py

Lines changed: 14 additions & 1 deletion

@@ -27,6 +27,7 @@ def torch2jax_with_vjp(
     use_zeros: bool = True,
     use_torch_vjp: bool = True,
     output_sharding_spec: P | None = None,
+    vmap_method: str = "sequential",
 ) -> Callable:
     """Convert a torch function to a jax function and define a custom vjp rule for it up to `depth` recursively deep.
 

@@ -45,7 +46,12 @@
             library PyTorch code may need this fallback. Defaults to True (i.e., do not use fallback).
         output_sharding_spec: (not supported) sharding spec of the output, use shard_map instead for a device-local
             version of this function
+        vmap_method: batching method, see
+            [https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap](https://docs.jax.dev/en/latest/ffi.html#batching-with-vmap)
 
+            NOTE: only vmap_method="sequential" is supported non-experimentally
+
+            NOTE: try "expand_dims", "broadcast_all" if you want to experiment with pytorch-side batching
     Returns:
         Callable: JIT-compatible JAX version of the torch function (VJP defined up to depth `depth`).
 

@@ -86,7 +92,13 @@
     if output_shapes is None:
         outputs = torch_fn(*example_args)
         output_shapes = tree_map(lambda x: ShapeDtypeStruct(dtype=dtype_t2j(x.dtype), shape=x.shape), outputs)
-    fn = torch2jax(torch_fn, *example_args, output_shapes=output_shapes, output_sharding_spec=output_sharding_spec)
+    fn = torch2jax(
+        torch_fn,
+        *example_args,
+        output_shapes=output_shapes,
+        output_sharding_spec=output_sharding_spec,
+        vmap_method=vmap_method,
+    )
 
     # if this we've reached the requested differentiation depth, refrain from defining a vjp rule ##
     if depth <= 0:

@@ -181,6 +193,7 @@ def bwd_fn_torch(args, gs):
         output_shapes=next_output_shapes,
         depth=depth - 1,
         use_torch_vjp=use_torch_vjp,
+        vmap_method=vmap_method,
     )
     # define the custom vjp using the fwd_fn and bwd_fn ############################################
     fn.defvjp(fwd_fn, bwd_fn)
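
A corresponding sketch for the gradient wrapper (again hypothetical, not from the commit): since `torch2jax_with_vjp` now forwards `vmap_method` to both the forward wrapper and the recursively defined backward function, an elementwise PyTorch op can be vmapped and differentiated together. Keeping reductions on the JAX side avoids shape surprises under `"broadcast_all"`.

```python
# Hypothetical sketch: torch_sin, the shapes, and the loss are illustrative only.
import torch
import jax
from jax import numpy as jnp

from torch2jax import torch2jax_with_vjp

def torch_sin(x):
    return torch.sin(x)  # elementwise, so batched inputs pass straight through

x_example = torch.randn(16)  # example argument for shape/dtype inference
jax_sin = torch2jax_with_vjp(torch_sin, x_example, vmap_method="broadcast_all")

def loss(x):
    return jax_sin(x).sum()  # reduce in JAX, not inside the PyTorch function

x = jnp.linspace(0.0, 1.0, 16)
g = jax.grad(loss)(x)                                # ~ cos(x) via the generated VJP rule
G = jax.vmap(jax.grad(loss))(jnp.stack([x, 2 * x]))  # batched gradients
print(g.shape, G.shape)                              # (16,) (2, 16)
```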
