
Commit 1bfc3a6

IshanAryendu authored and facebook-github-bot committed
CIFAR 10 Training Example (#12417)
Summary: This is a training example which demonstrates how a simple CNN model for CIFAR-10 can be trained using the traditional PTE-only training and the PTE + PTD file export. **NOTE:** The PTE + PTD training doesn't work in Python yet.

Differential Revision: D78124851
1 parent 6d2106a commit 1bfc3a6

File tree

7 files changed: +1789 −0 lines changed

Lines changed: 225 additions & 0 deletions
@@ -0,0 +1,225 @@
## Objective:

This project enables users to train PyTorch models on server infrastructure and obtain the files required to subsequently fine-tune those models on their edge devices.

### Key Objectives

1. **Server-Side Training**: Users can leverage server computational resources to perform the initial model training in PyTorch, relying on more powerful hardware for the computationally intensive training phase.
2. **Edge Device Fine-Tuning**: Pre-trained models are lowered and deployed on mobile devices through ExecuTorch, where they undergo fine-tuning. This allows us to create a more personalized model while maintaining data privacy and keeping users in control of their data.
3. **Performance Benchmarking**: We track comprehensive performance metrics for fine-tuning operations across various environments to check whether performance is consistent across runtimes.

### ExecuTorch Installation

To install ExecuTorch in a Python environment, run the following commands in a new terminal:

```bash
$ git clone https://github.com/pytorch/executorch.git
$ cd executorch
$ uv venv --seed --prompt et --python 3.10
$ source .venv/bin/activate
$ which python
$ git fetch origin
$ git submodule sync --recursive
$ git submodule update --init --recursive
$ ./install_requirements.sh
$ ./install_executorch.sh
```

### Prerequisites

We need the following packages for this example:

1. torch
2. torchvision
3. executorch
4. tqdm

Make sure these are installed in the `et` venv created in the previous steps. Torch and torchvision are installed by the ExecuTorch installation script; tqdm may have to be installed manually.

### Dataset

For simplicity and replicability we use the [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html). CIFAR-10 is a dataset of 60,000 32x32 color images in 10 classes, with 6,000 images per class. There are 50,000 training images and 10,000 test images.
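
Below is a minimal sketch of loading CIFAR-10 with torchvision; the `./data` root mirrors the `--data-dir` flag described later, and the normalization constants are an illustrative assumption rather than values taken from this example's script:

```python
import torch
import torchvision
from torchvision import transforms

# Convert images to tensors and normalize each RGB channel to roughly [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Downloads CIFAR-10 into ./data on first use.
train_set = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=4, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=4, shuffle=False)
```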

### PyTorch Model Architecture

Here is the simple CNN model that we have used for the classification of the CIFAR-10 dataset:

```python
class CIFAR10Model(torch.nn.Module):

    def __init__(self, num_classes=10) -> None:
        super(CIFAR10Model, self).__init__()
        self.features = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(32, 64, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(64, 128, kernel_size=3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(128 * 4 * 4, 512),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(512, num_classes),
        )

    def forward(self, x) -> torch.Tensor:
        """
        The forward function takes the input image and applies the
        convolutional layers and the fully connected layers to extract
        the features and classify the image, respectively.
        """
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
```

While this implementation demonstrates a relatively simple convolutional neural network, it incorporates the fundamental building blocks essential for developing sophisticated computer vision models: convolutional layers for feature extraction and max-pooling layers for spatial dimension reduction and computational efficiency.

#### Core Components

1. **Convolutional Layers**: Extract hierarchical features from input images, learning patterns that range from edges and textures to more complex, comprehensive object representations.
2. **Max-Pooling Layers**: Reduce spatial dimensions while preserving the most important features, improving computational efficiency and providing translation invariance.

### Exporting the PyTorch model to ExecuTorch runtime

To enable efficient on-device execution and fine-tuning, the trained PyTorch model must be converted to the ExecuTorch format. This conversion process involves several key steps that optimize the model for mobile deployment while preserving its ability to be fine-tuned on edge devices.

#### Wrapping the model with the loss function before export

```python
class ModelWithLoss(torch.nn.Module):
    """
    NOTE: A wrapper class that combines a model and the loss function into a
    single module. Used so that the entire computational graph, i.e. the
    forward pass and the loss calculation, is captured during export. Our
    objective is to enable on-device training, so the loss calculation should
    also be included in the exported graph.
    """

    def __init__(self, model, criterion):
        super().__init__()
        self.model = model
        self.criterion = criterion

    def forward(self, x, target):
        # Forward pass through the model
        output = self.model(x)
        # Calculate loss
        loss = self.criterion(output, target)
        # Return loss and predicted class
        return loss, output.detach().argmax(dim=1)
```
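
As a quick usage sketch (the variable names here are illustrative assumptions, not part of the example's script), the wrapper combines the model with a criterion and is paired with a representative input/target batch that will later serve as example inputs for export:

```python
import torch

# Wrap the CIFAR-10 model and a cross-entropy criterion into one module.
model = CIFAR10Model()
criterion = torch.nn.CrossEntropyLoss()
wrapped_model = ModelWithLoss(model, criterion)

# A representative batch: one 32x32 RGB image and one integer class label.
example_inputs = (torch.randn(1, 3, 32, 32), torch.tensor([0], dtype=torch.long))

# The wrapper returns the loss and the predicted class together.
loss, predicted_class = wrapped_model(*example_inputs)
```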

#### Conversion of PyTorch model to ExecuTorch

1. **Graph Capture**: The PyTorch model's computational graph is captured and serialized, creating a portable representation that can be executed across different hardware platforms without requiring the full PyTorch runtime.
   1. The exported format can run consistently across different mobile operating systems and hardware configurations.
2. **Runtime Optimization**: The model is optimized for the ExecuTorch runtime environment, which is specifically designed for resource-constrained edge devices. This includes memory layout optimizations and operator fusion where applicable.
   1. ExecuTorch models have significantly lower memory requirements than full PyTorch models.
3. **Fine-Tuning Compatibility**: The exported model retains the metadata and structure necessary to support gradient computation and parameter updates, enabling on-device fine-tuning.
   1. Optimized execution paths provide improved inference performance on mobile hardware.
   2. Traditionally, models are exported as `.pte` files, which are immutable. Therefore, we need `.ptd` files decoupled from the `.pte` in order to perform fine-tuning and save the updated weights and biases for future use.
   3. Unlike traditional inference-only exports, we set flags during model export to preserve the ability to perform gradient-based updates for fine-tuning.

##### Tracing the model:

The `strict=True` flag in the `export()` method controls the tracing method used during model export (see the sketch below). If we set `strict=True`:

* The export method uses TorchDynamo for tracing.
* It ensures complete soundness of the resulting graph by validating all implicit assumptions.
* It provides stronger guarantees about the correctness of the exported model.
* **Caveats:** TorchDynamo has limited Python feature coverage, so you may encounter more errors during export.
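
A minimal sketch of the export call under these settings, assuming the `wrapped_model` and `example_inputs` names from the wrapper sketch above:

```python
import torch

# Capture the forward graph (model + loss) with TorchDynamo-based tracing.
ep = torch.export.export(wrapped_model, example_inputs, strict=True)
```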

##### Capturing the forward and backward graphs:

`_export_forward_backward()` transforms a forward-only exported PyTorch model into a **joint forward-backward graph** that includes both the forward pass and the automatically generated backward pass (gradients) needed for training. The output of the `export()` method is an `ExportedProgram` containing only the forward computation graph. This method carries out the following steps (a sketch follows the list):

1. Applies core ATen decompositions to break complex operations down into simpler, more fundamental operations that are easier to handle during training.
2. Automatically generates the backward pass (gradient computation) for the forward graph, creating a joint graph that can compute both:
   * Forward pass: Input → Output
   * Backward pass: Loss gradients → Parameter gradients
3. **Graph Optimization**:
   * Removes unnecessary `detach` operations that would break gradient flow. (During model export, unnecessary detach operations are sometimes inserted that would prevent gradients from flowing backward through the model; removing them keeps the training graph connected.)
   * Eliminates dead code to optimize the graph. (Dead code refers to computational nodes in the graph that are never used by any output and don't contribute to the final result.)
   * Preserves the graph structure needed for gradient computation.
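
A hedged sketch of this step, assuming `ep` is the `ExportedProgram` from the export call above and that `_export_forward_backward` is imported from `torch.export.experimental`, where it currently lives as a private API:

```python
from torch.export.experimental import _export_forward_backward

# Produce a joint ExportedProgram whose outputs include the loss, the
# predicted class, and the gradients needed for on-device training.
joint_ep = _export_forward_backward(ep)
```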

##### Transform model from **ATen dialect** to **Edge dialect**

`to_edge()` converts exported PyTorch programs from the ATen (A Tensor Library) dialect to the Edge dialect, which is optimized for edge device deployment.

`EdgeCompileConfig(_check_ir_validity=False)` skips intermediate representation (IR) validity checks during the transformation and permits operations that might not pass strict validation.
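
A minimal sketch of this lowering step, plus serializing the result to a `.pte` file, assuming `joint_ep` is the joint graph from the previous step:

```python
from executorch.exir import EdgeCompileConfig, to_edge

# Lower the joint ATen-dialect program to the Edge dialect.
edge_program = to_edge(joint_ep, compile_config=EdgeCompileConfig(_check_ir_validity=False))

# Convert to an ExecuTorch program and write it out as a .pte file.
executorch_program = edge_program.to_executorch()
with open("cifar10_model.pte", "wb") as f:
    f.write(executorch_program.buffer)
```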

### Fine-tuning the ExecuTorch model

The fine-tuning process involves updating the model's weights and biases based on new training data, typically collected from the edge devices. Support for `PTE` files is baked into the ExecuTorch runtime, which enables the model to be fine-tuned on edge devices. However, at the time of writing, support for training with `PTD` files is not yet available in the ExecuTorch Python runtime. Therefore, we export these files to be used in our `C++` and `Java` runtimes.
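
For the `PTE`-only path, a rough sketch of one in-Python training step is below. It assumes the joint graph exported above and that the program's outputs are laid out as `[loss, prediction, gradients..., parameters...]`; the even gradient/parameter split is an assumption to adapt to your model, and `clone_outputs=False` is needed so the in-place parameter updates persist inside the loaded program:

```python
import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

et_module = _load_for_executorch("cifar10_model.pte")
learning_rate = 0.001

for images, labels in train_loader:
    out = et_module.forward((images, labels), clone_outputs=False)
    loss, prediction = out[0], out[1]
    # Assumption: the remaining outputs are the gradients followed by the
    # corresponding parameters, in matching order.
    mid = 2 + (len(out) - 2) // 2
    grads, params = out[2:mid], out[mid:]
    with torch.no_grad():
        for param, grad in zip(params, grads):
            param.sub_(learning_rate * grad)  # plain SGD update
```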

### Command Line Arguments

The training script supports various command line arguments to customize the training process. Here is a comprehensive list of all available flags:

#### Data Configuration
- `--data-dir` (str, default: `./data`)
  - Directory to download and store the CIFAR-10 dataset
  - Example: `--data-dir /path/to/data`

- `--batch-size` (int, default: `4`)
  - Batch size for data loaders during training and validation
  - Example: `--batch-size 32`

- `--use-balanced-dataset` (flag, default: `True`)
  - Use a balanced dataset instead of full CIFAR-10
  - When enabled, creates a subset with equal representation from each class
  - Example: `--use-balanced-dataset` (to enable) or omit the flag to use the full dataset

- `--images-per-class` (int, default: `100`)
  - Number of images per class for the balanced dataset
  - Only applies when `--use-balanced-dataset` is enabled
  - Example: `--images-per-class 200`

#### Model Paths
- `--model-path` (str, default: `cifar10_model.pth`)
  - Path to save/load the PyTorch model
  - Example: `--model-path models/my_cifar_model.pth`

- `--pte-model-path` (str, default: `cifar10_model.pte`)
  - Path to save the PTE (PyTorch ExecuTorch) model file
  - Example: `--pte-model-path models/cifar_model.pte`

- `--split-pte-model-path` (str, default: `split_cifar10_model.pte`)
  - Path to save the split PTE model (model architecture without weights)
  - Used in conjunction with PTD files for external weight storage
  - Example: `--split-pte-model-path models/split_model.pte`

- `--ptd-model-dir` (str, default: `.`)
  - Directory path to save PTD (PyTorch Tensor Data) files
  - Contains external weights and constants separate from the PTE file
  - Example: `--ptd-model-dir ./model_data`

#### Training History and Logging
- `--save-pt-json` (str, default: `cifar10_pt_model_finetuned_history.json`)
  - Path to save the PyTorch model training history as JSON
  - Contains metrics like loss, accuracy, and timing information
  - Example: `--save-pt-json results/pytorch_history.json`

- `--save-et-json` (str, default: `cifar10_et_pte_only_model_finetuned_history.json`)
  - Path to save the ExecuTorch model fine-tuning history as JSON
  - Contains metrics from the ExecuTorch fine-tuning process
  - Example: `--save-et-json results/executorch_history.json`

#### Training Hyperparameters
- `--epochs` (int, default: `1`)
  - Number of epochs for initial PyTorch model training
  - Example: `--epochs 5`

- `--fine-tune-epochs` (int, default: `10`)
  - Number of epochs for fine-tuning the ExecuTorch model
  - Example: `--fine-tune-epochs 20`

- `--learning-rate` (float, default: `0.001`)
  - Learning rate for both PyTorch training and ExecuTorch fine-tuning
  - Example: `--learning-rate 0.01`
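
For instance, a typical invocation might look like the following (the script name `main.py` is an assumption; substitute the actual training script shipped with this example):

```bash
$ python main.py \
    --data-dir ./data \
    --batch-size 32 \
    --use-balanced-dataset \
    --images-per-class 200 \
    --epochs 5 \
    --fine-tune-epochs 20 \
    --learning-rate 0.001
```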
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()
