Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to steerable pyramid #305

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
324 changes: 205 additions & 119 deletions examples/03_Steerable_Pyramid.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@
class SteerablePyramidFreq(nn.Module):
r"""Steerable frequency pyramid in Torch

Construct a steerable pyramid on matrix two dimensional signals, in the
Fourier domain. Boundary-handling is circular. Reconstruction is exact
(within floating point errors). However, if the image has an odd-shape,
the reconstruction will not be exact due to boundary-handling issues
that have not been resolved.
Construct a steerable pyramid on matrix two dimensional signals, in the Fourier
domain. Boundary-handling is circular. Reconstruction is exact (within floating
point errors). However, if the image has an odd-shape, the reconstruction will not
be exact due to boundary-handling issues that have not been resolved. Similarly, if
a complex pyramid of order=0 has non-exact reconstruction and cannot be tight-frame.

The squared radial functions tile the Fourier plane with a raised-cosine
falloff. Angular functions are cos(theta-k*pi/order+1)^(order).
Expand All @@ -50,7 +50,7 @@
The height of the pyramid. If 'auto', will automatically determine
based on the size of `image`.
order : `int`.
The Gaussian derivative order used for the steerable filters, in [1,
The Gaussian derivative order used for the steerable filters, in [0,
15]. Note that to achieve steerability the minimum number of
orientation is `order` + 1, and is used here. To get more orientations
at the same order, use the method `steer_coeffs`
Expand Down Expand Up @@ -133,12 +133,18 @@
if height == "auto":
self.num_scales = int(max_ht)
elif height > max_ht:
raise ValueError("Cannot build pyramid higher than %d levels." % (max_ht))
raise ValueError(f"Cannot build pyramid higher than {max_ht:d} levels.")

Check warning on line 136 in src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py

View check run for this annotation

Codecov / codecov/patch

src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py#L136

Added line #L136 was not covered by tests
else:
self.num_scales = int(height)

if self.order > 15 or self.order <= 0:
raise ValueError("order must be an integer in the range [1,15].")
if self.order > 15 or self.order < 0:
raise ValueError("order must be an integer in the range [0, 15].")
if self.order == 0 and self.is_complex:
warnings.warn(
"Reconstruction will not be perfect for a complex pyramid with order=0"
)
if self.tight_frame:
raise ValueError("Complex pyramid with order=0 cannot be tight-frame!")
self.num_orientations = int(self.order + 1)

if twidth <= 0:
Expand Down Expand Up @@ -615,9 +621,9 @@
)
levs_nums = np.array([int(i) for i in levels if isinstance(i, int)])
assert (levs_nums >= 0).all(), "Level numbers must be non-negative."
assert (levs_nums < self.num_scales).all(), (
"Level numbers must be in the range [0, %d]" % (self.num_scales - 1)
)
assert (
levs_nums < self.num_scales
).all(), f"Level numbers must be in the range [0, {self.num_scales-1:d}]"
levs_tmp = list(np.sort(levs_nums)) # we want smallest first
if "residual_highpass" in levels:
levs_tmp = ["residual_highpass"] + levs_tmp
Expand Down Expand Up @@ -669,10 +675,11 @@
)
bands: NDArray = np.array(bands, ndmin=1)
assert (bands >= 0).all(), "Error: band numbers must be larger than 0."
assert (bands < self.num_orientations).all(), (
"Error: band numbers must be in the range [0, %d]"
% (self.num_orientations - 1)
)
if any(bands > self.num_orientations):
raise ValueError(

Check warning on line 679 in src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py

View check run for this annotation

Codecov / codecov/patch

src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py#L679

Added line #L679 was not covered by tests
"Error: band numbers must be in the range "
f"[0, {self.num_orientations-1:d}]"
)
return list(bands)

def _recon_keys(
Expand Down Expand Up @@ -719,9 +726,9 @@
for i in bands:
if i >= max_orientations:
warnings.warn(
"You wanted band %d in the reconstruction but"
" max_orientation is %d, so we're ignoring that band"
% (i, max_orientations)
f"You wanted band {i:d} in the reconstruction but"
f" max_orientation is {max_orientations:d}, so "
"we're ignoring that band"
)
bands = [i for i in bands if i < max_orientations]
recon_keys = []
Expand Down
2 changes: 1 addition & 1 deletion src/plenoptic/tools/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,7 +1042,7 @@ def plot_representation(
# need to keep the shape the same because of how we
# check for shape below (unbinding removes a dimension,
# so we add it back)
data_dict[title + "_%02d" % i] = d.unsqueeze(1)
data_dict[title + f"_{i:02d}"] = d.unsqueeze(1)
else:
data_dict[title] = data
data = data_dict
Expand Down
131 changes: 94 additions & 37 deletions tests/test_steerable_pyr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,35 @@
from conftest import DEVICE, IMG_DIR
from plenoptic.tools.data import to_numpy

ALL_SPYRS = (
[
f"{h}-{o}-{c}-{d}-{tf}"
for h, o, c, d, tf in product(
["auto", 1, 3, 4, 5],
[1, 2, 3],
[True, False],
[True, False],
[True, False],
)
]
+ [
# pyramid with order=0 can only be tight frame if it's not complex
f"{h}-0-False-{d}-{tf}"
for h, d, tf in product(
["auto", 1, 3, 4, 5],
[True, False],
[True, False],
)
]
+ [
f"{h}-0-True-{d}-False"
for h, d in product(
["auto", 1, 3, 4, 5],
[True, False],
)
]
)


def check_pyr_coeffs(coeff_1, coeff_2, rtol=1e-3, atol=1e-3):
"""
Expand Down Expand Up @@ -167,7 +196,7 @@ def spyr_multi(self, multichannel_img, request):
# can't use one of the spyr fixtures here because we need to instantiate separately
# for each of these shapes
@pytest.mark.parametrize("height", ["auto", 1, 3, 4, 5])
@pytest.mark.parametrize("order", [1, 2, 3])
@pytest.mark.parametrize("order", [0, 1, 2, 3])
@pytest.mark.parametrize("is_complex", [True, False])
@pytest.mark.parametrize(
"im_shape",
Expand All @@ -176,12 +205,20 @@ def spyr_multi(self, multichannel_img, request):
def test_pyramid(self, basic_stim, height, order, is_complex, im_shape):
if im_shape is not None:
basic_stim = basic_stim[..., : im_shape[0], : im_shape[1]]
spc = po.simul.SteerablePyramidFreq(
basic_stim.shape[-2:],
height=height,
order=order,
is_complex=is_complex,
).to(DEVICE)
expectation = does_not_raise()
if (order == 0 and is_complex) or (
im_shape is not None and any([im_shape[0] % 2, im_shape[1] % 2])
):
expectation = pytest.warns(
Warning, match="Reconstruction will not be perfect"
)
with expectation:
spc = po.simul.SteerablePyramidFreq(
basic_stim.shape[-2:],
height=height,
order=order,
is_complex=is_complex,
).to(DEVICE)
spc(basic_stim)

@pytest.mark.parametrize(
Expand All @@ -191,20 +228,48 @@ def test_pyramid(self, basic_stim, height, order, is_complex, im_shape):
for h, o, c, d in product(
["auto", 1, 2, 3], [1, 2, 3], [True, False], [True, False]
)
]
+ [
# pyramid with order=0 can only be tight frame if it's not complex
f"{h}-0-False-{d}-True"
for h, d in product(["auto", 1, 2, 3], [True, False])
],
indirect=True,
)
def test_tight_frame(self, img, spyr):
pyr_coeffs = spyr.forward(img)
check_parseval(img, pyr_coeffs)

