Skip to content

Commit

Permalink
Added engines.md doc page
Browse files Browse the repository at this point in the history
  • Loading branch information
cnmy-ro committed Sep 29, 2021
1 parent cd6d941 commit 067b637
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 13 deletions.
Binary file added docs/imgs/uml-ganslate_engines.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs/package_overview/2_projects.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ In `ganslate`, a _project_ refers to a collection of all custom code and configu

The `__init__.py` file initializes your project directory as Python module which is necessary for `ganslate`'s configuration system to correctly function. (See [configuration](./7_configuration.md) for details). The `README.md` file could contain a description of your task.

`ganslate` provides a Cookiecutter template which can automatically generate an empty project for you. The tutorial [Your First Project](../tutorials_basic/2_new_project.md) provides detailed instructions on how to create and run your own project.
`ganslate` provides a Cookiecutter template which can automatically generate an empty project for you. The tutorial [Your First Project](../tutorials_basic/2_new_project.md) provides detailed instructions on how to create and operate your own project.
4 changes: 2 additions & 2 deletions docs/package_overview/3_datasets.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ The two image datasets have additional settings which are same across the two da

Note: In `PairedImageDataset`, these transforms are applied jointly to the _A_ and _B_ images, whereas in `UnpairedImageDataset`, they are applied on each image independently.

- `load_size`: This parameter accepts a tuple that specifies the size (`H`, `W`) to which the images are to be loaded from the storage and resized as a result of the `resize` preprocessing instruction.
- `load_size`: This parameter accepts a tuple that specifies the size (`H`, `W`) to which the images are to be loaded from the storage and resized as a result of the `resize` preprocessing instruction. If instead the `scale_width` instruction is specified, only the width component if of the `load_size` is considered and the image width is scaled to this value while preserving its aspect ratio.

- `final_size`: This parameter accepts a tuple that specifies the size (`H`, `W`) to which the images are converted as a result of the random transforms. This is the final size of the images that should be expected from the dataloader.

Note: When not using random transforms (for example, during validation), specify the `final_size` the same as `load_size`.
Note: When not using random transforms (for example, during validation/testing), specify the `final_size` the same as `load_size`.



Expand Down
14 changes: 7 additions & 7 deletions docs/package_overview/4_architectures.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

The `BaseGAN` class has an associated `dataclass` at `ganslate.configs.base.BaseGANConfig` that defines all its basic settings including the settings for optimizer, generator, and discriminator. Since the different GAN architectures have their own specific settings, each of them also has an associated configuration `dataclass` that inherits from `ganslate.configs.base.BaseGANConfig` and defines additional architecture-specific settings.

Owing to its extensible design, `ganslate` additionally enables users to modify the existing GANs by overriding certain functionalities or to define their own custom image translation GAN from scratch. The former is discussed in the context of loss functions as part of the basic tutorial [Your First Project](../tutorials_basic/2_new_project.md). Whereas, the latter is part of the advanced tutorial [Writing Your Own GAN Class from Scratch](../tutorials_advanced/1_custom_gan_architecture.md).
As a result to its extensible design, `ganslate` additionally enables users to modify the existing GANs by overriding certain functionalities or to define their own custom image translation GAN from scratch. The former is discussed in the context of loss functions as part of the basic tutorial [Your First Project](../tutorials_basic/2_new_project.md). Whereas, the latter is part of the advanced tutorial [Writing Your Own GAN Class from Scratch](../tutorials_advanced/1_custom_gan_architecture.md).



Expand Down Expand Up @@ -70,24 +70,24 @@ And here is the list of the available discriminator architectures:
-----------------
## Loss Functions

Several different loss function classes are provided in the `ganslate` package. These include flavors of the adversarial loss as well as various GAN architecture-specific losses.
Several different loss function classes are provided in the `ganslate` package. These include different flavors of the adversarial loss as well as various GAN architecture-specific losses.

1. Adversarial loss:
1. Adversarial loss
- Class: `ganslate.nn.losses.adversarial_loss.AdversarialLoss`
- Variants: `'vanilla'` (original adversarial loss based on cross-entropy), `'lsgan'` (least-squares loss), `'wgangp'` (Wasserstein-1 distance with gradient penalty), and `'nonsaturating'`

2. Pix2Pix loss:
2. Pix2Pix loss
- Class: `ganslate.nn.losses.pix2pix_losses.Pix2PixLoss`
- Components:
- Pixel-to-pixel L1 loss between synthetic image and ground truth (weighted by the scalar `lambda_pix2pix`)

