Skip to content

Commit f4c20bb

Browse files
committed
Fix powspec round-off error problem and add jit
1 parent b9a82bf commit f4c20bb

File tree

3 files changed

+63
-52
lines changed

3 files changed

+63
-52
lines changed

pmwd/spec_util.py

Lines changed: 52 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,40 @@
1+
from functools import partial
12
import math
2-
from functools import reduce
3-
from operator import mul
43

4+
from jax import jit, ensure_compile_time_eval
55
import jax.numpy as jnp
66

77
from pmwd.pm_util import rfftnfreq
8+
from pmwd.util import bincount
89

910

10-
def powspec(f, spacing, bins=1j/3, g=None, deconv=0, cut_zero=True, cut_nyq=True,
11-
int_dtype=jnp.uint32):
11+
@partial(jit, static_argnames=('bins', 'cut_zero', 'cut_nyq', 'dtype', 'int_dtype'))
12+
def powspec(f, spacing, bins=1j/3, g=None, deconv=None, cut_zero=True, cut_nyq=True,
13+
dtype=jnp.float_, int_dtype=jnp.uint32):
1214
"""Compute auto or cross power spectrum in 3D averaged in spherical bins.
1315
1416
Parameters
1517
----------
1618
f : ArrayLike
17-
The field, with the last 3 axes for FFT and the other summed over.
19+
The field, with the last 3 axes for FFT and the other reduced by sum of squares.
1820
spacing : float
1921
Field grid spacing.
20-
bins : float, complex, or 1D ArrayLike, optional
21-
Wavenumber bins. A real number sets the linear spaced bin width in unit of the
22-
smallest fundamental in 3D (right edge inclusive starting from zero); an
23-
imaginary number sets the log spaced bin width in octave (left edge inclusive
24-
starting from the smallest fundamental in 3D); and an array sets the bin edges
25-
directly (right edge inclusive and must starting from zero).
22+
bins : float, complex, or tuple, optional
23+
(Angular) wavenumber bins. A real number sets the linear bin width in unit of
24+
the smallest fundamental in 3D (right edge inclusive starting from zero); an
25+
imaginary number sets the log bin width in octave (left edge inclusive starting
26+
from the smallest fundamental in 3D); and a tuple sets the bin edges directly
27+
(right edge inclusive and must start from zero).
2628
g : ArrayLike, optional
2729
Another field of the same shape for cross correlation.
2830
deconv : int or None, optional
2931
Power of sinc factors to deconvolve in the power spectrum. ``None`` skips the deconvolution.
3032
cut_zero : bool, optional
3133
Whether to discard the bin containing the zero or DC mode.
3234
cut_nyq : bool, optional
33-
Whether to discard the bins beyond the Nyquist.
35+
Whether to discard the bins beyond the Nyquist, only for linear or log ``bins``.
36+
dtype : DTypeLike, optional
37+
Float dtype for the wavenumber and power spectrum.
3438
int_dtype : DTypeLike, optional
3539
Integer dtype for the number of modes.
3640
@@ -47,77 +51,75 @@ def powspec(f, spacing, bins=1j/3, g=None, deconv=0, cut_zero=True, cut_nyq=True
4751
4852
"""
4953
f = jnp.asarray(f)
50-
grid_shape = f.shape[-3:]
5154

5255
if g is not None and f.shape != jnp.shape(g):
5356
raise ValueError(f'shape mismatch: {f.shape} != {jnp.shape(g)}')
57+
grid_shape = f.shape[-3:]
58+
59+
with ensure_compile_time_eval():
60+
kfun = 1 / max(grid_shape)
61+
knyq = 0.5
62+
kmax = knyq * math.sqrt(3)
63+
if isinstance(bins, (int, float)):
64+
bins *= kfun
65+
bin_num = math.ceil(kmax / bins)
66+
bins *= jnp.arange(1 + bin_num)
67+
right = True
68+
bcut = jnp.digitize(knyq if cut_nyq else kmax, bins, right=right).item() + 1
69+
elif isinstance(bins, complex):
70+
kmaxable = all(s % 2 == 0 for s in grid_shape) # extra bin just in case
71+
bin_num = math.ceil(math.log2(kmax / kfun) / bins.imag) + kmaxable
72+
bins = kfun * 2 ** (bins.imag * jnp.arange(1 + bin_num))
73+
right = False
74+
bcut = jnp.digitize(knyq if cut_nyq else kmax, bins, right=right).item() + 1
75+
else:
76+
bin_num = len(bins) - 1
77+
bins = jnp.asarray(bins)
78+
bins *= spacing / (2 * jnp.pi) # convert to 2π spacing
79+
right = True
80+
bcut = bin_num + 1 # no trim otherwise hard to jit
5481

5582
last_three = range(-3, 0)
5683
f = jnp.fft.rfftn(f, axes=last_three)
57-
5884
if g is None:
5985
P = f.real**2 + f.imag**2
6086
else:
6187
g = jnp.asarray(g)
6288
g = jnp.fft.rfftn(g, axes=last_three)
63-
6489
P = f * g.conj()
6590

6691
if P.ndim > 3:
6792
P = P.sum(tuple(range(P.ndim-3)))
6893

6994
kvec = rfftnfreq(grid_shape, None, dtype=P.real.dtype)
7095
k = jnp.sqrt(sum(k**2 for k in kvec))
71-
kfun = 1 / max(grid_shape)
72-
knyq = 0.5
73-
kmax = knyq * math.sqrt(3)
7496

