Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to steerable pyramid #305

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
324 changes: 205 additions & 119 deletions examples/03_Steerable_Pyramid.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@
class SteerablePyramidFreq(nn.Module):
r"""Steerable frequency pyramid in Torch

Construct a steerable pyramid on matrix two dimensional signals, in the
Fourier domain. Boundary-handling is circular. Reconstruction is exact
(within floating point errors). However, if the image has an odd-shape,
the reconstruction will not be exact due to boundary-handling issues
that have not been resolved.
Construct a steerable pyramid on matrix two dimensional signals, in the Fourier
domain. Boundary-handling is circular. Reconstruction is exact (within floating
point errors). However, if the image has an odd-shape, the reconstruction will not
be exact due to boundary-handling issues that have not been resolved. Similarly, if
a complex pyramid of order=0 has non-exact reconstruction and cannot be tight-frame.

The squared radial functions tile the Fourier plane with a raised-cosine
falloff. Angular functions are cos(theta-k*pi/order+1)^(order).
Expand All @@ -52,7 +52,7 @@
log2(min(image_shape[1], image_shape[1]))-2. If height=0, this only returns the
residuals.
order : `int`.
The Gaussian derivative order used for the steerable filters, in [1,
The Gaussian derivative order used for the steerable filters, in [0,
15]. Note that to achieve steerability the minimum number of
orientation is `order` + 1, and is used here. To get more orientations
at the same order, use the method `steer_coeffs`
Expand Down Expand Up @@ -141,8 +141,14 @@
else:
self.num_scales = int(height)

if self.order > 15 or self.order <= 0:
raise ValueError("order must be an integer in the range [1,15].")
if self.order > 15 or self.order < 0:
raise ValueError("order must be an integer in the range [0, 15].")
if self.order == 0 and self.is_complex:
warnings.warn(
"Reconstruction will not be perfect for a complex pyramid with order=0"
)
if self.tight_frame:
raise ValueError("Complex pyramid with order=0 cannot be tight-frame!")
self.num_orientations = int(self.order + 1)

if twidth <= 0:
Expand Down Expand Up @@ -673,10 +679,11 @@
)
bands: NDArray = np.array(bands, ndmin=1)
assert (bands >= 0).all(), "Error: band numbers must be larger than 0."
assert (bands < self.num_orientations).all(), (
"Error: band numbers must be in the range [0, "
f"{self.num_orientations - 1:d}]"
)
if any(bands > self.num_orientations):
raise ValueError(

Check warning on line 683 in src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py

View check run for this annotation

Codecov / codecov/patch

src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py#L683

Added line #L683 was not covered by tests
"Error: band numbers must be in the range "
f"[0, {self.num_orientations - 1:d}]"
)
return list(bands)

def _recon_keys(
Expand Down Expand Up @@ -724,8 +731,8 @@
if i >= max_orientations:
warnings.warn(
f"You wanted band {i:d} in the reconstruction but"
f" max_orientation is {max_orientations:d}, so we"
"'re ignoring that band"
f" max_orientation is {max_orientations:d}, so "
"we're ignoring that band"
)
bands = [i for i in bands if i < max_orientations]
recon_keys = []
Expand Down
131 changes: 94 additions & 37 deletions tests/test_steerable_pyr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,35 @@
from conftest import DEVICE, IMG_DIR
from plenoptic.tools.data import to_numpy

ALL_SPYRS = (
[
f"{h}-{o}-{c}-{d}-{tf}"
for h, o, c, d, tf in product(
["auto", 1, 3, 4, 5],
[1, 2, 3],
[True, False],
[True, False],
[True, False],
)
]
+ [
# pyramid with order=0 can only be tight frame if it's not complex
f"{h}-0-False-{d}-{tf}"
for h, d, tf in product(
["auto", 1, 3, 4, 5],
[True, False],
[True, False],
)
]
+ [
f"{h}-0-True-{d}-False"
for h, d in product(
["auto", 1, 3, 4, 5],
[True, False],
)
]
)


def check_pyr_coeffs(coeff_1, coeff_2, rtol=1e-3, atol=1e-3):
"""
Expand Down Expand Up @@ -167,7 +196,7 @@ def spyr_multi(self, multichannel_img, request):
# can't use one of the spyr fixtures here because we need to instantiate separately
# for each of these shapes
@pytest.mark.parametrize("height", ["auto", 1, 3, 4, 5])
@pytest.mark.parametrize("order", [1, 2, 3])
@pytest.mark.parametrize("order", [0, 1, 2, 3])
@pytest.mark.parametrize("is_complex", [True, False])
@pytest.mark.parametrize(
"im_shape",
Expand All @@ -176,12 +205,20 @@ def spyr_multi(self, multichannel_img, request):
def test_pyramid(self, basic_stim, height, order, is_complex, im_shape):
if im_shape is not None:
basic_stim = basic_stim[..., : im_shape[0], : im_shape[1]]
spc = po.simul.SteerablePyramidFreq(
basic_stim.shape[-2:],
height=height,
order=order,
is_complex=is_complex,
).to(DEVICE)
expectation = does_not_raise()
if (order == 0 and is_complex) or (
im_shape is not None and any([im_shape[0] % 2, im_shape[1] % 2])
):
expectation = pytest.warns(
Warning, match="Reconstruction will not be perfect"
)
with expectation:
spc = po.simul.SteerablePyramidFreq(
basic_stim.shape[-2:],
height=height,
order=order,
is_complex=is_complex,
).to(DEVICE)
spc(basic_stim)

@pytest.mark.parametrize(
Expand All @@ -191,20 +228,48 @@ def test_pyramid(self, basic_stim, height, order, is_complex, im_shape):
for h, o, c, d in product(
["auto", 1, 2, 3], [1, 2, 3], [True, False], [True, False]
)
]
+ [
# pyramid with order=0 can only be tight frame if it's not complex
f"{h}-0-False-{d}-True"
for h, d in product(["auto", 1, 2, 3], [True, False])
],
indirect=True,
)
def test_tight_frame(self, img, spyr):
pyr_coeffs = spyr.forward(img)
check_parseval(img, pyr_coeffs)

