Commit aff6db0
MIDI-84: ECG model (#1)
* MIDI-84: first model
* create folders
* MIDI-84: first whole run
* save checkpoints
* naming and comments
* fix randomness?
* run name change
* normalize input
* revert normalization change
* random labels
* continue training code
* generator leaky relu
* get in touch with data
* training loop for 1 channel data
* generator fitted for data
* wider plots
* combine loaders
* log interval fix
* dataloader changes
* update README
* change wandb logging
* comment for soft labels
* change nz name
* log gradients
* powers of 2 input size compatibility
* config changes
* suggested changes
* option for SGD optim
* larger model
* update README with baseline run
* Revert "larger model" (reverts commit 9e5140e)
* update README
1 parent 7cea80a commit aff6db0

7 files changed: +1002 -0 lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -6,3 +6,5 @@ checkpoints/
 wandb/
 tests/
 ecgdata/
+tmp/
+multirun/

README.md

Lines changed: 9 additions & 0 deletions
@@ -1,5 +1,14 @@
 # ecg-gan
 
+Check out this [wandb run](https://wandb.ai/sjanas/ECG%20GAN/runs/bpr8skng) to get a feel for the baseline model, or run `python train.py` to reproduce the results.
+
+### Suggested reading
+
+- [DCGAN TUTORIAL](https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html) on the PyTorch website.
+- [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/abs/1511.06434) by Alec Radford, Luke Metz, and Soumith Chintala.
+- [Tips and tricks to make GANs work](https://github.com/soumith/ganhacks)
+  - Focused on points 6 and 7; applied most of them.
+
 ### Code Style
 
 This repository uses pre-commit hooks with forced python formatting ([black](https://github.com/psf/black),

configs/config.yaml

Lines changed: 37 additions & 0 deletions
train:
  batch_size: 64
  epochs: 5
  log_interval: 100
  generator_lr: 1e-4
  discriminator_adam_lr: 1e-5
  num_workers: 4
  # For the SGD optimizer
  use_sgd: true
  discriminator_sgd_lr: 2e-4 # SGD is slightly slower than Adam and needs a higher learning rate
  # For continuing training
  load_checkpoint: # change from None to a checkpoint path to continue training
  more_epochs: 20 # target epoch to reach when resuming (see train.py)

logger:
  checkpoint_path: "checkpoints/"
  chart_path: "tmp/"

system:
  device: "cuda:0"
  seed: 23

data:
  channels: 1
  size: 1000

discriminator:
  neurons: [64, 128, 256, 512]
  beta: 0.5

generator:
  noise_size: 200
  beta: 0.5

project: "ECG GAN"
run_date: ${now:%Y_%m_%d_%H_%M}
run_name: "ECG_GAN_${run_date}"
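
Since train.py loads this file through Hydra (via the @hydra.main decorator), any of these values can be overridden from the command line without editing the file. A minimal sketch using standard Hydra override syntax; the specific invocations below are illustrative, not commands documented by the repository:

    # switch the discriminator optimizer back to Adam and train longer
    python train.py train.use_sgd=false train.epochs=20

    # run on CPU with a different seed
    python train.py system.device=cpu system.seed=42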

data_showcase.ipynb

Lines changed: 581 additions & 0 deletions
Large diffs are not rendered by default.

model/dcgan.py

Lines changed: 63 additions & 0 deletions
import torch.nn as nn


def weights_init(m):
    # Standard DCGAN initialization: conv weights ~ N(0, 0.02),
    # batch-norm weights ~ N(1, 0.02) with zero bias
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


class Discriminator(nn.Module):
    def __init__(self, input_channels, input_size, neurons):
        super().__init__()
        layers = []
        prev_channels = input_channels
        for n in neurons:
            layers.extend(
                [
                    # Each stride-2 convolution halves the sequence length
                    nn.Conv1d(prev_channels, n, 4, 2, 1, bias=False),
                    # No batch norm on the input block, following the DCGAN paper
                    nn.BatchNorm1d(n) if prev_channels != input_channels else nn.Identity(),
                    nn.LeakyReLU(0.2, inplace=True),
                ]
            )
            prev_channels = n
        # A kernel spanning the remaining length collapses it to a single output
        layers.append(nn.Conv1d(prev_channels, 1, input_size // (2 ** len(neurons)), 1, 0, bias=False))
        layers.append(nn.Sigmoid())

        self.main = nn.Sequential(*layers)

    def forward(self, x):
        x = self.main(x)
        return x


class Generator(nn.Module):
    def __init__(self, noise_size, output_size):
        super().__init__()
        self.output_size = output_size
        size_multiplier = 1000 // output_size
        self.main = nn.Sequential(
            # Project the noise vector onto a 512-channel feature map
            nn.ConvTranspose1d(noise_size, 512, 64 // size_multiplier, 1, 0, bias=False),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # Each stride-2 transposed convolution doubles the sequence length
            nn.ConvTranspose1d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose1d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose1d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose1d(64, 1, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.main(x)
        # Upsampling by powers of 2 can overshoot, so crop to the requested length
        x = x[:, :, : self.output_size]
        return x
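
A quick way to check the stride arithmetic above is a shape smoke test. This is a minimal sketch assuming the defaults from configs/config.yaml (1 channel, length-1000 signals, 200-dimensional noise, neurons: [64, 128, 256, 512]); it is not part of the repository:

    import torch
    from model.dcgan import Generator, Discriminator

    gen = Generator(noise_size=200, output_size=1000)
    disc = Discriminator(input_channels=1, input_size=1000, neurons=[64, 128, 256, 512])

    # Noise is shaped (batch, noise_size, channels), matching train.py
    noise = torch.randn(4, 200, 1)
    fake = gen(noise)
    print(fake.shape)        # torch.Size([4, 1, 1000]) after cropping from 1024
    print(disc(fake).shape)  # torch.Size([4, 1, 1]), one probability per sample

The generator's four stride-2 layers expand the initial length-64 projection to 1024 samples, which forward crops back to output_size; the discriminator halves the length four times (1000 -> 500 -> 250 -> 125 -> 62), and its final kernel of input_size // 16 = 62 collapses that to a single value.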

train.py

Lines changed: 236 additions & 0 deletions
import os

import hydra
import torch
import torch.nn as nn
from tqdm import tqdm
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from omegaconf import OmegaConf, DictConfig

import wandb
from utils.dataloaders import set_seed, create_dataloader
from model.dcgan import Generator, Discriminator, weights_init


def create_folders(cfg: DictConfig):
    logger_dict = cfg.logger

    for _, value in logger_dict.items():
        if value[-1] == "/":
            value = value[:-1]
        os.makedirs(value, exist_ok=True)


@torch.no_grad()
def visualize_training(
    generator: Generator,
    fixed_noise: torch.Tensor,
    epoch: int,
    batch_idx: int,
    chart_path: str = "tmp/",
):
    generator.eval()

    # Generate fake data using the generator and fixed noise
    fake_data = generator(fixed_noise).detach().cpu().numpy()

    # Create a figure and a grid of subplots
    fig, axes = plt.subplots(
        nrows=4,
        ncols=1,
        figsize=[10, 4],
        gridspec_kw={"hspace": 0},
    )

    # Plot channel 0 of each generated sample in its own subplot
    for i in range(len(fixed_noise)):
        axes[i].plot(fake_data[i, 0, :])

    fig.suptitle(f"Epoch {epoch}, Batch {batch_idx}")
    fig.tight_layout()
    # save the figure to the chart_path/ folder
    # fig.savefig(f"{chart_path}epoch_{epoch}_batch_{batch_idx}.png")
    return fig


def average_gradient(model: nn.Module) -> dict:
    # Mean absolute gradient per parameter, for wandb logging
    avg_gradients = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            avg_gradients[name] = torch.mean(torch.abs(param.grad)).item()
    return avg_gradients


def train_epoch(
    generator: Generator,
    discriminator: Discriminator,
    train_loader: DataLoader,
    gen_optimizer: optim.Optimizer,
    disc_optimizer: optim.Optimizer,
    criterion: nn.Module,
    cfg: DictConfig,
    fixed_noise: torch.Tensor,
    epoch: int,
) -> None:
    generator.train()
    discriminator.train()

    def random_labels(batch_size, start, end, device):
        return (end - start) * torch.rand(batch_size, device=device) + start

    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
    for batch_idx, batch in progress_bar:
        real_data = batch[0].to(cfg.system.device)
        batch_size = real_data.size(0)

        # Soft labels, trick 6 from: https://github.com/soumith/ganhacks
        real_labels = random_labels(batch_size, 0.8, 1, cfg.system.device)
        fake_labels = random_labels(batch_size, 0.0, 0.1, cfg.system.device)

        # Train the discriminator on real data
        discriminator.zero_grad()
        label = real_labels
        real_data = real_data.view(real_data.shape[0], 1, real_data.shape[1]).to(cfg.system.device)
        disc_real_output = discriminator(real_data).view(-1)
        discriminator_error_real = criterion(disc_real_output, label)
        discriminator_error_real.backward()
        D_x = disc_real_output.mean().item()  # Mean discriminator output for real data

        # Train the discriminator on fake data
        noise = torch.randn(batch_size, cfg.generator.noise_size, cfg.data.channels, device=cfg.system.device)
        fake_data = generator(noise)
        label = fake_labels
        disc_fake_output = discriminator(fake_data.detach()).view(-1)

        discriminator_error_fake = criterion(disc_fake_output, label)
        discriminator_error_fake.backward()
        D_G_z1 = disc_fake_output.mean().item()  # Discriminator's average output on the fake data
        discriminator_error = discriminator_error_real + discriminator_error_fake
        disc_optimizer.step()

        # Train the generator: it wants the discriminator to label fakes as real
        generator.zero_grad()
        label = real_labels
        disc_output_after_update = discriminator(fake_data).view(-1)
        generator_error = criterion(disc_output_after_update, label)
        generator_error.backward()
        D_G_z2 = disc_output_after_update.mean().item()  # Discriminator's output on fake data after its update
        gen_optimizer.step()

        # log to wandb
        if batch_idx % cfg.train.log_interval == 0:
            generator_gradients = average_gradient(generator)
            wandb.log(
                {
                    "generator_error": generator_error.item(),
                    "discriminator_error": discriminator_error.item(),
                    "D_x": D_x,
                    "D_G_z1": D_G_z1,
                    "D_G_z2": D_G_z2,
                    "generator/generator_gradients": generator_gradients,
                },
                commit=False,
            )
            fig = visualize_training(generator, fixed_noise, epoch, batch_idx, cfg.logger.chart_path)
            wandb.log({"fixed noise": wandb.Image(fig)}, commit=True)
            # TODO: Might add local saving here as well
            plt.close(fig)

    # Save a full checkpoint once per epoch
    checkpoint = {
        "epoch": epoch,
        "generator_state_dict": generator.state_dict(),
        "discriminator_state_dict": discriminator.state_dict(),
        "gen_optimizer_state_dict": gen_optimizer.state_dict(),
        "disc_optimizer_state_dict": disc_optimizer.state_dict(),
        "config": OmegaConf.to_object(cfg),
        "fixed_noise": fixed_noise,
    }
    torch.save(checkpoint, f"{cfg.logger.checkpoint_path}{cfg.run_name}_{epoch}.pt")


@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig):
    run = wandb.init(
        project=cfg.project,
        name=cfg.run_name,
        job_type="train",
        config=OmegaConf.to_container(cfg, resolve=True),
    )
    set_seed(cfg.system.seed)
    create_folders(cfg)

    # Initialize models
    discriminator = Discriminator(
        input_channels=cfg.data.channels,
        input_size=cfg.data.size,
        neurons=cfg.discriminator.neurons,
    ).to(cfg.system.device)

    generator = Generator(
        noise_size=cfg.generator.noise_size,
        output_size=cfg.data.size,
    ).to(cfg.system.device)

    # Apply DCGAN weight initialization
    discriminator.apply(weights_init)
    generator.apply(weights_init)

    # criterion
    criterion = nn.BCELoss()
    # optimizer
    if cfg.train.use_sgd:
        optimizer_discriminator = optim.SGD(
            discriminator.parameters(),
            lr=cfg.train.discriminator_sgd_lr,
        )
    else:
        optimizer_discriminator = optim.Adam(
            discriminator.parameters(),
            lr=cfg.train.discriminator_adam_lr,
            betas=(cfg.discriminator.beta, 0.999),
        )

    optimizer_generator = optim.Adam(
        generator.parameters(),
        lr=cfg.train.generator_lr,
        betas=(cfg.generator.beta, 0.999),
    )
    if cfg.train.load_checkpoint is not None:
        # Resume models, optimizers, and the fixed noise from a saved checkpoint
        checkpoint = torch.load(cfg.train.load_checkpoint)
        generator.load_state_dict(checkpoint["generator_state_dict"])
        discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
        optimizer_generator.load_state_dict(checkpoint["gen_optimizer_state_dict"])
        optimizer_discriminator.load_state_dict(checkpoint["disc_optimizer_state_dict"])
        fixed_noise = checkpoint["fixed_noise"]
        epoch = checkpoint["epoch"]
    else:
        num_test_noises = 4
        epoch = 0
        # Fixed noise, used for visualizing training progress
        fixed_noise = torch.randn(num_test_noises, cfg.generator.noise_size, cfg.data.channels, device=cfg.system.device)
    # get loader:
    # train_loader, _, _ = create_dataloader(cfg, seed=cfg.system.seed)
    train_loader = create_dataloader(cfg, seed=cfg.system.seed, splits=["train"])
    print(len(train_loader))

    # train epochs
    epochs = cfg.train.epochs if epoch == 0 else cfg.train.more_epochs
    start_epoch = epoch + 1
    for epoch in range(start_epoch, epochs + 1):
        train_epoch(
            generator=generator,
            discriminator=discriminator,
            train_loader=train_loader,
            gen_optimizer=optimizer_generator,
            disc_optimizer=optimizer_discriminator,
            criterion=criterion,
            cfg=cfg,
            fixed_noise=fixed_noise,
            epoch=epoch,
        )
    run.finish()


if __name__ == "__main__":
    main()
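
To resume a run, point train.load_checkpoint at one of the files saved by train_epoch. The filename below is hypothetical, following the {checkpoint_path}{run_name}_{epoch}.pt pattern from the save call:

    python train.py train.load_checkpoint=checkpoints/ECG_GAN_2023_08_01_12_00_5.pt

Note that when resuming, the loop runs from the checkpoint's epoch + 1 up to train.more_epochs, so more_epochs is the total target epoch rather than an increment.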
