
Commit 9b6e4a7

MIDI-92: Improvements to Discriminator and Data Normalization (#2)
* experiment with noise size
* combine all data for loader
* batch_size
* broken TGAN
* changed to 2 epochs
* code from vqvae-midi
* encode, decode function
* copied transformer code to revise
* rename to vqvae
* refactor 01
* vqvae training for 1 channel
* cleaned dataset
* interpolation eval
* remove tanh
* compute statistics for all data
* normalized data, tanh
* denormalization in evals
* Add ResidualBlocks
* update config
* Update README
* preprocessing ltafdb
* change name
1 parent aff6db0 commit 9b6e4a7

26 files changed: +1788 −251 lines
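The preprocessing code itself is not part of this diff, but the commit-message bullets "compute statistics for all data", "normalized data, tanh", and "denormalization in evals" imply a global min-max scaling to [-1, 1]. A minimal sketch, assuming that convention (function names are illustrative; the inverse matches the expression used in evals/interpolate_ecg.py below):

# Illustrative sketch only: the forward [-1, 1] normalization implied by the
# commit message, paired with the inverse that appears in evals/interpolate_ecg.py.
import torch

def normalize(x: torch.Tensor, global_min: float, global_max: float) -> torch.Tensor:
    # map [global_min, global_max] -> [-1, 1], matching a tanh generator output
    return 2 * (x - global_min) / (global_max - global_min) - 1

def denormalize(x: torch.Tensor, global_min: float, global_max: float) -> torch.Tensor:
    # inverse map; identical in form to the expression in evals/interpolate_ecg.py
    return (x + 1) * (global_max - global_min) / 2 + global_min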

README.md

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 # ecg-gan
 
-Check out this [wandb run](https://wandb.ai/sjanas/ECG%20GAN/runs/bpr8skng) to get a feel for the baseline model or run `python train.py` to reproduce the results
+Check out this [wandb run](https://wandb.ai/sjanas/ECG%20GAN/runs/fz2ptaxy) to get a feel for the baseline model or run `python train.py` to reproduce the results
 
 ### Suggested reading

configs/config.yaml

Lines changed: 7 additions & 7 deletions
@@ -1,13 +1,13 @@
 train:
-  batch_size: 64
-  epochs: 5
+  batch_size: 16
+  epochs: 3
   log_interval: 100
-  generator_lr: 1e-4
+  generator_lr: 2e-4
   discriminator_adam_lr: 1e-5
-  num_workers: 4
+  num_workers: 6
   # For SGD optimizer
   use_sgd: true
-  discriminator_sgd_lr: 2e-4 # It's slightly slower and needs higher learning rate
+  discriminator_sgd_lr: 1e-4 # It's slightly slower and needs higher learning rate
   # for continuing training
   load_checkpoint: # change from None to checkpoint path to continue training
   more_epochs: 20 # How many more epochs to train
@@ -29,9 +29,9 @@ discriminator:
   beta: 0.5
 
 generator:
-  noise_size: 200
+  noise_size: 300
   beta: 0.5
 
 project: "ECG GAN"
 run_date: ${now:%Y_%m_%d_%H_%M}
-run_name: "ECG_GAN_${run_date}"
+run_name: "residual_discriminator_${run_date}"
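For orientation, these values are read as a nested config. A minimal sketch of consuming them, assuming the project loads this file with Hydra ≥ 1.2 (Hydra supplies the ${now:...} resolver used by run_name; the repo's actual entry point may differ):

import hydra
from omegaconf import DictConfig

@hydra.main(config_path="configs", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    # Hydra's built-in "now" resolver fills run_date, so run_name resolves to
    # e.g. "residual_discriminator_2023_09_29_10_00"
    print(cfg.run_name)
    print(cfg.train.batch_size)      # 16 after this commit
    print(cfg.generator.noise_size)  # 300 after this commit

if __name__ == "__main__":
    main()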

configs/config_tgan.yaml

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+train:
+  batch_size: 16
+  epochs: 5
+  log_interval: 100
+  generator_lr: 1e-4
+  discriminator_lr: 1e-5
+  num_workers: 4
+
+  # for continuing training
+  load_checkpoint: # change from None to checkpoint path to continue training
+  more_epochs: 20 # How many more epochs to train
+
+logger:
+  checkpoint_path: "checkpoints/"
+  chart_path: "tmp/"
+
+system:
+  device: "cuda:0"
+  seed: 23
+
+data:
+  channels: 1
+  size: 1000
+
+discriminator:
+  n_layers: 1
+  n_channel: 1
+  kernel_size: 8
+  dropout: 0.0
+
+generator:
+  noise_size: 100
+  beta: 0.5
+  n_layers: 1
+  hidden_dim: 256
+
+project: "ECG TGAN"
+run_date: ${now:%Y_%m_%d_%H_%M}
+run_name: "ECG_TGAN_${run_date}"

configs/config_transformer.yaml

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+latent_dim: 256
+image_size: 256
+num_codebook_vectors: 1024
+beta: 0.25
+image_channels: 1 # Since 1D data?
+dataset_path: "./data"
+checkpoint_path: "./checkpoints/last_ckpt.pt"
+batch_size: 20
+epochs: 100
+learning_rate: 2.25e-05
+beta1: 0.5
+beta2: 0.9
+disc_start: 10000
+disc_factor: 1.0
+l2_loss_factor: 1.0
+perceptual_loss_factor: 1.0
+pkeep: 0.5
+sos_token: 0
+
+system:
+  device: 'cuda:0'
+
+# Parameters specific to our data shape
+seq_len: 1000
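pkeep and sos_token are not used elsewhere in this diff. In VQGAN-style transformer training, which the "copied transformer code to revise" bullet suggests, pkeep conventionally controls token corruption during teacher forcing. A hypothetical sketch of that convention, not this repo's confirmed behavior:

import torch

pkeep, num_codebook_vectors = 0.5, 1024
indices = torch.randint(num_codebook_vectors, (20, 125))  # (batch, tokens) ground-truth codes
keep = (torch.rand(indices.shape) < pkeep).long()         # keep each token with prob pkeep
random_codes = torch.randint_like(indices, num_codebook_vectors)
corrupted = keep * indices + (1 - keep) * random_codes    # corrupted transformer input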

configs/config_vqvae.yaml

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+train:
+  batch_size: 16
+  epochs: 2
+  log_interval: 100
+  lr: 2e-4
+  num_workers: 4
+
+logger:
+  checkpoint_path: "checkpoints/"
+  chart_path: "tmp/"
+
+system:
+  device: "cuda:0"
+  seed: 23
+
+data:
+  channels: 1
+  size: 1000
+
+model:
+  # Structure/Architecture
+  output_features_filters: 1
+  augment_output_features: False
+  input_features_filters: 1 # Aligned with the number of channels
+  augment_input_features: False
+  output_features_dim: 1000 # Features size post decoding
+  input_features_dim: 1000 # Aligned with the number of data points in each channel
+  input_features_type: 'mfcc'
+  # Encoder-Decoder details
+  num_hiddens: 128
+  num_residual_layers: 2
+  num_residual_hiddens: 16
+  use_kaiming_normal: True
+  # VQ details
+  embedding_dim: 100 # We might want to lower this value in the future, size of embedding vectors
+  num_embeddings: 64 # K, needs empirical tuning
+  commitment_cost: 0.25
+  decay: 0.99
+  # Misc
+  record_codebook_stats: False
+  verbose: False
+
+project: "ECG GAN"
+run_date: ${now:%Y_%m_%d_%H_%M}
+run_name: "ECG_GAN_${run_date}"
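As a reference for the "VQ details" block, a minimal sketch of what num_embeddings (K = 64), embedding_dim (100), and commitment_cost (0.25) control in a vector-quantization bottleneck. This is the standard VQ-VAE formulation, not the repo's actual code; shapes are illustrative:

import torch

K, D = 64, 100                    # num_embeddings, embedding_dim from the config
codebook = torch.randn(K, D)      # learned codebook (randomly initialized here)
z_e = torch.randn(16, 125, D)     # hypothetical encoder output: (batch, time, D)

# nearest-neighbour lookup: index of the closest codebook vector per position
dists = torch.cdist(z_e, codebook.expand(16, K, D))
indices = dists.argmin(dim=-1)    # (batch, time) discrete codes
z_q = codebook[indices]           # quantized latents, same shape as z_e

# commitment loss pulls encoder outputs toward their assigned codes
commit_loss = 0.25 * torch.mean((z_e - z_q.detach()) ** 2)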

data_showcase.ipynb

Lines changed: 73 additions & 226 deletions
Large diffs are not rendered by default.

evals/checkpoint_utils.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+import torch
+from omegaconf import OmegaConf
+from huggingface_hub.file_download import hf_hub_download
+
+
+def load_checkpoint(ckpt_path: str = None, omegaconf: bool = True):
+    if ckpt_path is not None:
+        checkpoint = torch.load(ckpt_path)
+    else:
+        checkpoint = torch.load(hf_hub_download("SneakyInsect/GANs", filename="double_generator_update_2023_09_26_21_23_3.pt"))
+    cfg = checkpoint["config"]
+    if omegaconf:
+        cfg = OmegaConf.create(cfg)
+    return checkpoint, cfg
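A usage sketch for the helper above; the local path shown is the one referenced in evals/interpolate_ecg.py below, and omitting it falls back to the published Hub checkpoint:

from evals.checkpoint_utils import load_checkpoint

# load a local checkpoint together with its training config (as OmegaConf)
checkpoint, cfg = load_checkpoint("checkpoints/ECG_GAN_2023_09_29_10_00_1.pt")
print(cfg.generator.noise_size)

# or pull the default checkpoint from the Hugging Face Hub
checkpoint, cfg = load_checkpoint()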

evals/interpolate_ecg.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+import torch
+import imageio
+import numpy as np
+from matplotlib import pyplot as plt
+
+from model.dcgan import Generator
+from utils.dataloaders import set_seed
+from evals.checkpoint_utils import load_checkpoint
+
+
+def interpolate(noise1: torch.Tensor, noise2: torch.Tensor, num_interpolations: int) -> torch.Tensor:
+    """
+    Create interpolated noises between noise1 and noise2, including noise1 at the start and noise2 at the end.
+    """
+    alphas = torch.linspace(0, 1, num_interpolations + 2).to(noise1.device)  # +2 to account for noise1 and noise2
+    noises = [(1 - alpha) * noise1 + alpha * noise2 for alpha in alphas]
+    return torch.stack(noises, dim=0)
+
+
+@torch.no_grad()
+def save_interpolations(generator, noise1: torch.Tensor, noise2: torch.Tensor, num_interpolations: int, save_path: str):
+    generator.eval()
+
+    # create num_interpolation noises between noise1 and noise2
+    noises = interpolate(noise1, noise2, num_interpolations)
+
+    # generate fake data using the generator and noises
+    fake_data = generator(noises).detach().cpu().numpy()
+
+    # denormalize data from [-1, 1]
+    global_max = [9.494109153747559, 7.599456787109375]
+    global_min = [-10.515237808227539, -7.820725917816162]
+    fake_data = (fake_data + 1) * (global_max[0] - global_min[0]) / 2 + global_min[0]
+
+    # Visualize and save as gif
+    images = []
+    for i in range(len(fake_data)):
+        fig, ax = plt.subplots(figsize=[10, 4], gridspec_kw={"hspace": 0})
+        ax.plot(fake_data[i, 0, :])  # Plotting the data
+        ax.set_title("Generated Data Visualization")
+
+        # Convert Figure to image and append to images list
+        fig.canvas.draw()
+        image = np.frombuffer(fig.canvas.tostring_rgb(), dtype="uint8")
+        image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+        images.append(image)
+        plt.close(fig)
+
+    # Save images as gif
+    imageio.mimsave(save_path, images, duration=250)
+    print(f"Saved interpolation gif to {save_path}")
+
+
+if __name__ == "__main__":
+    checkpoint_path = (
+        "checkpoints/ECG_GAN_2023_09_29_10_00_1.pt"  # change this to string path of checkpoint if you want to load from local
+    )
+    number_of_interpolations = 20
+    save_path = "tmp/interpolation.gif"
+
+    # set seed for reproducibility
+    set_seed(23)
+
+    # load checkpoint, config
+    checkpoint, cfg = load_checkpoint(checkpoint_path)
+
+    # Initialize generator
+    generator = Generator(
+        noise_size=cfg.generator.noise_size,
+        output_size=cfg.data.size,
+    ).to(cfg.system.device)
+
+    generator.load_state_dict(checkpoint["generator_state_dict"])
+
+    # create 2 noise vectors
+    noise1 = torch.randn(cfg.generator.noise_size, cfg.data.channels, device=cfg.system.device)
+    noise2 = torch.randn(cfg.generator.noise_size, cfg.data.channels, device=cfg.system.device)
+
+    save_interpolations(generator, noise1, noise2, number_of_interpolations, save_path)
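A quick check of interpolate()'s endpoint behavior: with num_interpolations = 3, linspace yields alphas [0, 0.25, 0.5, 0.75, 1], so the first and last rows reproduce noise1 and noise2 exactly:

import torch
from evals.interpolate_ecg import interpolate

n1, n2 = torch.zeros(4), torch.ones(4)
steps = interpolate(n1, n2, num_interpolations=3)
print(steps.shape)  # torch.Size([5, 4]): 3 intermediates plus both endpoints
print(steps[:, 0])  # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])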

model/dcgan.py

Lines changed: 27 additions & 9 deletions
@@ -10,28 +10,46 @@ def weights_init(m):
         nn.init.constant_(m.bias.data, 0)
 
 
+class ResidualBlock(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(ResidualBlock, self).__init__()
+
+        self.residual = nn.Sequential(
+            nn.Conv1d(in_channels, out_channels, 4, 2, 1, bias=False),
+            nn.BatchNorm1d(out_channels),
+            nn.LeakyReLU(0.2, inplace=True),
+            # This line calculates the padding to ensure the size remains consistent
+            nn.ConstantPad1d((0, 1), 0),
+            nn.Conv1d(out_channels, out_channels, 4, 1, 1, bias=False),
+            nn.BatchNorm1d(out_channels),
+        )
+
+        self.shortcut = nn.Sequential()
+        if in_channels != out_channels:
+            self.shortcut.add_module("conv_shortcut", nn.Conv1d(in_channels, out_channels, 4, 2, 1, bias=False))
+            self.shortcut.add_module("bn_shortcut", nn.BatchNorm1d(out_channels))
+
+    def forward(self, x):
+        return nn.LeakyReLU(0.2, inplace=True)(self.residual(x) + self.shortcut(x))
+
+
 class Discriminator(nn.Module):
     def __init__(self, input_channels, input_size, neurons):
         super().__init__()
+
         layers = []
         prev_channels = input_channels
         for n in neurons:
-            layers.extend(
-                [
-                    nn.Conv1d(prev_channels, n, 4, 2, 1, bias=False),
-                    nn.BatchNorm1d(n) if prev_channels != input_channels else nn.Identity(),
-                    nn.LeakyReLU(0.2, inplace=True),
-                ]
-            )
+            layers.append(ResidualBlock(prev_channels, n))
             prev_channels = n
+
         layers.append(nn.Conv1d(prev_channels, 1, input_size // (2 ** len(neurons)), 1, 0, bias=False))
         layers.append(nn.Sigmoid())
 
         self.main = nn.Sequential(*layers)
 
     def forward(self, x):
-        x = self.main(x)
-        return x
+        return self.main(x)
 
 
 class Generator(nn.Module):
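Shape arithmetic for the new block: the strided Conv1d (kernel 4, stride 2, padding 1) halves the length, the (0, 1) pad plus the stride-1 conv preserves it, and the shortcut's strided conv keeps both branches aligned, so each ResidualBlock halves the temporal dimension just like the plain conv stack it replaces. A small sanity check (the neurons list here is illustrative):

import torch
from model.dcgan import Discriminator, ResidualBlock

x = torch.randn(16, 1, 1000)          # (batch, channels, samples), matching data.size
print(ResidualBlock(1, 64)(x).shape)  # torch.Size([16, 64, 500]): length halved

disc = Discriminator(input_channels=1, input_size=1000, neurons=[64, 128, 256])
print(disc(x).shape)                  # torch.Size([16, 1, 1]): one probability per sample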
