|
@@ -1,6 +1,7 @@
 import sys
 from pathlib import Path
 
+from absl.testing import parameterized, absltest
 import torch
 from torch import Size
 import jax
@@ -15,105 +16,104 @@
 
 from utils import jax_randn  # noqa: E402
 from torch2jax import torch2jax  # noqa: E402
-from torch2jax.compat import torch2jax as _torch2jax_flat  # noqa: E402
 
 ####################################################################################################
 
 
-def test_single_output_fn():
-    shape = (10, 2)
-
-    def torch_fn(x, y):
-        return (x + 1 - y.reshape(x.shape)) / torch.norm(y)
-
-    device_list = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
-    dtype_list = [jnp.float32, jnp.float64]
-
-    x = jax_randn(shape, device="cpu", dtype=jnp.float64)
-    y = jax_randn(shape, device="cpu", dtype=jnp.float64).reshape(-1)
-    jax_fn = torch2jax(torch_fn, x, y, output_shapes=Size(shape))
-
-    for device in device_list:
-        for dtype in dtype_list:
-            x = jax_randn(shape, device=device, dtype=dtype)
-            y = jax_randn(shape, device=device, dtype=dtype).reshape(-1)
-
-            # non-jit version
-            out = jax_fn(x, y)
-            assert isinstance(out, Array)
-            expected = (x + 1 - y.reshape(x.shape)) / jnp.linalg.norm(y)
-            err = jnp.linalg.norm(out - expected) / jnp.linalg.norm(expected)
-            assert err < 1e-5
-
-            # jit version
-            @jax.jit
-            def complication_fn(x, y):
-                a = jax_fn(x, y)
-                y2 = y.reshape(x.shape)
-                b, c = x - y2 + 1, x + y2 + 1
-                d = jnp.linalg.norm(x) - jnp.linalg.norm(y)
-                return a, b, c, d
-
-            out = complication_fn(x, y)
-            assert isinstance(out, (list, tuple)) and len(out) == 4
-            out1 = out[0]
-            assert isinstance(out1, Array)
-            expected = (x + 1 - y.reshape(x.shape)) / jnp.linalg.norm(y)
-            err = jnp.linalg.norm(out1 - expected) / jnp.linalg.norm(expected)
-            assert err < 1e-5, f"Error is quite high: {err:.4e}"
-
-
-def test_multi_output_fn():
-    shape = (10, 2)
-
-    def torch_fn(x, y):
-        a = (x + 1 - y.reshape(x.shape)) / torch.norm(y)
-        b = (x - y.reshape(x.shape)).reshape(-1)[:5]
-        return a, b
-
-    device_list = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
-    dtype_list = [jnp.float32, jnp.float64]
-    x = jax_randn(shape, device="cpu", dtype=jnp.float64)
-    y = jax_randn(shape, device="cpu", dtype=jnp.float64).reshape(-1)
-    jax_fn = torch2jax(torch_fn, x, y, output_shapes=(Size(shape), Size((5,))))
-
-    for device in device_list:
-        for dtype in dtype_list:
-            x = jax_randn(shape, device=device, dtype=dtype)
-            y = jax_randn(shape, device=device, dtype=dtype).reshape(-1)
-
-            # non-jit version
-            out = jax_fn(x, y)
-            assert isinstance(out, (list, tuple)) and len(out) == 2
-            assert all(isinstance(z, Array) for z in out)
-            expected1 = (x + 1 - y.reshape(x.shape)) / jnp.linalg.norm(y)
-            expected2 = (x - y.reshape(x.shape)).reshape(-1)[:5]
-            err1 = jnp.linalg.norm(out[0] - expected1) / jnp.linalg.norm(expected1)
-            err2 = jnp.linalg.norm(out[1] - expected2) / jnp.linalg.norm(expected2)
-            assert err1 < 1e-5 and err2 < 1e-5
-
-            # jit version
-            @jax.jit
-            def complication_fn(x, y):
-                a = jax_fn(x, y)
-                y2 = y.reshape(x.shape)
-                b, c = x - y2 + 1, x + y2 + 1
-                d = jnp.linalg.norm(x) - jnp.linalg.norm(y)
-                return a, b, c, d
-
-            out = complication_fn(x, y)
-            assert isinstance(out, (list, tuple)) and len(out) == 4
-            assert all(isinstance(z, Array) for z in out[0])
-            out1 = out[0]
-            expected1 = (x + 1 - y.reshape(x.shape)) / jnp.linalg.norm(y)
-            expected2 = (x - y.reshape(x.shape)).reshape(-1)[:5]
-            err1 = jnp.linalg.norm(out1[0] - expected1) / jnp.linalg.norm(expected2)
-            err2 = jnp.linalg.norm(out1[1] - expected2) / jnp.linalg.norm(expected2)
-            assert err1 < 1e-5 and err2 < 1e-5
+class TestFnWrapping(parameterized.TestCase):
+    @parameterized.product(device=["cpu", "cuda"], dtype=[jnp.float32, jnp.float64])
+    def test_single_output_fn(self, device, dtype):
+        if not torch.cuda.is_available() and device == "cuda":
+            self.skipTest("Skipping CUDA test when CUDA is not available.")
+        shape = (10, 2)
+
+        def torch_fn(x, y):
+            return (x + 1 - y.reshape(x.shape)) / torch.norm(y)
+
+        device_list = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
+        dtype_list = [jnp.float32, jnp.float64]
+
+        x = jax_randn(shape, device="cpu", dtype=jnp.float64)
+        y = jax_randn(shape, device="cpu", dtype=jnp.float64).reshape(-1)
+        jax_fn = torch2jax(torch_fn, x, y, output_shapes=Size(shape))
+
+        x = jax_randn(shape, device=device, dtype=dtype)
+        y = jax_randn(shape, device=device, dtype=dtype).reshape(-1)
+
+        # non-jit version
+        out = jax_fn(x, y)
+        assert isinstance(out, Array)
+        expected = (x + 1 - y.reshape(x.shape)) / jnp.linalg.norm(y)
+        err = jnp.linalg.norm(out - expected) / jnp.linalg.norm(expected)
+        assert err < 1e-5
+
+        # jit version
+        @jax.jit
+        def complication_fn(x, y):
+            a = jax_fn(x, y)
+            y2 = y.reshape(x.shape)
+            b, c = x - y2 + 1, x + y2 + 1
+            d = jnp.linalg.norm(x) - jnp.linalg.norm(y)
+            return a, b, c, d
+
+        out = complication_fn(x, y)
+        assert isinstance(out, (list, tuple)) and len(out) == 4
+        out1 = out[0]
+        assert isinstance(out1, Array)
+        expected = (x + 1 - y.reshape(x.shape)) / jnp.linalg.norm(y)
+        err = jnp.linalg.norm(out1 - expected) / jnp.linalg.norm(expected)
+        assert err < 1e-5, f"Error is quite high: {err:.4e}"
+
+    @parameterized.product(device=["cpu", "cuda"], dtype=[jnp.float32, jnp.float64])
+    def test_multi_output_fn(self, device, dtype):
+        if not torch.cuda.is_available() and device == "cuda":
+            self.skipTest("Skipping CUDA test when CUDA is not available.")
+
+        shape = (10, 2)
+
+        def torch_fn(x, y):
+            a = (x + 1 - y.reshape(x.shape)) / torch.norm(y)
+            b = (x - y.reshape(x.shape)).reshape(-1)[:5]
+            return a, b
+
+        x = jax_randn(shape, device="cpu", dtype=jnp.float64)
+        y = jax_randn(shape, device="cpu", dtype=jnp.float64).reshape(-1)
+        jax_fn = torch2jax(torch_fn, x, y, output_shapes=(Size(shape), Size((5,))))
+
+        x = jax_randn(shape, device=device, dtype=dtype)
+        y = jax_randn(shape, device=device, dtype=dtype).reshape(-1)
+
+        # non-jit version
+        out = jax_fn(x, y)
+        assert isinstance(out, (list, tuple)) and len(out) == 2
+        assert all(isinstance(z, Array) for z in out)
+        expected1 = (x + 1 - y.reshape(x.shape)) / jnp.linalg.norm(y)
+        expected2 = (x - y.reshape(x.shape)).reshape(-1)[:5]
+        err1 = jnp.linalg.norm(out[0] - expected1) / jnp.linalg.norm(expected1)
+        err2 = jnp.linalg.norm(out[1] - expected2) / jnp.linalg.norm(expected2)
+        assert err1 < 1e-5 and err2 < 1e-5
+
+        # jit version
+        @jax.jit
+        def complication_fn(x, y):
+            a = jax_fn(x, y)
+            y2 = y.reshape(x.shape)
+            b, c = x - y2 + 1, x + y2 + 1
+            d = jnp.linalg.norm(x) - jnp.linalg.norm(y)
+            return a, b, c, d
+
+        out = complication_fn(x, y)
+        assert isinstance(out, (list, tuple)) and len(out) == 4
+        assert all(isinstance(z, Array) for z in out[0])
+        out1 = out[0]
+        expected1 = (x + 1 - y.reshape(x.shape)) / jnp.linalg.norm(y)
+        expected2 = (x - y.reshape(x.shape)).reshape(-1)[:5]
+        err1 = jnp.linalg.norm(out1[0] - expected1) / jnp.linalg.norm(expected1)
+        err2 = jnp.linalg.norm(out1[1] - expected2) / jnp.linalg.norm(expected2)
+        assert err1 < 1e-5 and err2 < 1e-5
 
 
 ####################################################################################################
 
 if __name__ == "__main__":
-    test_single_output_fn()
-    test_multi_output_fn()
+    absltest.main()
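
For readers who have not used absl's parameterized tests before, here is a minimal, self-contained sketch (a toy class with toy values, not part of this commit) of the mechanism the rewritten tests rely on: parameterized.product generates one test method per combination of the keyword argument lists, and absltest.main() then discovers and runs every generated case like an ordinary unittest suite.

    from absl.testing import absltest, parameterized


    class ToyProductTest(parameterized.TestCase):
        # product(device=[...], dtype=[...]) expands this method into
        # 2 x 2 = 4 independent test cases, one per (device, dtype) pair.
        @parameterized.product(device=["cpu", "cuda"], dtype=["float32", "float64"])
        def test_combination(self, device, dtype):
            self.assertIn(device, ("cpu", "cuda"))
            self.assertIn(dtype, ("float32", "float64"))


    if __name__ == "__main__":
        absltest.main()  # runs all four generated cases

Running the rewritten test file directly (python <path to this test file>) therefore executes one case per device/dtype combination, with the CUDA cases skipping themselves on machines where torch.cuda.is_available() is False.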