Skip to content

Commit

Permalink
Merge pull request #14 from iKintosh/gaborlayer_as_nn.module
Browse files Browse the repository at this point in the history
Gaborlayer as nn.module
  • Loading branch information
iKintosh authored Aug 31, 2020
2 parents f00c06a + d8a853d commit a104371
Show file tree
Hide file tree
Showing 13 changed files with 292 additions and 308 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@
/GaborNet.egg-info/
/.coverage
/.ipynb_checkpoints/
/metrics.json
/poetry.lock
126 changes: 126 additions & 0 deletions GaborNet/GaborLayer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import math
from typing import Any

import torch
from torch.nn import Parameter
from torch.nn.modules import Module, Conv2d


class GaborConv2d(Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=False,
padding_mode="zeros",
):
super().__init__()

self.is_calculated = False

self.conv_layer = Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
padding_mode,
)
self.kernel_size = self.conv_layer.kernel_size

# small addition to avoid division by zero
self.delta = 1e-3

# freq, theta, sigma are set up according to S. Meshgini,
# A. Aghagolzadeh and H. Seyedarabi, "Face recognition using
# Gabor filter bank, kernel principal component analysis
# and support vector machine"
self.freq = Parameter(
(math.pi / 2)
* math.sqrt(2)
** (-torch.randint(0, 5, (out_channels, in_channels))).type(torch.Tensor),
requires_grad=True,
)
self.theta = Parameter(
(math.pi / 8)
* torch.randint(0, 8, (out_channels, in_channels)).type(torch.Tensor),
requires_grad=True,
)
self.sigma = Parameter(math.pi / self.freq, requires_grad=True)
self.psi = Parameter(
math.pi * torch.rand(out_channels, in_channels), requires_grad=True
)

self.x0 = Parameter(
torch.ceil(torch.Tensor([self.kernel_size[0] / 2]))[0], requires_grad=False
)
self.y0 = Parameter(
torch.ceil(torch.Tensor([self.kernel_size[1] / 2]))[0], requires_grad=False
)

self.y, self.x = torch.meshgrid(
[
torch.linspace(-self.x0 + 1, self.x0 + 0, self.kernel_size[0]),
torch.linspace(-self.y0 + 1, self.y0 + 0, self.kernel_size[1]),
]
)
self.y = Parameter(self.y)
self.x = Parameter(self.x)

self.weight = Parameter(
torch.empty(self.conv_layer.weight.shape, requires_grad=True),
requires_grad=True,
)

self.register_parameter("freq", self.freq)
self.register_parameter("theta", self.theta)
self.register_parameter("sigma", self.sigma)
self.register_parameter("psi", self.psi)
self.register_parameter("x_shape", self.x0)
self.register_parameter("y_shape", self.y0)
self.register_parameter("y_grid", self.y)
self.register_parameter("x_grid", self.x)
self.register_parameter("weight", self.weight)

def forward(self, input_tensor):
if self.training:
self.calculate_weights()
self.is_calculated = False
if not self.training:
if not self.is_calculated:
self.calculate_weights()
self.is_calculated = True
return self.conv_layer(input_tensor)

def calculate_weights(self):
for i in range(self.conv_layer.out_channels):
for j in range(self.conv_layer.in_channels):
sigma = self.sigma[i, j].expand_as(self.y)
freq = self.freq[i, j].expand_as(self.y)
theta = self.theta[i, j].expand_as(self.y)
psi = self.psi[i, j].expand_as(self.y)

rotx = self.x * torch.cos(theta) + self.y * torch.sin(theta)
roty = -self.x * torch.sin(theta) + self.y * torch.cos(theta)

g = torch.exp(
-0.5 * ((rotx ** 2 + roty ** 2) / (sigma + self.delta) ** 2)
)
g = g * torch.cos(freq * rotx + psi)
g = g / (2 * math.pi * sigma ** 2)
self.conv_layer.weight.data[i, j] = g

def _forward_unimplemented(self, *inputs: Any):
"""
code checkers makes implement this method,
looks like error in PyTorch
"""
raise NotImplementedError
55 changes: 0 additions & 55 deletions GaborNet/GaborLayers.py

This file was deleted.

4 changes: 3 additions & 1 deletion GaborNet/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .GaborLayers import GaborConv2d
from .GaborLayer import GaborConv2d

__version__ = "0.2.0"
11 changes: 5 additions & 6 deletions dataset.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import print_function, division
from __future__ import division
from __future__ import print_function

import os
from typing import Dict, Any, Union

from skimage import io
from torch.utils.data import Dataset


class DogsCatsDataset(Dataset):
def __init__(self, root_dir, transform=None):
def __init__(self, root_dir: str, transform=None):
self.root_dir = root_dir
self.pics_list = os.listdir(self.root_dir)
self.transform = transform
Expand All @@ -18,11 +18,10 @@ def __len__(self):

def __getitem__(self, idx):
img_name = os.path.join(self.root_dir, self.pics_list[idx])
target = 0 if 'cat' in self.pics_list[idx] else 1
target = 0 if "cat" in self.pics_list[idx] else 1
image = io.imread(img_name)
if self.transform:
image = self.transform(image)
sample = {'image': image,
'target': target} # type: Dict[str, Union[int, Any]]
sample = {"image": image, "target": target}

return sample
11 changes: 8 additions & 3 deletions dvc.lock
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
sanity-check:
cmd: python run_sanity_check.py
deps:
- path: GaborNet/GaborLayerNew.py
md5: ce749d117558c17d960339a96b39dda5
- path: GaborNet/GaborLayers.py
md5: e599ae293a6f3b6fe37065e71a1f9c26
md5: 7f38946a85a853164950b4e9edf6424e
- path: data
md5: 45680c6dfa0b2d63cb813bd33995301c.dir
- path: run_sanity_check.py
md5: f132a699372947aedb8b98b5017e034d
md5: b5c0886b91df66512b010504a821a5c7
params:
params.yaml:
sanity_check.epoch: 10
outs:
- path: metrics.json
md5: 44bad14190260263e082f2df61c2a4c3
md5: 94e307584cbaff3cc793981c7b8f6ace
13 changes: 8 additions & 5 deletions dvc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ stages:
sanity-check:
cmd: python run_sanity_check.py
deps:
- GaborNet/GaborLayers.py
- data
- run_sanity_check.py
- GaborNet/GaborLayers.py
- GaborNet/GaborLayer.py
- data
- run_sanity_check.py
params:
- sanity_check.epoch
metrics:
- metrics.json:
cache: true
- metrics.json:
cache: true
1 change: 0 additions & 1 deletion metrics.json

This file was deleted.

2 changes: 2 additions & 0 deletions params.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
sanity_check:
epoch: 10
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pytest-pylint = "^0.17.0"
pytest-pycodestyle = "^2.2.0"
pytest-pep257 = "^0.0.5"
pytest-cov = "^2.10.1"
black = "^20.8b1"

[tool.poetry.dev-dependencies]

Expand Down
Loading

0 comments on commit a104371

Please sign in to comment.