75-
if deconv != 0:
76-
P = reduce(mul, (jnp.sinc(k) ** -deconv for k in kvec), P) # numpy sinc has pi
97+
if deconv is not None:
98+
P = math.prod((jnp.sinc(k) ** -deconv for k in kvec), start=P) # numpy sinc has pi
7799

78-
N = jnp.full_like(P, 2, dtype=int_dtype)
100+
N = jnp.full_like(P, 2, dtype=jnp.uint8)
79101
N = N.at[..., 0].set(1)
80102
if grid_shape[-1] % 2 == 0:
81103
N = N.at[..., -1].set(1)
82104

83105
k = k.ravel()
84106
P = P.ravel()
85107
N = N.ravel()
86-
87-
if isinstance(bins, (int, float)):
88-
bins *= kfun
89-
bin_num = math.ceil(kmax / bins)
90-
bins *= jnp.arange(bin_num + 1)
91-
right = True
92-
elif isinstance(bins, complex):
93-
kmaxable = all(s % 2 == 0 for s in grid_shape)
94-
bin_num = math.ceil(math.log2(kmax / kfun) / bins.imag) + kmaxable
95-
bins = kfun * 2 ** (bins.imag * jnp.arange(bin_num + 1))
96-
right = False
97-
else:
98-
bin_num = len(bins) - 1
99-
bins = jnp.asarray(bins)
100-
bins *= spacing / (2 * jnp.pi) # convert to 2π spacing
101-
right = True
102-
103108
b = jnp.digitize(k, bins, right=right)
104-
k *= N
105-
P *= N
106-
k = jnp.bincount(b, weights=k, length=1+bin_num) # k=0 goes to b=0
107-
P = jnp.bincount(b, weights=P, length=1+bin_num)
108-
N = jnp.bincount(b, weights=N, length=1+bin_num)
109-
110-
bmax = jnp.digitize(knyq if cut_nyq else kmax, bins, right=True)
111-
k = k[cut_zero:bmax+1]
112-
P = P[cut_zero:bmax+1]
113-
N = N[cut_zero:bmax+1]
114-
bins = bins[:bmax+1]
109+
k = bincount(b, weights=k * N, length=1+bin_num, dtype=dtype) # k=0 goes to b=0
110+
P = bincount(b, weights=P * N, length=1+bin_num, dtype=dtype)
111+
N = bincount(b, weights=N, length=1+bin_num, dtype=int_dtype)
112+
113+
k = k[cut_zero:bcut]
114+
P = P[cut_zero:bcut]
115+
N = N[cut_zero:bcut]
116+
bins = bins[:bcut]
115117

116118
k /= N
117119
P /= N
118120

119121
k *= 2 * jnp.pi / spacing
120122
bins *= 2 * jnp.pi / spacing
121-
P *= spacing**3 / reduce(mul, grid_shape)
123+
P *= spacing**3 / math.prod(grid_shape)
122124

123125
return k, P, N, bins

pmwd/util.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
from jax import float0
3+
import jax.numpy as jnp
34

45

56
def is_float0_array(x):
@@ -8,3 +9,10 @@ def is_float0_array(x):
89

910
def float0_like(x):
1011
return np.empty(x.shape, dtype=float0) # see jax issue #4433
12+
13+
14+
def bincount(x, weights, *, length, dtype=None):
    """Accumulate ``weights`` into ``length`` bins selected by the indices ``x``.

    Stopgap replacement for ``jnp.bincount`` that lets the caller choose the
    accumulator dtype.

    Parameters
    ----------
    x : ArrayLike
        Integer bin indices. Negative values are clipped to bin 0.
    weights : ArrayLike
        Values added to the bins, one per entry of ``x``.
    length : int
        Number of bins in the output.
    dtype : DTypeLike, optional
        Dtype of the output accumulator. Default is the dtype of ``weights``.

    Returns
    -------
    binned : jax.Array
        1D array of shape ``(length,)`` holding the per-bin sums.

    """
    if dtype is None:
        # result_type accepts JAX arrays, tracers and NumPy arrays alike, whereas
        # jnp.dtype (i.e. numpy.dtype) raises TypeError when given a numpy.ndarray
        dtype = jnp.result_type(weights)
    return jnp.zeros(length, dtype=dtype).at[jnp.clip(x, 0)].add(weights)

setup.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
vis_require = ['matplotlib', 'scipy']
1010
#docs_require = ['sphinx', 'jupyterlab']
1111
#tests_require = ['pytest', 'pytest-cov', 'pytest-benchmark', 'pytest-xdist', 'scipy']
12+
#dev_require = vis_require + docs_require + tests_require
1213

1314

1415
setup(
@@ -22,7 +23,7 @@
2223
use_scm_version={'write_to': 'pmwd/_version.py'},
2324
setup_requires=['setuptools_scm'],
2425
packages=find_packages(),
25-
python_requires='>=3.7',
26+
python_requires='>=3.8', # math.prod
2627
install_requires=[
2728
'jax>=0.4.7',
2829
'mcfit>=0.0.18', # jax backend
@@ -31,6 +32,6 @@
3132
'vis': vis_require,
3233
#'docs': docs_require,
3334
#'tests': tests_require,
34-
#'dev': vis_require + docs_require + tests_require,
35+
#'dev': dev_require,
3536
}
3637
)

0 commit comments

Comments
 (0)