-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c89fc7c
commit 077850c
Showing
20 changed files
with
572 additions
and
0 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
# Original code from: | ||
# https://github.com/lucidrains/siren-pytorch (MIT license) | ||
# | ||
# [email protected], 2024 | ||
|
||
import math | ||
import torch | ||
from torch import nn | ||
import torch.nn.functional as F | ||
from einops import rearrange | ||
|
||
# helpers | ||
|
||
def exists(val):
    """Tell whether `val` holds a value, i.e. is anything other than None.

    Falsy values such as 0, '' or [] still count as existing.
    """
    if val is None:
        return False
    return True
|
||
def cast_tuple(val, repeat = 1):
    """Coerce `val` to a tuple.

    A tuple is handed back unchanged (its length is not checked against
    `repeat`); any other value is replicated `repeat` times inside a new tuple.
    """
    if isinstance(val, tuple):
        return val
    single = (val,)
    return single * repeat
|
||
# sin activation | ||
|
||
class Sine(nn.Module):
    """Sine activation with a frequency factor: computes ``sin(w0 * x)``."""

    def __init__(self, w0 = 1.):
        super().__init__()
        # Multiplier applied to the input before taking the sine.
        self.w0 = w0

    def forward(self, x):
        """Return the element-wise sine of the input scaled by `w0`."""
        scaled = self.w0 * x
        return torch.sin(scaled)
