Skip to content

Commit 5df6539

Browse files
authored
Merge pull request #280 from bashtage/mv-complex-normal
ENH: Add MV Complex normal
2 parents f71aeda + 2644dfa commit 5df6539

File tree

7 files changed

+365
-14
lines changed

7 files changed

+365
-14
lines changed

doc/source/change-log.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Change Log
1818

1919
Since v1.20.0
2020
=============
21+
- Added :func:`~randomgen.generator.ExtendedGenerator.multivariate_complex_normal`.
2122
- Added :func:`~randomgen.generator.ExtendedGenerator.standard_wishart` and
2223
:func:`~randomgen.generator.ExtendedGenerator.wishart` variate generators.
2324

doc/source/extended-generator.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ Distributions
2929

3030
~ExtendedGenerator.complex_normal
3131
~ExtendedGenerator.multivariate_normal
32+
~ExtendedGenerator.multivariate_complex_normal
3233
~ExtendedGenerator.uintegers
3334
~ExtendedGenerator.standard_wishart
3435
~ExtendedGenerator.wishart

randomgen/generator.pyi

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import numpy as np
55
from numpy import ndarray
66

77
from randomgen.common import BitGenerator
8-
from randomgen.typing import Size
8+
from randomgen.typing import RequiredSize, Size
99

1010
class Generator:
1111
_bit_generator: BitGenerator
@@ -141,27 +141,44 @@ class ExtendedGenerator:
141141
def state(self) -> Dict[str, Any]: ...
142142
@state.setter
143143
def state(self, value: Dict[str, Any]) -> None: ...
144+
def uintegers(self, size: None, bits: Literal[32, 64] = 64) -> int: ...
144145
def uintegers(
145-
self, size: Size = None, bits: Literal[32, 64] = 64
146+
self, size: RequiredSize, bits: Literal[32, 64] = 64
146147
) -> Union[int, ndarray]: ...
147148
# Multivariate distributions:
148149
def multivariate_normal(
149150
self,
150151
mean: ndarray,
151152
cov: ndarray,
152153
size: Size = ...,
153-
check_valid: str = ...,
154+
check_valid: Literal["raise", "ignore", "warn"] = ...,
154155
tol: float = ...,
155156
*,
156-
method: str = ...
157+
method: Literal["svd", "eigh", "cholesky", "factor"] = ...
157158
) -> ndarray: ...
159+
@overload
160+
def complex_normal(
161+
self,
162+
loc: complex = ...,
163+
gamma: complex = ...,
164+
relation: complex = ...,
165+
) -> complex: ...
166+
@overload
158167
def complex_normal(
159168
self,
160169
loc: complex = ...,
161170
gamma: complex = ...,
162171
relation: complex = ...,
172+
size: RequiredSize = ...,
173+
) -> ndarray: ...
174+
@overload
175+
def complex_normal(
176+
self,
177+
loc: Union[complex, ndarray] = ...,
178+
gamma: Union[complex, ndarray] = ...,
179+
relation: Union[complex, ndarray] = ...,
163180
size: Size = ...,
164-
) -> Union[complex, ndarray]: ...
181+
) -> ndarray: ...
165182
def standard_wishart(
166183
self, df: int, dim: int, size: Size = ..., *, rescale: bool = ...
167184
) -> ndarray: ...
@@ -171,10 +188,21 @@ class ExtendedGenerator:
171188
scale: ndarray,
172189
size: Size = ...,
173190
*,
174-
check_valid: str = ...,
191+
check_valid: Literal["raise", "ignore", "warn"] = ...,
175192
tol: float = ...,
176193
rank: Optional[int] = ...,
177-
method: str = ...
194+
method: Literal["svd", "eigh", "cholesky", "factor"] = ...
195+
) -> ndarray: ...
196+
def multivariate_complex_normal(
197+
self,
198+
loc: ndarray,
199+
gamma: Optional[ndarray] = ...,
200+
relation: Optional[ndarray] = ...,
201+
size: Size = ...,
202+
*,
203+
check_valid: Literal["raise", "ignore", "warn"] = ...,
204+
tol: float = ...,
205+
method: Literal["svd", "eigh", "cholesky", "factor"] = ...
178206
) -> ndarray: ...
179207

180208
_random_generator: Generator

randomgen/generator.pyx