@pytest.mark.parametrize("height", ["auto", 1, 2, 3])
@pytest.mark.parametrize("downsample", [True, False])
def test_not_tight_frame(self, height, downsample):
with pytest.raises(ValueError, match="cannot be tight-frame"):
po.simul.SteerablePyramidFreq(
(256, 256),
height,
0,
is_complex=True,
downsample=downsample,
tight_frame=True,
)

@pytest.mark.parametrize(
"spyr",
[
f"{h}-{o}-{c}-True-{t}"
for h, o, c, t in product(
[3, 4, 5], [1, 2, 3], [True, False], [True, False]
)
]
+ [
# pyramid with order=0 can only be tight frame if it's not complex
f"{h}-0-False-True-{t}"
for h, t in product([3, 4, 5], [True, False])
]
+ [
# pyramid with order=0 can only be tight frame if it's not complex
f"{h}-0-True-True-False"
for h in [3, 4, 5]
],
indirect=True,
)
Expand Down Expand Up @@ -243,7 +308,7 @@ def test_not_downsample(self, img, spyr):
"spyr",
[
f"{h}-{o}-{c}-False-False"
for h, o, c in product([3, 4, 5], [1, 2, 3], [True, False])
for h, o, c in product([3, 4, 5], [0, 1, 2, 3], [True, False])
],
indirect=True,
)
Expand All @@ -262,7 +327,7 @@ def test_pyr_to_tensor(self, img, spyr, scales, rtol=1e-12, atol=1e-12):
"spyr",
[
f"{h}-{o}-{c}-True-False"
for h, o, c in product([3, 4, 5], [1, 2, 3], [True, False])
for h, o, c in product([3, 4, 5], [0, 1, 2, 3], [True, False])
],
indirect=True,
)
Expand All @@ -282,43 +347,35 @@ def test_torch_vs_numpy_pyr(self, img, spyr):

@pytest.mark.parametrize(
"spyr",
[
f"{h}-{o}-{c}-{d}-{tf}"
for h, o, c, d, tf in product(
["auto", 1, 3, 4, 5],
[1, 2, 3],
[True, False],
[True, False],
[True, False],
)
],
ALL_SPYRS,
indirect=True,
)
def test_complete_recon(self, img, spyr):
pyr_coeffs = spyr.forward(img)
recon = to_numpy(spyr.recon_pyr(pyr_coeffs))
np.testing.assert_allclose(recon, to_numpy(img), rtol=1e-4, atol=1e-4)
# reconstruction is bad in this context
if spyr.order == 0 and spyr.is_complex:
np.testing.assert_allclose(recon, to_numpy(img), atol=5e-1, rtol=1e-1)
else:
np.testing.assert_allclose(recon, to_numpy(img), rtol=1e-4, atol=1e-4)

@pytest.mark.parametrize(
"spyr_multi",
[
f"{h}-{o}-{c}-{d}-{tf}"
for h, o, c, d, tf in product(
["auto", 1, 3, 4, 5],
[1, 2, 3],
[True, False],
[True, False],
[True, False],
)
],
ALL_SPYRS,
indirect=True,
)
def test_complete_recon_multi(self, multichannel_img, spyr_multi):
pyr_coeffs = spyr_multi.forward(multichannel_img)
recon = to_numpy(spyr_multi.recon_pyr(pyr_coeffs))
np.testing.assert_allclose(
recon, to_numpy(multichannel_img), rtol=1e-4, atol=1e-4
)
# reconstruction is bad in this context
if spyr_multi.order == 0 and spyr_multi.is_complex:
np.testing.assert_allclose(
recon, to_numpy(multichannel_img), atol=5e-1, rtol=1e-1
)
else:
np.testing.assert_allclose(
recon, to_numpy(multichannel_img), rtol=1e-4, atol=1e-4
)

@pytest.mark.parametrize(
"spyr",
Expand Down Expand Up @@ -352,7 +409,7 @@ def test_partial_recon(self, img, spyr):
"spyr",
[
f"{h}-{o}-{c}-True-False"
for h, o, c in product(["auto", 1, 3, 4], [1, 2, 3], [True, False])
for h, o, c in product(["auto", 1, 3, 4], [0, 1, 2, 3], [True, False])
],
indirect=True,
)
Expand Down Expand Up @@ -423,9 +480,9 @@ def test_height_values(self, img, height):
)
pyr(img)

@pytest.mark.parametrize("order", range(17))
@pytest.mark.parametrize("order", range(-1, 17))
def test_order_values(self, img, order):
if order in [0, 16]:
if order in [-1, 16]:
expectation = pytest.raises(
ValueError, match="order must be an integer in the range"
)
Expand All @@ -435,7 +492,7 @@ def test_order_values(self, img, order):
pyr = po.simul.SteerablePyramidFreq(img.shape[-2:], order=order).to(DEVICE)
pyr(img)

@pytest.mark.parametrize("order", range(1, 16))
@pytest.mark.parametrize("order", range(0, 16))
def test_buffers(self, order):
pyr = po.simul.SteerablePyramidFreq((256, 256), order=order)
buffers = [k for k, _ in pyr.named_buffers()]
Expand Down
Loading