|
||
# siren layer | ||
|
||
class Siren(nn.Module):
    """One SIREN layer: linear map, activation (sine by default), dropout.

    Args:
        dim_in: number of input features.
        dim_out: number of output features.
        w0: frequency factor handed to the default `Sine` activation; also
            divides the init bound of non-first layers.
        c: constant in the init bound ``sqrt(c / dim_in) / w0``.
        is_first: if True, the init bound is ``1 / dim_in`` instead.
        use_bias: whether a bias parameter is created.
        activation: module applied after the linear map; `Sine(w0)` if None.
        dropout: dropout probability applied after the activation.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        w0 = 1.,
        c = 6.,
        is_first = False,
        use_bias = True,
        activation = None,
        dropout = 0.
    ):
        super().__init__()
        self.dim_in = dim_in
        self.is_first = is_first

        # Start from plain tensors so they can be filled in place before
        # they are wrapped as parameters.
        init_weight = torch.zeros(dim_out, dim_in)
        init_bias = None
        if use_bias:
            init_bias = torch.zeros(dim_out)
        self.init_(init_weight, init_bias, c = c, w0 = w0)

        self.weight = nn.Parameter(init_weight)
        if use_bias:
            self.bias = nn.Parameter(init_bias)
        else:
            self.bias = None

        if activation is None:
            activation = Sine(w0)
        self.activation = activation
        self.dropout = nn.Dropout(dropout)

    def init_(self, weight, bias, c, w0):
        """Fill `weight` (and `bias`, if given) in place, uniformly in [-b, b].

        The bound b is ``1 / dim_in`` for the first layer and
        ``sqrt(c / dim_in) / w0`` otherwise.
        """
        fan_in = self.dim_in

        if self.is_first:
            bound = 1 / fan_in
        else:
            bound = math.sqrt(c / fan_in) / w0

        # Weight is drawn before bias; keep this order for reproducible seeds.
        weight.uniform_(-bound, bound)
        if bias is not None:
            bias.uniform_(-bound, bound)

    def forward(self, x):
        """Apply linear map, activation and dropout to `x`, in that order."""
        pre_activation = F.linear(x, self.weight, self.bias)
        activated = self.activation(pre_activation)
        return self.dropout(activated)
|
||
# siren network | ||
|
||
class SirenNet(nn.Module):
    """Stack of `num_layers` sine-activated Siren layers plus an output layer.

    Args:
        dim_in: input feature size (fed to the first layer).
        dim_hidden: width of every hidden layer.
        dim_out: output feature size of the last layer.
        num_layers: number of hidden Siren layers (the output layer is extra).
        w0: frequency factor for all layers except the first.
        w0_initial: frequency factor for the first layer.
        use_bias: whether the layers carry a bias.
        final_activation: activation of the output layer; identity if None.
        dropout: dropout probability of the hidden layers (the output layer
            is built without it).
    """

    def __init__(
        self,
        dim_in,
        dim_hidden,
        dim_out,
        num_layers,
        w0 = 1.,
        w0_initial = 30.,
        use_bias = True,
        final_activation = None,
        dropout = 0.
    ):
        super().__init__()
        self.num_layers = num_layers
        self.dim_hidden = dim_hidden

        # Build the hidden layers front to back; only the first one reads
        # `dim_in` and uses `w0_initial`.
        hidden_layers = []
        for depth in range(num_layers):
            first = depth == 0
            hidden_layers.append(Siren(
                dim_in = dim_in if first else dim_hidden,
                dim_out = dim_hidden,
                w0 = w0_initial if first else w0,
                use_bias = use_bias,
                is_first = first,
                dropout = dropout
            ))
        self.layers = nn.ModuleList(hidden_layers)

        if final_activation is None:
            final_activation = nn.Identity()
        self.last_layer = Siren(
            dim_in = dim_hidden,
            dim_out = dim_out,
            w0 = w0,
            use_bias = use_bias,
            activation = final_activation
        )

    def forward(self, x, mods = None):
        """Run `x` through the hidden layers and the output layer.

        `mods` is either None, a single value repeated for every hidden layer,
        or a tuple with one entry per hidden layer. Each non-None entry is a
        1-D vector that multiplies that layer's output (broadcast over the
        leading axis).
        """
        per_layer_mods = cast_tuple(mods, self.num_layers)

        for layer, mod in zip(self.layers, per_layer_mods):
            x = layer(x)

            if mod is None:
                continue
            x *= rearrange(mod, 'd -> () d')

        return self.last_layer(x)
|
||
# modulatory feed forward | ||
|
||
class Modulator(nn.Module):
    """Feed-forward network mapping a latent vector to per-layer activations.

    Every layer is ``Linear -> ReLU``. The first layer sees the latent alone;
    each later layer sees the previous hidden activation concatenated with
    the latent, hence its input width of ``dim_hidden + dim_in``.

    Args:
        dim_in: size of the latent vector.
        dim_hidden: width of every hidden activation.
        num_layers: number of layers, i.e. number of activations returned.
    """

    def __init__(self, dim_in, dim_hidden, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([])

        for ind in range(num_layers):
            is_first = ind == 0
            dim = dim_in if is_first else (dim_hidden + dim_in)

            self.layers.append(nn.Sequential(
                nn.Linear(dim, dim_hidden),
                nn.ReLU()
            ))

    def forward(self, z):
        """Return a tuple with the hidden activation of every layer.

        Args:
            z: latent of shape ``(dim_in,)`` or, with leading batch axes,
                ``(..., dim_in)``.

        Returns:
            Tuple of `num_layers` tensors shaped ``(..., dim_hidden)``.
        """
        x = z
        hiddens = []

        for layer in self.layers:
            x = layer(x)
            hiddens.append(x)
            # Re-attach the latent along the feature axis. For a 1-D latent
            # this equals concatenating on dim 0; for a batched latent it
            # keeps the batch axes intact instead of stacking along them.
            x = torch.cat((x, z), dim = -1)

        return tuple(hiddens)
|
||
# wrapper | ||
|
||
class SirenImageWrapper(nn.Module):
    """Evaluate a `SirenNet` on a fixed pixel grid and shape the result as an image.

    Args:
        net: the `SirenNet` to wrap.
        image_width: number of grid columns.
        image_height: number of grid rows.
        latent_dim: if given, a `Modulator` of this input size is created and
            `forward` then requires a `latent`.
    """

    def __init__(self, net, image_width, image_height, latent_dim = None):
        super().__init__()
        assert isinstance(net, SirenNet), 'SirenWrapper must receive a Siren network'

        self.net = net
        self.image_width = image_width
        self.image_height = image_height

        if latent_dim is None:
            self.modulator = None
        else:
            self.modulator = Modulator(
                dim_in = latent_dim,
                dim_hidden = net.dim_hidden,
                num_layers = net.num_layers
            )

        # Coordinate grid over [-1, 1] x [-1, 1], row coordinate first,
        # flattened to one (row, col) pair per pixel.
        row_axis = torch.linspace(-1, 1, steps = image_height)
        col_axis = torch.linspace(-1, 1, steps = image_width)
        row_grid, col_grid = torch.meshgrid(row_axis, col_axis, indexing = 'ij')
        pixel_coords = torch.stack((row_grid, col_grid), dim=-1)
        pixel_coords = rearrange(pixel_coords, 'h w c -> (h w) c')
        self.register_buffer('grid', pixel_coords)

    def forward(self, img = None, *, latent = None):
        """Predict the image on the stored grid, or return its MSE against `img`.

        Args:
            img: optional target; when given, the MSE loss between `img` and
                the prediction is returned instead of the prediction.
            latent: latent vector; must be given exactly when the wrapper was
                built with `latent_dim`.

        Returns:
            The prediction shaped ``(1, c, image_height, image_width)``, or a
            scalar loss tensor when `img` is supplied.
        """
        has_modulator = self.modulator is not None
        has_latent = latent is not None
        assert has_modulator == has_latent, 'latent vector must be only supplied if `latent_dim` was passed in on instantiation'

        mods = None
        if has_modulator:
            mods = self.modulator(latent)

        # Work on a fresh copy of the grid that tracks gradients.
        coords = self.grid.clone().detach().requires_grad_()
        flat_pred = self.net(coords, mods)
        out = rearrange(
            flat_pred,
            '(h w) c -> () c h w',
            h = self.image_height,
            w = self.image_width
        )

        if img is None:
            return out
        return F.mse_loss(img, out)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
# Neural regression of brain image data with SIREN neural net | ||
# | ||
# Number of iterations required may be O(10000) | ||
# | ||
# [email protected], 2024 | ||
|
||
import sys | ||
sys.path.append("./src/") | ||
|
||
import torch | ||
import torchvision.transforms as transforms | ||
import numpy as np | ||
from PIL import Image | ||
|
||
import innate.torch.torch_tools as torch_tools | ||
from innate.torch.siren import SirenNet, SirenImageWrapper | ||
|
||
|
||
# -------------------------
# Open the input image file

img_jpeg = Image.open('examples/data/brains.jpeg')

# Define and apply the transformation to convert the image to a tensor
transform = transforms.ToTensor()
img = transform(img_jpeg).float()

# Add [dummy] batch dimension to beginning
img = img.unsqueeze(0)

# Save a copy of the input for visual comparison with the predictions
torch_tools.visualize_img(img=img, savename=f'input_image.jpeg')
# -----------------------------


# -----------------------------
# Set reproducible seed first
# (done before the network is built; presumably this seeds the RNG used for
#  the weight initialisation -- verify in torch_tools.set_seed)

torch_tools.set_seed(1234)

# Create the neural net: R^{dim_in} --> R^{dim_out}
# NOTE(review): dim_out = 1 presumes a single-channel input image; with a
# multi-channel image the MSE loss would broadcast instead -- confirm.
net = SirenNet(
    dim_in = 2, # input dimension, e.g. 2D coordinate
    dim_out = 1, # output dimension per coordinate, e.g. for RGB image (3)
    dim_hidden = 512, # (hyperparam) hidden dimension
    num_layers = 4, # (hyperparam) number of layers
    w0_initial = 30.0 # (hyperparam) sine frequency factor w0 of the first layer (default 30)
)

# Helper wrapper for image signals
# (grid size is taken from the last two axes of the image tensor)
model = SirenImageWrapper(
    net,
    image_height = img.shape[-2],
    image_width = img.shape[-1],
)

print(f'img.shape = {img.shape} | (mean,std) = ({torch.mean(img.flatten()):0.2f}, {torch.std(img.flatten()):0.2f})')

# -----------------------------------------------
# Set the device and transfer tensors

device = torch_tools.get_device('auto')

img = img.to(device)
model = model.to(device)


# -----------------------------------------------
# Define the optimizer

num_iter = 10000 # number of gradient steps
lr = 3e-4 # initial (maximum) learning rate
weight_decay = 1e-5 # AdamW weight decay
max_grad_norm = 1.0 # gradient-norm clipping threshold

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)


# -----------------------------------------------
# Define learning rate scheduler
# (stepped once per training iteration below, so the period is in iterations)

scheduler_param = {
    'T_0' : 500, # Period
    'eta_min': lr / 10, # Minimum learning rate
    'T_mult' : 1
}

scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, **scheduler_param)
# -----------------------------------------------

modelfile = 'siren_model.pth' # checkpoint path
load_model = False # resume from `modelfile` before training
visualize = False # write intermediate prediction images

# Load the existing model
if load_model:
    torch_tools.load_torch_model(modelfile = modelfile,
                                 model = model,
                                 optimizer = optimizer,
                                 scheduler = scheduler)

# Training loop
# (single batch gradient descent == full image at once)
# NOTE(review): indentation was lost in the reviewed source; the nesting
# below (prediction, visualisation and checkpoint all inside the
# every-10-iterations branch) is the assumed structure -- confirm.
for i in range(1,num_iter+1):

    optimizer.zero_grad() # reset gradients from the previous iteration
    loss = model(img) # MSE loss inside the wrapper
    loss.backward()

    # Clip gradients by norm
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    optimizer.step()
    scheduler.step()

    print(f'iter = {i} / {num_iter} | loss = {loss.item():0.3E} | lr = {scheduler.get_last_lr()[0]:0.3E}')

    with torch.no_grad():
        if (i % 10) == 0:

            # This will evaluate the model prediction implicitly
            # at the training image point grid coordinates
            # (the result is only consumed when `visualize` is True)
            pred_img = model()

            # Visualize
            if visualize:
                torch_tools.visualize_img(img=pred_img, savename=f'pred_image_iter_{i}.jpeg')

            # Save the model
            torch_tools.save_torch_model(modelfile = modelfile,
                                         model = model,
                                         optimizer = optimizer,
                                         scheduler = scheduler)
Oops, something went wrong.