Lines changed: 266 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,32 @@ __all__ = ["Generator", "beta", "binomial", "bytes", "chisquare", "choice",
3737

3838
np.import_array()
3939

40-
def _factorize(cov, meth, check_valid, tol, rank):
40+
cdef object broadcast_shape(tuple x, tuple y, bint strict):
41+
cdef bint cond, bcast=True
42+
if x == () or y == ():
43+
if len(x) > len(y):
44+
return True, x
45+
return True, y
46+
lx = len(x)
47+
ly = len(y)
48+
if lx > ly:
49+
shape = list(x[:lx-ly])
50+
x = x[lx-ly:]
51+
else:
52+
shape = list(y[:ly-lx])
53+
y = y[ly-lx:]
54+
for xs, ys in zip(x, y):
55+
cond = xs == ys
56+
if not strict:
57+
cond |= min(xs, ys) == 1
58+
bcast &= cond
59+
if not bcast:
60+
break
61+
shape.append(max(xs, ys))
62+
return bcast, tuple(shape)
63+
64+
65+
cdef _factorize(cov, meth, check_valid, tol, rank):
4166
if meth == "svd":
4267
from numpy.linalg import svd
4368

@@ -4988,13 +5013,15 @@ cdef class ExtendedGenerator:
49885013
tol : float, optional
49895014
Tolerance when checking the singular values in covariance matrix.
49905015
cov is cast to double before the check.
4991-
method : {'svd', 'eigh', 'cholesky'}, optional
5016+
method : {'svd', 'eigh', 'cholesky', 'factor'}, optional
49925017
The cov input is used to compute a factor matrix A such that
49935018
``A @ A.T = cov``. This argument is used to select the method
49945019
used to compute the factor matrix A. The default method 'svd' is
49955020
the slowest, while 'cholesky' is the fastest but less robust than
49965021
the slowest method. The method `eigh` uses eigen decomposition to
4997-
compute A and is faster than svd but slower than cholesky.
5022+
compute A and is faster than svd but slower than cholesky. `factor`
5023+
assumes that cov has been pre-factored so that no transformation is
5024+
applied.
49985025
49995026
Returns
50005027
-------
@@ -5553,6 +5580,242 @@ and the trailing dimensions must match exactly so that
55535580
return out.reshape(out_shape)
55545581

55555582

5583+
def multivariate_complex_normal(self, loc, gamma=None, relation=None, size=None, *,
5584+
check_valid="warn", tol=1e-8, method="svd"):
5585+
r"""
5586+
multivariate_complex_normal(loc, gamma=None, relation=None, size=None, *, check_valid="warn", tol=1e-8, method="svd")
5587+
5588+
Draw random samples from a multivariate complex normal (Gaussian) distribution.
5589+
5590+
Parameters
5591+
----------
5592+
loc : array_like of complex
5593+
Mean of the distribution. Must have shape (m1, m2, ..., mk, N) where
5594+
(m1, m2, ..., mk) would broadcast with (g1, g2, ..., gj) and
5595+
(r1, r2, ..., rq).
5596+
gamma : array_like of float or complex, optional
5597+
Covariance of the real component of the distribution. Must have shape
5598+
(g1, g2, ..., gj, N, N) where (g1, g2, ..., gj) would broadcast
5599+
with (m1, m2, ..., mk) and (r1, r2, ..., rq). If not provided,
5600+
an identity matrix is used which produces the circularly-symmetric
5601+
complex normal when relation is an array of 0.
5602+
relation : array_like of float or complex, optional
5603+
Relation between the two component normals. (r1, r2, ..., rq, N, N)
5604+
where (r1, r2, ..., rq, N, N) would broadcast with (m1, m2, ..., mk)
5605+
and (g1, g2, ..., gj). If not provided, set to zero which produces
5606+
the circularly-symmetric complex normal when gamma is an identify matrix.
5607+
size : int or tuple of ints, optional
5608+
Given a shape of, for example, ``(m,n,k)``, ``m*n*k`` samples are
5609+
generated, and packed in an `m`-by-`n`-by-`k` arrangement. Because
5610+
each sample is `N`-dimensional, the output shape is ``(m,n,k,N)``.
5611+
If no shape is specified, a single (`N`-D) sample is returned.
5612+
check_valid : {'warn', 'raise', 'ignore' }, optional
5613+
Behavior when the covariance matrix implied by `gamma` and `relation`
5614+
is not positive semidefinite.
5615+
tol : float, optional
5616+
Tolerance when checking the singular values in the covariance matrix
5617+
implied by `gamma` and `relation`.
5618+
method : {'svd', 'eigh', 'cholesky'}, optional
5619+
The cov input is used to compute a factor matrix A such that
5620+
``A @ A.T = cov``. This argument is used to select the method
5621+
used to compute the factor matrix A for the covariance implied by
5622+
`gamma` and `relation`. The default method 'svd' is
5623+
the slowest, while 'cholesky' is the fastest but less robust than
5624+
the slowest method. The method `eigh` uses eigen decomposition to
5625+
compute A and is faster than svd but slower than cholesky.
5626+
5627+
Returns
5628+
-------
5629+
out : ndarray
5630+
Drawn samples from the parameterized complex normal distributions.
5631+
5632+
See Also
5633+
--------
5634+
numpy.random.Generator.normal : random values from a real-valued
5635+
normal distribution
5636+
randomgen.generator.ExtendedGenerator.complex_normal : random values from a
5637+
scalar complex-valued normal distribution
5638+
randomgen.generator.ExtendedGenerator.multivariate_normal : random values from a
5639+
scalar complex-valued normal distribution
5640+
5641+
Notes
5642+
-----
5643+
Complex normals are generated from a multivariate normal where the
5644+
covariance matrix of the real and imaginary components is
5645+
5646+
.. math::
5647+
5648+
\begin{array}{c}
5649+
X\\
5650+
Y
5651+
\end{array}\sim N\left(\left[\begin{array}{c}
5652+
\mathrm{Re\left[\mu\right]}\\
5653+
\mathrm{Im\left[\mu\right]}
5654+
\end{array}\right],\frac{1}{2}\left[\begin{array}{cc}
5655+
\mathrm{Re}\left[\Gamma+C\right] & \mathrm{Im}\left[C-\Gamma\right]\\
5656+
\mathrm{Im}\left[\Gamma+C\right] & \mathrm{Re}\left[\Gamma-C\right]
5657+
\end{array}\right]\right)
5658+
5659+
The complex normals are then
5660+
5661+
.. math::
5662+
5663+
Z = X + iY
5664+
5665+
If the implied covariance matrix is not positive semi-definite a warning
5666+
or exception may be raised depending on the value `check_valid`.
5667+
5668+
References
5669+
----------
5670+
.. [1] Wikipedia, "Complex normal distribution",
5671+
https://en.wikipedia.org/wiki/Complex_normal_distribution
5672+
.. [2] Leigh J. Halliwell, "Complex Random Variables" in "Casualty
5673+
Actuarial Society E-Forum", Fall 2015.
5674+
5675+
Examples
5676+
--------
5677+
Draw samples from the standard multivariate complex normal
5678+
5679+
>>> from randomgen import ExtendedGenerator
5680+
>>> eg = ExtendedGenerator()
5681+
>>> loc = np.zeros(3)
5682+
>>> eg.multivariate_complex_normal(loc, size=2)
5683+
array([[ 0.42551611+0.44163456j,
5684+
-0.18366146+0.88380663j,
5685+
-0.3035725 -1.19754723j],
5686+
[-0.86649667-0.88447445j,
5687+
-0.04913229-0.04674949j,
5688+
-0.28145563+1.04682163j]])
5689+
5690+
Draw samples a trivariate centered circularly symmetric complex normal
5691+
5692+
>>> rho = 0.7
5693+
>>> gamma = rho * np.eye(3) + (1-rho) * np.diag(np.ones(3))
5694+
>>> eg.multivariate_complex_normal(loc, gamma, size=3)
5695+
array([[ 0.32699266-0.57787275j, 0.46716898-0.06687298j,
5696+
-0.31483301+0.17233599j],
5697+
[ 0.28036548-0.56994348j, 0.18011468-0.50539209j,
5698+
0.35185607-0.15184288j],
5699+
[-0.1866397 +1.2701576j , -0.18419364-0.06912343j,
5700+
-0.66462037+0.73939778j]])
5701+
5702+
Draw samples from a bivariate distribution with
5703+
correlation between the real and imaginary components
5704+
5705+
>>> loc = np.array([3-7j, 2+4j])
5706+
>>> gamma = np.array([[2, 0 + 1.0j], [-0 - 1.0j, 2]])
5707+
>>> rel = np.array([[-1.8, 0 + 0.1j], [0 + 0.1j, -1.8]])
5708+
>>> eg.multivariate_complex_normal(loc, gamma, size=3)
5709+
array([[2.97279918-5.64185732j, 2.32361134+3.23587346j],
5710+
[1.91476019-7.91802901j, 1.76788821+3.84832672j],
5711+
[4.44740101-7.93782402j, 1.59809459+1.35360097j]])
5712+
"""
5713+
cdef np.ndarray garr, rarr, larr
5714+
cdef np.npy_intp *gshape
5715+
cdef np.npy_intp *rshape
5716+
cdef int gdim, rdim, dim, ldim
5717+
5718+
larr = <np.ndarray>np.PyArray_FROM_OTF(loc,
5719+
np.NPY_CDOUBLE,
5720+
np.NPY_ARRAY_ALIGNED |
5721+
np.NPY_ARRAY_C_CONTIGUOUS)
5722+
ldim = np.PyArray_NDIM(larr)
5723+
if ldim < 1:
5724+
raise ValueError("loc must be at least 1-dimensional")
5725+
dim = np.PyArray_DIMS(larr)[ldim - 1]
5726+
5727+
if gamma is None:
5728+
garr = <np.ndarray>np.eye(dim, dtype=complex)
5729+
else:
5730+
garr = <np.ndarray>np.PyArray_FROM_OTF(gamma,
5731+
np.NPY_CDOUBLE,
5732+
np.NPY_ARRAY_ALIGNED |
5733+
np.NPY_ARRAY_C_CONTIGUOUS)
5734+
5735+
gdim = np.PyArray_NDIM(garr)
5736+
gshape = np.PyArray_DIMS(garr)
5737+
if gdim < 2 or gshape[gdim - 2] != gshape[gdim - 1] or gshape[gdim - 1] != dim:
5738+
raise ValueError(
5739+
"gamma must be at least 2-dimensional and the final two dimensions "
5740+
f"must match the final dimension of loc, {dim}."
5741+
)
5742+
if relation is None:
5743+
rarr = <np.ndarray>np.zeros((dim,dim), dtype=complex)
5744+
else:
5745+
rarr = <np.ndarray>np.PyArray_FROM_OTF(relation,
5746+
np.NPY_CDOUBLE,
5747+
np.NPY_ARRAY_ALIGNED |
5748+
np.NPY_ARRAY_C_CONTIGUOUS)
5749+
rdim = np.PyArray_NDIM(rarr)
5750+
rshape = np.PyArray_DIMS(rarr)
5751+
if rdim < 2 or rshape[rdim - 2] != rshape[rdim - 1] or rshape[rdim - 1] != dim:
5752+
raise ValueError(
5753+
"relation must be at least 2-dimensional and the final two dimensions "
5754+
f"must match the final dimension of loc, {dim}."
5755+
)
5756+
can_bcast, cov_shape = broadcast_shape(np.shape(garr), np.shape(rarr), False)
5757+
if not can_bcast:
5758+
raise ValueError(
5759+
f"The leading dimensions of gamma {np.shape(garr)[:gdim-2]} "
5760+
"must broadcast with the leading dimension of relation "
5761+
f"{np.shape(rarr)[:rdim-2]}.")
5762+
common_shape = cov_shape[: len(cov_shape) - 2]
5763+
l_shape = np.shape(larr)
5764+
l_common = l_shape[: len(l_shape) - 1]
5765+
can_bcast, bcast_shape = broadcast_shape(l_common, common_shape, False)
5766+
if size is not None:
5767+
if isinstance(size, (int, np.integer)):
5768+
size = (size, )
5769+
can_bcast, bcast_shape = broadcast_shape(tuple(size), common_shape, True)
5770+
temp = np.empty((2 * dim, 2 * dim))
5771+
p = np.arange(2 * dim).reshape((2, -1))
5772+
p = p.T.ravel()
5773+
ix = np.ix_(p, p)
5774+
5775+
if gdim == 2:
5776+
gidx = np.array([0])
5777+
garr = np.reshape(garr, (1,) + np.shape(garr))
5778+
else:
5779+
_shape = np.shape(garr)[: gdim - 2]
5780+
gidx = np.arange(np.prod(_shape)).reshape(_shape)
5781+
if rdim == 2:
5782+
ridx = np.array([0])
5783+
rarr = np.reshape(rarr, (1,) + np.shape(rarr))
5784+
else:
5785+
_shape = np.shape(rarr)[: rdim - 2]
5786+
ridx = np.arange(np.prod(_shape)).reshape(_shape)
5787+
5788+
factors = np.empty(common_shape + (2 * dim, 2 * dim)).reshape(
5789+
(-1, 2 * dim, 2 * dim)
5790+
)
5791+
fidx = 0
5792+
5793+
for i, j, in np.broadcast(gidx, ridx):
5794+
g = garr[i]
5795+
r = rarr[j]
5796+
gpr = 0.5 * (g + r)
5797+
gmr = 0.5 * (g - r)
5798+
temp[:dim, :dim] = gpr.real
5799+
temp[:dim, dim:] = -gmr.imag
5800+
temp[dim:, :dim] = gpr.imag
5801+
temp[dim:, dim:] = gmr.real
5802+
if not np.allclose(temp, temp.T, rtol=tol):
5803+
raise ValueError(
5804+
"The covariance matrix implied by gamma and relation is "
5805+
"not symmetric. Each component in gamma must be positive "
5806+
"semi-definite Hermetian and each component in relation "
5807+
"must be symmetric."
5808+
)
5809+
factors[fidx] = _factorize(
5810+
temp[ix], meth=method, check_valid=check_valid, tol=tol, rank=2 * dim
5811+
)
5812+
fidx += 1
5813+
factors = factors.reshape(common_shape + (2 * dim, 2 * dim))
5814+
out = self.multivariate_normal(
5815+
larr.view(np.float64), factors, size=size, method="factor"
5816+
)
5817+
return out.view(complex)
5818+
55565819
with warnings.catch_warnings():
55575820
warnings.simplefilter("ignore")
55585821
_random_generator = Generator()

0 commit comments

Comments
 (0)