Commit
add torch models
mieskolainen committed Jun 28, 2024
1 parent c89fc7c commit 077850c
Showing 20 changed files with 572 additions and 0 deletions.
Binary file added examples/data/brains.jpeg
3 changes: 3 additions & 0 deletions requirements.txt
@@ -9,3 +9,6 @@ pytensor~=2.23
tomli >= 2.0.0 ; python_version < "3.11"
pytest==8
pytest-cov==5
torch~=2.3.0
torchvision~=0.18.1
einops~=0.8.0
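
The added libraries can be installed on top of the existing requirements with, for example, the following command (illustrative only, not part of the commit):

pip install "torch~=2.3.0" "torchvision~=0.18.1" "einops~=0.8.0"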
Binary file added siren_model.pth
Binary file not shown.
Binary file added siren_model_ngrid.pth
Binary file not shown.
Binary file added src/innate/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file added src/innate/__pycache__/approximation.cpython-310.pyc
Binary file not shown.
Binary file added src/innate/__pycache__/io.cpython-310.pyc
Binary file not shown.
Binary file added src/innate/__pycache__/main.cpython-310.pyc
Binary file not shown.
Binary file added src/innate/__pycache__/plotting.cpython-310.pyc
Binary file not shown.
(7 further binary files added; their paths are not shown in this view.)
185 changes: 185 additions & 0 deletions src/innate/torch/siren.py
@@ -0,0 +1,185 @@
# Original code from:
# https://github.com/lucidrains/siren-pytorch (MIT license)
#
# [email protected], 2024

import math
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange

# helpers

def exists(val):
    return val is not None

def cast_tuple(val, repeat = 1):
    return val if isinstance(val, tuple) else ((val,) * repeat)

# sin activation

class Sine(nn.Module):
    def __init__(self, w0 = 1.):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        return torch.sin(self.w0 * x)

# siren layer

class Siren(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        w0 = 1.,
        c = 6.,
        is_first = False,
        use_bias = True,
        activation = None,
        dropout = 0.
    ):
        super().__init__()
        self.dim_in = dim_in
        self.is_first = is_first

        weight = torch.zeros(dim_out, dim_in)
        bias = torch.zeros(dim_out) if use_bias else None
        self.init_(weight, bias, c = c, w0 = w0)

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias) if use_bias else None
        self.activation = Sine(w0) if activation is None else activation
        self.dropout = nn.Dropout(dropout)

    def init_(self, weight, bias, c, w0):
        dim = self.dim_in

        w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
        weight.uniform_(-w_std, w_std)

        if exists(bias):
            bias.uniform_(-w_std, w_std)

    def forward(self, x):
        out = F.linear(x, self.weight, self.bias)
        out = self.activation(out)
        out = self.dropout(out)
        return out

# siren network

class SirenNet(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_hidden,
        dim_out,
        num_layers,
        w0 = 1.,
        w0_initial = 30.,
        use_bias = True,
        final_activation = None,
        dropout = 0.
    ):
        super().__init__()
        self.num_layers = num_layers
        self.dim_hidden = dim_hidden

        self.layers = nn.ModuleList([])
        for ind in range(num_layers):
            is_first = ind == 0
            layer_w0 = w0_initial if is_first else w0
            layer_dim_in = dim_in if is_first else dim_hidden

            layer = Siren(
                dim_in = layer_dim_in,
                dim_out = dim_hidden,
                w0 = layer_w0,
                use_bias = use_bias,
                is_first = is_first,
                dropout = dropout
            )

            self.layers.append(layer)

        final_activation = nn.Identity() if not exists(final_activation) else final_activation
        self.last_layer = Siren(dim_in = dim_hidden, dim_out = dim_out, w0 = w0, use_bias = use_bias, activation = final_activation)

    def forward(self, x, mods = None):
        mods = cast_tuple(mods, self.num_layers)

        for layer, mod in zip(self.layers, mods):
            x = layer(x)

            if exists(mod):
                x *= rearrange(mod, 'd -> () d')

        return self.last_layer(x)

# modulatory feed forward

class Modulator(nn.Module):
    def __init__(self, dim_in, dim_hidden, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([])

        for ind in range(num_layers):
            is_first = ind == 0
            dim = dim_in if is_first else (dim_hidden + dim_in)

            self.layers.append(nn.Sequential(
                nn.Linear(dim, dim_hidden),
                nn.ReLU()
            ))

    def forward(self, z):
        x = z
        hiddens = []

        for layer in self.layers:
            x = layer(x)
            hiddens.append(x)
            x = torch.cat((x, z))

        return tuple(hiddens)

# wrapper

class SirenImageWrapper(nn.Module):
    def __init__(self, net, image_width, image_height, latent_dim = None):
        super().__init__()
        assert isinstance(net, SirenNet), 'SirenWrapper must receive a Siren network'

        self.net = net
        self.image_width = image_width
        self.image_height = image_height

        self.modulator = None
        if exists(latent_dim):
            self.modulator = Modulator(
                dim_in = latent_dim,
                dim_hidden = net.dim_hidden,
                num_layers = net.num_layers
            )

        tensors = [torch.linspace(-1, 1, steps = image_height), torch.linspace(-1, 1, steps = image_width)]
        mgrid = torch.stack(torch.meshgrid(*tensors, indexing = 'ij'), dim=-1)
        mgrid = rearrange(mgrid, 'h w c -> (h w) c')
        self.register_buffer('grid', mgrid)

    def forward(self, img = None, *, latent = None):
        modulate = exists(self.modulator)
        assert not (modulate ^ exists(latent)), 'latent vector must be only supplied if `latent_dim` was passed in on instantiation'

        mods = self.modulator(latent) if modulate else None

        coords = self.grid.clone().detach().requires_grad_()
        out = self.net(coords, mods)
        out = rearrange(out, '(h w) c -> () c h w', h = self.image_height, w = self.image_width)

        if exists(img):
            return F.mse_loss(img, out)

        return out
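
For orientation, a minimal usage sketch of the classes above (illustrative only, not part of the committed file; the single-channel 64x64 image and all dimensions are made-up values):

import torch
from innate.torch.siren import SirenNet, SirenImageWrapper

net   = SirenNet(dim_in=2, dim_hidden=64, dim_out=1, num_layers=3)
model = SirenImageWrapper(net, image_width=64, image_height=64)

img  = torch.rand(1, 1, 64, 64)  # dummy target image (batch, channels, H, W)
loss = model(img)                # passing an image returns the MSE reconstruction loss
pred = model()                   # no arguments: returns the reconstruction, shape (1, 1, 64, 64)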
132 changes: 132 additions & 0 deletions src/innate/torch/siren_train_image.py
@@ -0,0 +1,132 @@
# Neural regression of brain image data with SIREN neural net
#
# Number of iterations required may be O(10000)
#
# [email protected], 2024

import sys
sys.path.append("./src/")

import torch
import torchvision.transforms as transforms
import numpy as np
from PIL import Image

import innate.torch.torch_tools as torch_tools
from innate.torch.siren import SirenNet, SirenImageWrapper


# -------------------------
# Open the input image file

img_jpeg = Image.open('examples/data/brains.jpeg')

# Define and apply the transformation to convert the image to a tensor
transform = transforms.ToTensor()
img = transform(img_jpeg).float()

# Add [dummy] batch dimension to beginning
img = img.unsqueeze(0)

torch_tools.visualize_img(img=img, savename=f'input_image.jpeg')
# -----------------------------


# -----------------------------
# Set reproducible seed first

torch_tools.set_seed(1234)

# Create the neural net: R^{dim_in} --> R^{dim_out}
net = SirenNet(
    dim_in     = 2,    # input dimension, e.g. a 2D coordinate
    dim_out    = 1,    # output dimension per coordinate (e.g. 3 for an RGB image)
    dim_hidden = 512,  # (hyperparam) hidden dimension
    num_layers = 4,    # (hyperparam) number of layers
    w0_initial = 30.0  # (hyperparam) frequency scale w0 of the first layer (default 30)
)

# Helper wrapper for image signals
model = SirenImageWrapper(
    net,
    image_height = img.shape[-2],
    image_width  = img.shape[-1],
)

print(f'img.shape = {img.shape} | (mean,std) = ({torch.mean(img.flatten()):0.2f}, {torch.std(img.flatten()):0.2f})')

# -----------------------------------------------
# Set the device and transfer tensors

device = torch_tools.get_device('auto')

img = img.to(device)
model = model.to(device)


# -----------------------------------------------
# Define the optimizer

num_iter = 10000
lr = 3e-4
weight_decay = 1e-5
max_grad_norm = 1.0

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)


# -----------------------------------------------
# Define learning rate scheduler

scheduler_param = {
    'T_0'    : 500,      # Period
    'eta_min': lr / 10,  # Minimum learning rate
    'T_mult' : 1
}

scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, **scheduler_param)
# -----------------------------------------------

modelfile = 'siren_model.pth'
load_model = False
visualize = False

# Load the existing model
if load_model:
    torch_tools.load_torch_model(modelfile = modelfile,
                                 model     = model,
                                 optimizer = optimizer,
                                 scheduler = scheduler)

# Training loop
# (single batch gradient descent == full image at once)
for i in range(1, num_iter+1):

    optimizer.zero_grad() #!
    loss = model(img)  # MSE loss inside the wrapper
    loss.backward()

    # Clip gradients by norm
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    optimizer.step()
    scheduler.step()

    print(f'iter = {i} / {num_iter} | loss = {loss.item():0.3E} | lr = {scheduler.get_last_lr()[0]:0.3E}')

    with torch.no_grad():
        if (i % 10) == 0:

            # This will evaluate the model prediction implicitly
            # at the training image point grid coordinates
            pred_img = model()

            # Visualize
            if visualize:
                torch_tools.visualize_img(img=pred_img, savename=f'pred_image_iter_{i}.jpeg')

# Save the model
torch_tools.save_torch_model(modelfile = modelfile,
                             model     = model,
                             optimizer = optimizer,
                             scheduler = scheduler)
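
As a follow-up sketch, the saved checkpoint could be reloaded for inference along these lines. This is illustrative only: it reuses torch_tools.load_torch_model and torch_tools.visualize_img exactly as the script above does, and assumes load_torch_model restores the states in place, as the resume branch of the training loop suggests.

# Reload the checkpoint into an identically constructed model/optimizer/scheduler
torch_tools.load_torch_model(modelfile = 'siren_model.pth',
                             model     = model,
                             optimizer = optimizer,
                             scheduler = scheduler)

model.eval()
with torch.no_grad():
    pred_img = model()  # implicit evaluation on the stored coordinate grid

torch_tools.visualize_img(img=pred_img, savename='final_pred_image.jpeg')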
