import os

import hydra
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader
from tqdm import tqdm

from model.dcgan import Discriminator, Generator, weights_init
from utils.dataloaders import create_dataloader, set_seed
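
# The Hydra config (configs/config.yaml) is expected to provide the keys read
# below; this is a minimal sketch with illustrative values, not the repo's
# actual config:
#
#   project: my-dcgan-project
#   run_name: baseline
#   system: {device: cuda, seed: 42}
#   data: {channels: 1, size: 1024}
#   generator: {noise_size: 100, beta: 0.5}
#   discriminator: {neurons: 64, beta: 0.5}
#   train: {epochs: 50, more_epochs: 100, log_interval: 50, use_sgd: false,
#           generator_lr: 2.0e-4, discriminator_adam_lr: 2.0e-4,
#           discriminator_sgd_lr: 1.0e-3, load_checkpoint: null}
#   logger: {chart_path: tmp/, checkpoint_path: checkpoints/}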


def create_folders(cfg: DictConfig):
    """Create every directory listed under cfg.logger."""
    for _, value in cfg.logger.items():
        # Strip a trailing slash, if any, before creating the directory
        os.makedirs(value.rstrip("/"), exist_ok=True)


@torch.no_grad()
def visualize_training(
    generator: Generator,
    fixed_noise: torch.Tensor,
    epoch: int,
    batch_idx: int,
    chart_path: str = "tmp/",
):
    """Plot the generator's current output on the fixed noise batch."""
    was_training = generator.training
    generator.eval()

    # Generate fake data from the fixed noise
    fake_data = generator(fixed_noise).detach().cpu().numpy()

    # One stacked subplot per noise vector
    fig, axes = plt.subplots(
        nrows=len(fixed_noise),
        ncols=1,
        figsize=(10, 4),
        gridspec_kw={"hspace": 0},
    )

    # Plot the first channel of each generated sample
    for i in range(len(fixed_noise)):
        axes[i].plot(fake_data[i, 0, :])

    fig.suptitle(f"Epoch {epoch}, Batch {batch_idx}")
    fig.tight_layout()
    # Optionally save the figure to the chart_path folder:
    # fig.savefig(f"{chart_path}epoch_{epoch}_batch_{batch_idx}.png")

    # Restore the generator's previous mode so training is unaffected
    if was_training:
        generator.train()
    return fig


def average_gradient(model: nn.Module) -> dict:
    """Return the mean absolute gradient for each named parameter that has one."""
    avg_gradients = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            avg_gradients[name] = torch.mean(torch.abs(param.grad)).item()
    return avg_gradients
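

# Training follows the standard DCGAN recipe: per batch, update the
# discriminator on real data (smoothed labels near 1) and on detached fake
# data (labels near 0), then update the generator so the discriminator's
# output on fake data moves toward the real label.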
def train_epoch(
    generator: Generator,
    discriminator: Discriminator,
    train_loader: DataLoader,
    gen_optimizer: optim.Optimizer,
    disc_optimizer: optim.Optimizer,
    criterion: nn.Module,
    cfg: DictConfig,
    fixed_noise: torch.Tensor,
    epoch: int,
) -> None:
    generator.train()
    discriminator.train()

    def random_labels(batch_size, start, end, device):
        # Uniform random labels in [start, end), used for label smoothing
        return (end - start) * torch.rand(batch_size, device=device) + start

    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
    for batch_idx, batch in progress_bar:
        real_data = batch[0].to(cfg.system.device)
        batch_size = real_data.size(0)

        # Smoothed labels; trick 6 from https://github.com/soumith/ganhacks
        real_labels = random_labels(batch_size, 0.8, 1.0, cfg.system.device)
        fake_labels = random_labels(batch_size, 0.0, 0.1, cfg.system.device)

        # Train discriminator on real data
        discriminator.zero_grad()
        real_data = real_data.view(batch_size, 1, real_data.shape[1])
        disc_real_output = discriminator(real_data).view(-1)
        discriminator_error_real = criterion(disc_real_output, real_labels)
        discriminator_error_real.backward()
        D_x = disc_real_output.mean().item()  # Mean discriminator output on real data

        # Train discriminator on fake data; detach() keeps this backward pass
        # out of the generator's graph
        noise = torch.randn(batch_size, cfg.generator.noise_size, cfg.data.channels, device=cfg.system.device)
        fake_data = generator(noise)
        disc_fake_output = discriminator(fake_data.detach()).view(-1)
        discriminator_error_fake = criterion(disc_fake_output, fake_labels)
        discriminator_error_fake.backward()
        D_G_z1 = disc_fake_output.mean().item()  # Mean discriminator output on fake data, before its update
        discriminator_error = discriminator_error_real + discriminator_error_fake
        disc_optimizer.step()

        # Train generator: push the discriminator's output on fake data toward the real label
        generator.zero_grad()
        disc_output_after_update = discriminator(fake_data).view(-1)
        generator_error = criterion(disc_output_after_update, real_labels)
        generator_error.backward()
        D_G_z2 = disc_output_after_update.mean().item()  # Mean output on fake data, after the discriminator's update
        gen_optimizer.step()
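
        # Heuristic, not from this codebase: D_x and D_G_z values hovering
        # around 0.5 suggest the two networks are balanced, while D_x pinned
        # near 1 with D_G_z near 0 usually means the discriminator is winning.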

        # Periodically log scalars and a sample figure to wandb
        if batch_idx % cfg.train.log_interval == 0:
            generator_gradients = average_gradient(generator)
            wandb.log(
                {
                    "generator_error": generator_error.item(),
                    "discriminator_error": discriminator_error.item(),
                    "D_x": D_x,
                    "D_G_z1": D_G_z1,
                    "D_G_z2": D_G_z2,
                    "generator/generator_gradients": generator_gradients,
                },
                commit=False,
            )
            fig = visualize_training(generator, fixed_noise, epoch, batch_idx, cfg.logger.chart_path)
            wandb.log({"fixed noise": wandb.Image(fig)}, commit=True)
            # TODO: Might add local saving here as well
            plt.close(fig)

    # Save a full training checkpoint after every epoch
    checkpoint = {
        "epoch": epoch,
        "generator_state_dict": generator.state_dict(),
        "discriminator_state_dict": discriminator.state_dict(),
        "gen_optimizer_state_dict": gen_optimizer.state_dict(),
        "disc_optimizer_state_dict": disc_optimizer.state_dict(),
        "config": OmegaConf.to_object(cfg),
        "fixed_noise": fixed_noise,
    }
    torch.save(checkpoint, os.path.join(cfg.logger.checkpoint_path, f"{cfg.run_name}_{epoch}.pt"))
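

# A saved checkpoint can be reloaded outside of training, e.g. to sample from
# the generator; a sketch, assuming the stored config matches the original run:
#
#   ckpt = torch.load("checkpoints/<run_name>_<epoch>.pt", map_location="cpu")
#   generator = Generator(noise_size=ckpt["config"]["generator"]["noise_size"],
#                         output_size=ckpt["config"]["data"]["size"])
#   generator.load_state_dict(ckpt["generator_state_dict"])
#   generator.eval()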


@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(cfg: DictConfig):
    run = wandb.init(
        project=cfg.project,
        name=cfg.run_name,
        job_type="train",
        config=OmegaConf.to_container(cfg, resolve=True),
    )
    set_seed(cfg.system.seed)
    create_folders(cfg)

    # Initialize models
    discriminator = Discriminator(
        input_channels=cfg.data.channels,
        input_size=cfg.data.size,
        neurons=cfg.discriminator.neurons,
    ).to(cfg.system.device)

    generator = Generator(
        noise_size=cfg.generator.noise_size,
        output_size=cfg.data.size,
    ).to(cfg.system.device)

    # Apply the custom weight initialization from model.dcgan
    discriminator.apply(weights_init)
    generator.apply(weights_init)

    criterion = nn.BCELoss()

    # Optimizers; SGD for the discriminator is an optional alternative to Adam
    if cfg.train.use_sgd:
        optimizer_discriminator = optim.SGD(
            discriminator.parameters(),
            lr=cfg.train.discriminator_sgd_lr,
        )
    else:
        optimizer_discriminator = optim.Adam(
            discriminator.parameters(),
            lr=cfg.train.discriminator_adam_lr,
            betas=(cfg.discriminator.beta, 0.999),
        )

    optimizer_generator = optim.Adam(
        generator.parameters(),
        lr=cfg.train.generator_lr,
        betas=(cfg.generator.beta, 0.999),
    )
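
    # To resume a run, point train.load_checkpoint at a saved file via a Hydra
    # command-line override (script name and path are illustrative):
    #   python train.py train.load_checkpoint=checkpoints/baseline_10.pt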
    if cfg.train.load_checkpoint is not None:
        # Restore models, optimizers, fixed noise, and the epoch counter
        checkpoint = torch.load(cfg.train.load_checkpoint, map_location=cfg.system.device)
        generator.load_state_dict(checkpoint["generator_state_dict"])
        discriminator.load_state_dict(checkpoint["discriminator_state_dict"])
        optimizer_generator.load_state_dict(checkpoint["gen_optimizer_state_dict"])
        optimizer_discriminator.load_state_dict(checkpoint["disc_optimizer_state_dict"])
        fixed_noise = checkpoint["fixed_noise"]
        epoch = checkpoint["epoch"]
    else:
        num_test_noises = 4
        epoch = 0
        # Fixed noise, used for visualizing training progress
        fixed_noise = torch.randn(num_test_noises, cfg.generator.noise_size, cfg.data.channels, device=cfg.system.device)

    # Only the training split is needed
    train_loader = create_dataloader(cfg, seed=cfg.system.seed, splits=["train"])
    print(f"Batches per epoch: {len(train_loader)}")

    # Fresh runs train for cfg.train.epochs; resumed runs continue up to
    # cfg.train.more_epochs (counted from epoch 1)
    epochs = cfg.train.epochs if epoch == 0 else cfg.train.more_epochs
    start_epoch = epoch + 1
    for epoch in range(start_epoch, epochs + 1):
        train_epoch(
            generator=generator,
            discriminator=discriminator,
            train_loader=train_loader,
            gen_optimizer=optimizer_generator,
            disc_optimizer=optimizer_discriminator,
            criterion=criterion,
            cfg=cfg,
            fixed_noise=fixed_noise,
            epoch=epoch,
        )
    run.finish()


if __name__ == "__main__":
    main()