1
+ from functools import partial
1
2
import math
2
- from functools import reduce
3
- from operator import mul
4
3
4
+ from jax import jit , ensure_compile_time_eval
5
5
import jax .numpy as jnp
6
6
7
7
from pmwd .pm_util import rfftnfreq
8
+ from pmwd .util import bincount
8
9
9
10
10
- def powspec (f , spacing , bins = 1j / 3 , g = None , deconv = 0 , cut_zero = True , cut_nyq = True ,
11
- int_dtype = jnp .uint32 ):
11
+ @partial (jit , static_argnames = ('bins' , 'cut_zero' , 'cut_nyq' , 'dtype' , 'int_dtype' ))
12
+ def powspec (f , spacing , bins = 1j / 3 , g = None , deconv = None , cut_zero = True , cut_nyq = True ,
13
+ dtype = jnp .float_ , int_dtype = jnp .uint32 ):
12
14
"""Compute auto or cross power spectrum in 3D averaged in spherical bins.
13
15
14
16
Parameters
15
17
----------
16
18
f : ArrayLike
17
- The field, with the last 3 axes for FFT and the other summed over .
19
+ The field, with the last 3 axes for FFT and the other reduced by sum of squares .
18
20
spacing : float
19
21
Field grid spacing.
20
- bins : float, complex, or 1D ArrayLike , optional
21
- Wavenumber bins. A real number sets the linear spaced bin width in unit of the
22
- smallest fundamental in 3D (right edge inclusive starting from zero); an
23
- imaginary number sets the log spaced bin width in octave (left edge inclusive
24
- starting from the smallest fundamental in 3D); and an array sets the bin edges
25
- directly (right edge inclusive and must starting from zero).
22
+ bins : float, complex, or tuple , optional
23
+ (Angular) wavenumber bins. A real number sets the linear bin width in unit of
24
+ the smallest fundamental in 3D (right edge inclusive starting from zero); an
25
+ imaginary number sets the log bin width in octave (left edge inclusive starting
26
+ from the smallest fundamental in 3D); and a tuple sets the bin edges directly
27
+ (right edge inclusive and must starting from zero).
26
28
g : ArrayLike, optional
27
29
Another field of the same shape for cross correlation.
28
30
deconv : int, optional
29
31
Power of sinc factors to deconvolve in the power spectrum.
30
32
cut_zero : bool, optional
31
33
Whether to discard the bin containing the zero or DC mode.
32
34
cut_nyq : bool, optional
33
- Whether to discard the bins beyond the Nyquist.
35
+ Whether to discard the bins beyond the Nyquist, only for linear or log ``bins``.
36
+ dtype : DTypeLike, optional
37
+ Float dtype for the wavenumber and power spectrum.
34
38
int_dtype : DTypeLike, optional
35
39
Integer dtype for the number of modes.
36
40
@@ -47,77 +51,75 @@ def powspec(f, spacing, bins=1j/3, g=None, deconv=0, cut_zero=True, cut_nyq=True
47
51
48
52
"""
49
53
f = jnp .asarray (f )
50
- grid_shape = f .shape [- 3 :]
51
54
52
55
if g is not None and f .shape != jnp .shape (g ):
53
56
raise ValueError (f'shape mismatch: { f .shape } != { jnp .shape (g )} ' )
57
+ grid_shape = f .shape [- 3 :]
58
+
59
+ with ensure_compile_time_eval ():
60
+ kfun = 1 / max (grid_shape )
61
+ knyq = 0.5
62
+ kmax = knyq * math .sqrt (3 )
63
+ if isinstance (bins , (int , float )):
64
+ bins *= kfun
65
+ bin_num = math .ceil (kmax / bins )
66
+ bins *= jnp .arange (1 + bin_num )
67
+ right = True
68
+ bcut = jnp .digitize (knyq if cut_nyq else kmax , bins , right = right ).item () + 1
69
+ elif isinstance (bins , complex ):
70
+ kmaxable = all (s % 2 == 0 for s in grid_shape ) # extra bin just in case
71
+ bin_num = math .ceil (math .log2 (kmax / kfun ) / bins .imag ) + kmaxable
72
+ bins = kfun * 2 ** (bins .imag * jnp .arange (1 + bin_num ))
73
+ right = False
74
+ bcut = jnp .digitize (knyq if cut_nyq else kmax , bins , right = right ).item () + 1
75
+ else :
76
+ bin_num = len (bins ) - 1
77
+ bins = jnp .asarray (bins )
78
+ bins *= spacing / (2 * jnp .pi ) # convert to 2π spacing
79
+ right = True
80
+ bcut = bin_num + 1 # no trim otherwise hard to jit
54
81
55
82
last_three = range (- 3 , 0 )
56
83
f = jnp .fft .rfftn (f , axes = last_three )
57
-
58
84
if g is None :
59
85
P = f .real ** 2 + f .imag ** 2
60
86
else :
61
87
g = jnp .asarray (g )
62
88
g = jnp .fft .rfftn (g , axes = last_three )
63
-
64
89
P = f * g .conj ()
65
90
66
91
if P .ndim > 3 :
67
92
P = P .sum (tuple (range (P .ndim - 3 )))
68
93
69
94
kvec = rfftnfreq (grid_shape , None , dtype = P .real .dtype )
70
95
k = jnp .sqrt (sum (k ** 2 for k in kvec ))
71
- kfun = 1 / max (grid_shape )
72
- knyq = 0.5
73
- kmax = knyq * math .sqrt (3 )
74
96
75
- if deconv != 0 :
76
- P = reduce ( mul , (jnp .sinc (k ) ** - deconv for k in kvec ), P ) # numpy sinc has pi
97
+ if deconv is not None :
98
+ P = math . prod ( (jnp .sinc (k ) ** - deconv for k in kvec ), start = P ) # numpy sinc has pi
77
99
78
- N = jnp .full_like (P , 2 , dtype = int_dtype )
100
+ N = jnp .full_like (P , 2 , dtype = jnp . uint8 )
79
101
N = N .at [..., 0 ].set (1 )
80
102
if grid_shape [- 1 ] % 2 == 0 :
81
103
N = N .at [..., - 1 ].set (1 )
82
104
83
105
k = k .ravel ()
84
106
P = P .ravel ()
85
107
N = N .ravel ()
86
-
87
- if isinstance (bins , (int , float )):
88
- bins *= kfun
89
- bin_num = math .ceil (kmax / bins )
90
- bins *= jnp .arange (bin_num + 1 )
91
- right = True
92
- elif isinstance (bins , complex ):
93
- kmaxable = all (s % 2 == 0 for s in grid_shape )
94
- bin_num = math .ceil (math .log2 (kmax / kfun ) / bins .imag ) + kmaxable
95
- bins = kfun * 2 ** (bins .imag * jnp .arange (bin_num + 1 ))
96
- right = False
97
- else :
98
- bin_num = len (bins ) - 1
99
- bins = jnp .asarray (bins )
100
- bins *= spacing / (2 * jnp .pi ) # convert to 2π spacing
101
- right = True
102
-
103
108
b = jnp .digitize (k , bins , right = right )
104
- k *= N
105
- P *= N
106
- k = jnp .bincount (b , weights = k , length = 1 + bin_num ) # k=0 goes to b=0
107
- P = jnp .bincount (b , weights = P , length = 1 + bin_num )
108
- N = jnp .bincount (b , weights = N , length = 1 + bin_num )
109
-
110
- bmax = jnp .digitize (knyq if cut_nyq else kmax , bins , right = True )
111
- k = k [cut_zero :bmax + 1 ]
112
- P = P [cut_zero :bmax + 1 ]
113
- N = N [cut_zero :bmax + 1 ]
114
- bins = bins [:bmax + 1 ]
109
+ k = bincount (b , weights = k * N , length = 1 + bin_num , dtype = dtype ) # k=0 goes to b=0
110
+ P = bincount (b , weights = P * N , length = 1 + bin_num , dtype = dtype )
111
+ N = bincount (b , weights = N , length = 1 + bin_num , dtype = int_dtype )
112
+
113
+ k = k [cut_zero :bcut ]
114
+ P = P [cut_zero :bcut ]
115
+ N = N [cut_zero :bcut ]
116
+ bins = bins [:bcut ]
115
117
116
118
k /= N
117
119
P /= N
118
120
119
121
k *= 2 * jnp .pi / spacing
120
122
bins *= 2 * jnp .pi / spacing
121
- P *= spacing ** 3 / reduce ( mul , grid_shape )
123
+ P *= spacing ** 3 / math . prod ( grid_shape )
122
124
123
125
return k , P , N , bins
0 commit comments