diff --git a/docs/imgs/uml-ganslate_engines.png b/docs/imgs/uml-ganslate_engines.png
new file mode 100644
index 00000000..226571ad
Binary files /dev/null and b/docs/imgs/uml-ganslate_engines.png differ
diff --git a/docs/getting_started/using_cli.md b/docs/package_overview/1_cli.md
similarity index 95%
rename from docs/getting_started/using_cli.md
rename to docs/package_overview/1_cli.md
index bc85f326..4ff8cc1e 100644
--- a/docs/getting_started/using_cli.md
+++ b/docs/package_overview/1_cli.md
@@ -1,8 +1,5 @@
 # Using the command line interface
-Interact with `ganslate` using:
-
-## CLI
 The command line interface for `ganslate` offers a very simple way to interact with various functionalities. After installing the package, you can type
diff --git a/docs/package_overview/2_projects.md b/docs/package_overview/2_projects.md
new file mode 100644
index 00000000..77181604
--- /dev/null
+++ b/docs/package_overview/2_projects.md
@@ -0,0 +1,25 @@
+# `ganslate` Projects
+
+In `ganslate`, a _project_ refers to a collection of all custom code and configuration files pertaining to your specific task. The project directory is expected to have a certain structure that isolates logically different parts of the project, such as the data pipeline, the GAN implementation, and the configuration. The directory structure is as follows:
+
+```text
+
+    |
+    |- datasets
+    |   |- custom_train_dataset.py
+    |   |- custom_val_test_dataset.py
+    |
+    |- architectures
+    |   |- custom_gan.py
+    |
+    |- experiments
+    |   |- exp1_config.yaml
+    |
+    |- __init__.py
+    |
+    |- README.md
+```
+
+The `__init__.py` file initializes your project directory as a Python module, which is necessary for `ganslate`'s configuration system to function correctly (see [configuration](./7_configuration.md) for details). The `README.md` file can contain a description of your task.
+
+`ganslate` provides a Cookiecutter template which can automatically generate an empty project for you. The tutorial [Your First Project](../tutorials_basic/2_new_project.md) provides detailed instructions on how to create and operate your own project.
\ No newline at end of file
diff --git a/docs/package_overview/3_datasets.md b/docs/package_overview/3_datasets.md
new file mode 100644
index 00000000..941469d4
--- /dev/null
+++ b/docs/package_overview/3_datasets.md
@@ -0,0 +1,76 @@
+# Datasets
+
+If you are familiar with the [standard data loading workflow in _PyTorch_](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html), `ganslate`'s dataset classes will look familiar: they are derived from `torch.utils.data.Dataset` and define the data fetching and preprocessing pipeline. Additionally, `ganslate`'s datasets have an associated Python `dataclass` in which all the data-related settings are defined.
+
+
+
+--------------------------------------------------------------
+## The `PairedImageDataset` and `UnpairedImageDataset` classes
+
+Two classes - `PairedImageDataset` and `UnpairedImageDataset` - are supplied by `ganslate` by default and can be used out-of-the-box on your image data.
+
+The `PairedImageDataset` class loads an _A_-_B_ image pair from storage given a common index and applies optional joint transformations to the pair. This class is to be used in paired training, as well as during validation and/or testing when paired data is available.
+
+On the other hand, `UnpairedImageDataset` randomly fetches a domain _A_ image and a domain _B_ image and applies optional transformations to each independently. As the name suggests, this class is meant to be used for unpaired training.
+
+### Input and Output
+Both classes expect your data directory to be structured in the following manner:
+```text
+
+    |
+    |- A
+    |   |- ...
+    |   |- ...
+    |
+    |- B
+    |   |- ...
+    |   |- ...
+```
+
+And, if using validation or testing data:
+```text
+
+    |
+    |- A
+    |   |- ...
+    |   |- ...
+    |
+    |- B
+    |   |- ...
+    |   |- ...
+```
+
+where the sub-directories _A_ and _B_ contain the images. In situations where paired data is provided (i.e. paired training, or any validation/testing), the ordering of images in _A_ corresponds to the ordering of images in _B_, meaning that the first _A_ image and the first _B_ image form a pair, and so on. Images with extensions `.jpg`, `.jpeg`, and `.png` are supported.
+
+
+Both image dataset classes implement a `__getitem__` method that outputs a sample dictionary of the following form
+```python
+sample = {'A': a_tensor, 'B': b_tensor}
+```
+where each tensor is of shape (`C`, `H`, `W`).
+
+
+### Available Settings
+The configuration `dataclasses` associated with both default image datasets inherit from `configs.base.BaseDatasetConfig`, which defines two settings common to all dataset classes. These are `num_workers` and `pin_memory`, which are the settings for the `torch.utils.data.DataLoader` used by `ganslate` internally.
+
+
+The two image datasets have additional settings, which are the same across the two datasets. These are:
+
+- `image_channels`: Refers to the number of channels in the images. Only images with 1 or 3 channels are supported (i.e. grayscale and RGB), and the channels should be the same across _A_ and _B_ images (i.e. either both should be grayscale or both should be RGB).
+
+- `preprocess`: Accepts a tuple of predefined strings that defines the preprocessing instructions. These predefined strings include `resize`, `scale_width`, `random_zoom`, `random_crop`, and `random_flip`, of which `resize` and `scale_width` specify the initial resizing operations (choose either one or none), whereas the rest correspond to the random transforms used as data augmentation during training. An example value for the `preprocess` setting is `('resize', 'random_crop', 'random_flip')`.
+
+Note: In `PairedImageDataset`, these transforms are applied jointly to the _A_ and _B_ images, whereas in `UnpairedImageDataset`, they are applied to each image independently.
+
+- `load_size`: This parameter accepts a tuple that specifies the size (`H`, `W`) to which the images are resized upon loading from storage as a result of the `resize` preprocessing instruction. If the `scale_width` instruction is specified instead, only the width component of `load_size` is considered and the image width is scaled to this value while preserving the aspect ratio.
+
+- `final_size`: This parameter accepts a tuple that specifies the size (`H`, `W`) to which the images are converted as a result of the random transforms. This is the final size of the images that should be expected from the dataloader.
+
+Note: When not using random transforms (for example, during validation/testing), set `final_size` to the same value as `load_size`.
+
+
+
+-------------------------
+## Custom Dataset Classes
+
+It is also possible to define your custom dataset class for use-cases requiring specialized processing of the data, for example, in the case of medical images. See [Your First Project](../tutorials_basic/2_new_project.md) for more details on creating custom datasets.
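+
+To see how the settings above fit together, below is a minimal sketch of the `dataset` section of a training configuration for the default image datasets. The `root`, `num_workers`, `image_channels`, `preprocess`, `load_size`, and `final_size` keys mirror the configuration examples elsewhere in these docs; the `name`-based selection of the dataset class is an assumption here, so consult the [configuration](./7_configuration.md) page for the exact schema.
+
+```yaml
+train:
+  dataset:
+    name: "UnpairedImageDataset"          # assumed selector for the dataset class
+    root: "path/to/your/train/data"       # directory containing the A/ and B/ sub-directories
+    num_workers: 4
+    pin_memory: True
+    image_channels: 3
+    preprocess: ["resize", "random_crop", "random_flip"]
+    load_size: [286, 286]    # (H, W) after the initial resize
+    final_size: [256, 256]   # (H, W) after the random transforms
+```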
\ No newline at end of file
diff --git a/docs/package_overview/4_architectures.md b/docs/package_overview/4_architectures.md
new file mode 100644
index 00000000..c4ad77ad
--- /dev/null
+++ b/docs/package_overview/4_architectures.md
@@ -0,0 +1,93 @@
+# Model Architectures and Loss Functions
+
+
+
+--------------------
+## GAN Architectures
+
+`ganslate` provides implementations of several popular image translation GANs, which you can use out-of-the-box in your projects. Here is the list of currently supported GAN architectures:
+
+1. Pix2Pix
+    - Class: `ganslate.nn.gans.paired.pix2pix.Pix2PixConditionalGAN`
+    - Data requirements: Paired, pixel-wise aligned domain _A_ and domain _B_ images
+    - Original paper: Isola et al. - Image-to-Image Translation with Conditional Adversarial Networks ([arXiv](https://arxiv.org/abs/1611.07004))
+
+2. CycleGAN
+    - Class: `ganslate.nn.gans.unpaired.cyclegan.CycleGAN`
+    - Data requirements: Unpaired domain _A_ and domain _B_ images
+    - Original paper: Zhu et al. - Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks ([arXiv](https://arxiv.org/abs/1703.10593))
+
+3. RevGAN
+    - Class: `ganslate.nn.gans.unpaired.revgan.RevGAN`
+    - Data requirements: Unpaired domain _A_ and domain _B_ images
+    - Original paper: van der Ouderaa et al. - Reversible GANs for Memory-efficient Image-to-Image Translation ([arXiv](https://arxiv.org/abs/1902.02729))
+
+4. CUT
+    - Class: `ganslate.nn.gans.unpaired.cut.CUT`
+    - Data requirements: Unpaired domain _A_ and domain _B_ images
+    - Original paper: Park et al. - Contrastive Learning for Unpaired Image-to-Image Translation ([arXiv](https://arxiv.org/abs/2007.15651))
+
+`ganslate` defines an abstract base class `ganslate.nn.gans.base.BaseGAN` ([source](https://github.com/ganslate-team/ganslate/nn/gans/base.py)) that implements some of the basic functionality common to all the aforementioned GAN architectures, such as methods related to model setup, saving, loading, and learning rate updates. Additionally, it declares certain abstract methods whose implementation might differ across GAN architectures, such as the forward pass and backpropagation logic. Each of the aforementioned GAN architectures inherits from `BaseGAN` and implements the necessary abstract methods.
+
+The `BaseGAN` class has an associated `dataclass` at `ganslate.configs.base.BaseGANConfig` that defines all its basic settings, including the settings for the optimizer, generator, and discriminator. Since the different GAN architectures have their own specific settings, each of them also has an associated configuration `dataclass` that inherits from `ganslate.configs.base.BaseGANConfig` and defines additional architecture-specific settings.
+
+As a result of its extensible design, `ganslate` additionally enables users to modify the existing GANs by overriding certain functionalities, or to define their own custom image translation GAN from scratch. The former is discussed in the context of loss functions as part of the basic tutorial [Your First Project](../tutorials_basic/2_new_project.md), whereas the latter is covered in the advanced tutorial [Writing Your Own GAN Class from Scratch](../tutorials_advanced/1_custom_gan_architecture.md).
+
+
+
+--------------------------------------------
+## Generator and Discriminator Architectures
+
+Generators and discriminators are defined in `ganslate` as regular _PyTorch_ modules derived from `torch.nn.Module`.
+
+Following is the list of the available generator architectures:
+
+1. ResNet variants (Original ResNet paper - [arXiv](https://arxiv.org/abs/1512.03385)):
+    - 2D ResNet: `ganslate.nn.generators.resent.resnet2d.Resnet2D`
+    - 3D ResNet: `ganslate.nn.generators.resent.resnet3d.Resnet3D`
+    - Partially-invertible ResNet generator: `ganslate.nn.generators.resent.piresnet3d.Piresnet3D`
+
+2. U-Net variants (Original U-Net paper - [arXiv](https://arxiv.org/abs/1505.04597)):
+    - 2D U-Net: `ganslate.nn.generators.unet.unet2d.Unet2D`
+    - 3D U-Net: `ganslate.nn.generators.unet.unet3d.Unet3D`
+
+3. V-Net variants (Original V-Net paper - [arXiv](https://arxiv.org/abs/1606.04797)):
+    - 2D V-Net: `ganslate.nn.generators.vnet.vnet2d.Vnet2D`
+    - 3D V-Net: `ganslate.nn.generators.vnet.vnet3d.Vnet3D`
+    - Partially-invertible 3D V-Net generator with Self-Attention: `ganslate.nn.generators.vnet.sa_vnet3d.SAVnet3D`
+
+
+And here is the list of the available discriminator architectures:
+
+1. PatchGAN discriminator variants (PatchGAN originally described in the Pix2Pix paper - [arXiv](https://arxiv.org/abs/1611.07004)):
+    - 2D PatchGAN: `ganslate.nn.discriminators.patchgan.patchgan2d.PatchGAN2D`
+    - 3D PatchGAN: `ganslate.nn.discriminators.patchgan.patchgan3d.PatchGAN3D`
+    - Multiscale 3D PatchGAN: `ganslate.nn.discriminators.patchgan.ms_patchgan3d.MSPatchGAN3D`
+    - 3D PatchGAN with Self-Attention: `ganslate.nn.discriminators.patchgan.sa_patchgan3d.SAPatchGAN3D`
+
+
+
+-----------------
+## Loss Functions
+
+Several different loss function classes are provided in the `ganslate` package. These include different flavors of the adversarial loss as well as various GAN architecture-specific losses.
+
+1. Adversarial loss
+    - Class: `ganslate.nn.losses.adversarial_loss.AdversarialLoss`
+    - Variants: `'vanilla'` (original adversarial loss based on cross-entropy), `'lsgan'` (least-squares loss), `'wgangp'` (Wasserstein-1 distance with gradient penalty), and `'nonsaturating'`
+
+2. Pix2Pix loss
+    - Class: `ganslate.nn.losses.pix2pix_losses.Pix2PixLoss`
+    - Components:
+        - Pixel-to-pixel L1 loss between the synthetic image and the ground truth (weighted by the scalar `lambda_pix2pix`)
+
+3. CycleGAN losses
+    - Class: `ganslate.nn.losses.cyclegan_losses.CycleGANLosses`
+    - Components:
+        - Cycle-consistency loss based on the L1 distance (_A-B-A_ and _B-A-B_ components separately weighted by `lambda_AB` and `lambda_BA`, respectively). There is an option to compute cycle-consistency using a weighted sum of L1 and SSIM losses (weights defined by the hyperparameter `proportion_ssim`).
+        - Identity loss implemented with the L1 distance
+
+4. CUT losses
+    - Class: `ganslate.nn.losses.cut_losses.PatchNCELoss`
+    - Components:
+        - PatchNCE loss
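+
+To make the naming above concrete, here is a rough sketch of how a configuration file might select one of the supported GANs together with a generator and a discriminator. The `name`-based selection of the GAN and generator mirrors the YAML examples elsewhere in these docs; the `discriminator` and `optimizer` keys shown here are assumptions, so consult the [configuration](./7_configuration.md) page for the exact schema.
+
+```yaml
+train:
+  gan:
+    name: "CycleGAN"                  # one of the GAN classes listed above
+    generator:
+      name: "Resnet2D"                # one of the generator classes listed above
+      in_out_channels:
+        AB: [3, 3]
+    discriminator:
+      name: "PatchGAN2D"              # assumed to follow the same name-based selection
+    optimizer:
+      adversarial_loss_type: "lsgan"  # one of the adversarial loss variants listed above
+```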
\ No newline at end of file
diff --git a/docs/package_overview/5_engines.md b/docs/package_overview/5_engines.md
new file mode 100644
index 00000000..ff684dda
--- /dev/null
+++ b/docs/package_overview/5_engines.md
@@ -0,0 +1,54 @@
+# Engines
+
+`ganslate` defines four _engines_ that implement the processes crucial to a deep learning workflow. These are `Trainer`, `Validator`, `Tester`, and `Inferer`. The following UML diagram shows the design of `ganslate`'s `engines` module and the relationship between the different engine classes defined in it.
+
+![alt text](../imgs/uml-ganslate_engines.png "Relationship between ganslate's engine classes")
+
+
+
+------------
+## `Trainer`
+
+The `Trainer` class ([source](https://github.com/ganslate-team/ganslate/ganslate/engines/trainer.py)) implements the training procedure and is instantiated at the start of the training process. Upon initialization, the trainer object executes the following tasks:
+1. Preparing the environment.
+2. Initializing the GAN model, training data loader, training tracker, and validator.
+
+The `Trainer` class provides the `run()` method, which defines the training logic. This includes:
+1. Fetching data from the training data loader.
+2. Invoking the GAN model's methods that set the inputs and perform the forward pass, backpropagation, and parameter update.
+3. Obtaining the results of the iteration, which include the computed images, loss values, metrics, and I/O and computation times, and pushing them into the experiment tracker for logging.
+4. Running model validation.
+5. Saving checkpoints locally.
+6. Updating the learning rates.
+
+All configuration pertaining to the `Trainer` is grouped under the `'train'` mode in `ganslate`.
+
+
+
+--------------
+## `Validator`
+The `Validator` class ([source](https://github.com/ganslate-team/ganslate/ganslate/engines/validator_tester.py)) inherits almost all of its properties and functionality from the `BaseValTestEngine` and is responsible for performing validation of a given model during the training process. It is instantiated and utilized within the `Trainer`, where it is supplied with its configuration and the model. Upon initialization, a `Validator` object executes the following:
+1. Initializing the sliding-window inferer, validation data loader, validation tracker, and the validation-test metricizer.
+
+The `run()` method of the `Validator` iterates over the validation dataset and executes the following steps:
+1. Fetching data from the validation data loader.
+2. Running inference on the given model and holding the computed images.
+3. Saving the computed image and its relevant metadata (useful in the case of medical images).
+4. Calculating image quality/similarity metrics by comparing the generated image with the ground truth.
+5. Pushing the images and metrics into the validation tracker for logging.
+
+All configuration pertaining to the `Validator` is grouped under the `'val'` mode in `ganslate`.
+
+
+
+-----------
+## `Tester`
+The `Tester` class ([source](https://github.com/ganslate-team/ganslate/ganslate/engines/validator_tester.py)), like the `Validator`, inherits from the `BaseValTestEngine` and has the same properties and functionality as the `Validator`. The only difference is that a `Tester` instance sets up the environment and builds its own GAN model, and is therefore used independently of the `Trainer`.
+
+All configuration pertaining to the `Tester` is grouped under the `'test'` mode in `ganslate`.
+
+
+
+------------
+## `Inferer`
+The `Inferer` class ([source](https://github.com/ganslate-team/ganslate/ganslate/engines/validator_tester.py)) represents a simplified inference engine without any mechanism for metric calculation; it therefore expects data without a ground truth to compare against. Under normal circumstances it still executes utility tasks such as fetching data from a data loader, tracking I/O and computation time, and logging and saving images. However, when used in the _deployment_ mode, the `Inferer` acts as a minimal inference engine that can be easily integrated into other applications.
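+
+As a rough illustration of this mode-based grouping, the sketch below shows the top-level layout of a configuration file. The `train` keys follow the examples elsewhere in these docs, while the contents of the `val` and `test` sections are only indicative assumptions; see the [configuration](./7_configuration.md) page for the authoritative structure.
+
+```yaml
+project_dir: projects/your_project
+
+train:        # consumed by the Trainer
+  dataset: ...
+  gan: ...
+
+val:          # consumed by the Validator, which runs periodically during training
+  dataset: ...
+
+test:         # consumed by the Tester, run independently of training
+  dataset: ...
+```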
\ No newline at end of file diff --git a/docs/package_overview/6_trackers.md b/docs/package_overview/6_trackers.md new file mode 100644 index 00000000..c414af96 --- /dev/null +++ b/docs/package_overview/6_trackers.md @@ -0,0 +1 @@ +# Logging and Visualization: \ No newline at end of file diff --git a/docs/tutorials/configuration.md b/docs/package_overview/7_configuration.md similarity index 100% rename from docs/tutorials/configuration.md rename to docs/package_overview/7_configuration.md diff --git a/docs/tutorials/custom_dataset.md b/docs/tutorials/custom_dataset.md deleted file mode 100644 index 6f4ba15e..00000000 --- a/docs/tutorials/custom_dataset.md +++ /dev/null @@ -1,91 +0,0 @@ -# Loading Your Own Data into ganslate with Custom Pytorch Datasets - -`ganslate` can be run on your own data through creating [your own Pytorch Dataset](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html). - - -## Integrating your Pytorch Dataset for Training -Once you have your custom Dataset, it needs to be modified as `ganslate` expects certain structure to make the Dataset class compatible with the framework. - -Namely, it expects *atleast* the following, - -1. Specific return format from the `__getitem__` function -2. Dataclass configuration for your dataset - -#### Specific return format from the `__getitem__` function -A sample of the format expected to be returned: -```python -def __getitem__(self, index): - ... - - return {'A': ..., 'B': ...} -``` -A dictionary with the following keys is expected to be returned - -1. `A` - corresponds to `torch.Tensor` image (2D/3D) from domain A -2. `B` - corresponds to `torch.Tensor` image (2D/3D) from domain B - - -#### Dataclass configuration for your dataset -The Dataset can be dynamically configured through Dataclasses as a part of the [OmegaConf configuration system](https://github.com/omry/omegaconf). Apart from configuration, the Dataclass is also important to allow the framework to easily import your Dataset while training. A sample of this can be found in the [default ImageDataset](https://github.com/Maastro-CDS-Imaging-Group/midaGAN/blob/26564fa721f71c024aa88fb278ecba7de748e55c/midaGAN/data/image_dataset.py#L15) provided with `ganslate`. - -The structure of the Dataclass configuration -```python -from dataclasses import dataclass -from ganslate import configs - -@dataclass -class YourDatasetNameConfig(configs.base.BaseDatasetConfig): # Your dataset always needs to inherit the BaseDatasetConfig - # Define additional parameters below, these parameters are passed to - # the dataset and can be used for dynamic configuration. - # Examples of parameters - flip: bool = True - -``` - -`YourDatasetName` is to be consistent with the name of your Pytorch Dataset. The name is also used to import the Dataset module. - - -To allow your Dataset to access parameters defined in the Dataclass configuration, the `___init___` function of the Dataset can be modified. -```python -from torch.utils.data import Dataset - -class YourDatasetName(Dataset): - def __init__(self, conf): # `conf` contains the entire configuration yaml - self.flip = conf[conf.mode].dataset.flip # Accessing the flip parameter defined in the dataclass - -``` - -#### Importing your Dataset with `ganslate` -Your Dataset along with its Dataclass configuration can be placed in the `projects` folder under a specific project name. 
-For example, -``` -projects/ - your_project/ - your_dataset.py # Contains both the Dataset and Dataclass configuration - default_docker.yaml -``` - -Modify the `default_docker.yaml` -```yaml -project: "./projects/your_project" # This needs to point to the directory where your_dataset.py is located -train: - dataset: - _target_: project.datasets.YourDatasetName - root: "" # Path to where the data is - # Additional parameters - flip: True -``` - -Apart from this, make sure the other parameters in the `default_docker.yaml` are set appropriately. [Refer to configuring your training with yaml files](configuration.md). - -You can now run training with your custom Dataset! Run this command from the root of the repository, -```python -python tools/train.py config=projects/your_project/default_docker.yaml -``` - - - - - - - diff --git a/docs/tutorials/starting_a_new_project.md b/docs/tutorials/starting_a_new_project.md deleted file mode 100644 index e69de29b..00000000 diff --git a/docs/tutorials/custom_architecture.md b/docs/tutorials_advanced/1_custom_gan_architecture.md similarity index 50% rename from docs/tutorials/custom_architecture.md rename to docs/tutorials_advanced/1_custom_gan_architecture.md index 24e02493..7a8723ac 100644 --- a/docs/tutorials/custom_architecture.md +++ b/docs/tutorials_advanced/1_custom_gan_architecture.md @@ -1,167 +1,4 @@ -# Defining Your Own GAN and Network Architectures - -In addition to using out-of-the-box the [popular architectures](https://github.com/Maastro-CDS-Imaging-Group/ganslate/docs/index.md) of GANs and of the generators and discriminators supplied by `ganslate`, you can easily define your custom architectures to suit your specific requirements. - - ------------------------------- -## 1. Custom GAN Architectures - -In `ganslate`, a `gan` represents the *system* of generator(s) and discriminator(s) which, during training, includes a set of loss criteria and optimizers, the specification of the flow of data among the generator and discriminator networks during forward pass, the computation of losses, and the update sequence for the generator and discriminator parameters. Depending on your requirements, you can either override one or more of these specific functionalities of the existing GAN classes or write new GAN classes with entirely different architectures. - - -### Example 1.1. CycleGAN with Custom Losses - Adding a New Loss Component - -This example shows how you can modify the default loss criteria of `CycleGAN` to include your custom loss criterion as an _additional_ loss component. This criterion could, for instance, be a *structure-consistency loss* that would constrain the high-level structure of a fake domain `B` image to be similar to that of its corresponding real domain `A` image. - -First, create a new file `projects/your_project/architectures/custom_cyclegan.py` and add the following lines. Note that your `CustomCycleGAN1` class must have an associated dataclass as shown. 
-```python -from dataclasses import dataclass -from ganslate import configs -from ganslate.nn.gans.unpaired import cyclegan - - -@dataclass -class OptimizerConfig(configs.base.BaseOptimizerConfig): - # Define your optimizer parameters specific to your GAN - # such as the scaling factor for your custom loss - lambda_structure_loss: float = 0.1 - - -@dataclass -class CustomCycleGAN1Config(cyclegan.CycleGANConfig): # Inherit from the `CycleGANConfig` class - """ Dataclass containing confiuration for your custom CycleGAN """ - optimizer: OptimizerConfig = OptimizerConfig - - -class CustomCycleGAN1(cyclegan.CycleGAN): # Inherit from the `CycleGAN` class - """ CycleGAN with a structure-consistency loss """ - - def __init__(self, conf): - # Initialize by invoking the constructor of the parent class - super().__init__(conf) - - # Now, extend or redefine method(s). - # In this example, we need to redefine only the `init_criterions` method. - def init_criterions(self): - # Standard adversarial loss [Same as in the original CycleGAN] - self.criterion_adv = AdversarialLoss( - self.conf.train.gan.optimizer.adversarial_loss_type).to(self.device) - - # Custom set of losses for the generators [Default CycleGAN losses plus your structure-consistency criterion] - self.criterion_G = CycleGANLossesWithStructure(self.conf) -``` - -Now, define the `CycleGANLossesWithStructure` by adding the following lines -```python -from ganslate.nn.losses.cyclegan_losses import CycleGANLosses - - -class CycleGANLossesWithStructure(CycleGANLosses): # Inherit from the default CycleGAN losses class - - def __init__(self, conf): - # Invoke the constructor of the parent class to initialize the default loss criteria - # such as cycle-consistency and identity (if enabled) losses - super.__init__(conf) - - # Initialize your structure criterion. - # The hyperparameter `lambda_structure_loss` is the scaling factor for this loss component - lambda_structure_loss = self.conf.train.optimizer.lambda_structure_loss - self.your_structure_criterion = YourStructureCriterion(lambda_structure_loss) - - def __call__(self, visuals): - # Invoke the `__call__` method of the parent class to compute the the default CycleGAN losses - losses = super.__call__(visuals) - - # Compute your custom loss and store as an addiitonal entry in the `losses` dictionary - real_A = visuals['real_A'] - fake_B = visuals['fake_B'] - losses['your_structure_loss'] = self.your_structure_criterion(real_A, fake_B) - - return losses -``` - -Define the class`YourStructureCriterion` that actually implements the structure-consistency criterion -```python -class YourStructureCriterion(): - def __init__(self, lambda_structure_loss): - self.lambda_structure_loss = lambda_structure_loss - # Your structure criterion could be, for instance, an L1 loss, an SSIM loss, - # or a custom distance metric - ... - - def __call__(self, real_image, fake_image): - # Compute the loss and return the scaled value - ... - return self.lambda_structure_loss * loss_value -``` - -Finally, edit your YAML configuration file to include the settings for your custom hyperparameter `lambda_structure_loss` -```yaml -project: projects/your_project -... - -train: - ... - - gan: - _target_: project.architectures.CustomCycleGAN1 # Location of your GAN class - ... - - optimizer: # Optimizer config that includes your custom hyperparameter - lambda_structure_loss: 0.1 - ... -... 
- -``` -Upon starting the training process, `ganslate` will search `your_project` directory for the `CustomCycleGAN1` class and instantiate from it the GAN object with the supplied settings. - - - -### Example 1.2. CycleGAN with Custom Losses - Writing a New Set of CycleGAN Losses - -In this example, we seek to not use the default CycleGAN losses at all but instead completely redefine them. The original cycle-consistency criterion involves computing an `L1` loss between the real domain `A` or domain `B`images and their corresponding reconstructed versions. For the sake of this example, let us consider implementing cycle-consistency using a custom distance metric. - -Let your custom CycleGAN class be named `CustomCycleGAN2`. Its definition would be mostly the same as that of `CustomCycleGAN1` from _Example 1_. Moving on to the definition of your `CustomCycleGANLosses`, it would be of the following form -```python -class CustomCycleGANLosses(CycleGANLosses): # Inherit from the default CycleGAN losses class - - def __init__(self, conf): - # Hyperparameters (here, scaling factors) for your loss - self.lambda_AB = conf.train.gan.optimizer.lambda_AB - self.lambda_BA = conf.train.gan.optimizer.lambda_BA - - # Instantiate your custom cycle-consistency - self.criterion_custom_cycle = CustomCycleLoss() - - def __call__(self, visuals): - real_A, real_B = visuals['real_A'], visuals['real_B'] - fake_A, fake_B = visuals['fake_A'], visuals['fake_B'] - rec_A, rec_B = visuals['rec_A'], visuals['rec_B'] - idt_A, idt_B = visuals['idt_A'], visuals['idt_B'] - - losses = {} - - # Compute cycle-consistency loss - losses['cycle_A'] = self.lambda_AB * self.criterion_cycle(real_A, rec_A) # L_cyc( real_A, G_BA(G_AB(real_A)) ) - losses['cycle_B'] = self.lambda_BA * self.criterion_cycle(real_B, rec_B) # L_cyc( real_B, G_AB(G_BA(real_B)) ) - - return losses - - -class CustomCycleLoss(): - - def __init__(self, proportion_ssim): - ... - - def __call__(self, real, reconstructed): - # Your alternate formulation of the cycle-consistency criterion - ... - return custom_cycle_loss -``` - - - -### Example 1.3. Writing Your Own GAN Class from Scratch +# Writing Your Own GAN Class from Scratch Advanced users may opt to implement a new GAN architecture from scratch. You can do this by inheriting from the abstract base class `BaseGAN` and implementing the required methods. All the existing GAN architectures in `ganslate` are defined in this manner. The file containing your `FancyNewGAN` must be structured as follows @@ -181,6 +18,7 @@ class OptimizerConfig(configs.base.BaseOptimizerConfig): @dataclass class FancyNewGANConfig(configs.base.BaseGANConfig): # Configuration dataclass for your GAN + name: str = "FancyNewGAN" optimizer: OptimizerConfig = OptimizerConfig @@ -365,14 +203,14 @@ def backward_D(self): The aforementioned methods are to be mandatorily implemented if you wish to contruct your own GAN architecture in `ganslate` from scratch. Additionally, We also recommend referring to the abstract `BaseGAN` class ([source](https://github.com/Maastro-CDS-Imaging-Group/ganslate/blob/documentation/ganslate/nn/gans/base.py)) to get an overview of other existing methods and of the internal logic. Finally, update your YAML configuration file to include the apporapriate settings for your custom-defined components ```yaml -project: projects/your_project +project_dir: projects/your_project ... train: ... gan: - _target_: project.architectures.YourFancyGAN # Location of your GAN class + name: "YourFancyGAN" # Name of your GAN class ... 
optimizer: # Optimizer config that includes your custom hyperparameter @@ -381,50 +219,4 @@ train: ... ... -``` - - ------------------------------------------------------ -## 2. Custom Generator or Discriminator Architectures - -In image translation GANs, the "generator" can be any network with an architecture that enables taking as input an image and producing an output image of the same size as the input. Whereas, the discriminator is any network that can take as input these images and produce a real/fake validity score which may either be a scalar or a 2D/3D map with each unit casting a fixed receptive field on the input. In `ganslate`, the generator and discriminator networks are defined as standard _PyTorch_ modules, [constructed by inheriting](https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html) from the type `torch.nn.Module`. In addition to defining your custom generator or discriminator network, you must also define a configuration dataclass for your network in the same file as follows -```python -from torch import nn -from dataclasses import dataclass -from ganslate import configs - -@dataclass -class CustomGeneratorConfig(configs.base.BaseGeneratorConfig): - n_residual_blocks: int = 9 - use_dropout: bool = False - -class CustomGenerator(nn.Module): - """Create a custom generator module""" - def __init__(self, in_channels, out_channels, norm_type, n_residual_blocks, use_dropout): - # Define the class attributes - ... - - def forward(self, input_tensor): - # Define the forward pass operation - ... -``` - -Ensure that your YAML configuration file includes the pointer to your `CustomGenerator` as well as the appropriate settings -```yaml -project: projects/your_project -... - -train: - ... - - gan: - ... - - generator: - _target_: project.architectures.CustomGenerator # Location of your custom generator class - n_residual_blocks: 9 # Configuration - in_out_channels: - AB: [3, 3] - ... -... -``` +``` \ No newline at end of file diff --git a/docs/tutorials_advanced/2_custom_G_and_D_architecture.md b/docs/tutorials_advanced/2_custom_G_and_D_architecture.md new file mode 100644 index 00000000..597c746c --- /dev/null +++ b/docs/tutorials_advanced/2_custom_G_and_D_architecture.md @@ -0,0 +1,44 @@ +# Custom Generator or Discriminator Architectures + +In image translation GANs, the "generator" can be any network with an architecture that enables accepting as input an image and producing an output image of the same size as the input. Whereas, the discriminator is any network that can take as input these images and produce a real/fake validity score which may either be a scalar or a 2D/3D map with each unit casting a fixed receptive field on the input. In `ganslate`, the generator and discriminator networks are defined as standard _PyTorch_ modules, [constructed by inheriting](https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html) from the type `torch.nn.Module`. 
In addition to defining your custom generator or discriminator network, you must also define a configuration dataclass for your network in the same file, as follows:
+```python
+from torch import nn
+from dataclasses import dataclass
+from ganslate import configs
+
+@dataclass
+class CustomGeneratorConfig(configs.base.BaseGeneratorConfig):
+    name: str = 'CustomGenerator'
+    n_residual_blocks: int = 9
+    use_dropout: bool = False
+
+class CustomGenerator(nn.Module):
+    """Create a custom generator module"""
+    def __init__(self, in_channels, out_channels, norm_type, n_residual_blocks, use_dropout):
+        # Define the class attributes
+        ...
+
+    def forward(self, input_tensor):
+        # Define the forward pass operation
+        ...
+```
+
+Ensure that your YAML configuration file includes the pointer to your `CustomGenerator` as well as the appropriate settings:
+```yaml
+project_dir: projects/your_project
+...
+
+train:
+  ...
+
+  gan:
+    ...
+
+    generator:
+      name: "CustomGenerator"  # Name of your custom generator class
+      n_residual_blocks: 9     # Configuration
+      in_out_channels:
+        AB: [3, 3]
+      ...
+...
+```
diff --git a/docs/getting_started/first_run.md b/docs/tutorials_basic/1_first_run_with_aerial2maps.md
similarity index 100%
rename from docs/getting_started/first_run.md
rename to docs/tutorials_basic/1_first_run_with_aerial2maps.md
diff --git a/docs/tutorials_basic/2_new_project.md b/docs/tutorials_basic/2_new_project.md
new file mode 100644
index 00000000..731d51c2
--- /dev/null
+++ b/docs/tutorials_basic/2_new_project.md
@@ -0,0 +1,259 @@
+# Your First Project
+
+
+
+-------------------------------------
+## Creating a Project from a Template
+
+TODO: Cookiecutter stuff
+
+
+
+-------------------------------------------------------------------
+## Loading Your Own Data into ganslate with Custom Pytorch Datasets
+
+`ganslate` can be run on your own data by creating [your own Pytorch Dataset](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html).
+
+
+### Integrating your Pytorch Dataset for Training
+Once you have your custom Dataset, it needs to be modified, as `ganslate` expects a certain structure to make the Dataset class compatible with the framework.
+
+Namely, it expects *at least* the following:
+
+1. A specific return format from the `__getitem__` function
+2. A Dataclass configuration for your dataset
+
+#### Specific return format from the `__getitem__` function
+A sample of the format expected to be returned:
+```python
+def __getitem__(self, index):
+    ...
+
+    return {'A': ..., 'B': ...}
+```
+A dictionary with the following keys is expected to be returned:
+
+1. `A` - corresponds to a `torch.Tensor` image (2D/3D) from domain A
+2. `B` - corresponds to a `torch.Tensor` image (2D/3D) from domain B
+
+
+#### Dataclass configuration for your dataset
+The Dataset can be dynamically configured through Dataclasses as a part of the [OmegaConf configuration system](https://github.com/omry/omegaconf). Apart from configuration, the Dataclass is also important to allow the framework to easily import your Dataset while training. A sample of this can be found in the [default ImageDataset](https://github.com/Maastro-CDS-Imaging-Group/midaGAN/blob/26564fa721f71c024aa88fb278ecba7de748e55c/midaGAN/data/image_dataset.py#L15) provided with `ganslate`.
+
+The structure of the Dataclass configuration:
+```python
+from dataclasses import dataclass
+from ganslate import configs
+
+@dataclass
+class YourDatasetNameConfig(configs.base.BaseDatasetConfig):  # Your dataset config always needs to inherit from BaseDatasetConfig
+    # Define additional parameters below; these parameters are passed to
+    # the dataset and can be used for dynamic configuration.
+    # Example parameter:
+    flip: bool = True
+
+```
+
+`YourDatasetName` is to be consistent with the name of your Pytorch Dataset. The name is also used to import the Dataset module.
+
+
+To allow your Dataset to access the parameters defined in the Dataclass configuration, the `__init__` function of the Dataset can be modified:
+```python
+from torch.utils.data import Dataset
+
+class YourDatasetName(Dataset):
+    def __init__(self, conf):  # `conf` contains the entire configuration yaml
+        self.flip = conf[conf.mode].dataset.flip  # Accessing the flip parameter defined in the dataclass
+
+```
+
+#### Importing your Dataset with `ganslate`
+Your Dataset, along with its Dataclass configuration, can be placed in the `projects` folder under a specific project name.
+For example:
+```
+projects/
+    your_project/
+        your_dataset.py  # Contains both the Dataset and Dataclass configuration
+        default_docker.yaml
+```
+
+Modify the `default_docker.yaml`:
+```yaml
+project: "./projects/your_project"  # This needs to point to the directory where your_dataset.py is located
+train:
+  dataset:
+    _target_: project.datasets.YourDatasetName
+    root: ""  # Path to where the data is
+    # Additional parameters
+    flip: True
+```
+
+Apart from this, make sure the other parameters in the `default_docker.yaml` are set appropriately. [Refer to configuring your training with YAML files](../package_overview/7_configuration.md).
+
+You can now run training with your custom Dataset! Run this command from the root of the repository:
+```bash
+python tools/train.py config=projects/your_project/default_docker.yaml
+```
+
+
+
+-----------------------------------
+## Adding a Custom Loss to CycleGAN
+
+TODO: Edit this section to remove redundant examples
+
+In addition to using the [popular architectures](https://github.com/Maastro-CDS-Imaging-Group/ganslate/docs/index.md) of GANs, generators, and discriminators supplied by `ganslate` out-of-the-box, you can easily define custom architectures to suit your specific requirements.
+
+In `ganslate`, a `gan` represents the *system* of generator(s) and discriminator(s) which, during training, includes a set of loss criteria and optimizers, the specification of the flow of data among the generator and discriminator networks during the forward pass, the computation of losses, and the update sequence for the generator and discriminator parameters. Depending on your requirements, you can either override one or more of these specific functionalities of the existing GAN classes or write new GAN classes with entirely different architectures.
+
+
+### Example 1.1. CycleGAN with Custom Losses - Adding a New Loss Component
+
+This example shows how you can modify the default loss criteria of `CycleGAN` to include your custom loss criterion as an _additional_ loss component. This criterion could, for instance, be a *structure-consistency loss* that would constrain the high-level structure of a fake domain `B` image to be similar to that of its corresponding real domain `A` image.
+
+First, create a new file `projects/your_project/architectures/custom_cyclegan.py` and add the following lines. Note that your `CustomCycleGAN1` class must have an associated dataclass as shown.
+```python
+from dataclasses import dataclass
+from ganslate import configs
+from ganslate.nn.gans.unpaired import cyclegan
+from ganslate.nn.losses.adversarial_loss import AdversarialLoss
+
+
+@dataclass
+class OptimizerConfig(configs.base.BaseOptimizerConfig):
+    # Define the optimizer parameters specific to your GAN,
+    # such as the scaling factor for your custom loss
+    lambda_structure_loss: float = 0.1
+
+
+@dataclass
+class CustomCycleGAN1Config(cyclegan.CycleGANConfig):  # Inherit from the `CycleGANConfig` class
+    """Dataclass containing the configuration for your custom CycleGAN"""
+    name: str = "CustomCycleGAN1"
+    optimizer: OptimizerConfig = OptimizerConfig
+
+
+class CustomCycleGAN1(cyclegan.CycleGAN):  # Inherit from the `CycleGAN` class
+    """CycleGAN with a structure-consistency loss"""
+
+    def __init__(self, conf):
+        # Initialize by invoking the constructor of the parent class
+        super().__init__(conf)
+
+    # Now, extend or redefine method(s).
+    # In this example, we need to redefine only the `init_criterions` method.
+    def init_criterions(self):
+        # Standard adversarial loss [same as in the original CycleGAN]
+        self.criterion_adv = AdversarialLoss(
+            self.conf.train.gan.optimizer.adversarial_loss_type).to(self.device)
+
+        # Custom set of losses for the generators [default CycleGAN losses plus your structure-consistency criterion]
+        self.criterion_G = CycleGANLossesWithStructure(self.conf)
+```
+
+Now, define `CycleGANLossesWithStructure` by adding the following lines:
+```python
+from ganslate.nn.losses.cyclegan_losses import CycleGANLosses
+
+
+class CycleGANLossesWithStructure(CycleGANLosses):  # Inherit from the default CycleGAN losses class
+
+    def __init__(self, conf):
+        # Invoke the constructor of the parent class to initialize the default loss criteria,
+        # such as the cycle-consistency and identity (if enabled) losses
+        super().__init__(conf)
+
+        # Initialize your structure criterion.
+        # The hyperparameter `lambda_structure_loss` is the scaling factor for this loss component
+        lambda_structure_loss = conf.train.gan.optimizer.lambda_structure_loss
+        self.your_structure_criterion = YourStructureCriterion(lambda_structure_loss)
+
+    def __call__(self, visuals):
+        # Invoke the `__call__` method of the parent class to compute the default CycleGAN losses
+        losses = super().__call__(visuals)
+
+        # Compute your custom loss and store it as an additional entry in the `losses` dictionary
+        real_A = visuals['real_A']
+        fake_B = visuals['fake_B']
+        losses['your_structure_loss'] = self.your_structure_criterion(real_A, fake_B)
+
+        return losses
+```
+
+Define the class `YourStructureCriterion` that actually implements the structure-consistency criterion:
+```python
+class YourStructureCriterion():
+    def __init__(self, lambda_structure_loss):
+        self.lambda_structure_loss = lambda_structure_loss
+        # Your structure criterion could be, for instance, an L1 loss, an SSIM loss,
+        # or a custom distance metric
+        ...
+
+    def __call__(self, real_image, fake_image):
+        # Compute the loss and return the scaled value
+        ...
+        return self.lambda_structure_loss * loss_value
+```
+
+Finally, edit your YAML configuration file to include the settings for your custom hyperparameter `lambda_structure_loss`:
+```yaml
+project_dir: projects/your_project
+...
+
+train:
+  ...
+
+  gan:
+    name: "CustomCycleGAN1"  # Name of your GAN class
+    ...
+
+    optimizer:  # Optimizer config that includes your custom hyperparameter
+      lambda_structure_loss: 0.1
+      ...
+...
+
+```
+Upon starting the training process, `ganslate` will search the `your_project` directory for the `CustomCycleGAN1` class and instantiate from it the GAN object with the supplied settings.
+
+
+
+### Example 1.2. CycleGAN with Custom Losses - Writing a New Set of CycleGAN Losses
+
+In this example, we seek not to use the default CycleGAN losses at all but instead to redefine them completely. The original cycle-consistency criterion involves computing an `L1` loss between the real domain `A` or domain `B` images and their corresponding reconstructed versions. For the sake of this example, let us consider implementing cycle-consistency using a custom distance metric.
+
+Let your custom CycleGAN class be named `CustomCycleGAN2`. Its definition would be mostly the same as that of `CustomCycleGAN1` from _Example 1.1_. Moving on to the definition of your `CustomCycleGANLosses`, it would be of the following form:
+```python
+class CustomCycleGANLosses(CycleGANLosses):  # Inherit from the default CycleGAN losses class
+
+    def __init__(self, conf):
+        # Hyperparameters (here, scaling factors) for your loss
+        self.lambda_AB = conf.train.gan.optimizer.lambda_AB
+        self.lambda_BA = conf.train.gan.optimizer.lambda_BA
+
+        # Instantiate your custom cycle-consistency criterion
+        self.criterion_custom_cycle = CustomCycleLoss()
+
+    def __call__(self, visuals):
+        real_A, real_B = visuals['real_A'], visuals['real_B']
+        fake_A, fake_B = visuals['fake_A'], visuals['fake_B']
+        rec_A, rec_B = visuals['rec_A'], visuals['rec_B']
+        idt_A, idt_B = visuals['idt_A'], visuals['idt_B']
+
+        losses = {}
+
+        # Compute the cycle-consistency loss with the custom criterion
+        losses['cycle_A'] = self.lambda_AB * self.criterion_custom_cycle(real_A, rec_A)  # L_cyc( real_A, G_BA(G_AB(real_A)) )
+        losses['cycle_B'] = self.lambda_BA * self.criterion_custom_cycle(real_B, rec_B)  # L_cyc( real_B, G_AB(G_BA(real_B)) )
+
+        return losses
+
+
+class CustomCycleLoss():
+
+    def __init__(self):
+        ...
+
+    def __call__(self, real, reconstructed):
+        # Your alternate formulation of the cycle-consistency criterion
+        ...
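+        # Purely as an illustration (an assumption, not part of ganslate's API), the custom
+        # metric could, for example, be an L1 distance computed on edge maps of the two images:
+        #   custom_cycle_loss = torch.nn.functional.l1_loss(edge_map(reconstructed), edge_map(real))
+        # where `edge_map` would be a hypothetical helper of your own (e.g. a Sobel filter).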
+        return custom_cycle_loss
+```
diff --git a/docs/user_guide/engines.md b/docs/user_guide/engines.md
deleted file mode 100644
index e69de29b..00000000
diff --git a/docs/user_guide/visualizations.md b/docs/user_guide/visualizations.md
deleted file mode 100644
index 20b28e05..00000000
--- a/docs/user_guide/visualizations.md
+++ /dev/null
@@ -1 +0,0 @@
-# Visualizations and Logging with ganslate
\ No newline at end of file
diff --git a/ganslate/data/paired_image_dataset.py b/ganslate/data/paired_image_dataset.py
index c98fc634..8144863c 100644
--- a/ganslate/data/paired_image_dataset.py
+++ b/ganslate/data/paired_image_dataset.py
@@ -43,8 +43,6 @@ def __init__(self, conf):
         self.transform = get_paired_image_transform(conf)
         self.rgb_or_grayscale = 'RGB' if conf[conf.mode].dataset.image_channels == 3 else 'L'
 
-        self.mode = conf.mode
-
     def __getitem__(self, index):
         index = index % self.n_samples
 
diff --git a/ganslate/nn/gans/paired/pix2pix.py b/ganslate/nn/gans/paired/pix2pix.py
index 4cc9f99a..8f869acf 100644
--- a/ganslate/nn/gans/paired/pix2pix.py
+++ b/ganslate/nn/gans/paired/pix2pix.py
@@ -3,7 +3,6 @@
 import torch
 
 from ganslate import configs
-from ganslate.data.utils.image_pool import ImagePool
 from ganslate.nn.gans.base import BaseGAN
 from ganslate.nn.losses.adversarial_loss import AdversarialLoss
 from ganslate.nn.losses.pix2pix_losses import Pix2PixLoss
diff --git a/mkdocs.yml b/mkdocs.yml
index 5f27b18c..fdd1eefe 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -1,28 +1,36 @@
 site_name: ganslate
+
 nav:
+  - Home: index.md
   - Installation: installation.md
-  - Getting Started:
-    - Using the command line interface: getting_started/using_cli.md
-    - First run with maps to aerial photo dataset: getting_started/first_run.md
-
-  - Tutorials:
-    - Intro to Configuration: tutorials/configuration.md
-    - Starting a New Project: tutorials/starting_a_new_project.md
-    - Custom Dataset: tutorials/custom_dataset.md
-    - Custom Architecture: tutorials/custom_architecture.md
-
-  - User Guide:
-    - Visualizations and logging with ganslate: user_guide/visualizations.md
-    - Training, Validation, Testing, Inference: user_guide/engines.md
+
+  - Package Overview:
+    - Commandline Interface: package_overview/1_cli.md
+    - ganslate Projects: package_overview/2_projects.md
+    - Datasets: package_overview/3_datasets.md
+    - Model Architectures and Loss Functions: package_overview/4_architectures.md
+    - Engines: package_overview/5_engines.md
+    - Logging and Visualization: package_overview/6_trackers.md
+    - Configuration: package_overview/7_configuration.md
+
+  - Basic Tutorials:
+    - First Run: tutorials_basic/1_first_run_with_aerial2maps.md
+    - Your First Project: tutorials_basic/2_new_project.md
+
+  - Advanced Tutorials:
+    - Custom GAN Architectures: tutorials_advanced/1_custom_gan_architecture.md
+    - Custom Generator and Discriminator Architectures: tutorials_advanced/2_custom_G_and_D_architecture.md
 
   - API: api/*
 
   - Community:
     - Contributing: community/contributing.md
+
 theme: readthedocs
+
 markdown_extensions:
   - admonition
\ No newline at end of file
diff --git a/projects/horse2zebra/experiments/default.yaml b/projects/horse2zebra/experiments/default.yaml
index c908a7ca..869ec644 100644
--- a/projects/horse2zebra/experiments/default.yaml
+++ b/projects/horse2zebra/experiments/default.yaml
@@ -21,7 +21,7 @@ train:
     root: "/home/chinmay/Datasets/horse2zebra/train/"
     num_workers: 16
     image_channels: 3
-    preprocess: ["resize", "flip"]
+    preprocess: ["resize", "random_flip"]
     load_size: [128, 128]  # (H, W)
     final_size: [128, 128]  # (H, W)