3. CycleGAN losses:
3. CycleGAN losses
- Class: `ganslate.nn.losses.cyclegan_losses.CycleGANLosses`
- Components:
- Cycle-consistency loss based on L1 distance (A-B-A and B-A-B components separated weighted by `lambda_AB` and `lambda_BA`, respectively). Option to compute cycle-consistency as using a weighted sum of L1 and SSIM losses (weights defined by the hyperparameter `proportion_ssim`).
- Cycle-consistency loss based on L1 distance (_A-B-A_ and _B-A-B_ components separated weighted by `lambda_AB` and `lambda_BA`, respectively). Option to compute cycle-consistency as using a weighted sum of L1 and SSIM losses (weights defined by the hyperparameter `proportion_ssim`).
- Identity loss implemented with L1 distance

3. CUT losses:
3. CUT losses
- Class: `ganslate.nn.losses.cut_losses.PatchNCELoss`
- Components:
- PatchNCE loss
55 changes: 54 additions & 1 deletion docs/package_overview/5_engines.md
Original file line number Diff line number Diff line change
@@ -1 +1,54 @@
# Engines
# Engines

`ganslate` defines four _engines_ that implement processes crucial to deep learning workflow. These are `Trainer`, `Validator`, `Tester`, and `Inferer`. The following UML diagram shows the design of the `ganslate`'s `engines` module and the relationship between the different engine classes defined in it.

![alt text](../imgs/uml-ganslate_engines.png "Relationship between ganslate's engine classes")



------------
## `Trainer`

The `Trainer` class ([source](https://github.com/ganslate-team/ganslate/ganslate/engines/trainer.py)) implements the training procedure and is instantiated at the start of the training process. Upon initialization, the trainer object executes the following tasks:
1. Preparing the environment.
2. Initializing the GAN model, training data loader, traning tracker, and validator.

The `Trainer` class provides the `run()` method which defines the training logic. This includes:
1. Fetching data from the training dataloader
2. Invoking the GAN model's methods that set the inputs and perform forward pass, backpropagation, and parameter update.
3. Obtaining the results of the iteration which includes the computed images, loss values, metrics, and I/O and computation times, and pushing them into the experiment tracker for logging.
4. Running model validation.
5. Saving checkpoints locally.
6. Updating the learning rates.

All configuration pertaining to the `Trainer` is grouped under the `'train'` mode in `ganslate`.



--------------
## `Validator`
The `Validator`class ([source](https://github.com/ganslate-team/ganslate/ganslate/engines/validator_tester.py)) inherits almost all of its properties and functionalities from the `BaseValTestEngine`, and is responsible for performing validation given a model during the training process. It is instantiated and utilized within the `Trainer` where it is supplied with its configuration and the model. Upon initialization, a `Validator` object executes the following:
1. Initializes the sliding window inferer, validation data loader, validation tracker, and the validation-test metricizer

The `run()` method of the `Validator` iterates over the validation dataset and executes the following steps:
1. Fetching data from the validation data loader.
2. Running inference on the given model and holding the computed images.
3. Saving the computed image and its relevant metadata (useful in case of medical images).
4. Calculate image quality/similarity metrics by comparing the generated image with the geound truth.
5. Pushing the images and metrics into the validation tracker for logging.

All configuration pertaining to the `Validator` is grouped under the `'val'` mode in `ganslate`.



-----------
## `Tester`
The `Tester` class ([source](https://github.com/ganslate-team/ganslate/ganslate/engines/validator_tester.py)), like the `Validator`, inherits from the `BaseValTestEngine` and has the same properties and functionalities as the `Validator`. The only difference is that a `Tester` instance sets up the environment and builds its own GAN model, and is therefore used independently of the `Trainer`.

All configuration pertaining to the `Tester` is grouped under the `'test'` mode in `ganslate`.



------------
## `Inferer`
The `Inferer` class ([source](https://github.com/ganslate-team/ganslate/ganslate/engines/validator_tester.py)) represents a simplified inference engine without any mechanism for metric calculation. Therefore, it expects data without a ground truth to compare against. It does execute utility tasks like fetching data from a data loader, tracking I/O and computation time, and logging and saving images under normal circumstances. However, when used in the _deployment_ mode, the `Inferer` essentially acts as a minimal inference engine that can be easily integrated into other applications.
1 change: 0 additions & 1 deletion docs/package_overview/6_logging.md

This file was deleted.

1 change: 1 addition & 0 deletions docs/package_overview/6_trackers.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Logging and Visualization:
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ nav:
- Datasets: package_overview/3_datasets.md
- Model Architectures and Loss Functions: package_overview/4_architectures.md
- Engines: package_overview/5_engines.md
- Logging and Visualization: package_overview/6_logging.md
- Logging and Visualization: package_overview/6_trackers.md
- Configuration: package_overview/7_configuration.md

- Basic Tutorials:
Expand Down

0 comments on commit 067b637

Please sign in to comment.