-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use Euclidean distance by default to follow the MATLAB implementation
- Loading branch information
1 parent
5dc7d87
commit ee14dbf
Showing
1 changed file
with
12 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
# encoding: utf8 | ||
""" | ||
r""" | ||
PyTorch implementation of the HMAX model of human vision. For more information | ||
about HMAX, check: | ||
|
@@ -24,6 +24,10 @@ | |
| | | | | | | | | ||
C2 C2 C2 C2 C2 C2 C2 C2 | ||
This implementation tries to follow the original MATLAB implementation by | ||
Maximilian Riesenhuber as closely as possible: | ||
https://maxlab.neuro.georgetown.edu/hmax.html | ||
Author: Marijn van Vliet <[email protected]> | ||
References | ||
|
@@ -196,8 +200,8 @@ def forward(self, s1_outputs): | |
# Pool over local (c1_space x c1_space) neighbourhood | ||
c1_output = self.local_pool(s1_output) | ||
|
||
# Hence we need to shift the output after the convolution by 1 pixel to | ||
# exactly match the wonky implementation. | ||
# We need to shift the output after the convolution by 1 pixel to | ||
# exactly match the wonky MATLAB implementation. | ||
c1_output = torch.roll(c1_output, (-1, -1), dims=(2, 3)) | ||
c1_output[:, :, -1, :] = 0 | ||
c1_output[:, :, :, -1] = 0 | ||
|
@@ -226,6 +230,7 @@ class S2(nn.Module): | |
Neuroscience uses the euclidean distance ('euclidean'). | ||
sigma : float | ||
The sharpness of the tuning (sigma in eqn 1 of [1]_). Defaults to 1. | ||
Only used when using gaussian activation. | ||
References: | ||
----------- | ||
|
@@ -235,7 +240,7 @@ class S2(nn.Module): | |
National Academy of Sciences 104, no. 15 (April 10, 2007): 6424–29. | ||
https://doi.org/10.1073/pnas.0700622104. | ||
""" | ||
def __init__(self, patches, activation='gaussian', sigma=1): | ||
def __init__(self, patches, activation='euclidean', sigma=1): | ||
super().__init__() | ||
self.activation = activation | ||
self.sigma = sigma | ||
|
@@ -359,7 +364,7 @@ def __init__(self, universal_patch_set, s2_act='gaussian'): | |
S1(size=33, wavelength=3.35), | ||
S1(size=35, wavelength=3.3), | ||
S1(size=37, wavelength=3.25), | ||
S1(size=39, wavelength=3.20), | ||
S1(size=39, wavelength=3.20), # Unused as far as I can tell | ||
] | ||
|
||
# Explicitly add the S1 units as submodules of the model | ||
|
@@ -421,7 +426,7 @@ def run_all_layers(self, img): | |
S2 units. | ||
c2_outputs : List of Tensors, shape (batch_size, num_patches) | ||
For each patch scale, the output of the layer of C2 units. | ||
""" | ||
""" # noqa | ||
s1_outputs = [s1(img) for s1 in self.s1_units] | ||
|
||
# Each C1 layer pools across two S1 layers | ||
|
@@ -460,7 +465,7 @@ def get_all_layers(self, img): | |
S2 units. | ||
c2_outputs : List of arrays, shape (batch_size, num_patches) | ||
For each patch scale, the output of the layer of C2 units. | ||
""" | ||
""" # noqa | ||
s1_out, c1_out, s2_out, c2_out = self.run_all_layers(img) | ||
return ( | ||
[s1.cpu().detach().numpy() for s1 in s1_out], | ||
|