Skip to content

Commit 6ff355c

Browse files
committed
test: add numpy benchmark tests
Signed-off-by: Nathaniel Starkman <[email protected]>
1 parent ef8ca8d commit 6ff355c

File tree

3 files changed

+170
-4
lines changed

3 files changed

+170
-4
lines changed

.github/workflows/run_tests.yml

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,30 @@ jobs:
5757
- name: Checkout code
5858
uses: actions/checkout@v4
5959

60+
- name: Install uv
61+
uses: astral-sh/setup-uv@v5
62+
with:
63+
enable-cache: true
64+
cache-dependency-glob: "**/pyproject.toml"
65+
python-version: "3.10"
66+
67+
- name: Install the project
68+
run: uv sync --extra dev --resolution lowest-direct
69+
70+
- name: Test package
71+
run: uv run --frozen pytest
72+
73+
benchmark:
74+
needs: [format]
75+
runs-on: ubuntu-24.04
76+
if:
77+
github.event_name == 'workflow_dispatch' || (github.event_name ==
78+
'pull_request' && contains(github.event.pull_request.labels.*.name,
79+
'run-benchmarks')) || (github.event_name == 'push' && github.ref ==
80+
'refs/heads/main')
81+
steps:
82+
- uses: actions/checkout@v4
83+
6084
- name: Install uv
6185
uses: astral-sh/setup-uv@v5
6286
with:
@@ -66,10 +90,13 @@ jobs:
6690
- name: "Set up Python"
6791
uses: actions/setup-python@v5
6892
with:
69-
python-version: "3.10"
93+
python-version: "3.13"
7094

7195
- name: Install the project
72-
run: uv sync --extra dev --resolution lowest-direct
96+
run: uv sync --extra dev
7397

74-
- name: Test package
75-
run: uv run --extra dev --resolution lowest-direct pytest
98+
- name: Run benchmarks
99+
uses: CodSpeedHQ/action@v3
100+
with:
101+
token: ${{ secrets.CODSPEED_TOKEN }}
102+
run: uv run --frozen pytest tests/benchmark --codspeed

tests/benchmark/__init__.py

Whitespace-only changes.

