Commit b56b671 (parent dfe537d): Utility functions for diffusion models
Showing 1 changed file with 250 additions and 0 deletions.
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from matplotlib.animation import FuncAnimation, PillowWriter
from PIL import Image
from torch.utils.data import Dataset
from torchvision.utils import save_image, make_grid

class ResidualConvBlock(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, is_res: bool = False
    ) -> None:
        super().__init__()

        # Check if input and output channels are the same for the residual connection
        self.same_channels = in_channels == out_channels

        # Flag for whether or not to use the residual connection
        self.is_res = is_res

        # First convolutional layer
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),  # 3x3 kernel with stride 1 and padding 1
            nn.BatchNorm2d(out_channels),  # Batch normalization
            nn.GELU(),  # GELU activation function
        )

        # Second convolutional layer
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),  # 3x3 kernel with stride 1 and padding 1
            nn.BatchNorm2d(out_channels),  # Batch normalization
            nn.GELU(),  # GELU activation function
        )

        # 1x1 convolution to match dimensions on the residual path when the
        # channel counts differ. Defined here rather than in forward() so its
        # weights are registered with the module and trained with the rest.
        if self.is_res and not self.same_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # If using the residual connection
        if self.is_res:
            # Apply the first convolutional layer
            x1 = self.conv1(x)

            # Apply the second convolutional layer
            x2 = self.conv2(x1)

            # If input and output channels are the same, add the residual connection directly
            if self.same_channels:
                out = x + x2
            # Otherwise, project the input with the 1x1 shortcut before adding
            else:
                out = self.shortcut(x) + x2

            # Scale by 1/sqrt(2) to keep the variance of the sum stable
            return out / 1.414

        # If not using the residual connection, return the output of the second convolutional layer
        else:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            return x2

    # Method to get the number of output channels for this block
    def get_out_channels(self):
        return self.conv2[0].out_channels

    # Method to set the number of output channels for this block.
    # Note: this only updates the recorded attributes; the underlying weight
    # tensors are not reallocated.
    def set_out_channels(self, out_channels):
        self.conv1[0].out_channels = out_channels
        self.conv2[0].in_channels = out_channels
        self.conv2[0].out_channels = out_channels

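# Usage sketch (hypothetical shapes, not part of the original commit): a
# residual block preserves spatial size and maps in_channels -> out_channels.
#   block = ResidualConvBlock(3, 64, is_res=True)
#   y = block(torch.randn(2, 3, 16, 16))  # -> torch.Size([2, 64, 16, 16])
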
class UnetUp(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetUp, self).__init__()

        # Layers for the upsampling block: a ConvTranspose2d layer that doubles
        # the spatial resolution, followed by two ResidualConvBlock layers
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ResidualConvBlock(out_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
        ]

        # Use the layers to create a sequential model
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip):
        # Concatenate the input tensor x with the skip-connection tensor along the channel dimension
        x = torch.cat((x, skip), 1)

        # Pass the concatenated tensor through the sequential model and return the output
        x = self.model(x)
        return x

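# Usage sketch (hypothetical shapes): the channel counts of x and skip must
# sum to in_channels, and the transposed conv doubles the spatial size.
#   up = UnetUp(128, 64)
#   x, skip = torch.randn(1, 64, 8, 8), torch.randn(1, 64, 8, 8)
#   y = up(x, skip)  # -> torch.Size([1, 64, 16, 16])
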
class UnetDown(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetDown, self).__init__()

        # Layers for the downsampling block: two ResidualConvBlock layers,
        # followed by a MaxPool2d layer that halves the spatial resolution
        layers = [
            ResidualConvBlock(in_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
            nn.MaxPool2d(2),
        ]

        # Use the layers to create a sequential model
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # Pass the input through the sequential model and return the output
        return self.model(x)

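# Usage sketch (hypothetical shapes):
#   down = UnetDown(64, 128)
#   y = down(torch.randn(1, 64, 16, 16))  # -> torch.Size([1, 128, 8, 8])
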
class EmbedFC(nn.Module):
    def __init__(self, input_dim, emb_dim):
        '''
        A generic one-hidden-layer feed-forward network that embeds input data
        of dimensionality input_dim into an embedding space of dimensionality emb_dim.
        '''
        super(EmbedFC, self).__init__()
        self.input_dim = input_dim

        # Define the layers for the network
        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        ]

        # Create a PyTorch sequential model consisting of the defined layers
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # Flatten the input tensor to (batch, input_dim)
        x = x.view(-1, self.input_dim)
        # Apply the model layers to the flattened tensor
        return self.model(x)

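# Usage sketch (hypothetical dimensions), e.g. embedding a 5-feature context
# vector or a scalar timestep into a 128-dim space:
#   embed = EmbedFC(5, 128)
#   e = embed(torch.randn(4, 5))  # -> torch.Size([4, 128])
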
def unorm(x):
    # Unity norm: rescales each channel to the range [0,1]
    # Assumes x has shape (h, w, 3)
    xmax = x.max((0, 1))
    xmin = x.min((0, 1))
    return (x - xmin) / (xmax - xmin)


def norm_all(store, n_t, n_s):
    # Runs unity norm on all timesteps of all samples
    nstore = np.zeros_like(store)
    for t in range(n_t):
        for s in range(n_s):
            nstore[t, s] = unorm(store[t, s])
    return nstore


def norm_torch(x_all):
    # Runs unity norm on each sample of a batch, per channel
    # Input is (n_samples, 3, h, w), the torch image format
    x = x_all.cpu().numpy()
    xmax = x.max((2, 3))
    xmin = x.min((2, 3))
    xmax = np.expand_dims(xmax, (2, 3))
    xmin = np.expand_dims(xmin, (2, 3))
    nstore = (x - xmin) / (xmax - xmin)
    return torch.from_numpy(nstore)

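# Usage sketch (hypothetical data):
#   img = np.random.rand(16, 16, 3)     # (h, w, 3) numpy image
#   u = unorm(img)                      # each channel stretched to [0, 1]
#   batch = torch.randn(4, 3, 16, 16)   # (n, 3, h, w) torch batch
#   nb = norm_torch(batch)              # per-sample, per-channel [0, 1]
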
def gen_tst_context(n_cfeat):
    """
    Generate test context vectors: six repeats of the five one-hot context
    classes (human, non-human, food, spell, side-facing) plus the all-zero
    (null) context. Note that n_cfeat is currently unused; the vectors are
    fixed at five features.
    """
    vec = torch.tensor([
        [1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0], [0, 0, 0, 0, 1], [0, 0, 0, 0, 0],
    ]).repeat(6, 1)
    return len(vec), vec

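# Usage sketch:
#   n, ctx = gen_tst_context(5)  # n == 36, ctx.shape == torch.Size([36, 5])
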
def plot_grid(x, n_sample, n_rows, save_dir, w):
    # x: (n_sample, 3, h, w)
    ncols = n_sample // n_rows
    # Note: make_grid's nrow argument is the number of items per row,
    # i.e. the number of columns
    grid = make_grid(norm_torch(x), nrow=ncols)
    save_image(grid, save_dir + f"run_image_w{w}.png")
    print('saved image at ' + save_dir + f"run_image_w{w}.png")
    return grid

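# Usage sketch (hypothetical tensor and path; save_dir must end with a slash
# and the directory must already exist):
#   grid = plot_grid(torch.randn(16, 3, 16, 16), n_sample=16, n_rows=4,
#                    save_dir='./output/', w=0.0)
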
def plot_sample(x_gen_store, n_sample, nrows, save_dir, fn, w, save=False):
    ncols = n_sample // nrows
    # Change to the numpy image format (h, w, channels) from (channels, h, w)
    sx_gen_store = np.moveaxis(x_gen_store, 2, 4)
    # Unity norm to put the values in the range [0,1] for imshow
    nsx_gen_store = norm_all(sx_gen_store, sx_gen_store.shape[0], n_sample)

    # Create a gif of the images evolving over time, based on x_gen_store;
    # squeeze=False keeps axs two-dimensional even when nrows == 1
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True,
                            figsize=(ncols, nrows), squeeze=False)

    def animate_diff(i, store):
        print(f'gif animating frame {i} of {store.shape[0]}', end='\r')
        plots = []
        for row in range(nrows):
            for col in range(ncols):
                axs[row, col].clear()
                axs[row, col].set_xticks([])
                axs[row, col].set_yticks([])
                plots.append(axs[row, col].imshow(store[i, (row * ncols) + col]))
        return plots

    ani = FuncAnimation(fig, animate_diff, fargs=[nsx_gen_store], interval=200,
                        blit=False, repeat=True, frames=nsx_gen_store.shape[0])
    plt.close()
    if save:
        ani.save(save_dir + f"{fn}_w{w}.gif", dpi=100, writer=PillowWriter(fps=5))
        print('saved gif at ' + save_dir + f"{fn}_w{w}.gif")
    return ani

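# Usage sketch (hypothetical array and path): x_gen_store is expected to hold
# intermediate samples with shape (n_frames, n_sample, 3, h, w), e.g. images
# saved at intervals during diffusion sampling.
#   frames = np.random.rand(10, 16, 3, 16, 16).astype(np.float32)
#   ani = plot_sample(frames, n_sample=16, nrows=4, save_dir='./output/',
#                     fn='run', w=0.0, save=False)
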
class CustomDataset(Dataset):
    def __init__(self, sfilename, lfilename, transform, null_context=False):
        self.sprites = np.load(sfilename)
        self.slabels = np.load(lfilename)
        print(f"sprite shape: {self.sprites.shape}")
        print(f"labels shape: {self.slabels.shape}")
        self.transform = transform
        self.null_context = null_context
        self.sprites_shape = self.sprites.shape
        self.slabel_shape = self.slabels.shape

    # Return the number of images in the dataset
    def __len__(self):
        return len(self.sprites)

    # Get the image and label at a given index
    def __getitem__(self, idx):
        # Apply the transform if one was provided; otherwise return the raw array
        if self.transform:
            image = self.transform(self.sprites[idx])
        else:
            image = self.sprites[idx]
        if self.null_context:
            label = torch.tensor(0).to(torch.int64)
        else:
            label = torch.tensor(self.slabels[idx]).to(torch.int64)
        return (image, label)

    # Return the shapes of the data and labels
    def getshapes(self):
        return self.sprites_shape, self.slabel_shape

transform = transforms.Compose([
    transforms.ToTensor(),                # from [0,255] to range [0.0,1.0]
    transforms.Normalize((0.5,), (0.5,))  # range [-1,1]
])
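# Usage sketch (hypothetical .npy paths, not part of the original commit):
#   from torch.utils.data import DataLoader
#   dataset = CustomDataset("sprites.npy", "sprite_labels.npy", transform)
#   loader = DataLoader(dataset, batch_size=32, shuffle=True)
#   images, labels = next(iter(loader))  # images in [-1, 1], labels int64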