
Commit 642b3ef

updating tests to use absltest
1 parent 1f9fb76 commit 642b3ef
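
Note: the commit replaces hand-rolled device/dtype loops and manual test_*() calls with absl parameterized test cases. For orientation only, here is a minimal sketch of the pattern being adopted; the class and test names are illustrative and not part of this repository:

from absl.testing import absltest, parameterized


class ExampleTest(parameterized.TestCase):
    # parameterized.product generates one test case per (device, dtype) combination,
    # replacing the nested `for device ...: for dtype ...:` loops in the old tests.
    @parameterized.product(device=["cpu", "cuda"], dtype=["float32", "float64"])
    def test_square(self, device, dtype):
        if device == "cuda":
            # In the real tests this guard only triggers when torch.cuda.is_available()
            # is False; unavailable combinations are skipped per-case rather than
            # silently dropped from a device list.
            self.skipTest("CUDA not exercised in this sketch.")
        self.assertEqual(3 * 3, 9)


if __name__ == "__main__":
    # absltest.main() replaces the manual test_...() calls at the bottom of each file.
    absltest.main()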

File tree

7 files changed (+443, -446 lines)


tests/test_fn_wrapping.py

Lines changed: 93 additions & 93 deletions
@@ -1,6 +1,7 @@
 import sys
 from pathlib import Path
 
+from absl.testing import parameterized, absltest
 import torch
 from torch import Size
 import jax
@@ -15,105 +16,104 @@
 
 from utils import jax_randn # noqa: E402
 from torch2jax import torch2jax # noqa: E402
-from torch2jax.compat import torch2jax as _torch2jax_flat # noqa: E402
 
 ####################################################################################################
 
 
-def test_single_output_fn():
-    shape = (10, 2)
-
-    def torch_fn(x, y):
-        return (x + 1 - y.reshape(x.shape)) / torch.norm(y)
-
-    device_list = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
-    dtype_list = [jnp.float32, jnp.float64]
-
-    x = jax_randn(shape, device="cpu", dtype=jnp.float64)
-    y = jax_randn(shape, device="cpu", dtype=jnp.float64).reshape(-1)
-    jax_fn = torch2jax(torch_fn, x, y, output_shapes=Size(shape))
-
-    for device in device_list:
-        for dtype in dtype_list:
-            x = jax_randn(shape, device=device, dtype=dtype)
-            y = jax_randn(shape, device=device, dtype=dtype).reshape(-1)
-
-            # non-jit version
-            out = jax_fn(x, y)
-            assert isinstance(out, Array)
-            expected = (x + 1 - y.reshape(x.shape)) / jnp.linalg.norm(y)
-            err = jnp.linalg.norm(out - expected) / jnp.linalg.norm(expected)
-            assert err < 1e-5
-
-            # jit version
-            @jax.jit
-            def complication_fn(x, y):
-                a = jax_fn(x, y)
-                y2 = y.reshape(x.shape)
-                b, c = x - y2 + 1, x + y2 + 1
-                d = jnp.linalg.norm(x) - jnp.linalg.norm(y)
-                return a, b, c, d
-
-            out = complication_fn(x, y)
-            assert isinstance(out, (list, tuple)) and len(out) == 4
-            out1 = out[0]
-            assert isinstance(out1, Array)
-            expected = (x + 1 - y.reshape(x.shape)) / jnp.linalg.norm(y)
-            err = jnp.linalg.norm(out1 - expected) / jnp.linalg.norm(expected)
-            assert err < 1e-5, f"Error is quite high: {err:.4e}"
-
-
-def test_multi_output_fn():
-    shape = (10, 2)
-
-    def torch_fn(x, y):
-        a = (x + 1 - y.reshape(x.shape)) / torch.norm(y)
-        b = (x - y.reshape(x.shape)).reshape(-1)[:5]
-        return a, b
-
-    device_list = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
-    dtype_list = [jnp.float32, jnp.float64]
-    x = jax_randn(shape, device="cpu", dtype=jnp.float64)
-    y = jax_randn(shape, device="cpu", dtype=jnp.float64).reshape(-1)
-    jax_fn = torch2jax(torch_fn, x, y, output_shapes=(Size(shape), Size((5,))))
-
-    for device in device_list:
-        for dtype in dtype_list:
-            x = jax_randn(shape, device=device, dtype=dtype)
-            y = jax_randn(shape, device=device, dtype=dtype).reshape(-1)
-
-            # non-jit version
-            out = jax_fn(x, y)
-            assert isinstance(out, (list, tuple)) and len(out) == 2
-            assert all(isinstance(z, Array) for z in out)
-            expected1 = (x + 1 - y.reshape(x.shape)) / jnp.linalg.norm(y)
-            expected2 = (x - y.reshape(x.shape)).reshape(-1)[:5]
-            err1 = jnp.linalg.norm(out[0] - expected1) / jnp.linalg.norm(expected1)
-            err2 = jnp.linalg.norm(out[1] - expected2) / jnp.linalg.norm(expected2)
-            assert err1 < 1e-5 and err2 < 1e-5
-
-            # jit version
-            @jax.jit
-            def complication_fn(x, y):
-                a = jax_fn(x, y)
-                y2 = y.reshape(x.shape)
-                b, c = x - y2 + 1, x + y2 + 1
-                d = jnp.linalg.norm(x) - jnp.linalg.norm(y)
-                return a, b, c, d
-
-            out = complication_fn(x, y)
-            assert isinstance(out, (list, tuple)) and len(out) == 4
-            assert all(isinstance(z, Array) for z in out[0])
-            out1 = out[0]
-            expected1 = (x + 1 - y.reshape(x.shape)) / jnp.linalg.norm(y)
-            expected2 = (x - y.reshape(x.shape)).reshape(-1)[:5]
-            err1 = jnp.linalg.norm(out1[0] - expected1) / jnp.linalg.norm(expected2)
-            err2 = jnp.linalg.norm(out1[1] - expected2) / jnp.linalg.norm(expected2)
-            assert err1 < 1e-5 and err2 < 1e-5
+class TestFnWrapping(parameterized.TestCase):
+    @parameterized.product(device=["cpu", "cuda"], dtype=[jnp.float32, jnp.float64])
+    def test_single_output_fn(self, device, dtype):
+        if not torch.cuda.is_available() and device == "cuda":
+            self.skipTest("Skipping CUDA test when CUDA is not available.")
+        shape = (10, 2)
+
+        def torch_fn(x, y):
+            return (x + 1 - y.reshape(x.shape)) / torch.norm(y)
+
+        device_list = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+        dtype_list = [jnp.float32, jnp.float64]
+
+        x = jax_randn(shape, device="cpu", dtype=jnp.float64)
+        y = jax_randn(shape, device="cpu", dtype=jnp.float64).reshape(-1)
+        jax_fn = torch2jax(torch_fn, x, y, output_shapes=Size(shape))
+
+        x = jax_randn(shape, device=device, dtype=dtype)
+        y = jax_randn(shape, device=device, dtype=dtype).reshape(-1)
+
+        # non-jit version
+        out = jax_fn(x, y)
+        assert isinstance(out, Array)
+        expected = (x + 1 - y.reshape(x.shape)) / jnp.linalg.norm(y)
+        err = jnp.linalg.norm(out - expected) / jnp.linalg.norm(expected)
+        assert err < 1e-5
+
+        # jit version
+        @jax.jit
+        def complication_fn(x, y):
+            a = jax_fn(x, y)
+            y2 = y.reshape(x.shape)
+            b, c = x - y2 + 1, x + y2 + 1
+            d = jnp.linalg.norm(x) - jnp.linalg.norm(y)
+            return a, b, c, d
+
+        out = complication_fn(x, y)
+        assert isinstance(out, (list, tuple)) and len(out) == 4
+        out1 = out[0]
+        assert isinstance(out1, Array)
+        expected = (x + 1 - y.reshape(x.shape)) / jnp.linalg.norm(y)
+        err = jnp.linalg.norm(out1 - expected) / jnp.linalg.norm(expected)
+        assert err < 1e-5, f"Error is quite high: {err:.4e}"
+
+    @parameterized.product(device=["cpu", "cuda"], dtype=[jnp.float32, jnp.float64])
+    def test_multi_output_fn(self, device, dtype):
+        if not torch.cuda.is_available() and device == "cuda":
+            self.skipTest("Skipping CUDA test when CUDA is not available.")
+
+        shape = (10, 2)
+
+        def torch_fn(x, y):
+            a = (x + 1 - y.reshape(x.shape)) / torch.norm(y)
+            b = (x - y.reshape(x.shape)).reshape(-1)[:5]
+            return a, b
+
+        x = jax_randn(shape, device="cpu", dtype=jnp.float64)
+        y = jax_randn(shape, device="cpu", dtype=jnp.float64).reshape(-1)
+        jax_fn = torch2jax(torch_fn, x, y, output_shapes=(Size(shape), Size((5,))))
+
+        x = jax_randn(shape, device=device, dtype=dtype)
+        y = jax_randn(shape, device=device, dtype=dtype).reshape(-1)
+
+        # non-jit version
+        out = jax_fn(x, y)
+        assert isinstance(out, (list, tuple)) and len(out) == 2
+        assert all(isinstance(z, Array) for z in out)
+        expected1 = (x + 1 - y.reshape(x.shape)) / jnp.linalg.norm(y)
+        expected2 = (x - y.reshape(x.shape)).reshape(-1)[:5]
+        err1 = jnp.linalg.norm(out[0] - expected1) / jnp.linalg.norm(expected1)
+        err2 = jnp.linalg.norm(out[1] - expected2) / jnp.linalg.norm(expected2)
+        assert err1 < 1e-5 and err2 < 1e-5
+
+        # jit version
+        @jax.jit
+        def complication_fn(x, y):
+            a = jax_fn(x, y)
+            y2 = y.reshape(x.shape)
+            b, c = x - y2 + 1, x + y2 + 1
+            d = jnp.linalg.norm(x) - jnp.linalg.norm(y)
+            return a, b, c, d
+
+        out = complication_fn(x, y)
+        assert isinstance(out, (list, tuple)) and len(out) == 4
+        assert all(isinstance(z, Array) for z in out[0])
+        out1 = out[0]
+        expected1 = (x + 1 - y.reshape(x.shape)) / jnp.linalg.norm(y)
+        expected2 = (x - y.reshape(x.shape)).reshape(-1)[:5]
+        err1 = jnp.linalg.norm(out1[0] - expected1) / jnp.linalg.norm(expected2)
+        err2 = jnp.linalg.norm(out1[1] - expected2) / jnp.linalg.norm(expected2)
+        assert err1 < 1e-5 and err2 < 1e-5
 
 
 ####################################################################################################
 
 if __name__ == "__main__":
-    test_single_output_fn()
-    test_multi_output_fn()
+    absltest.main()

tests/test_grad_fallback.py

Lines changed: 58 additions & 55 deletions
@@ -3,6 +3,7 @@
 import sys
 from pathlib import Path
 
+from absl.testing import parameterized, absltest
 import torch
 from torch.autograd import Function
 import jax
@@ -19,76 +20,78 @@
 ####################################################################################################
 
 
-def test_torch2jax_with_vjp_vjp_fallback():
-    shape = (5, 7)
+class TestGradFallback(parameterized.TestCase):
+    @parameterized.product(device=["cpu", "cuda"], dtype=[jnp.float32, jnp.float64])
+    def test_torch2jax_with_vjp_vjp_fallback(self, device, dtype):
+        if device == "cuda" and not torch.cuda.is_available():
+            self.skipTest("Skipping CUDA tests when CUDA is not available")
+        shape = (5, 7)
 
-    class OldInterfaceFunction(Function):
-        @staticmethod
-        def forward(ctx, x, y):
-            ctx.save_for_backward(x, y)
-            return x**2 + y**2
+        class OldInterfaceFunction(Function):
+            @staticmethod
+            def forward(ctx, x, y):
+                ctx.save_for_backward(x, y)
+                return x**2 + y**2
 
-        @staticmethod
-        def backward(ctx, grad_output):
-            x, y = ctx.saved_tensors
-            return 2 * x * grad_output, 2 * y * grad_output
+            @staticmethod
+            def backward(ctx, grad_output):
+                x, y = ctx.saved_tensors
+                return 2 * x * grad_output, 2 * y * grad_output
 
-    def torch_fn(x, y):
-        return OldInterfaceFunction.apply(x, y)
+        def torch_fn(x, y):
+            return OldInterfaceFunction.apply(x, y)
 
-    def expected_f_fn(x, y):
-        return x**2 + y**2
+        def expected_f_fn(x, y):
+            return x**2 + y**2
 
-    expected_g_fn = jax.grad(lambda *args: jnp.sum(expected_f_fn(*args)), argnums=(0, 1))
-    expected_h_fn = jax.grad(lambda *args: jnp.sum(expected_g_fn(*args)[0] + expected_g_fn(*args)[1]), argnums=(0, 1))
+        expected_g_fn = jax.grad(lambda *args: jnp.sum(expected_f_fn(*args)), argnums=(0, 1))
+        expected_h_fn = jax.grad(
+            lambda *args: jnp.sum(expected_g_fn(*args)[0] + expected_g_fn(*args)[1]), argnums=(0, 1)
+        )
 
-    xt, yt = torch.randn(shape), torch.randn(shape)
-    device_list = ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"]
-    dtype_list = [jnp.float32, jnp.float64]
+        xt, yt = torch.randn(shape), torch.randn(shape)
 
-    wrap_jax_f_fn = torch2jax_with_vjp(torch_fn, xt, yt, depth=2)
-    wrap_jax_g_fn = jax.grad(lambda x, y: jnp.sum(wrap_jax_f_fn(x, y)), argnums=(0, 1))
-    wrap_jax_h_fn = jax.grad(lambda x, y: jnp.sum(wrap_jax_g_fn(x, y)[0] + wrap_jax_g_fn(x, y)[1]), argnums=(0, 1))
+        wrap_jax_f_fn = torch2jax_with_vjp(torch_fn, xt, yt, depth=2)
+        wrap_jax_g_fn = jax.grad(lambda x, y: jnp.sum(wrap_jax_f_fn(x, y)), argnums=(0, 1))
+        wrap_jax_h_fn = jax.grad(lambda x, y: jnp.sum(wrap_jax_g_fn(x, y)[0] + wrap_jax_g_fn(x, y)[1]), argnums=(0, 1))
 
-    for device in device_list:
-        for dtype in dtype_list:
-            x = jax_randn(shape, dtype=dtype, device=device)
-            y = jax_randn(shape, dtype=dtype, device=device)
-            f = wrap_jax_f_fn(x, y)
-            g = wrap_jax_g_fn(x, y)
-            h = wrap_jax_h_fn(x, y)
+        x = jax_randn(shape, dtype=dtype, device=device)
+        y = jax_randn(shape, dtype=dtype, device=device)
+        f = wrap_jax_f_fn(x, y)
+        g = wrap_jax_g_fn(x, y)
+        h = wrap_jax_h_fn(x, y)
 
-            f_expected = expected_f_fn(x, y)
-            g_expected = expected_g_fn(x, y)
-            h_expected = expected_h_fn(x, y)
+        f_expected = expected_f_fn(x, y)
+        g_expected = expected_g_fn(x, y)
+        h_expected = expected_h_fn(x, y)
 
-            # test output structure #############################
-            assert isinstance(f, Array)
-            assert isinstance(g, (tuple, list)) and len(g) == 2 and isinstance(g[0], Array) and isinstance(g[1], Array)
-            assert isinstance(h, (tuple, list)) and len(h) == 2 and isinstance(h[0], Array) and isinstance(h[1], Array)
+        # test output structure #############################
+        self.assertIsInstance(f, Array)
+        assert isinstance(g, (tuple, list)) and len(g) == 2 and isinstance(g[0], Array) and isinstance(g[1], Array)
+        assert isinstance(h, (tuple, list)) and len(h) == 2 and isinstance(h[0], Array) and isinstance(h[1], Array)
 
-            # test values not under JIT #########################
-            err_f = jnp.linalg.norm(f - f_expected)
-            err_g = jnp.linalg.norm(g[0] - g_expected[0]) + jnp.linalg.norm(g[1] - g_expected[1])
-            err_h = jnp.linalg.norm(h[0] - h_expected[0]) + jnp.linalg.norm(h[1] - h_expected[1])
+        # test values not under JIT #########################
+        err_f = jnp.linalg.norm(f - f_expected)
+        err_g = jnp.linalg.norm(g[0] - g_expected[0]) + jnp.linalg.norm(g[1] - g_expected[1])
+        err_h = jnp.linalg.norm(h[0] - h_expected[0]) + jnp.linalg.norm(h[1] - h_expected[1])
 
-            assert err_f < 1e-5, f"Error in f value is {err_f:.4e}"
-            assert err_g < 1e-5, f"Error in g value is {err_g:.4e}"
-            assert err_h < 1e-5, f"Error in h value is {err_h:.4e}"
+        self.assertLess(err_f, 1e-5)
+        self.assertLess(err_g, 1e-5)
+        self.assertLess(err_h, 1e-5)
 
-            # test values when under JIT ########################
-            f = jax.jit(wrap_jax_f_fn)(x, y)
-            g = jax.jit(wrap_jax_g_fn)(x, y)
-            h = jax.jit(wrap_jax_h_fn)(x, y)
+        # test values when under JIT ########################
+        f = jax.jit(wrap_jax_f_fn)(x, y)
+        g = jax.jit(wrap_jax_g_fn)(x, y)
+        h = jax.jit(wrap_jax_h_fn)(x, y)
 
-            err_f = jnp.linalg.norm(f - f_expected)
-            err_g = jnp.linalg.norm(g[0] - g_expected[0]) + jnp.linalg.norm(g[1] - g_expected[1])
-            err_h = jnp.linalg.norm(h[0] - h_expected[0]) + jnp.linalg.norm(h[1] - h_expected[1])
+        err_f = jnp.linalg.norm(f - f_expected)
+        err_g = jnp.linalg.norm(g[0] - g_expected[0]) + jnp.linalg.norm(g[1] - g_expected[1])
+        err_h = jnp.linalg.norm(h[0] - h_expected[0]) + jnp.linalg.norm(h[1] - h_expected[1])
 
-            assert err_f < 1e-5, f"Error in f value is {err_f:.4e}"
-            assert err_g < 1e-5, f"Error in g value is {err_g:.4e}"
-            assert err_h < 1e-5, f"Error in h value is {err_h:.4e}"
+        self.assertLess(err_f, 1e-5)
+        self.assertLess(err_g, 1e-5)
+        self.assertLess(err_h, 1e-5)
 
 
 if __name__ == "__main__":
-    test_torch2jax_with_vjp_vjp_fallback()
+    absltest.main()
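
Besides the parameterization, this file also starts replacing bare assert statements with unittest-style assertions (self.assertIsInstance, self.assertLess), which report the offending values on failure. A minimal, self-contained sketch of that assertion style, with illustrative names only:

from absl.testing import absltest


class AssertionStyleSketch(absltest.TestCase):
    def test_relative_error_is_small(self):
        err = 3.2e-7  # stand-in for a computed relative error
        # On failure these assertions include the actual values in the message
        # (roughly "X not less than Y"), unlike a bare `assert err < 1e-5`.
        self.assertIsInstance(err, float)
        self.assertLess(err, 1e-5)


if __name__ == "__main__":
    absltest.main()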
