## Objective:

This project enables users to train PyTorch models on server infrastructure and obtain the files required to subsequently fine-tune those models on their edge devices.

### Key Objectives

1. **Server-Side Training**: Users can perform the initial model training with PyTorch on server infrastructure, taking advantage of more powerful hardware for the computationally intensive training phase.
2. **Edge Device Fine-Tuning**: Pre-trained models are lowered and deployed on mobile devices through ExecuTorch, where they undergo fine-tuning. This produces a more personalized model while preserving data privacy and keeping users in control of their data.
3. **Performance Benchmarking**: We track comprehensive performance metrics for fine-tuning operations across various environments to check whether performance is consistent across runtimes.

### ExecuTorch Installation

To install ExecuTorch in a Python environment, run the following commands in a new terminal:

```bash
$ git clone https://github.com/pytorch/executorch.git
$ cd executorch
$ uv venv --seed --prompt et --python 3.10
$ source .venv/bin/activate
$ which python
$ git fetch origin
$ git submodule sync --recursive
$ git submodule update --init --recursive
$ ./install_requirements.sh
$ ./install_executorch.sh
```

### Prerequisites

We need the following packages for this example:
1. torch
2. torchvision
3. executorch
4. tqdm

Make sure these are installed in the `et` venv created in the previous steps. torch and torchvision are installed by the ExecuTorch installation script; tqdm may have to be installed manually.
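
A quick way to verify the environment is to import the packages directly (a minimal sanity check; the version output is informational only):

```python
# Confirm the required packages are importable inside the `et` venv.
import executorch  # noqa: F401
import torch
import torchvision
import tqdm

print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("tqdm:", tqdm.__version__)
```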

### Dataset

For simplicity and reproducibility, we use the [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html). CIFAR-10 is a dataset of 60,000 32x32 color images in 10 classes, with 6,000 images per class. There are 50,000 training images and 10,000 test images.
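
For reference, the dataset can be loaded with `torchvision` as in the sketch below (the download path and batch size mirror the script defaults documented later; the actual script may apply different transforms):

```python
import torch
import torchvision
import torchvision.transforms as transforms

# Convert PIL images to float tensors in [0, 1]; the training script may
# additionally normalize or augment the data.
transform = transforms.ToTensor()

train_set = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)
test_set = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform
)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=4, shuffle=True)
```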

### PyTorch Model Architecture

Here is the simple CNN model we use to classify the CIFAR-10 dataset:

```python
import torch


class CIFAR10Model(torch.nn.Module):

    def __init__(self, num_classes=10) -> None:
        super().__init__()
        self.features = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(32, 64, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(64, 128, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(128 * 4 * 4, 512),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(512, num_classes),
        )

    def forward(self, x) -> torch.Tensor:
        """
        The forward function takes the input image and applies the
        convolutional layers and the fully connected layers to extract
        features and classify the image, respectively.
        """
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
```

While this implementation demonstrates a relatively simple convolutional neural network, it incorporates fundamental building blocks essential for developing sophisticated computer vision models: convolutional layers for feature extraction and max-pooling layers for spatial dimension reduction and computational efficiency.
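
To see where the `128 * 4 * 4` input size of the first linear layer comes from, we can trace a CIFAR-10-sized batch through the feature extractor (a quick check, assuming the `CIFAR10Model` definition above):

```python
import torch

model = CIFAR10Model()
x = torch.randn(1, 3, 32, 32)         # one CIFAR-10 sized image
feats = model.features(x)             # 32x32 -> 16x16 -> 8x8 -> 4x4 after three poolings
print(feats.shape)                    # torch.Size([1, 128, 4, 4])
print(torch.flatten(feats, 1).shape)  # torch.Size([1, 2048]) == 128 * 4 * 4
```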

#### Core Components

1. **Convolutional Layers**: Extract hierarchical features from input images, learning patterns that range from simple edges and textures to more complex object representations.
2. **Max-Pooling Layers**: Reduce spatial dimensions while preserving the most important features, improving computational efficiency and providing a degree of translation invariance.
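
For the server-side phase, a minimal training loop for this model might look like the sketch below (the optimizer choice is an assumption for illustration; `train_loader` is the DataLoader from the Dataset section):

```python
import torch

model = CIFAR10Model()
criterion = torch.nn.CrossEntropyLoss()
# Assumed optimizer; the actual script may use SGD or different hyperparameters.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

model.train()
for epoch in range(1):  # --epochs defaults to 1 (see Command Line Arguments)
    for images, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

# File name mirrors the --model-path default; whether the script saves a
# state_dict or the full model is an assumption.
torch.save(model.state_dict(), "cifar10_model.pth")
```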

### Exporting the PyTorch model to the ExecuTorch runtime

To enable efficient on-device execution and fine-tuning, the trained PyTorch model must be converted to the ExecuTorch format. This conversion involves several key steps that optimize the model for mobile deployment while preserving its ability to be fine-tuned on edge devices.

#### Wrapping the model with the loss function before export

```python
class ModelWithLoss(torch.nn.Module):
    """
    A wrapper that combines a model and its loss function into a single
    module, so that the entire computational graph (the forward pass plus
    the loss calculation) is captured during export. Since our objective
    is on-device training, the loss calculation must be part of the
    exported graph.
    """

    def __init__(self, model, criterion):
        super().__init__()
        self.model = model
        self.criterion = criterion

    def forward(self, x, target):
        # Forward pass through the model
        output = self.model(x)
        # Calculate loss
        loss = self.criterion(output, target)
        # Return loss and predicted class
        return loss, output.detach().argmax(dim=1)
```
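
A short usage sketch: the wrapper behaves like a regular module whose outputs are the scalar loss and the per-sample predicted classes (the names below are illustrative):

```python
import torch

model = CIFAR10Model()
criterion = torch.nn.CrossEntropyLoss()
wrapped = ModelWithLoss(model, criterion)

images = torch.randn(4, 3, 32, 32)     # dummy CIFAR-10 batch
labels = torch.randint(0, 10, (4,))    # dummy class targets
loss, preds = wrapped(images, labels)  # scalar loss, shape-(4,) predictions
```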

#### Conversion of the PyTorch model to ExecuTorch

1. **Graph Capture**: The PyTorch model's computational graph is captured and serialized, creating a portable representation that can be executed across different hardware platforms without requiring the full PyTorch runtime.
   1. The exported format runs consistently across different mobile operating systems and hardware configurations.
2. **Runtime Optimization**: The model is optimized for the ExecuTorch runtime environment, which is specifically designed for resource-constrained edge devices. This includes memory layout optimizations and operator fusion where applicable.
   1. ExecuTorch models have significantly lower memory requirements than full PyTorch models.
3. **Fine-Tuning Compatibility**: The exported model retains the metadata and structure needed to support gradient computation and parameter updates, enabling on-device fine-tuning.
   1. Optimized execution paths provide improved inference performance on mobile hardware.
   2. Traditionally, models are exported as `.pte` files, which are immutable. To fine-tune on device and save the updated weights and biases for future use, the weights must live in `.ptd` files decoupled from the `.pte` program.
   3. Unlike traditional inference-only exports, we set flags during model export to preserve the ability to perform gradient-based updates for fine-tuning.

##### Tracing the model:

The `strict=True` flag in the `export()` method controls the tracing method used during model export; a sketch follows the list below. If we set `strict=True`:
* The export method uses TorchDynamo for tracing.
* It ensures complete soundness of the resulting graph by validating all implicit assumptions.
* It provides stronger guarantees about the correctness of the exported model.
* **Caveat:** TorchDynamo has limited Python feature coverage, so you may encounter more errors during export.
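
Continuing from the wrapper sketch above, a minimal export call might look like this (the example inputs only fix the shapes seen during tracing):

```python
import torch
from torch.export import export

# A CIFAR-10-shaped batch of images plus integer class targets.
example_inputs = (torch.randn(4, 3, 32, 32), torch.randint(0, 10, (4,)))
exported_program = export(wrapped, example_inputs, strict=True)
```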

##### Capturing the forward and backward graphs:

The output of the `export()` method is an `ExportedProgram` containing only the forward computation graph. `_export_forward_backward()` transforms this forward-only program into a **joint forward-backward graph** that includes both the forward pass and the automatically generated backward pass (gradients) needed for training; see the sketch after this list.
Steps carried out by this method:

1. **ATen Decomposition**: Applies core ATen decompositions to break complex operations down into simpler, more fundamental operations that are easier to handle during training.
2. **Backward Pass Generation**: Automatically generates the backward pass (gradient computation) for the forward graph, creating a joint graph that can compute both:
   * Forward pass: Input → Output
   * Backward pass: Loss gradients → Parameter gradients
3. **Graph Optimization**:
   * Removes unnecessary `detach` operations that would break gradient flow. (During model export, unnecessary detach operations are sometimes inserted that would prevent gradients from flowing backward through the model; removing them keeps the training graph connected.)
   * Eliminates dead code, i.e. computational nodes in the graph that are never used by any output and do not contribute to the final result.
   * Preserves the graph structure needed for gradient computation.
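
In code, this step is small. The sketch below follows the training export flow used in the ExecuTorch examples; note that `_export_forward_backward` lives under `torch.export.experimental` and, as an underscored experimental API, may change between releases:

```python
from torch.export.experimental import _export_forward_backward

# Turn the forward-only ExportedProgram into a joint forward-backward graph.
joint_program = _export_forward_backward(exported_program)
```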

##### Transform model from **ATen dialect** to **Edge dialect**

`to_edge()` converts exported PyTorch programs from ATen (A Tensor Library) dialect to Edge dialect, which is optimized for edge device deployment.

`EdgeCompileConfig(_check_ir_validity=False)` skips intermediate representation (IR) validity checks during the transformation and permits operations that might not pass strict validation.
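
A sketch of the lowering and serialization steps, assuming the `joint_program` from the previous step (the output file name is illustrative):

```python
from executorch.exir import EdgeCompileConfig, to_edge

edge_program = to_edge(
    joint_program,
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
executorch_program = edge_program.to_executorch()

# Serialize the program to a .pte file for the on-device runtime.
with open("cifar10_model.pte", "wb") as f:
    f.write(executorch_program.buffer)
```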

### Fine-tuning the ExecuTorch model

The fine-tuning process updates the model's weights and biases based on new training data, typically collected on the edge devices. Support for `PTE` files is baked into the ExecuTorch runtime, which enables the model to be fine-tuned on edge devices. However, at the time of writing, support for training with `PTD` files is not yet available in the ExecuTorch Python runtime. Therefore, we export these files for use in our `C++` and `Java` runtimes.
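
For a `.pte` file that embeds its own weights, the joint graph can at least be exercised from Python through the portable pybindings, as in the hedged sketch below (the API names are taken from the ExecuTorch repository and may change; the optimizer step itself is what the C++/Java training runtimes provide):

```python
import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

et_module = _load_for_executorch("cifar10_model.pte")

images = torch.randn(4, 3, 32, 32)
labels = torch.randint(0, 10, (4,))
# The exported joint graph returns (loss, predicted_class) as its outputs.
loss, preds = et_module.forward([images, labels])
```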

### Command Line Arguments

The training script supports various command line arguments to customize the training process. Here is a comprehensive list of the available flags:

#### Data Configuration
- `--data-dir` (str, default: `./data`)
  - Directory to download and store the CIFAR-10 dataset
  - Example: `--data-dir /path/to/data`

- `--batch-size` (int, default: `4`)
  - Batch size for data loaders during training and validation
  - Example: `--batch-size 32`

- `--use-balanced-dataset` (flag, default: `True`)
  - Use a balanced subset instead of the full CIFAR-10 dataset
  - When enabled, creates a subset with equal representation from each class
  - Example: `--use-balanced-dataset` (to enable) or omit the flag to use the full dataset

- `--images-per-class` (int, default: `100`)
  - Number of images per class for the balanced dataset
  - Only applies when `--use-balanced-dataset` is enabled
  - Example: `--images-per-class 200`

#### Model Paths
- `--model-path` (str, default: `cifar10_model.pth`)
  - Path to save/load the PyTorch model
  - Example: `--model-path models/my_cifar_model.pth`

- `--pte-model-path` (str, default: `cifar10_model.pte`)
  - Path to save the PTE (PyTorch ExecuTorch) model file
  - Example: `--pte-model-path models/cifar_model.pte`

- `--split-pte-model-path` (str, default: `split_cifar10_model.pte`)
  - Path to save the split PTE model (model architecture without weights)
  - Used in conjunction with PTD files for external weight storage
  - Example: `--split-pte-model-path models/split_model.pte`

- `--ptd-model-dir` (str, default: `.`)
  - Directory path to save PTD (PyTorch Tensor Data) files
  - Contains external weights and constants separate from the PTE file
  - Example: `--ptd-model-dir ./model_data`

#### Training History and Logging
- `--save-pt-json` (str, default: `cifar10_pt_model_finetuned_history.json`)
  - Path to save the PyTorch model training history as JSON
  - Contains metrics like loss, accuracy, and timing information
  - Example: `--save-pt-json results/pytorch_history.json`

- `--save-et-json` (str, default: `cifar10_et_pte_only_model_finetuned_history.json`)
  - Path to save the ExecuTorch model fine-tuning history as JSON
  - Contains metrics from the ExecuTorch fine-tuning process
  - Example: `--save-et-json results/executorch_history.json`

#### Training Hyperparameters
- `--epochs` (int, default: `1`)
  - Number of epochs for initial PyTorch model training
  - Example: `--epochs 5`

- `--fine-tune-epochs` (int, default: `10`)
  - Number of epochs for fine-tuning the ExecuTorch model
  - Example: `--fine-tune-epochs 20`

- `--learning-rate` (float, default: `0.001`)
  - Learning rate for both PyTorch training and ExecuTorch fine-tuning
  - Example: `--learning-rate 0.01`
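
Putting it together, a typical invocation might look like the following (the script name `train.py` is hypothetical; substitute this project's actual entry point):

```bash
$ python train.py \
    --data-dir ./data \
    --batch-size 32 \
    --epochs 5 \
    --fine-tune-epochs 20 \
    --learning-rate 0.001 \
    --use-balanced-dataset \
    --images-per-class 200
```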