Merge branch 'documentation'
ibro45 committed Sep 29, 2021
2 parents b05095f + 067b637 commit f544b26
Showing 20 changed files with 579 additions and 325 deletions.
Binary file added docs/imgs/uml-ganslate_engines.png
@@ -1,8 +1,5 @@
# Using the command line interface

Interact with `ganslate` using:

## CLI
The command line interface for `ganslate` offers a simple way to interact with its various functionalities.

After installing the package, you can type
25 changes: 25 additions & 0 deletions docs/package_overview/2_projects.md
@@ -0,0 +1,25 @@
# `ganslate` Projects

In `ganslate`, a _project_ refers to the collection of all custom code and configuration files pertaining to your specific task. The project directory is expected to have a certain structure that isolates logically different parts of the project, such as the data pipeline, GAN implementation, and configuration. The directory structure is as follows:

```text
<your_project_dir>
|
|- datasets
| |- custom_train_dataset.py
| |- custom_val_test_dataset.py
|
|- architectures
| |- custom_gan.py
|
|- experiments
| |- exp1_config.yaml
|
|- __init__.py
|
|- README.md
```

The `__init__.py` file initializes your project directory as a Python module, which is necessary for `ganslate`'s configuration system to function correctly (see [configuration](./7_configuration.md) for details). The `README.md` file can contain a description of your task.

`ganslate` provides a Cookiecutter template which can automatically generate an empty project for you. The tutorial [Your First Project](../tutorials_basic/2_new_project.md) provides detailed instructions on how to create and operate your own project.
76 changes: 76 additions & 0 deletions docs/package_overview/3_datasets.md
@@ -0,0 +1,76 @@
# Datasets

`ganslate` follows the [standard data loading workflow in _PyTorch_](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html): dataset classes derived from `torch.utils.data.Dataset` define the data fetching and preprocessing pipeline. Additionally, each `ganslate` dataset has an associated Python `dataclass` in which all of its data-related settings are defined.



--------------------------------------------------------------
## The `PairedImageDataset` and `UnpairedImageDataset` classes

Two classes, `PairedImageDataset` and `UnpairedImageDataset`, are supplied by `ganslate` by default and can be used out-of-the-box on your image data.

The `PairedImageDataset` class loads an _A_-_B_ image pair from storage given a common index and applies optional joint transformations to the pair. It is meant to be used in paired training, as well as during validation and/or testing whenever paired data is available.

On the other hand, `UnpairedImageDataset` randomly fetches a domain _A_ image and a domain _B_ image and applies optional transformations to each independently. As the name suggests, this class is meant to be used for unpaired training.

### Input and Output
Both classes expect your data directory to be structured in the following manner:
```text
<train_dataset_root_dir>
|
|- A
| |- ...
| |- ...
|
|- B
| |- ...
| |- ...
```

And if using validation or testing data
```text
<val_dataset_root_dir>
|
|- A
| |- ...
| |- ...
|
|- B
| |- ...
| |- ...
```

where the sub-directories _A_ and _B_ contain the images. When paired data is provided (i.e. paired training, or validation/testing), the ordering of images in _A_ corresponds to the ordering of images in _B_, meaning that the first _A_ image and the first _B_ image form a pair, and so on. Images with extensions `.jpg`, `.jpeg`, and `.png` are supported.


Both image dataset classes implement a `__getitem__` method that outputs a sample dictionary of the following form
```python
sample = {'A': a_tensor, 'B': b_tensor}
```
where each tensor is of shape (`C`, `H`, `W`).
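
Since each sample is a plain dictionary, it can be consumed with a standard `torch.utils.data.DataLoader`. The short sketch below uses a dummy stand-in dataset (not a `ganslate` class) purely to illustrate the sample format; in practice, `ganslate` constructs the data loader internally from your configuration.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class DummyTranslationDataset(Dataset):
    """Stand-in dataset that mimics the {'A': ..., 'B': ...} sample format."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        # Each tensor has shape (C, H, W), e.g. a 3-channel 256x256 image.
        return {"A": torch.rand(3, 256, 256), "B": torch.rand(3, 256, 256)}

loader = DataLoader(DummyTranslationDataset(), batch_size=2)

for sample in loader:
    a, b = sample["A"], sample["B"]  # shapes: (batch_size, C, H, W)
    print(a.shape, b.shape)
```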


### Available Settings
The configuration `dataclasses` associated with both default image datasets inherit from `configs.base.BaseDatasetConfig`, which defines two settings common to all dataset classes: `num_workers` and `pin_memory`, the settings for the `torch.utils.data.DataLoader` used internally by `ganslate`.


The two image datasets have additional settings which are the same across both (a sketch of the resulting configuration follows this list). These are:

- `image_channels`: Refers to the number of channels in the images. Only images with 1 or 3 channels are supported (i.e. grayscale and RGB), and the number of channels must be the same across _A_ and _B_ images (i.e. either both grayscale or both RGB).

- `preprocess`: Accepts a tuple of predefined strings that defines the preprocessing instructions. These strings are `resize`, `scale_width`, `random_zoom`, `random_crop`, and `random_flip`. Of these, `resize` and `scale_width` specify the initial resizing operation (choose either one or none), whereas the rest correspond to the random transforms used as data augmentation during training. An example value for the `preprocess` setting is `('resize', 'random_crop', 'random_flip')`.

Note: In `PairedImageDataset`, these transforms are applied jointly to the _A_ and _B_ images, whereas in `UnpairedImageDataset`, they are applied on each image independently.

- `load_size`: This parameter accepts a tuple that specifies the size (`H`, `W`) to which the images are loaded from storage and resized as a result of the `resize` preprocessing instruction. If the `scale_width` instruction is specified instead, only the width component of `load_size` is considered and the image width is scaled to this value while preserving the aspect ratio.

- `final_size`: This parameter accepts a tuple that specifies the size (`H`, `W`) to which the images are converted as a result of the random transforms. This is the final size of the images that should be expected from the dataloader.

Note: When not using random transforms (for example, during validation/testing), set `final_size` to the same value as `load_size`.
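
For concreteness, the sketch below shows how such a configuration could look when expressed as a Python `dataclass`. The class name and default values are illustrative assumptions, not `ganslate`'s actual definitions; in a real project these settings are typically written in the experiment's YAML configuration file.

```python
from dataclasses import dataclass
from typing import Tuple

# Illustrative sketch only -- the real dataclasses live in ganslate.configs and may differ.
@dataclass
class ExampleImageDatasetConfig:
    # Settings common to all datasets (inherited from BaseDatasetConfig in ganslate):
    num_workers: int = 4
    pin_memory: bool = True
    # Settings shared by PairedImageDataset and UnpairedImageDataset:
    image_channels: int = 3                                  # 1 (grayscale) or 3 (RGB)
    preprocess: Tuple[str, ...] = ("resize", "random_crop", "random_flip")
    load_size: Tuple[int, int] = (286, 286)                  # (H, W) after loading/resizing
    final_size: Tuple[int, int] = (256, 256)                 # (H, W) after random transforms
```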



-------------------------
## Custom Dataset Classes

It is also possible to define your own custom dataset class for use-cases requiring specialized data processing, for example, in the case of medical images. See [Your First Project](../tutorials_basic/2_new_project.md) for more details on creating custom datasets.
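
As an illustration of the general pattern, here is a minimal sketch of a custom dataset class paired with its configuration `dataclass`. All names here are hypothetical; in a real project the config would inherit from `ganslate`'s `BaseDatasetConfig`, and the image loading would be replaced with your own logic (e.g. reading medical image formats).

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple

import torch
from torch.utils.data import Dataset

@dataclass
class MyCustomDatasetConfig:
    # Hypothetical config; a real one would inherit from ganslate's BaseDatasetConfig.
    root: str = "data/train"
    final_size: Tuple[int, int] = (256, 256)

class MyCustomDataset(Dataset):
    """Illustrative custom dataset yielding {'A': ..., 'B': ...} samples."""

    def __init__(self, conf: MyCustomDatasetConfig):
        self.paths_a = sorted(Path(conf.root, "A").glob("*"))
        self.paths_b = sorted(Path(conf.root, "B").glob("*"))
        self.final_size = conf.final_size

    def __len__(self):
        return min(len(self.paths_a), len(self.paths_b))

    def __getitem__(self, idx):
        # Replace with real loading and preprocessing (e.g. SimpleITK for medical
        # images); random tensors keep this sketch self-contained and runnable.
        h, w = self.final_size
        return {"A": torch.rand(1, h, w), "B": torch.rand(1, h, w)}
```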
93 changes: 93 additions & 0 deletions docs/package_overview/4_architectures.md
@@ -0,0 +1,93 @@
# Model Architectures and Loss Functions



--------------------
## GAN Architectures

`ganslate` provides implementations of several popular image translation GANs, which you can use out-of-the-box in your projects. Here is the list of currently supported GAN architectures:

1. Pix2Pix
- Class: `ganslate.nn.gans.paired.pix2pix.Pix2PixConditionalGAN`
- Data requirements: Paired pixel-wise aligned domain _A_ and domain _B_ images
    - Original paper: Isola et al. - Image-to-Image Translation with Conditional Adversarial Networks ([arXiv](https://arxiv.org/abs/1611.07004))

2. CycleGAN
- Class: `ganslate.nn.gans.unpaired.cyclegan.CycleGAN`
- Data requirements: Unpaired domain _A_ and domain _B_ images
    - Original paper: Zhu et al. - Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks ([arXiv](https://arxiv.org/abs/1703.10593))

3. RevGAN
- Class: `ganslate.nn.gans.unpaired.revgan.RevGAN`
- Data requirements: Unpaired domain _A_ and domain _B_ images
    - Original paper: Ouderaa et al. - Reversible GANs for Memory-efficient Image-to-Image Translation ([arXiv](https://arxiv.org/abs/1902.02729))

4. CUT
- Class: `ganslate.nn.gans.unpaired.cut.CUT`
- Data requirements: Unpaired domain _A_ and domain _B_ images
    - Original paper: Park et al. - Contrastive Learning for Unpaired Image-to-Image Translation ([arXiv](https://arxiv.org/abs/2007.15651))

`ganslate` defines an abstract base class `ganslate.nn.gans.base.BaseGAN` ([source](https://github.com/ganslate-team/ganslate/nn/gans/base.py)) that implements some of the basic functionality common to all the aforementioned GAN architectures, such as methods related to model setup, saving, loading, and learning rate updates. Additionally, it declares certain abstract methods whose implementation might differ across GAN architectures, such as the forward pass and backpropagation logic. Each of the aforementioned GAN architectures inherits from `BaseGAN` and implements the necessary abstract methods.

The `BaseGAN` class has an associated `dataclass` at `ganslate.configs.base.BaseGANConfig` that defines all its basic settings including the settings for optimizer, generator, and discriminator. Since the different GAN architectures have their own specific settings, each of them also has an associated configuration `dataclass` that inherits from `ganslate.configs.base.BaseGANConfig` and defines additional architecture-specific settings.
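
The sketch below illustrates this configuration pattern. Both classes are stand-ins: the field names are assumptions chosen for illustration (the loss weights echo the hyperparameters described in the Loss Functions section) and do not reproduce `ganslate`'s actual dataclass definitions.

```python
from dataclasses import dataclass

@dataclass
class SketchBaseGANConfig:
    # Stand-in for ganslate.configs.base.BaseGANConfig: settings shared by all
    # architectures (optimizer, generator, discriminator, ...) would live here.
    norm_type: str = "instance"   # illustrative field, not ganslate's actual schema

@dataclass
class SketchCycleGANConfig(SketchBaseGANConfig):
    # Architecture-specific additions layered on top of the shared base settings.
    lambda_AB: float = 10.0
    lambda_BA: float = 10.0
    proportion_ssim: float = 0.0
```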

As a result of its extensible design, `ganslate` additionally enables users to modify existing GANs by overriding certain functionality, or to define their own image translation GAN from scratch. The former is discussed in the context of loss functions in the basic tutorial [Your First Project](../tutorials_basic/2_new_project.md), while the latter is covered in the advanced tutorial [Writing Your Own GAN Class from Scratch](../tutorials_advanced/1_custom_gan_architecture.md).



--------------------------------------------
## Generator and Discriminator Architectures

Generators and discriminators are defined in `ganslate` as regular _PyTorch_ modules derived from `torch.nn.Module`.

Following is the list of the available generator architectures:

1. ResNet variants (Original ResNet paper - [arXiv](https://arxiv.org/abs/1512.03385)):
- 2D ResNet: `ganslate.nn.generators.resent.resnet2d.Resnet2D`
- 3D ResNet: `ganslate.nn.generators.resent.resnet3d.Resnet3D`
- Partially-invertible ResNet generator: `ganslate.nn.generators.resent.piresnet3d.Piresnet3D`

2. U-Net variants (Original U-Net paper - [arXiv](https://arxiv.org/abs/1505.04597)):
- 2D U-Net: `ganslate.nn.generators.unet.unet2d.Unet2D`
    - 3D U-Net: `ganslate.nn.generators.unet.unet3d.Unet3D`

3. V-Net variants (Original V-Net paper - [arXiv](https://arxiv.org/abs/1606.04797)):
- 2D V-Net: `ganslate.nn.generators.vnet.vnet2d.Vnet2D`
- 3D V-Net: `ganslate.nn.generators.vnet.vnet3d.Vnet3D`
- Partially-invertible 3D V-Net generator with Self-Attention: `ganslate.nn.generators.vnet.sa_vnet3d.SAVnet3D`


And here is the list of the available discriminator architectures:

1. PatchGAN discriminator variants (PatchGAN originally described in the Pix2Pix paper - [arXiv](https://arxiv.org/abs/1611.07004))
- 2D PatchGAN: `ganslate.nn.discriminators.patchgan.patchgan2d.PatchGAN2D`
- 3D PatchGAN: `ganslate.nn.discriminators.patchgan.patchgan3d.PatchGAN3D`
- Multiscale 3D PatchGAN: `ganslate.nn.discriminators.patchgan.ms_patchgan3d.MSPatchGAN3D`
- 3D PatchGAN with Self-Attention: `ganslate.nn.discriminators.patchgan.sa_patchgan3d.SAPatchGAN3D`



-----------------
## Loss Functions

Several different loss function classes are provided in the `ganslate` package. These include different flavors of the adversarial loss as well as various GAN architecture-specific losses; a short, self-contained sketch of the adversarial-loss variants follows the list.

1. Adversarial loss
- Class: `ganslate.nn.losses.adversarial_loss.AdversarialLoss`
- Variants: `'vanilla'` (original adversarial loss based on cross-entropy), `'lsgan'` (least-squares loss), `'wgangp'` (Wasserstein-1 distance with gradient penalty), and `'nonsaturating'`

2. Pix2Pix loss
- Class: `ganslate.nn.losses.pix2pix_losses.Pix2PixLoss`
- Components:
- Pixel-to-pixel L1 loss between synthetic image and ground truth (weighted by the scalar `lambda_pix2pix`)

3. CycleGAN losses
- Class: `ganslate.nn.losses.cyclegan_losses.CycleGANLosses`
- Components:
      - Cycle-consistency loss based on L1 distance (_A-B-A_ and _B-A-B_ components separately weighted by `lambda_AB` and `lambda_BA`, respectively). Optionally, cycle-consistency can be computed using a weighted sum of L1 and SSIM losses (with weights defined by the hyperparameter `proportion_ssim`).
- Identity loss implemented with L1 distance

4. CUT losses
- Class: `ganslate.nn.losses.cut_losses.PatchNCELoss`
- Components:
- PatchNCE loss
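
To make the adversarial-loss variants above concrete, here is a self-contained sketch of the `'vanilla'` (cross-entropy) and `'lsgan'` (least-squares) criteria applied to a discriminator's prediction map. It illustrates the underlying math only and is not `ganslate`'s `AdversarialLoss` implementation.

```python
import torch
import torch.nn.functional as F

def adversarial_criterion(pred, target_is_real, variant="vanilla"):
    """Generic GAN criterion for a discriminator prediction map `pred` (logits)."""
    target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
    if variant == "vanilla":  # cross-entropy on raw logits
        return F.binary_cross_entropy_with_logits(pred, target)
    if variant == "lsgan":    # least-squares loss
        return F.mse_loss(pred, target)
    raise ValueError(f"Unsupported variant: {variant}")

# Discriminator loss on a real batch and a generated (fake) batch:
pred_real = torch.randn(4, 1, 30, 30)
pred_fake = torch.randn(4, 1, 30, 30)
d_loss = (adversarial_criterion(pred_real, True, "lsgan")
          + adversarial_criterion(pred_fake, False, "lsgan"))
```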
54 changes: 54 additions & 0 deletions docs/package_overview/5_engines.md
@@ -0,0 +1,54 @@
# Engines

`ganslate` defines four _engines_ that implement the processes crucial to a deep learning workflow: `Trainer`, `Validator`, `Tester`, and `Inferer`. The following UML diagram shows the design of `ganslate`'s `engines` module and the relationship between the different engine classes defined in it.

![Relationship between ganslate's engine classes](../imgs/uml-ganslate_engines.png)



------------
## `Trainer`

The `Trainer` class ([source](https://github.com/ganslate-team/ganslate/ganslate/engines/trainer.py)) implements the training procedure and is instantiated at the start of the training process. Upon initialization, the trainer object executes the following tasks:
1. Preparing the environment.
2. Initializing the GAN model, training data loader, training tracker, and validator.

The `Trainer` class provides the `run()` method, which defines the training logic (sketched conceptually after this list). This includes:
1. Fetching data from the training dataloader.
2. Invoking the GAN model's methods that set the inputs and perform forward pass, backpropagation, and parameter update.
3. Obtaining the results of the iteration (the computed images, loss values, metrics, and I/O and computation times) and pushing them to the experiment tracker for logging.
4. Running model validation.
5. Saving checkpoints locally.
6. Updating the learning rates.
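
The sketch below condenses these steps into pseudocode-style Python. All attribute and method names on the hypothetical `trainer` object are illustrative and do not correspond to `ganslate`'s actual `Trainer` API.

```python
def training_loop(trainer, num_iters):
    """Conceptual outline of a training run; names are illustrative only."""
    for i, batch in zip(range(num_iters), trainer.train_loader):
        trainer.model.set_input(batch)                  # 1. fetch data, hand it to the GAN
        trainer.model.optimize_parameters()             # 2. forward, backprop, param update
        results = trainer.model.get_current_results()   # 3. images, losses, metrics, timings
        trainer.tracker.log(i, results)                 #    ... pushed to the experiment tracker
        if trainer.should_validate(i):
            trainer.validator.run()                     # 4. periodic validation
        if trainer.should_checkpoint(i):
            trainer.save_checkpoint(i)                  # 5. save checkpoints locally
        trainer.update_learning_rates()                 # 6. learning-rate schedule step
```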

All configuration pertaining to the `Trainer` is grouped under the `'train'` mode in `ganslate`.



--------------
## `Validator`
The `Validator` class ([source](https://github.com/ganslate-team/ganslate/ganslate/engines/validator_tester.py)) inherits almost all of its properties and functionality from the `BaseValTestEngine` and is responsible for performing validation on a given model during the training process. It is instantiated and used within the `Trainer`, which supplies it with its configuration and the model. Upon initialization, a `Validator` object executes the following:
1. Initializing the sliding-window inferer, validation data loader, validation tracker, and the validation-test metricizer (sliding-window inference is sketched at the end of this section).

The `run()` method of the `Validator` iterates over the validation dataset and executes the following steps:
1. Fetching data from the validation data loader.
2. Running inference on the given model and holding the computed images.
3. Saving the computed image and its relevant metadata (useful in case of medical images).
4. Calculating image quality/similarity metrics by comparing the generated image with the ground truth.
5. Pushing the images and metrics into the validation tracker for logging.

All configuration pertaining to the `Validator` is grouped under the `'val'` mode in `ganslate`.
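
The sliding-window inference used during validation can be illustrated with a generic, self-contained sketch: the input is split into overlapping windows, each window is passed through the model, and overlapping predictions are averaged. This is a conceptual 2D illustration, not `ganslate`'s actual inferer, which may differ in interface and implementation.

```python
import torch

def _positions(size, win, step):
    """Window start positions along one axis, always touching the far edge."""
    last = max(size - win, 0)
    pos = list(range(0, last + 1, step))
    if pos[-1] != last:
        pos.append(last)
    return pos

def sliding_window_inference(image, model, window=(64, 64), overlap=0.5):
    """Average overlapping window predictions over an image of shape (C, H, W)."""
    c, h, w = image.shape
    wh, ww = window
    step_h = max(1, int(wh * (1 - overlap)))
    step_w = max(1, int(ww * (1 - overlap)))
    output = torch.zeros_like(image)
    counts = torch.zeros(1, h, w)
    for top in _positions(h, wh, step_h):
        for left in _positions(w, ww, step_w):
            patch = image[:, top:top + wh, left:left + ww].unsqueeze(0)
            with torch.no_grad():
                pred = model(patch).squeeze(0)  # assumes output matches input patch shape
            output[:, top:top + wh, left:left + ww] += pred
            counts[:, top:top + wh, left:left + ww] += 1
    return output / counts.clamp(min=1)

# Example with an identity "model" standing in for a trained generator:
result = sliding_window_inference(torch.rand(3, 128, 128), torch.nn.Identity())
```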



-----------
## `Tester`
The `Tester` class ([source](https://github.com/ganslate-team/ganslate/ganslate/engines/validator_tester.py)), like the `Validator`, inherits from the `BaseValTestEngine` and has the same properties and functionalities as the `Validator`. The only difference is that a `Tester` instance sets up the environment and builds its own GAN model, and is therefore used independently of the `Trainer`.

All configuration pertaining to the `Tester` is grouped under the `'test'` mode in `ganslate`.



------------
## `Inferer`
The `Inferer` class ([source](https://github.com/ganslate-team/ganslate/ganslate/engines/validator_tester.py)) is a simplified inference engine without any mechanism for metric calculation; it therefore expects data without a ground truth to compare against. Under normal circumstances it still performs utility tasks such as fetching data from a data loader, tracking I/O and computation time, and logging and saving images. However, when used in _deployment_ mode, the `Inferer` acts as a minimal inference engine that can easily be integrated into other applications.
1 change: 1 addition & 0 deletions docs/package_overview/6_trackers.md
@@ -0,0 +1 @@
# Logging and Visualization
File renamed without changes.
91 changes: 0 additions & 91 deletions docs/tutorials/custom_dataset.md

This file was deleted.

