Added content for architectures and projects doc pages
cnmy-ro committed Sep 24, 2021
1 parent aa73987 commit cd6d941
Showing 8 changed files with 146 additions and 14 deletions.
26 changes: 25 additions & 1 deletion docs/package_overview/2_projects.md
@@ -1 +1,25 @@
# Projects
# `ganslate` Projects

In `ganslate`, a _project_ refers to a collection of all custom code and configuration files pertaining to your specific task. The project directory is expected to have a certain structure that isolates logically different parts of the project, such as the data pipeline, GAN implementation, and configuration. The directory structure is as follows:

```text
<your_project_dir>
|
|- datasets
| |- custom_train_dataset.py
| |- custom_val_test_dataset.py
|
|- architectures
| |- custom_gan.py
|
|- experiments
| |- exp1_config.yaml
|
|- __init__.py
|
|- README.md
```

The `__init__.py` file initializes your project directory as a Python module, which is necessary for `ganslate`'s configuration system to function correctly (see [configuration](./7_configuration.md) for details). The `README.md` file can contain a description of your task.

`ganslate` provides a Cookiecutter template which can automatically generate an empty project for you. The tutorial [Your First Project](../tutorials_basic/2_new_project.md) provides detailed instructions on how to create and run your own project.
25 changes: 17 additions & 8 deletions docs/package_overview/3_datasets.md
@@ -1,15 +1,17 @@
# Datasets

If you are familiar with the [standard dataloader workflow](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html) in _PyTorch_, `ganslate` uses similar dataset classes which are derived from `torch.utils.data.Dataset` to define the data fetching and preprocessing pipeline. Additionally, `ganslate`'s datasets have an associated Python `dataclass` in which all the data-related settings are defined.
If you are familiar with the [standard data loading workflow in _PyTorch_](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html), `ganslate` uses similar dataset classes which are derived from `torch.utils.data.Dataset` to define the data fetching and preprocessing pipeline. Additionally, `ganslate`'s datasets have an associated Python `dataclass` in which all the data-related settings are defined.

Two classes - `PairedImageDataset` and `UnpairedImageDataset` - are supplied by `ganslate` by default and can be used out-of-the-box on your image data. It is also possible to define your own custom dataset class for use-cases requiring specialized processing steps for the data, for example, in the case of medical images. See [Your First Project](../tutorials_basic/2_new_project.md) for more details on creating custom datasets.


--------------------------------------------------------------
## The `PairedImageDataset` and `UnpairedImageDataset` classes

The `PairedImageDataset` class enables loading an `A`-`B` image pair from storage given a common index and applying optional joint transformations on the pair. This class is to be used in paired training, as well as during validation and/or testing when paired data is available.
Two classes - `PairedImageDataset` and `UnpairedImageDataset` - are supplied by `ganslate` by default and can be used out-of-the-box on your image data.

On the other hand, `UnpairedImageDataset` randomly fetches a domain `A` image and a domain `B` image and applies optional transformations on each independently. As the name suggests, this class is meant to be used for unpaired training.
The `PairedImageDataset` class enables loading an _A_-_B_ image pair from storage given a common index and applying optional joint transformations on the pair. This class is to be used in paired training, as well as during validation and/or testing when paired data is available.

On the other hand, `UnpairedImageDataset` randomly fetches a domain _A_ image and a domain _B_ image and applies optional transformations on each independently. As the name suggests, this class is meant to be used for unpaired training.
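
To illustrate the difference, here is a rough sketch of the two fetching strategies (illustrative only, not `ganslate`'s actual implementation; it assumes in-memory lists of images and generic transform callables):

```python
# Illustrative sketch of paired vs. unpaired fetching -- not ganslate's actual code.
import random

from torch.utils.data import Dataset


class PairedSketch(Dataset):
    def __init__(self, images_a, images_b, joint_transform=None):
        assert len(images_a) == len(images_b)
        self.images_a, self.images_b = images_a, images_b
        self.joint_transform = joint_transform

    def __len__(self):
        return len(self.images_a)

    def __getitem__(self, idx):
        # The same index selects from A and B -> an aligned pair.
        a, b = self.images_a[idx], self.images_b[idx]
        if self.joint_transform:
            a, b = self.joint_transform(a, b)  # identical transform applied jointly
        return {'A': a, 'B': b}


class UnpairedSketch(Dataset):
    def __init__(self, images_a, images_b, transform=None):
        self.images_a, self.images_b = images_a, images_b
        self.transform = transform

    def __len__(self):
        return max(len(self.images_a), len(self.images_b))

    def __getitem__(self, idx):
        # Index into A deterministically, pick B at random -> no alignment assumed.
        a = self.images_a[idx % len(self.images_a)]
        b = random.choice(self.images_b)
        if self.transform:
            a, b = self.transform(a), self.transform(b)  # independent transforms
        return {'A': a, 'B': b}
```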

### Input and Output
Both classes expect your data directory to be structured in the following manner
@@ -38,7 +40,7 @@ And if using validation or testing data
| |- ...
```

where the sub-directories `A` and `B` contain the images. In situations where paired data is provided (i.e. paired training or all validation/testing), the ordering of images in `A` corresponds to the ordering of images in `B`, meaning that the first `A` image and the first `B` image form a pair, and so on. Images with extensions `.jpg`, `.jpeg`, and `.png` are supported.
where the sub-directories _A_ and _B_ contain the images. In situations where paired data is provided (i.e. paired training or all validation/testing), the ordering of images in _A_ corresponds to the ordering of images in _B_, meaning that the first _A_ image and the first _B_ image form a pair, and so on. Images with extensions `.jpg`, `.jpeg`, and `.png` are supported.


Both image dataset classes implement a `__getitem__` method that outputs a sample dictionary of the following form
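
A plausible sketch of such a sample, inferred from the _A_/_B_ domain naming used throughout (an illustrative assumption, not a verbatim excerpt from the package):

```python
import torch

# Assumed form of a sample produced by the image datasets (illustrative only).
sample = {
    'A': torch.zeros(3, 256, 256),  # domain-A image tensor of shape (C, H, W)
    'B': torch.zeros(3, 256, 256),  # corresponding (paired) or randomly drawn (unpaired) domain-B image
}
```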
@@ -54,14 +56,21 @@ The configuration `dataclasses` associated with both default image datasets are

The two image datasets have additional settings which are the same across both; a configuration sketch follows this list. These are:

- `image_channels`: Refers to the number of channels in the images. Only images with 1 or 3 channels are supported (i.e. grayscale and RGB), and the number of channels should be the same across `A` and `B` images (i.e. either both should be grayscale or both should be RGB).
- `image_channels`: Refers to the number of channels in the images. Only images with 1 or 3 channels are supported (i.e. grayscale and RGB), and the number of channels should be the same across _A_ and _B_ images (i.e. either both should be grayscale or both should be RGB).

- `preprocess`: Accepts a tuple of predefined strings that define the preprocessing instructions. These strings include `resize`, `scale_width`, `random_zoom`, `random_crop`, and `random_flip`, of which `resize` and `scale_width` specify the initial resizing operation (choose either one or none), whereas the rest correspond to the random transforms used for data augmentation during training. An example value for the `preprocess` setting is `('resize', 'random_crop', 'random_flip')`.

Note: In `PairedImageDataset`, these transforms are applied jointly to the `A` and `B` images, whereas in `UnpairedImageDataset`, they are applied on each image independently.
Note: In `PairedImageDataset`, these transforms are applied jointly to the _A_ and _B_ images, whereas in `UnpairedImageDataset`, they are applied on each image independently.

- `load_size`: Accepts a tuple that specifies the size (`H`, `W`) to which the images are resized upon loading from storage when the `resize` preprocessing instruction is used.

- `final_size`: Accepts a tuple that specifies the size (`H`, `W`) to which the images are converted as a result of the random transforms. This is the final size of the images that should be expected from the dataloader.

Note: When not using random transforms (for example, during validation), set `final_size` to the same value as `load_size`.
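
Here is the configuration sketch referenced above, grouping the shared settings; the class name and default values are illustrative assumptions rather than `ganslate`'s actual configuration `dataclass`:

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class ImageDatasetSettingsSketch:
    """Hypothetical grouping of the shared image-dataset settings described above."""
    image_channels: int = 3                                    # 1 (grayscale) or 3 (RGB)
    preprocess: Tuple[str, ...] = ('resize', 'random_crop', 'random_flip')
    load_size: Tuple[int, int] = (286, 286)                    # (H, W) after the initial resize
    final_size: Tuple[int, int] = (256, 256)                   # (H, W) after the random transforms
```

In an actual project, these values would come from the experiment's YAML configuration (see [configuration](./7_configuration.md)).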



-------------------------
## Custom Dataset Classes

It is also possible to define your own custom dataset class for use-cases requiring specialized processing for the data, for example, in the case of medical images. See [Your First Project](../tutorials_basic/2_new_project.md) for more details on creating custom datasets.
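
For a flavor of what such a custom dataset might look like for volumetric medical images, here is a rough sketch; the file format, helper names, and directory layout are assumptions chosen for illustration, not a prescribed `ganslate` API:

```python
from pathlib import Path

import numpy as np
import SimpleITK as sitk
import torch
from torch.utils.data import Dataset


class CTtoMRDatasetSketch(Dataset):
    """Hypothetical paired CT-to-MR dataset that reads NRRD volumes with SimpleITK."""

    def __init__(self, root_dir):
        root = Path(root_dir)
        self.paths_ct = sorted((root / "CT").glob("*.nrrd"))
        self.paths_mr = sorted((root / "MR").glob("*.nrrd"))
        assert len(self.paths_ct) == len(self.paths_mr)

    def __len__(self):
        return len(self.paths_ct)

    def _load_volume(self, path):
        array = sitk.GetArrayFromImage(sitk.ReadImage(str(path)))       # (D, H, W)
        return torch.from_numpy(array.astype(np.float32)).unsqueeze(0)  # (1, D, H, W)

    def __getitem__(self, idx):
        return {'A': self._load_volume(self.paths_ct[idx]),
                'B': self._load_volume(self.paths_mr[idx])}
```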
93 changes: 93 additions & 0 deletions docs/package_overview/4_architectures.md
@@ -0,0 +1,93 @@
# Model Architectures and Loss Functions



--------------------
## GAN Architectures

`ganslate` provides implementations of several popular image translation GANs, which you can use out-of-the-box in your projects. Here is the list of currently supported GAN architectures:

1. Pix2Pix
- Class: `ganslate.nn.gans.paired.pix2pix.Pix2PixConditionalGAN`
- Data requirements: Paired pixel-wise aligned domain _A_ and domain _B_ images
   - Original paper: Isola et al. - Image-to-Image Translation with Conditional Adversarial Networks ([arXiv](https://arxiv.org/abs/1611.07004))

2. CycleGAN
- Class: `ganslate.nn.gans.unpaired.cyclegan.CycleGAN`
- Data requirements: Unpaired domain _A_ and domain _B_ images
   - Original paper: Zhu et al. - Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks ([arXiv](https://arxiv.org/abs/1703.10593))

3. RevGAN
- Class: `ganslate.nn.gans.unpaired.revgan.RevGAN`
- Data requirements: Unpaired domain _A_ and domain _B_ images
   - Original paper: Ouderaa et al. - Reversible GANs for Memory-efficient Image-to-Image Translation ([arXiv](https://arxiv.org/abs/1902.02729))

4. CUT
- Class: `ganslate.nn.gans.unpaired.cut.CUT`
- Data requirements: Unpaired domain _A_ and domain _B_ images
   - Original paper: Park et al. - Contrastive Learning for Unpaired Image-to-Image Translation ([arXiv](https://arxiv.org/abs/2007.15651))

`ganslate` defines an abstract base class `ganslate.nn.gans.base.BaseGAN` ([source](https://github.com/ganslate-team/ganslate/nn/gans/base.py)) that implements some of the basic functionality common to all the aforementioned GAN architectures, such as methods related to model setup, saving, loading, learning rate update, etc. Additionally, it also declares certain abstract methods whose implementation might differ across various GAN architectures, such as the forward pass and backpropagation logic. Each of the aforementioned GAN architectures inherits from `BaseGAN` and implements the necessary abstract methods.

The `BaseGAN` class has an associated `dataclass` at `ganslate.configs.base.BaseGANConfig` that defines all its basic settings including the settings for optimizer, generator, and discriminator. Since the different GAN architectures have their own specific settings, each of them also has an associated configuration `dataclass` that inherits from `ganslate.configs.base.BaseGANConfig` and defines additional architecture-specific settings.
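
The relationship between a GAN class and its configuration `dataclass` roughly follows the pattern sketched below; the field and method names are placeholders chosen to illustrate the idea, not the exact abstract interface declared by `BaseGAN`:

```python
from dataclasses import dataclass

from ganslate.configs.base import BaseGANConfig
from ganslate.nn.gans.base import BaseGAN


@dataclass
class MyGANConfig(BaseGANConfig):
    # Architecture-specific setting (hypothetical example).
    lambda_consistency: float = 10.0


class MyGAN(BaseGAN):
    def __init__(self, conf):
        super().__init__(conf)  # assumed constructor signature
        # ... build generators, discriminators, losses, and optimizers here ...

    # The abstract methods declared by BaseGAN (names below are placeholders)
    # would be implemented to define the forward pass and backpropagation logic:
    # def forward(self): ...
    # def optimize_parameters(self): ...
```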

Owing to its extensible design, `ganslate` additionally enables users to modify the existing GANs by overriding certain functionalities or to define their own custom image translation GAN from scratch. The former is discussed in the context of loss functions as part of the basic tutorial [Your First Project](../tutorials_basic/2_new_project.md), whereas the latter is covered in the advanced tutorial [Writing Your Own GAN Class from Scratch](../tutorials_advanced/1_custom_gan_architecture.md).



--------------------------------------------
## Generator and Discriminator Architectures

Generators and discriminators are defined in `ganslate` as regular _PyTorch_ modules derived from `torch.nn.Module`.

Following is the list of the available generator architectures:

1. ResNet variants (Original ResNet paper - [arXiv](https://arxiv.org/abs/1512.03385)):
- 2D ResNet: `ganslate.nn.generators.resent.resnet2d.Resnet2D`
- 3D ResNet: `ganslate.nn.generators.resent.resnet3d.Resnet3D`
- Partially-invertible ResNet generator: `ganslate.nn.generators.resent.piresnet3d.Piresnet3D`

2. U-Net variants (Original U-Net paper - [arXiv](https://arxiv.org/abs/1505.04597)):
- 2D U-Net: `ganslate.nn.generators.unet.unet2d.Unet2D`
   - 3D U-Net: `ganslate.nn.generators.unet.unet3d.Unet3D`

3. V-Net variants (Original V-Net paper - [arXiv](https://arxiv.org/abs/1606.04797)):
- 2D V-Net: `ganslate.nn.generators.vnet.vnet2d.Vnet2D`
- 3D V-Net: `ganslate.nn.generators.vnet.vnet3d.Vnet3D`
- Partially-invertible 3D V-Net generator with Self-Attention: `ganslate.nn.generators.vnet.sa_vnet3d.SAVnet3D`


And here is the list of the available discriminator architectures:

1. PatchGAN discriminator variants (PatchGAN originally described in the Pix2Pix paper - [arXiv](https://arxiv.org/abs/1611.07004))
- 2D PatchGAN: `ganslate.nn.discriminators.patchgan.patchgan2d.PatchGAN2D`
- 3D PatchGAN: `ganslate.nn.discriminators.patchgan.patchgan3d.PatchGAN3D`
- Multiscale 3D PatchGAN: `ganslate.nn.discriminators.patchgan.ms_patchgan3d.MSPatchGAN3D`
- 3D PatchGAN with Self-Attention: `ganslate.nn.discriminators.patchgan.sa_patchgan3d.SAPatchGAN3D`
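
To make concrete that these networks are regular `torch.nn.Module`s, here is a minimal PatchGAN-style discriminator sketch (a simplified stand-in, not `ganslate`'s `PatchGAN2D`):

```python
import torch.nn as nn


class PatchDiscriminatorSketch(nn.Module):
    """Minimal 2D PatchGAN-style discriminator: outputs a grid of real/fake logits,
    one per receptive-field patch, instead of a single scalar per image."""

    def __init__(self, in_channels=3, base_filters=64):
        super().__init__()
        layers = [nn.Conv2d(in_channels, base_filters, 4, stride=2, padding=1),
                  nn.LeakyReLU(0.2, inplace=True)]
        channels = base_filters
        for _ in range(2):  # two further downsampling blocks
            layers += [nn.Conv2d(channels, channels * 2, 4, stride=2, padding=1),
                       nn.InstanceNorm2d(channels * 2),
                       nn.LeakyReLU(0.2, inplace=True)]
            channels *= 2
        layers += [nn.Conv2d(channels, 1, 4, stride=1, padding=1)]  # per-patch logits
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)  # shape: (N, 1, H', W'), one logit per patch
```

The output is a grid of logits, one per receptive-field patch, which lets the discriminator penalize structure at the patch scale rather than judging the whole image at once.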



-----------------
## Loss Functions

Several different loss function classes are provided in the `ganslate` package. These include flavors of the adversarial loss as well as various GAN architecture-specific losses.

1. Adversarial loss:
- Class: `ganslate.nn.losses.adversarial_loss.AdversarialLoss`
   - Variants: `'vanilla'` (original adversarial loss based on cross-entropy), `'lsgan'` (least-squares loss), `'wgangp'` (Wasserstein-1 distance with gradient penalty), and `'nonsaturating'` (see the sketch at the end of this list)

2. Pix2Pix loss:
- Class: `ganslate.nn.losses.pix2pix_losses.Pix2PixLoss`
- Components:
- Pixel-to-pixel L1 loss between synthetic image and ground truth (weighted by the scalar `lambda_pix2pix`)

3. CycleGAN losses:
- Class: `ganslate.nn.losses.cyclegan_losses.CycleGANLosses`
- Components:
     - Cycle-consistency loss based on L1 distance (the A-B-A and B-A-B components are weighted separately by `lambda_AB` and `lambda_BA`, respectively). Optionally, cycle-consistency can be computed using a weighted sum of L1 and SSIM losses (weights defined by the hyperparameter `proportion_ssim`).
- Identity loss implemented with L1 distance

4. CUT losses:
- Class: `ganslate.nn.losses.cut_losses.PatchNCELoss`
- Components:
- PatchNCE loss
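
As referenced above, here is a compact sketch of the standard formulations behind the listed adversarial-loss variants and the L1 cycle-consistency term; it is illustrative only, not `ganslate`'s loss classes, and the `'wgangp'` variant additionally requires a gradient-penalty term computed on interpolated samples, which is omitted here:

```python
import torch
import torch.nn.functional as F


def adversarial_loss(pred, target_is_real, variant='lsgan'):
    """Standard formulations of the adversarial-loss variants listed above."""
    target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
    if variant == 'vanilla':        # cross-entropy on raw discriminator logits
        return F.binary_cross_entropy_with_logits(pred, target)
    if variant == 'lsgan':          # least-squares loss
        return F.mse_loss(pred, target)
    if variant == 'wgangp':         # Wasserstein critic term (gradient penalty omitted)
        return -pred.mean() if target_is_real else pred.mean()
    if variant == 'nonsaturating':  # -log D formulation
        return F.softplus(-pred).mean() if target_is_real else F.softplus(pred).mean()
    raise ValueError(f"Unknown adversarial loss variant: {variant}")


def cycle_consistency_loss(real, reconstructed, lambda_cycle=10.0):
    """L1 cycle-consistency term, e.g. ||A - G_BA(G_AB(A))||_1, scaled by its weight."""
    return lambda_cycle * F.l1_loss(reconstructed, real)
```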
1 change: 0 additions & 1 deletion docs/package_overview/4_gans.md

This file was deleted.

8 changes: 8 additions & 0 deletions docs/tutorials_basic/2_new_project.md
@@ -1,6 +1,14 @@
# Your First Project



-------------------------------------
## Creating a Project from a Template

TODO: Cookiecutter stuff



-------------------------------------------------------------------
## Loading Your Own Data into ganslate with Custom PyTorch Datasets

1 change: 0 additions & 1 deletion ganslate/nn/gans/paired/pix2pix.py
@@ -3,7 +3,6 @@

import torch
from ganslate import configs
from ganslate.data.utils.image_pool import ImagePool
from ganslate.nn.gans.base import BaseGAN
from ganslate.nn.losses.adversarial_loss import AdversarialLoss
from ganslate.nn.losses.pix2pix_losses import Pix2PixLoss
6 changes: 3 additions & 3 deletions mkdocs.yml
@@ -8,9 +8,9 @@ nav:

- Package Overview:
- Commandline Interface: package_overview/1_cli.md
- Projects: package_overview/2_projects.md
- ganslate Projects: package_overview/2_projects.md
- Datasets: package_overview/3_datasets.md
- GAN Architectures and Loss Functions: package_overview/4_gans.md
- Model Architectures and Loss Functions: package_overview/4_architectures.md
- Engines: package_overview/5_engines.md
- Logging and Visualization: package_overview/6_logging.md
- Configuration: package_overview/7_configuration.md
@@ -20,7 +20,7 @@ nav:
- Your New Project: tutorials_basic/2_new_project.md

- Advanced Tutorials:
- Custom GAN Architectures: tutorials_advanced/1_custom_gans_architecture.md
- Custom GAN Architectures: tutorials_advanced/1_custom_gan_architecture.md
- Custom Generator and Discriminator Architectures: tutorials_advanced/2_custom_G_and_D_architectures.md

- API: api/*