@pytest.mark.parametrize("height", ["auto", 1, 2, 3])
@pytest.mark.parametrize("downsample", [True, False])
def test_not_tight_frame(self, height, downsample):
with pytest.raises(ValueError, match="cannot be tight-frame"):
po.simul.SteerablePyramidFreq(
(256, 256),
height,
0,
is_complex=True,
downsample=downsample,
tight_frame=True,
)

@pytest.mark.parametrize(
"spyr",
[
f"{h}-{o}-{c}-True-{t}"
for h, o, c, t in product(
[3, 4, 5], [1, 2, 3], [True, False], [True, False]
)
]
+ [
# pyramid with order=0 can only be tight frame if it's not complex
f"{h}-0-False-True-{t}"
for h, t in product([3, 4, 5], [True, False])
]
+ [
# pyramid with order=0 can only be tight frame if it's not complex
f"{h}-0-True-True-False"
for h in [3, 4, 5]
],
indirect=True,
)
Expand Down Expand Up @@ -243,7 +308,7 @@ def test_not_downsample(self, img, spyr):
"spyr",
[
f"{h}-{o}-{c}-False-False"
for h, o, c in product([3, 4, 5], [1, 2, 3], [True, False])
for h, o, c in product([3, 4, 5], [0, 1, 2, 3], [True, False])
],
indirect=True,
)
Expand All @@ -262,7 +327,7 @@ def test_pyr_to_tensor(self, img, spyr, scales, rtol=1e-12, atol=1e-12):
"spyr",
[
f"{h}-{o}-{c}-True-False"
for h, o, c in product([3, 4, 5], [1, 2, 3], [True, False])
for h, o, c in product([3, 4, 5], [0, 1, 2, 3], [True, False])
],
indirect=True,
)
Expand All @@ -282,43 +347,35 @@ def test_torch_vs_numpy_pyr(self, img, spyr):