tests/benchmark/test_numpy.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
"""Benchmark tests for quaxed functions on quantities."""
2+
3+
from collections.abc import Callable
4+
from typing import Any, TypeAlias, TypedDict
5+
from typing_extensions import Unpack
6+
7+
import jax
8+
import jax.numpy as jnp
9+
import pytest
10+
from jax._src.stages import Compiled
11+
12+
import quax
13+
14+
from ..myarray import MyArray
15+
16+
17+
Args: TypeAlias = tuple[Any, ...]
18+
19+
x = jnp.linspace(0, 1, 1000)
20+
xm = MyArray(x)
21+
22+
23+
def process_func(func: Callable[..., Any], args: Args) -> tuple[Compiled, Args]:
24+
"""JIT and compile the function."""
25+
return jax.jit(quax.quaxify(func)), args
26+
27+
28+
class ParameterizationKWArgs(TypedDict):
29+
"""Keyword arguments for a pytest parameterization."""
30+
31+
argvalues: list[tuple[Callable[..., Any], Args]]
32+
ids: list[str]
33+
34+
35+
def process_pytest_argvalues(
36+
process_fn: Callable[[Callable[..., Any], Args], tuple[Callable[..., Any], Args]],
37+
argvalues: list[tuple[Callable[..., Any], Unpack[tuple[Args, ...]]]],
38+
) -> ParameterizationKWArgs:
39+
"""Process the argvalues."""
40+
# Get the ID for each parameterization
41+
get_types = lambda args: tuple(str(type(a)) for a in args)
42+
ids: list[str] = []
43+
processed_argvalues: list[tuple[Compiled, Args]] = []
44+
45+
for func, *many_args in argvalues:
46+
for args in many_args:
47+
ids.append(f"{func.__name__}-{get_types(args)}")
48+
processed_argvalues.append(process_fn(func, args))
49+
50+
# Process the argvalues and return the parameterization, with IDs
51+
return {"argvalues": processed_argvalues, "ids": ids}
52+
53+
54+
funcs_and_args: list[tuple[Callable[..., Any], Unpack[tuple[Args, ...]]]] = [
55+
(jnp.abs, (xm,)),
56+
(jnp.acos, (xm,)),
57+
(jnp.acosh, (xm,)),
58+
(jnp.add, (xm, xm)),
59+
(jnp.asin, (xm,)),
60+
(jnp.asinh, (xm,)),
61+
(jnp.atan, (xm,)),
62+
(jnp.atan2, (xm, xm)),
63+
(jnp.atanh, (xm,)),
64+
# bitwise_and
65+
# bitwise_left_shift
66+
# bitwise_invert
67+
# bitwise_or
68+
# bitwise_right_shift
69+
# bitwise_xor
70+
(jnp.ceil, (xm,)),
71+
(jnp.conj, (xm,)),
72+
(jnp.cos, (xm,)),
73+
(jnp.cosh, (xm,)),
74+
(jnp.divide, (xm, xm)),
75+
(jnp.equal, (xm, xm)),
76+
(jnp.exp, (xm,)),
77+
(jnp.expm1, (xm,)),
78+
(jnp.floor, (xm,)),
79+
(jnp.floor_divide, (xm, xm)),
80+
(jnp.greater, (xm, xm)),
81+
(jnp.greater_equal, (xm, xm)),
82+
(jnp.imag, (xm,)),
83+
(jnp.isfinite, (xm,)),
84+
(jnp.isinf, (xm,)),
85+
(jnp.isnan, (xm,)),
86+
(jnp.less, (xm, xm)),
87+
(jnp.less_equal, (xm, xm)),
88+
(jnp.log, (xm,)),
89+
(jnp.log1p, (xm,)),
90+
(jnp.log2, (xm,)),
91+
(jnp.log10, (xm,)),
92+
(jnp.logaddexp, (xm, xm)),
93+
(jnp.logical_and, (xm, xm)),
94+
(jnp.logical_not, (xm,)),
95+
(jnp.logical_or, (xm, xm)),
96+
(jnp.logical_xor, (xm, xm)),
97+
(jnp.multiply, (xm, xm)),
98+
(jnp.negative, (xm,)),
99+
(jnp.not_equal, (xm, xm)),
100+
(jnp.positive, (xm,)),
101+
(jnp.power, (xm, 2.0)),
102+
(jnp.real, (xm,)),
103+
(jnp.remainder, (xm, xm)),
104+
(jnp.round, (xm,)),
105+
(jnp.sign, (xm,)),
106+
(jnp.sin, (xm,)),
107+
(jnp.sinh, (xm,)),
108+
(jnp.square, (xm,)),
109+
(jnp.sqrt, (xm,)),
110+
(jnp.subtract, (xm, xm)),
111+
(jnp.tan, (xm,)),
112+
(jnp.tanh, (xm,)),
113+
(jnp.trunc, (xm,)),
114+
]
115+
116+
117+
# =============================================================================
118+
119+
120+
@pytest.mark.parametrize(
121+
("func", "args"), **process_pytest_argvalues(process_func, funcs_and_args)
122+
)
123+
@pytest.mark.benchmark(group="quaxed", max_time=1.0, warmup=False)
124+
def test_jit_compile(func, args):
125+
"""Test the speed of jitting a function."""
126+
_ = func.lower(*args).compile()
127+
128+
129+
@pytest.mark.parametrize(
130+
("func", "args"), **process_pytest_argvalues(process_func, funcs_and_args)
131+
)
132+
@pytest.mark.benchmark(
133+
group="quaxed",
134+
max_time=1.0, # NOTE: max_time is ignored
135+
warmup=True,
136+
)
137+
def test_execute(func, args):
138+
"""Test the speed of calling the function."""
139+
_ = jax.block_until_ready(func(*args))

0 commit comments

Comments
 (0)