TezRomacH/layer-to-layer-pytorch

L2L execution algorithm PyTorch

A PyTorch implementation of the L2L execution algorithm from the paper "Training Large Neural Networks with Constant Memory using a New Execution Algorithm".
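
As a rough illustration of the idea behind layer-to-layer execution, the sketch below contrasts the usual order (each microbatch runs through all layers) with a layer-major order (each layer processes all microbatches before the next layer is touched). It uses plain Python functions as stand-ins for layers so it runs without torch; the helper names `split_microbatches`, `run_conventional`, and `run_layer_to_layer` are hypothetical and are not part of this library.

```python
from typing import Callable, List

Layer = Callable[[List[float]], List[float]]


def split_microbatches(batch: List[float], size: int) -> List[List[float]]:
    """Cut a minibatch into consecutive microbatches of at most `size` items."""
    return [batch[i:i + size] for i in range(0, len(batch), size)]


def run_conventional(layers: List[Layer], batch: List[float], size: int) -> List[float]:
    """Microbatch-major order: every layer must be available for every microbatch."""
    outputs: List[float] = []
    for microbatch in split_microbatches(batch, size):
        for layer in layers:
            microbatch = layer(microbatch)
        outputs.extend(microbatch)
    return outputs


def run_layer_to_layer(layers: List[Layer], batch: List[float], size: int) -> List[float]:
    """Layer-major order: only the current layer is needed at any moment."""
    microbatches = split_microbatches(batch, size)
    for layer in layers:
        # In a real setup this is where one layer would be moved to the device.
        microbatches = [layer(mb) for mb in microbatches]
    return [value for mb in microbatches for value in mb]


# Toy "layers": scale, shift, ReLU.
layers: List[Layer] = [
    lambda xs: [2.0 * x for x in xs],
    lambda xs: [x + 1.0 for x in xs],
    lambda xs: [max(0.0, x) for x in xs],
]
batch = [-2.0, -1.0, 0.0, 1.0, 2.0]

conventional_out = run_conventional(layers, batch, size=2)
l2l_out = run_layer_to_layer(layers, batch, size=2)
```

Both orders give the same forward result; the difference is that the layer-major loop only ever needs one layer resident at a time, which is what keeps device memory roughly constant as depth grows.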

🚀 Example

Define a torch model whose layers are all held in an nn.ModuleList.

See the examples folder.

Basic usage

from typing import Optional

import torch
from torch import nn, optim

class M(nn.Module):
    def __init__(self, depth: int, dim: int, hidden_dim: Optional[int] = None):
        super().__init__()
        hidden_dim = hidden_dim or dim
        self.layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(dim, hidden_dim),
                    nn.BatchNorm1d(hidden_dim),
                    nn.LeakyReLU(),
                )
            ]
            + [
                nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.BatchNorm1d(hidden_dim),
                    nn.LeakyReLU(),
                )
                for i in range(depth)
            ]
            + [nn.Linear(hidden_dim, dim), nn.Sigmoid()]
        )

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        x = batch
        for l in self.layers:
            x = l(x)

        return x


model = M(depth=5, dim=40).train() # on CPU

Then, you can use the L2L wrapper over this model.

from layer_to_layer_pytorch.l2l import Layer2Layer

l2l_model = Layer2Layer(
    model,
    layers_attr="layers", # attribute with ModuleList
    microbatch_size=100,  # size of each microbatch within a minibatch (terminology from the original paper)
    verbose=False,  # set to True to show tqdm progress bars
)
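
To make `microbatch_size` concrete, here is a minimal sketch of how a minibatch of a given length could be cut into index ranges. The helper `microbatch_slices` is hypothetical, and how the library itself treats a final, shorter microbatch is not stated in this README, so the ragged case below is an assumption.

```python
from typing import List


def microbatch_slices(num_samples: int, microbatch_size: int) -> List[slice]:
    """Return consecutive slices covering range(num_samples)."""
    return [
        slice(start, min(start + microbatch_size, num_samples))
        for start in range(0, num_samples, microbatch_size)
    ]


# 1,000 samples with microbatch_size=100 split evenly.
even = microbatch_slices(1_000, 100)

# 1,050 samples leave a shorter final microbatch (assumed behaviour).
ragged = microbatch_slices(1_050, 100)
last_size = ragged[-1].stop - ragged[-1].start
```

With the values used in this README (1,000 samples, `microbatch_size=100`), each layer would see ten microbatches per minibatch.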

Then train it almost like a regular torch model:

from tqdm.auto import tqdm, trange

x = torch.rand(1_000, 40) # on CPU
y = torch.rand(1_000, 40) # on CPU

losses = []
criterion = nn.MSELoss()

optimizer = optim.AdamW(l2l_model.main_params) # optimizer works with the main model on CPU

for i in trange(2000):
    l2l_model.zero_grad()
    _ = l2l_model.forward(x)

    loss_value: float = l2l_model.compute_loss(y, criterion)

    if i % 50 == 0:
        tqdm.write(f"[{i}] loss = {loss_value}")
    losses.append(loss_value)


    l2l_model.backward()
    optimizer.step()
    l2l_model.update_main_model_params() # Sync params with CPU
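
One reason a minibatch can be processed in microbatches is that, for a mean-reduced loss, size-weighted microbatch gradients add up to the full-batch gradient. The sketch below checks this for a one-parameter linear model in plain Python. It illustrates the general principle only; whether this library normalizes gradients in exactly this way is an assumption, and `mse_grad` is a hypothetical helper.

```python
import math
from typing import List


def mse_grad(w: float, xs: List[float], ys: List[float]) -> float:
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ys = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0]

# Gradient computed over the whole minibatch at once.
full_grad = mse_grad(w, xs, ys)

# Same gradient accumulated microbatch by microbatch,
# weighting each microbatch by its share of the minibatch.
microbatch_size = 2
accumulated = 0.0
for start in range(0, len(xs), microbatch_size):
    mb_x = xs[start:start + microbatch_size]
    mb_y = ys[start:start + microbatch_size]
    accumulated += mse_grad(w, mb_x, mb_y) * len(mb_x) / len(xs)
```

Note that layers whose output depends on batch statistics, such as the `nn.BatchNorm1d` layers in the example model, see microbatch statistics rather than minibatch statistics, so this equivalence holds exactly only for per-sample layers.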

FP-16 usage

Mixed-precision training is available through the init parameters:

from layer_to_layer_pytorch.l2l import Layer2Layer

l2l_model = Layer2Layer(
    model,
    layers_attr="layers",
    microbatch_size=100,

    # fp-16
    mixed_precision=True,
    loss_scale=128.0,
)
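
The following sketch shows why a loss scale helps in half precision, assuming `loss_scale` plays the usual static loss-scaling role (multiply before the backward pass, divide before the optimizer step). It uses the standard library's `struct` half-precision format as a stand-in for fp16 tensors; `to_fp16` and the gradient value are illustrative and not taken from this library.

```python
import struct


def to_fp16(value: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


loss_scale = 128.0
true_grad = 1e-8  # a small gradient value, chosen for illustration

# Stored directly in fp16, the value is below the smallest subnormal and becomes 0.0.
unscaled_fp16 = to_fp16(true_grad)

# Scaled first, it lands in the representable range and survives.
scaled_fp16 = to_fp16(true_grad * loss_scale)

# Dividing by the scale in higher precision recovers an approximation.
recovered = scaled_fp16 / loss_scale
```

A larger scale protects smaller gradients but raises the risk of overflow in fp16, so the value is a trade-off that depends on the model.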

And then train the same way 😉

Installation

pip install layer-to-layer-pytorch

Or install with Poetry:

poetry add layer-to-layer-pytorch

📈 Releases

You can see the list of available releases on the GitHub Releases page.

We follow the Semantic Versions specification.

🛡 License

This project is licensed under the terms of the MIT license. See LICENSE for more details.

📃 Citation

This library

@misc{layer-to-layer-pytorch,
  author = {Roman Tezikov},
  title = {PyTorch implementation of L2L execution algorithm},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/TezRomacH/layer-to-layer-pytorch}}
}

Original paper

@article{Pudipeddi2020TrainingLN,
  title={Training Large Neural Networks with Constant Memory using a New Execution Algorithm},
  author={Bharadwaj Pudipeddi and Maral Mesmakhosroshahi and J. Xi and S. Bharadwaj},
  journal={ArXiv},
  year={2020},
  volume={abs/2002.05645}
}

Credits

This project was generated with python-package-template.