@pytest.mark.parametrize(
"spyr",
[
f"{h}-{o}-{c}-{d}-{tf}"
for h, o, c, d, tf in product(
["auto", 1, 3, 4, 5],
[1, 2, 3],
[True, False],
[True, False],
[True, False],
)
],
ALL_SPYRS,
indirect=True,
)
def test_complete_recon(self, img, spyr):
pyr_coeffs = spyr.forward(img)
recon = to_numpy(spyr.recon_pyr(pyr_coeffs))
np.testing.assert_allclose(recon, to_numpy(img), rtol=1e-4, atol=1e-4)
# reconstruction is bad in this context
if spyr.order == 0 and spyr.is_complex:
np.testing.assert_allclose(recon, to_numpy(img), atol=5e-1, rtol=1e-1)
else:
np.testing.assert_allclose(recon, to_numpy(img), rtol=1e-4, atol=1e-4)

@pytest.mark.parametrize(
"spyr_multi",
[
f"{h}-{o}-{c}-{d}-{tf}"
for h, o, c, d, tf in product(
["auto", 1, 3, 4, 5],
[1, 2, 3],
[True, False],
[True, False],
[True, False],
)
],
ALL_SPYRS,
indirect=True,
)
def test_complete_recon_multi(self, multichannel_img, spyr_multi):
pyr_coeffs = spyr_multi.forward(multichannel_img)
recon = to_numpy(spyr_multi.recon_pyr(pyr_coeffs))
np.testing.assert_allclose(
recon, to_numpy(multichannel_img), rtol=1e-4, atol=1e-4
)
# reconstruction is bad in this context
if spyr_multi.order == 0 and spyr_multi.is_complex:
np.testing.assert_allclose(
recon, to_numpy(multichannel_img), atol=5e-1, rtol=1e-1
)
else:
np.testing.assert_allclose(
recon, to_numpy(multichannel_img), rtol=1e-4, atol=1e-4
)

@pytest.mark.parametrize(
"spyr",
Expand Down Expand Up @@ -352,7 +409,7 @@ def test_partial_recon(self, img, spyr):
"spyr",
[
f"{h}-{o}-{c}-True-False"
for h, o, c in product(["auto", 1, 3, 4], [1, 2, 3], [True, False])
for h, o, c in product(["auto", 1, 3, 4], [0, 1, 2, 3], [True, False])
],
indirect=True,
)
Expand Down Expand Up @@ -405,9 +462,9 @@ def test_scales_arg(self, img, spyr, scales):
with pytest.raises(Exception):
spyr.recon_pyr(scales)

@pytest.mark.parametrize("order", range(17))
@pytest.mark.parametrize("order", range(-1, 17))
def test_order_values(self, img, order):
if order in [0, 16]:
if order in [-1, 16]:
expectation = pytest.raises(
ValueError, match="order must be an integer in the range"
)
Expand All @@ -417,7 +474,7 @@ def test_order_values(self, img, order):
pyr = po.simul.SteerablePyramidFreq(img.shape[-2:], order=order).to(DEVICE)
pyr(img)

@pytest.mark.parametrize("order", range(1, 16))
@pytest.mark.parametrize("order", range(0, 16))
def test_buffers(self, order):
pyr = po.simul.SteerablePyramidFreq((256, 256), order=order)
buffers = [k for k, _ in pyr.named_buffers()]
Expand Down