diff --git a/extension/training/examples/CIFAR/README.md b/extension/training/examples/CIFAR/README.md new file mode 100644 index 00000000000..822bf4f339d --- /dev/null +++ b/extension/training/examples/CIFAR/README.md @@ -0,0 +1,225 @@ +## Objective: + +This project enables the users to train PyTorch models on server infrastructure and get the required files to subsequently fine-tune these models on their edge devices. + +### Key Objectives + +1. **Server-Side Training**: Users can leverage server computational resources to perform initial model training using PyTorch, leveraging a more powerful hardware setup for the computationally intensive training phase. +2. **Edge Device Fine-Tuning**: Pre-trained models are lowered and deployed on mobile devices through ExecuTorch where they undergo fine-tuning. This allows us to create a more personalized model while maintaining data privacy and allowing the users to be in control of their data. +3. **Performance Benchmarking**: We will track comprehensive performance metrics for fine-tuning operations across various environments to see if the performance is consistent across various runtimes. + +### ExecuTorch Installation + +To install ExecuTorch in a python environment we can use the following commands in a new terminal: + +```bash +$ git clone https://github.com/pytorch/executorch.git +$ cd executorch +$ uv venv --seed --prompt et --python 3.10 +$ source .venv/bin/activate +$ which python +$ git fetch origin +$ git submodule sync --recursive +$ git submodule update --init --recursive +$ ./install_requirements.sh +$ ./install_executorch.sh +``` + +### Prerequisites + +We need the following packages for this example: +1. torch +2. torchvision +3. executorch +4. tqdm + +Make sure these are installed in the `et` venv created in the previous steps. Torchvision and Toech are installed by the installation script of ExecuTorch. Tqdm might have to be installed manually. + +### Dataset + +For simplicity and replicatability we will be using the [CIFAR 10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html). CIFAR-10 is a dataset of 60,000 32x32 color images in 10 classes, with 6000 images per class. There are 50,000 training images and 10,000 test images. + +### PyTorch Model Architecture + +Here is a simple CNN Model that we have used for the classification of the CIFAR 10 dataset: + +```python +class CIFAR10Model(torch.nn.Module): + + def __init__(self, num_classes=10) -> None: + super(CIFAR10Model, self).__init__() + self.features = torch.nn.Sequential( + torch.nn.Conv2d(3, 32, kernel_size=3, padding=1), + torch.nn.ReLU(inplace=True), + torch.nn.MaxPool2d(kernel_size=2, stride=2), + torch.nn.Conv2d(32, 64, kernel_size=3, padding=1), + torch.nn.ReLU(inplace=True), + torch.nn.MaxPool2d(kernel_size=2, stride=2), + torch.nn.Conv2d(64, 128, kernel_size=3, padding=1), + torch.nn.ReLU(inplace=True), + torch.nn.MaxPool2d(kernel_size=2, stride=2), + ) + + self.classifier = torch.nn.Sequential( + torch.nn.Linear(128 * 4 * 4, 512), + torch.nn.ReLU(inplace=True), + torch.nn.Dropout(0.5), + torch.nn.Linear(512, num_classes), + ) + + def forward(self, x) -> torch.Tensor: + """ + The forward function takes the input image and applies the convolutional + layers and the fully connected layers to extract the features and + classify the image respectively. + """ + x = self.features(x) + x = torch.flatten(x, 1) + x = self.classifier(x) + return x +``` + +While this implementation demonstrates a relatively simple convolutional neural network, it incorporates fundamental building blocks essential for developing sophisticated computer vision models: convolutional layers for feature extraction and max-pooling layers for spatial dimension reduction and computational efficiency. + +#### Core Components + +1. **Convolutional Layers**: Extract hierarchical features from input images, learning patterns ranging from edges and textures in the images for a more complex and comprehensive object representations. +2. **Max-Pooling Layers**: Reduce spatial dimensions while preserving the most important features, improving computational efficiency and providing translation invariance. + +### Exporting the PyTorch model to ExecuTorch runtime + +To enable efficient on-device execution and fine-tuning, the trained PyTorch model must be converted to the ExecuTorch format. This conversion process involves several key steps that optimize the model for mobile deployment while preserving its ability to be fine-tuned on edge devices. + +#### Wrapping the model with the loss function before export + +```python +class ModelWithLoss(torch.nn.Module): + + """ + NOTE: A wrapper class that combines a model and the loss function into a + single module. Used for capturing the entire computational graph, i.e. + forward pass and the loss calculation, to be captured during export. Our + objective is to enable on-device training, so the loss calculation should + also be included in the exported graph. + """ + + def __init__(self, model, criterion): + super().__init__() + self.model = model + self.criterion = criterion + + def forward(self, x, target): + # Forward pass through the model + output = self.model(x) + # Calculate loss + loss = self.criterion(output, target) + # Return loss and predicted class + return loss, output.detach().argmax(dim=1) +``` + +#### Conversion of PyTorch model to ExecuTorch + +1. **Graph Capture**: The PyTorch model's computational graph is captured and serialized, creating a portable representation that can be executed across different hardware platforms without requiring the full PyTorch runtime. + 1. The exported format can run consistently across different mobile operating systems and hardware configurations. +2. **Runtime Optimization**: The model is optimized for the ExecuTorch runtime environment, which is specifically designed for resource-constrained edge devices. This includes memory layout optimizations and operator fusion where applicable. + 1. ExecuTorch models have significantly lower memory requirements compared to full PyTorch models. +3. **Fine-Tuning Compatibility**: The exported model retains the necessary metadata and structure to support gradient computation and parameter updates, enabling on-device fine-tuning capabilities. + 1. Optimized execution paths provide improved inference performance on mobile hardware. + 2. Traditionally the models are exported as `.pte` files which are immutable. Therefore, we need the `.ptd` files decoupled from `.pte` to perform fine-tuning and save the updated weights and biases for future use. + 3. Unlike traditional inference-only exports, we will set the flags during the model export to preserve the ability to perform gradient-based updates for fine-tuning. +##### Tracing the model: + +The `strict=True` flag in the `export()`method controls the tracing method used during model export. If we set `strict=True`: +* Export method uses TorchDynamo for tracing +* It ensures complete soundness of the resulting graph by validating all implicit assumptions +* It provides stronger guarantees about the correctness of the exported model +* **Caveats:** TorchDynamo has limited Python feature coverage, so you may encounter more errors during export +##### Capturing the forward and backward graphs: + +`_export_forward_backward()` transforms a forward-only exported PyTorch model into a **joint forward-backward graph** that includes both the forward pass and the automatically generated backward pass (gradients) needed for training. +We get an `ExportedProgram` containing only the forward computation graph as the output of the `export()`method. +Steps carried out by this method: + +1. Apply core ATen decompositions to break down complex operations into simpler, more fundamental operations that are easier to handle during training. +2. Automatically generates the backward pass (gradient computation) for the forward graph, creating a joint graph that can compute both: + * Forward pass: Input → Output + * Backward pass: Loss gradients → Parameter gradients +3. **Graph Optimization**: + * Removes unnecessary `detach` operations that would break gradient flow. (During model export, sometimes unnecessary detach operations get inserted that would prevent gradients from flowing backward through the model. Removing these ensures the training graph remains connected.) + * Eliminates dead code to optimize the graph. (Dead code refers to computational nodes in the graph that are never used by any output and don't contribute to the final result.) + * Preserves the graph structure needed for gradient computation. + +##### Transform model from **ATen dialect** to **Edge dialect** +`to_edge()`converts exported PyTorch programs from ATen (A Tensor Library) dialect to Edge dialect, which is optimized for edge device deployment. + +`EdgeCompileConfig(_check_ir_validity=False)`skips intermediate representation (IR) validity checks during transformation and permits operations that might not pass strict validation. + +### Fine-tuning the ExecuTorch model + +The fine-tuning process involves updating the model's weights and biases based on new training data, typically collected from the edge devices. The support for `PTE` files is baked into the ExecuTorch runtime, which enables the model to be fine-tuned on the edge devices. However, at the time of writing, the support for training with `PTD` files is not yet available in the ExecuTorch Python runtime. Therefore, we export these files to be used in our `C++` and `Java` runtimes. + +### Command Line Arguments + +The training script supports various command line arguments to customize the training process. Here is a comprehensive list of all available flags: + +#### Data Configuration +- `--data-dir` (str, default: `./data`) + - Directory to download and store CIFAR-10 dataset + - Example: `--data-dir /path/to/data` + +- `--batch-size` (int, default: `4`) + - Batch size for data loaders during training and validation + - Example: `--batch-size 32` + +- `--use-balanced-dataset` (flag, default: `True`) + - Use balanced dataset instead of full CIFAR-10 + - When enabled, creates a subset with equal representation from each class + - Example: `--use-balanced-dataset` (to enable) or omit flag to use full dataset + +- `--images-per-class` (int, default: `100`) + - Number of images per class for balanced dataset + - Only applies when `--use-balanced-dataset` is enabled + - Example: `--images-per-class 200` + +#### Model Paths +- `--model-path` (str, default: `cifar10_model.pth`) + - Path to save/load the PyTorch model + - Example: `--model-path models/my_cifar_model.pth` + +- `--pte-model-path` (str, default: `cifar10_model.pte`) + - Path to save the PTE (PyTorch ExecuTorch) model file + - Example: `--pte-model-path models/cifar_model.pte` + +- `--split-pte-model-path` (str, default: `split_cifar10_model.pte`) + - Path to save the split PTE model (model architecture without weights) + - Used in conjunction with PTD files for external weight storage + - Example: `--split-pte-model-path models/split_model.pte` + +- `--ptd-model-dir` (str, default: `.`) + - Directory path to save PTD (PyTorch Tensor Data) files + - Contains external weights and constants separate from the PTE file + - Example: `--ptd-model-dir ./model_data` + +#### Training History and Logging +- `--save-pt-json` (str, default: `cifar10_pt_model_finetuned_history.json`) + - Path to save PyTorch model training history as JSON + - Contains metrics like loss, accuracy, and timing information + - Example: `--save-pt-json results/pytorch_history.json` + +- `--save-et-json` (str, default: `cifar10_et_pte_only_model_finetuned_history.json`) + - Path to save ExecuTorch model fine-tuning history as JSON + - Contains metrics from the ExecuTorch fine-tuning process + - Example: `--save-et-json results/executorch_history.json` + +#### Training Hyperparameters +- `--epochs` (int, default: `1`) + - Number of epochs for initial PyTorch model training + - Example: `--epochs 5` + +- `--fine-tune-epochs` (int, default: `10`) + - Number of epochs for fine-tuning the ExecuTorch model + - Example: `--fine-tune-epochs 20` + +- `--learning-rate` (float, default: `0.001`) + - Learning rate for both PyTorch training and ExecuTorch fine-tuning + - Example: `--learning-rate 0.01` diff --git a/extension/training/examples/CIFAR/TARGETS b/extension/training/examples/CIFAR/TARGETS new file mode 100644 index 00000000000..2341af9282f --- /dev/null +++ b/extension/training/examples/CIFAR/TARGETS @@ -0,0 +1,8 @@ +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. + +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/extension/training/examples/CIFAR/main.py b/extension/training/examples/CIFAR/main.py new file mode 100644 index 00000000000..0b9772c47b4 --- /dev/null +++ b/extension/training/examples/CIFAR/main.py @@ -0,0 +1,279 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import argparse + +import torch + +from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge +from executorch.extension.pybindings.portable_lib import ( + ExecuTorchModule, +) +from executorch.extension.training.examples.CIFAR.model import ( + CIFAR10Model, + ModelWithLoss, + train_model, + fine_tune_executorch_model, +) +from executorch.extension.training.examples.CIFAR.utils import ( + get_data_loaders, + save_json, +) + +from torch.export import export +from torch.export.experimental import _export_forward_backward + + +def export_model( + net: torch.nn.Module, input_tensor: torch.Tensor, label_tensor: torch.Tensor +) -> ExecuTorchModule: + """ + Export a PyTorch model to an ExecutorTorch module format. + + This function takes a PyTorch model and sample input/label tensors, wraps the model + with a loss function, exports it using torch.export, applies forward-backward pass + optimization, converts it to edge format, and finally to ExecutorTorch format. + + Args: + net (torch.nn.Module): The PyTorch model to be exported + input_tensor (torch.Tensor): A sample input tensor with the correct shape + label_tensor (torch.Tensor): A sample label tensor with the correct shape + + Returns: + ExecuTorchModule: The exported model in ExecutorTorch format ready for deployment + """ + criterion = torch.nn.CrossEntropyLoss() + model_with_loss = ModelWithLoss(net, criterion) + ep = export(model_with_loss, (input_tensor, label_tensor), strict=True) + ep = _export_forward_backward(ep) + ep = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)) + ep = ep.to_executorch() + return ep + + +def export_model_with_ptd( + net: torch.nn.Module, input_tensor: torch.Tensor, label_tensor: torch.Tensor +) -> ExecuTorchModule: + """ + Export a PyTorch model to an ExecutorTorch module format with external tensor data. + + This function takes a PyTorch model and sample input/label tensors, wraps the model + with a loss function, exports it using torch.export, applies forward-backward pass + optimization, converts it to edge format, and finally to ExecutorTorch format with + external constants and mutable weights. + + Args: + net (torch.nn.Module): The PyTorch model to be exported + input_tensor (torch.Tensor): A sample input tensor with the correct shape + label_tensor (torch.Tensor): A sample label tensor with the correct shape + + Returns: + ExecuTorchModule: The exported model in ExecutorTorch format ready for deployment + """ + criterion = torch.nn.CrossEntropyLoss() + model_with_loss = ModelWithLoss(net, criterion) + ep = export(model_with_loss, (input_tensor, label_tensor), strict=True) + ep = _export_forward_backward(ep) + ep = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)) + ep = ep.to_executorch( + config=ExecutorchBackendConfig( + external_constants=True, # This is the flag that enables the external + # constants to be stored in a separate file external to the PTE file. + external_mutable_weights=True, # This is the flag that enables all trainable + # weights will be stored in a separate file external to the PTE file. + ) + ) + return ep + + +def save_model(ep: ExecuTorchModule, model_path: str) -> None: + """ + Save an ExecutorTorch model to a specified file path. + + This function writes the buffer of an ExecutorTorchModule to a file in binary format. + + Args: + ep (ExecuTorchModule): The ExecutorTorch module to be saved. + model_path (str): The file path where the model will be saved. + """ + with open(model_path, "wb") as file: + file.write(ep.buffer) + + +def parse_args() -> argparse.Namespace: + """ + Parse command line arguments for the CIFAR-10 training script. + + This function sets up an argument parser with various configuration options + for training a CIFAR-10 model with ExecutorTorch, including data paths, + training hyperparameters, and model save locations. + + Returns: + argparse.Namespace: An object containing all the parsed command line arguments + with their respective values (either user-provided or defaults). + """ + parser = argparse.ArgumentParser(description="CIFAR-10 Training Example") + parser.add_argument( + "--data-dir", + type=str, + default="./data", + help="Directory to download and store CIFAR-10 dataset (default: ./data)", + ) + parser.add_argument( + "--batch-size", + type=int, + default=4, + help="Batch size for data loaders (default: 4)", + ) + parser.add_argument( + "--use-balanced-dataset", + action="store_true", + default=True, + help="Use balanced dataset instead of full CIFAR-10 (default: True)", + ) + parser.add_argument( + "--images-per-class", + type=int, + default=100, + help="Number of images per class for balanced dataset (default: 100)", + ) + parser.add_argument( + "--model-path", + type=str, + default="cifar10_model.pth", + help="PyTorch model path (default: cifar10_model.pth)", + ) + + parser.add_argument( + "--pte-model-path", + type=str, + default="cifar10_model.pte", + help="PTE model path (default: cifar10_model.pte)", + ) + + parser.add_argument( + "--split-pte-model-path", + type=str, + default="split_cifar10_model.pte", + help="Split PTE model path (default: split_cifar10_model.pte)", + ) + + parser.add_argument( + "--ptd-model-dir", type=str, default=".", help="PTD model path (default: .)" + ) + + parser.add_argument( + "--save-pt-json", + type=str, + default="cifar10_pt_model_finetuned_history.json", + help="Save the et json file (default: cifar10_et_pte_only_model_finetuned_history.json)", + ) + + parser.add_argument( + "--save-et-json", + type=str, + default="cifar10_et_pte_only_model_finetuned_history.json", + help="Save the et json file (default: cifar10_et_pte_only_model_finetuned_history.json)", + ) + + parser.add_argument( + "--epochs", + type=int, + default=1, + help="Number of epochs for training (default: 1)", + ) + + parser.add_argument( + "--fine-tune-epochs", + type=int, + default=10, + help="Number of fine-tuning epochs for fine-tuning (default: 150)", + ) + + parser.add_argument( + "--learning-rate", + type=float, + default=0.001, + help="Learning rate for fine-tuning (default: 0.001)", + ) + + return parser.parse_args() + + +def main() -> None: + + args = parse_args() + + train_loader, test_loader = get_data_loaders( + batch_size=args.batch_size, + data_dir=args.data_dir, + use_balanced_dataset=args.use_balanced_dataset, + images_per_class=args.images_per_class, + ) + + # initialize the main model + model = CIFAR10Model() + + model, train_hist = train_model( + model, + train_loader, + test_loader, + epochs=1, + lr=0.001, + momentum=0.9, + save_path=args.model_path, + ) + + save_json(train_hist, args.save_pt_json) + + # Export the model for et runtime + validation_sample_data = next(iter(test_loader)) + img, lbl = validation_sample_data + sample_input = img[0:1, :] + sample_label = lbl[0:1] + + ep = export_model(model, sample_input, sample_label) + + save_model(ep, args.pte_model_path) + + et_model, et_hist = fine_tune_executorch_model( + args.pte_model_path, + args.pte_model_path, + train_loader, + test_loader, + epochs=args.fine_tune_epochs, + learning_rate=args.learning_rate, + ) + + save_json(et_hist, args.save_et_json) + + # Split the model into the pte and ptd files + exported_program = export_model_with_ptd(model, sample_input, sample_label) + + exported_program._tensor_data["generic_cifar"] = exported_program._tensor_data.pop( + "_default_external_constant" + ) + exported_program.write_tensor_data_to_file(args.ptd_model_dir) + save_model(exported_program, args.split_pte_model_path) + + # Finetune the PyTorch model + model, train_hist = train_model( + model, + train_loader, + test_loader, + epochs=args.fine_tune_epochs, + lr=args.learning_rate, + momentum=0.9, + save_path=args.model_path, + ) + + save_json(train_hist, args.save_pt_json) + + +if __name__ == "__main__": + main() diff --git a/extension/training/examples/CIFAR/model.py b/extension/training/examples/CIFAR/model.py new file mode 100644 index 00000000000..3cb1ff6a7c8 --- /dev/null +++ b/extension/training/examples/CIFAR/model.py @@ -0,0 +1,350 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import os +import time +import typing +import torch +from torch.utils.data import DataLoader +from executorch.extension.pybindings.portable_lib import ( + _load_for_executorch_from_buffer, + ExecuTorchModule, +) +from tqdm import tqdm + + +class CIFAR10Model(torch.nn.Module): + + def __init__(self, num_classes: int = 10) -> None: + super(CIFAR10Model, self).__init__() + self.features = torch.nn.Sequential( + torch.nn.Conv2d(3, 32, kernel_size=3, padding=1), + torch.nn.ReLU(inplace=True), + torch.nn.MaxPool2d(kernel_size=2, stride=2), + torch.nn.Conv2d(32, 64, kernel_size=3, padding=1), + torch.nn.ReLU(inplace=True), + torch.nn.MaxPool2d(kernel_size=2, stride=2), + torch.nn.Conv2d(64, 128, kernel_size=3, padding=1), + torch.nn.ReLU(inplace=True), + torch.nn.MaxPool2d(kernel_size=2, stride=2), + ) + + self.classifier = torch.nn.Sequential( + torch.nn.Linear(128 * 4 * 4, 512), + torch.nn.ReLU(inplace=True), + torch.nn.Dropout(0.5), + torch.nn.Linear(512, num_classes), + ) + + def forward(self, x) -> torch.Tensor: + """ + The forward function takes the input image and applies the convolutional + layers and the fully connected layers to extract the features and + classify the image respectively. + """ + x = self.features(x) + x = torch.flatten(x, 1) + x = self.classifier(x) + return x + + +class ModelWithLoss(torch.nn.Module): + """ + NOTE: A wrapper class that combines a model and the loss function into a + single module. Used for capturing the entire computational graph, i.e. + forward pass and the loss calculation, to be captured during export. Our + objective is to enable on-device training, so the loss calculation should + also be included in the exported graph. + """ + + def __init__( + self, model: torch.nn.Module, criterion: torch.nn.CrossEntropyLoss + ) -> None: + super().__init__() + self.model = model + self.criterion = criterion + + def forward( + self, x: torch.Tensor, target: torch.Tensor + ) -> typing.Tuple[torch.Tensor, torch.Tensor]: + # Forward pass through the model + output = self.model(x) + # Calculate loss + loss = self.criterion(output, target) + # Return loss and predicted class + return loss, output.detach().argmax(dim=1) + + +def train_model( + model: torch.nn.Module, + train_loader: DataLoader, + test_loader: DataLoader, + epochs: int = 1, + lr: float = 0.001, + momentum: float = 0.9, + save_path: str = "./best_cifar10_model.pth", +) -> typing.Tuple[torch.nn.Module, typing.Dict[int, typing.Dict[str, float]]]: + """ + The train_model function takes a model, a train_loader, and the number of + epochs as input.It then trains the model on the training data for the + specified number of epochs using the SGD optimizer and a cross-entropy loss + function. The function returns the trained model. + + args: + model (Required): The model to be trained. + train_loader (tuple, Required): The training data loader. + test_loader (tuple, Optional): The testing data loader. + epochs (int, optional): The number of epochs to train the model for. + lr (float, optional): The learning rate for the SGD optimizer. + momentum (float, optional): The momentum for the SGD optimizer. + save_path (str, optional): Path to save the best model. + """ + + history = {} + criterion = torch.nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum) + + # Initialize best testing loss to a high value for checkpointing on the best model + best_test_loss = float("inf") + + # Create directory for save_path if it doesn't exist + save_dir = os.path.dirname(save_path) + if save_dir and not os.path.exists(save_dir): + os.makedirs(save_dir) + + train_start_time = time.time() + # Training loop + for epoch in range(epochs): + model.train() + epoch_loss = 0.0 + epoch_correct = 0 + epoch_total = 0 + for data in train_loader: + # Get the input data as a list of [inputs, labels] + inputs, labels = data + + # Set the gradients to zero for the next backward pass + optimizer.zero_grad() + + # Forward + Backward pass and optimization + outputs = model(inputs) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + # Calculate correct predictions for epoch statistics + _, predicted = torch.max(outputs.data, 1) + total = labels.size(0) + correct = (predicted == labels).sum().item() + + # Accumulate statistics for epoch summary + epoch_loss += loss.detach().item() + epoch_correct += correct + epoch_total += total + + train_end_time = time.time() + # Calculate the stats for average loss and accuracy for the entire epoch + avg_epoch_loss = epoch_loss / len(train_loader) + avg_epoch_accuracy = 100 * epoch_correct / epoch_total + print( + f"Epoch {epoch + 1}: Train Loss: {avg_epoch_loss:.4f}, Train Accuracy: {avg_epoch_accuracy:.2f}%" + ) + + test_start_time = time.time() + # Testing phase + if test_loader is not None: + model.eval() # Set model to evaluation mode + test_loss = 0.0 + test_correct = 0 + test_total = 0 + with torch.no_grad(): # No need to track gradients + for data in test_loader: + images, labels = data + outputs = model(images) + loss = criterion(outputs, labels) + test_loss += loss.detach().item() + + # Calculate Testing accuracy as well + _, predicted = torch.max(outputs.data, 1) + test_total += labels.size(0) + test_correct += (predicted == labels).sum().item() + + # Calculate average Testing loss and accuracy + avg_test_loss = test_loss / len(test_loader) + test_accuracy = 100 * test_correct / test_total + test_end_time = time.time() + print( + f"\t Testing Loss: {avg_test_loss:.4f}, Testing Accuracy: {test_accuracy:.2f}%" + ) + + # Save the model with the best Testing loss + if avg_test_loss < best_test_loss: + best_test_loss = avg_test_loss + torch.save(model.state_dict(), save_path) + print( + f"New best model saved with Testing loss: {avg_test_loss:.4f} and Testing accuracy: {test_accuracy:.2f}%" + ) + + history[epoch] = { + "train_loss": avg_epoch_loss, + "train_accuracy": avg_epoch_accuracy, + "testing_loss": avg_test_loss, + "testing_accuracy": test_accuracy, + "training_time": train_end_time - train_start_time, + "train_time_per_image": (train_end_time - train_start_time) + / epoch_total, + "testing_time": test_end_time - test_start_time, + "test_time_per_image": (test_end_time - test_start_time) / test_total, + } + + print("\nTraining Completed!\n") + print("\n###########SUMMARY#############\n") + print(f"Best Testing loss: {best_test_loss:.4f}") + print(f"Model saved at: {save_path}\n") + print("################################\n") + + return model, history + + +def fine_tune_executorch_model( + model_path: str, + save_path: str, + train_loader: DataLoader, + val_loader: DataLoader, + epochs: int = 10, + learning_rate: float = 0.001, + momentum: float = 0.9, +) -> tuple[ExecuTorchModule, typing.Dict[str, typing.Any]]: + """ + Fine-tune an ExecutorTorch model using a training and validation dataset. + + This function loads an ExecutorTorch model from a file, fine-tunes it using + the provided training data loader, and evaluates it on the validation data + loader. The function returns the fine-tuned model and a history dictionary + containing training and validation metrics. + + Args: + model_path (str): Path to the ExecutorTorch model file to be fine-tuned. + save_path (str): Path where the fine-tuned model will be saved. + train_loader (DataLoader): DataLoader for the training dataset. + val_loader (DataLoader): DataLoader for the validation dataset. + epochs (int, optional): Number of epochs for fine-tuning (default: 10). + learning_rate (float, optional): Learning rate for parameter updates (default: 0.001). + momentum (float, optional): Momentum for parameter updates (default: 0.9). + + Returns: + tuple: A tuple containing the fine-tuned ExecutorTorchModule and a dictionary + with training and validation metrics. + """ + with open(model_path, "rb") as f: + model_bytes = f.read() + et_mod = _load_for_executorch_from_buffer(model_bytes) + + grad_start = et_mod.run_method("__et_training_gradients_index_forward", [])[0] + param_start = et_mod.run_method("__et_training_parameters_index_forward", [])[0] + history = {} + + # Initialize momentum buffers for SGD with momentum + momentum_buffers = {} + + for epoch in range(epochs): + print(f"Epoch {epoch+1}/{epochs}") + epoch_loss = 0.0 + train_correct = 0 + train_total = 0 + train_start_time = time.time() + + for batch in tqdm(train_loader): + inputs, labels = batch + # Process each image-label pair individually + for i in range(len(inputs)): + input_image = inputs[ + i : i + 1 + ] # Use list slicing to extract single image as the unqueeze method resulted in errors for some reason + label = labels[i : i + 1] + # Forward pass + out = et_mod.forward((input_image, label), clone_outputs=False) + loss = out[0] + predicted = out[1] + epoch_loss += loss.item() + + # Calculate accuracy + if predicted.item() == label.item(): + train_correct += 1 + train_total += 1 + + # Update parameters using SGD with momentum + with torch.no_grad(): + for param_idx, (grad, param) in enumerate( + zip(out[grad_start:param_start], out[param_start:]) + ): + if momentum > 0: + # Initialize momentum buffer if not exists + if param_idx not in momentum_buffers: + momentum_buffers[param_idx] = torch.zeros_like(grad) + + # Update momentum buffer: v = momentum * v + grad + momentum_buffers[param_idx].mul_(momentum).add_(grad) + # Update parameter: param = param - lr * v + param.sub_(learning_rate * momentum_buffers[param_idx]) + else: + # Standard SGD without momentum + param.sub_(learning_rate * grad) + + train_end_time = time.time() + + train_accuracy = 100 * train_correct / train_total if train_total != 0 else 0 + + avg_epoch_loss = epoch_loss / len(train_loader) / (train_loader.batch_size or 1) + + # Evaluate on validation set + + val_loss = 0.0 + val_correct = 0 + val_total = 0 + val_samples = 100 # Reducing the number of samples to 100 for faster validation + val_start_time = time.time() + + for i, val_batch in tqdm(enumerate(val_loader)): + if i == val_samples: + print(f"Reached {val_samples} samples. Terminating validation loop.") + break + + inputs, labels = val_batch + + for i in range(len(inputs)): + input_image = inputs[ + i : i + 1 + ] # Use list slicing to extract single image as the unqueeze method resulted in errors for some reason + label = labels[i : i + 1] + # Forward pass + out = et_mod.forward((input_image, label), clone_outputs=False) + loss = out[0] + predicted = out[1] + val_loss += loss.item() + # Calculate accuracy + if predicted.item() == label.item(): + val_correct += 1 + val_total += 1 + + val_end_time = time.time() + val_accuracy = 100 * val_correct / val_total if val_total != 0 else 0 + avg_val_loss = val_loss / len(val_loader) / (val_loader.batch_size or 1) + + history[epoch] = { + "train_loss": avg_epoch_loss, + "train_accuracy": train_accuracy, + "validation_loss": avg_val_loss, + "validation_accuracy": val_accuracy, + "training_time": train_end_time - train_start_time, + "train_time_per_image": (train_end_time - train_start_time) / train_total, + "testing_time": val_end_time - val_start_time, + "test_time_per_image": (val_end_time - val_start_time) / val_total, + } + + return et_mod, history diff --git a/extension/training/examples/CIFAR/targets.bzl b/extension/training/examples/CIFAR/targets.bzl new file mode 100644 index 00000000000..3131f8e496d --- /dev/null +++ b/extension/training/examples/CIFAR/targets.bzl @@ -0,0 +1,62 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + """Defines targets that should be shared between fbcode and xplat. + + The directory containing this targets.bzl file should also contain both + TARGETS and BUCK files that call this function. + """ + + runtime.python_library( + name = "model", + srcs = ["model.py"], + visibility = [], # Private + deps = [ + "//caffe2:torch", + ], + ) + + runtime.python_library( + name = "utils", + srcs = ["utils.py"], + visibility = [], # Private + deps = [ + "//caffe2:torch", + "fbsource//third-party/pypi/tqdm:tqdm", + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/extension/pybindings:portable_lib", # @manual + "//pytorch/vision:torchvision", + ], + ) + + runtime.python_binary( + name = "main", + srcs = ["main.py"], + main_function = "executorch.extension.training.examples.CIFAR.main.main", + deps = [ + ":model", + ":utils", + "fbsource//third-party/pypi/tqdm:tqdm", + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/extension/pybindings:portable_lib", # @manual + "//pytorch/vision:torchvision", + ], + ) + + runtime.cxx_binary( + name = "train", + srcs = ["train.cpp"], + deps = [ + "//executorch/extension/training/module:training_module", + "//executorch/extension/tensor:tensor", + "//executorch/extension/training/optimizer:sgd", + "//executorch/runtime/executor:program", + "//executorch/extension/data_loader:file_data_loader", + "//executorch/kernels/portable:generated_lib", + "//executorch/extension/flat_tensor/serialize:serialize_cpp", + ], + external_deps = ["gflags"], + define_static_target = True, + ) diff --git a/extension/training/examples/CIFAR/train.cpp b/extension/training/examples/CIFAR/train.cpp new file mode 100644 index 00000000000..f32f2537c9c --- /dev/null +++ b/extension/training/examples/CIFAR/train.cpp @@ -0,0 +1,575 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Define namespace aliases for cleaner code +using executorch::extension::training::optimizer::SGD; // Stochastic Gradient + // Descent optimizer +using executorch::extension::training::optimizer::SGDOptions; // Options for SGD + // optimizer +using executorch::runtime::Error; // Error handling + +// Define command-line flags +DEFINE_string( + model_path, + "/data/sandcastle/boxes/fbsource/fbcode/executorch/extension/training/" + "examples/CIFAR/cifar10_model.pte", + "Model serialized in flatbuffer format."); // Path to the model file +DEFINE_string( + ptd_path, + "", + "Model weights serialized in flatbuffer format."); // Path to trained + // weights (optional) +DEFINE_string( + train_data_path, + "/data/sandcastle/boxes/fbsource/fbcode/executorch/extension/training/" + "examples/CIFAR/cifar-10/extracted_data/train_data.bin", + "Path to the combined training data file."); // Path to the combined train + // data file +DEFINE_string( + test_data_path, + "/data/sandcastle/boxes/fbsource/fbcode/executorch/extension/" + "training/examples/CIFAR/cifar-10/extracted_data/test_data.bin", + "Path to the combined test data file."); // Path to the combined + // test data file + +DEFINE_string( + ptd_save_path, + "/data/sandcastle/boxes/fbsource/fbcode/executorch/extension/training/" + "examples/CIFAR/CPP/", + "Path to save the cpp model trained weights."); // Path to save the trained + // weights + +DEFINE_int32( + batch_size, + 1, + "Batch size for training."); // Batch size for training (must match + // export batch size) + +DEFINE_int32( + num_epochs, + 1, + "Number of epochs to train."); // Number of epochs to train + +DEFINE_double( + learning_rate, + 0.001, + "Learning rate for SGD optimizer."); // Learning rate + +DEFINE_double(momentum, 0.9, + "Momentum for SGD optimizer."); // Momentum + +// Constants for the CIFAR-10 dataset +const size_t IMAGE_C = 3; // Number of color channels +const size_t IMAGE_H = 32; // Image height +const size_t IMAGE_W = 32; // Image width +const size_t IMAGE_TENSOR_SIZE = IMAGE_C * IMAGE_H * IMAGE_W; // Size of image + +void train_model( + executorch::extension::training::TrainingModule& mod, + const std::vector>& dataset, + SGD& optimizer, + std::mt19937& g) { + ET_LOG( + Info, + "Starting training for %d epochs with batch size %d...", + FLAGS_num_epochs, + FLAGS_batch_size); + + for (int epoch = 0; epoch < FLAGS_num_epochs; epoch++) { + auto epoch_start = std::chrono::high_resolution_clock::now(); + + float epoch_loss = 0.0; + size_t correct_predictions = 0; + size_t total_samples = 0; + + // Shuffling the dataset indices for each epoch for better learning + std::vector indices(dataset.size()); + std::iota(indices.begin(), indices.end(), 0); + std::shuffle(indices.begin(), indices.end(), g); + + // Process data in batches + size_t num_batches = 0; + for (size_t i = 0; i < dataset.size(); i += FLAGS_batch_size) { + // Skip incomplete batches at the end + if (i + FLAGS_batch_size > dataset.size()) { + break; + } + + // Start timing data batch preparation + auto data_prep_start = std::chrono::high_resolution_clock::now(); + + // Create batch tensors + auto batch_image_buffer = std::make_shared>( + FLAGS_batch_size * IMAGE_C * IMAGE_H * IMAGE_W); + auto batch_label_buffer = + std::make_shared>(FLAGS_batch_size); + + // Fill batch tensors with data from batch size samples + for (int j = 0; j < FLAGS_batch_size; j++) { + size_t idx = indices.at(i + j); + auto& data = dataset[idx]; + + // Copy image data + const float* src_img = data.first->const_data_ptr(); + float* dst_img = + batch_image_buffer->data() + (j * IMAGE_C * IMAGE_H * IMAGE_W); + std::memcpy( + dst_img, src_img, IMAGE_C * IMAGE_H * IMAGE_W * sizeof(float)); + + // Copy label data + batch_label_buffer->at(j) = data.second->const_data_ptr()[0]; + } + + // Create batch tensors + executorch::extension::TensorPtr batch_image_tensor = + executorch::extension::make_tensor_ptr( + {FLAGS_batch_size, IMAGE_C, IMAGE_H, IMAGE_W}, + *batch_image_buffer); + + // Convert int32_t labels to int64_t as expected by the model + auto batch_label_buffer_int64 = + std::make_shared>(FLAGS_batch_size); + for (int j = 0; j < FLAGS_batch_size; j++) { + batch_label_buffer_int64->at(j) = + static_cast(batch_label_buffer->at(j)); + } + + executorch::extension::TensorPtr batch_label_tensor = + executorch::extension::make_tensor_ptr( + {FLAGS_batch_size}, *batch_label_buffer_int64); + + // End timing data batch preparation + auto data_prep_end = std::chrono::high_resolution_clock::now(); + std::chrono::duration data_prep_time = + data_prep_end - data_prep_start; + + // Start timing model training + auto train_start = std::chrono::high_resolution_clock::now(); + + // Execute forward and backward pass on the batch + const auto& results = mod.execute_forward_backward( + "forward", {*batch_image_tensor, *batch_label_tensor}); + if (results.error() != Error::Ok) { + ET_LOG( + Error, + "Failed to execute the forward method on batch starting at " + "sample %zu", + i); + return; + } + + // Process results + float loss = results.get()[0].toTensor().const_data_ptr()[0]; + epoch_loss += loss; + + // Count correct predictions in the batch + const int64_t* predictions = + results.get()[1].toTensor().const_data_ptr(); + for (int j = 0; j < FLAGS_batch_size; j++) { + if (predictions[j] == static_cast(batch_label_buffer->at(j))) { + correct_predictions++; + } + } + total_samples += FLAGS_batch_size; + + // Get gradients and update parameters + auto grads = mod.named_gradients("forward"); + if (grads.error() != Error::Ok) { + ET_LOG(Error, "Failed to get named gradients"); + return; + } + optimizer.step(grads.get()); + + // End timing model training + auto train_end = std::chrono::high_resolution_clock::now(); + std::chrono::duration train_time = + train_end - train_start; + + num_batches++; + + // Log for tracking progress + if (num_batches % 100 == 0) { + ET_LOG( + Info, + "Epoch [%d/%d], Batch [%zu/%zu], Loss: %.4f, Data prep: %.2f " + "ms, Train: %.2f ms", + epoch + 1, + FLAGS_num_epochs, + num_batches, + dataset.size() / FLAGS_batch_size, + loss, + data_prep_time.count(), + train_time.count()); + } + } + + auto epoch_end = std::chrono::high_resolution_clock::now(); + std::chrono::duration epoch_time = epoch_end - epoch_start; + + // Log epoch summary + float avg_loss = epoch_loss / num_batches; + float accuracy = 100.0f * correct_predictions / total_samples; + ET_LOG( + Info, + "Epoch %d/%d Summary: Avg Loss: %.4f, Accuracy: %.2f%% (%zu/%zu), " + "Time: %.2f s", + epoch + 1, + FLAGS_num_epochs, + avg_loss, + accuracy, + correct_predictions, + total_samples, + epoch_time.count()); + } + + ET_LOG(Info, "Training finished..."); +} + +void evaluate_on_test_set( + executorch::extension::training::TrainingModule& mod, + const std::vector>& test_dataset) { + ET_LOG(Info, "Starting final evaluation on test set..."); + auto eval_start = std::chrono::high_resolution_clock::now(); + + float test_loss = 0.0; + size_t test_correct = 0; + size_t test_total = 0; + size_t test_batches = 0; + + for (size_t i = 0; i < test_dataset.size(); i += FLAGS_batch_size) { + if (i + FLAGS_batch_size > test_dataset.size()) { + break; + } + + // Create batch tensors for test data + auto batch_image_buffer = std::make_shared>( + FLAGS_batch_size * IMAGE_C * IMAGE_H * IMAGE_W); + auto batch_label_buffer = + std::make_shared>(FLAGS_batch_size); + + // Fill batch tensors with test data + for (int j = 0; j < FLAGS_batch_size; j++) { + auto& data = test_dataset[i + j]; + + // Copy image data + const float* src_img = data.first->const_data_ptr(); + float* dst_img = + batch_image_buffer->data() + (j * IMAGE_C * IMAGE_H * IMAGE_W); + std::memcpy( + dst_img, src_img, IMAGE_C * IMAGE_H * IMAGE_W * sizeof(float)); + + // Copy label data + batch_label_buffer->at(j) = data.second->const_data_ptr()[0]; + } + + // Create batch tensors + executorch::extension::TensorPtr batch_image_tensor = + executorch::extension::make_tensor_ptr( + {FLAGS_batch_size, IMAGE_C, IMAGE_H, IMAGE_W}, *batch_image_buffer); + + // Convert int32_t labels to int64_t as expected by the model + auto batch_label_buffer_int64 = + std::make_shared>(FLAGS_batch_size); + for (int j = 0; j < FLAGS_batch_size; j++) { + batch_label_buffer_int64->at(j) = + static_cast(batch_label_buffer->at(j)); + } + + executorch::extension::TensorPtr batch_label_tensor = + executorch::extension::make_tensor_ptr( + {FLAGS_batch_size}, *batch_label_buffer_int64); + + const auto& results = mod.execute_forward_backward( + "forward", {*batch_image_tensor, *batch_label_tensor}); + if (results.error() != Error::Ok) { + ET_LOG( + Error, + "Failed to execute forward pass on test batch starting at sample %zu", + i); + continue; + } + + // Process results + float loss = results.get()[0].toTensor().const_data_ptr()[0]; + test_loss += loss; + + // Count correct predictions + const int64_t* predictions = + results.get()[1].toTensor().const_data_ptr(); + for (int j = 0; j < FLAGS_batch_size; j++) { + if (predictions[j] == static_cast(batch_label_buffer->at(j))) { + test_correct++; + } + } + test_total += FLAGS_batch_size; + test_batches++; + } + + auto eval_end = std::chrono::high_resolution_clock::now(); + std::chrono::duration eval_time = eval_end - eval_start; + + float test_avg_loss = test_loss / test_batches; + float test_accuracy = 100.0f * test_correct / test_total; + + ET_LOG( + Info, + "Final Test Results: Avg Loss: %.4f, Accuracy: %.2f%% (%zu/%zu), " + "Time: %.2f s", + test_avg_loss, + test_accuracy, + test_correct, + test_total, + eval_time.count()); +} + +torch::executor::Error load_data_from_combined_binary( + const std::string& data_path, + std::vector>& data_set) { + std::ifstream data_file(data_path, std::ios::binary); + + if (!data_file.is_open()) { + ET_LOG(Error, "Failed to open data file: %s", data_path.c_str()); + return torch::executor::Error::InvalidState; + } + + ET_LOG( + Info, + "Loading the dataset from the combined binary file: %s", + data_path.c_str()); + + data_file.seekg(0, std::ios::end); + std::streampos file_size = data_file.tellg(); + data_file.seekg(0, std::ios::beg); + + // Debug: Read first 32 bytes to understand file format + char debug_bytes[32]; + data_file.read(debug_bytes, 32); + data_file.seekg(0, std::ios::beg); // Reset to beginning + + // Try CIFAR-10 format: label (1 byte) + image (3072 bytes) + // This is the standard CIFAR-10 binary format + size_t cifar_sample_size = + 1 + IMAGE_TENSOR_SIZE; // 1 byte label + 3072 bytes image + size_t cifar_max_samples = file_size / cifar_sample_size; + + for (size_t i = 0; i < cifar_max_samples; i++) { + // Read label (1 byte) + uint8_t label_byte; + data_file.read(reinterpret_cast(&label_byte), 1); + if (data_file.gcount() != 1) { + ET_LOG(Error, "Failed to read label byte at sample %zu", i); + return torch::executor::Error::InvalidState; + } + + // Read image data (3072 bytes as uint8_t, then convert to float) + std::vector image_bytes(IMAGE_TENSOR_SIZE); + data_file.read( + reinterpret_cast(image_bytes.data()), IMAGE_TENSOR_SIZE); + if (data_file.gcount() != IMAGE_TENSOR_SIZE) { + ET_LOG(Error, "Failed to read image bytes at sample %zu", i); + return torch::executor::Error::InvalidState; + } + + // Validate label range + if (label_byte > 9) { + ET_LOG( + Error, + "Invalid label value %u at sample %zu (expected 0-9)", + label_byte, + i); + return torch::executor::Error::InvalidState; + } + + // Convert image bytes to floats (normalize to 0-1 range) + auto image_buffer = std::make_shared>(IMAGE_TENSOR_SIZE); + for (size_t j = 0; j < IMAGE_TENSOR_SIZE; j++) { + (*image_buffer)[j] = static_cast(image_bytes[j]) / 255.0f; + } + + // Create label buffer + auto label_buffer = std::make_shared>(1); + (*label_buffer)[0] = static_cast(label_byte); + + // Store the image and label buffers + data_set.emplace_back( + executorch::extension::make_tensor_ptr( + {1, IMAGE_C, IMAGE_H, IMAGE_W}, *image_buffer), + executorch::extension::make_tensor_ptr({1}, *label_buffer)); + } + + ET_LOG( + Info, + "Successfully loaded %zu samples using CIFAR-10 format.", + data_set.size()); + return Error::Ok; +} + +int main(int argc, char** argv) { + // Parse command-line flags + gflags::ParseCommandLineFlags(&argc, &argv, true); + + // Load the model: The following code works for loading the pte model + executorch::runtime::Result + loader_res = + executorch::extension::FileDataLoader::from(FLAGS_model_path.c_str()); + if (loader_res.error() != Error::Ok) { + ET_LOG(Error, "Failed to open model file: %s", FLAGS_model_path.c_str()); + return 1; + } else { + ET_LOG( + Info, "Successfully opened model file: %s", FLAGS_model_path.c_str()); + } + + auto loader = std::make_unique( + std::move(loader_res.get())); + + std::unique_ptr ptd_loader = nullptr; + if (!FLAGS_ptd_path.empty()) { + executorch::runtime::Result + ptd_loader_res = + executorch::extension::FileDataLoader::from(FLAGS_ptd_path.c_str()); + if (ptd_loader_res.error() != Error::Ok) { + ET_LOG(Error, "Failed to open ptd file: %s", FLAGS_ptd_path.c_str()); + return 1; + } else { + ET_LOG( + Info, + "Successfully opened trained weights file: %s", + FLAGS_ptd_path.c_str()); + } + ptd_loader = std::make_unique( + std::move(ptd_loader_res.get())); + } + + auto mod = executorch::extension::training::TrainingModule( + std::move(loader), nullptr, nullptr, nullptr, std::move(ptd_loader)); + + // Load the training dataset from combined binary file + std::vector> + dataset; + Error data_load_res = + load_data_from_combined_binary(FLAGS_train_data_path, dataset); + if (data_load_res != Error::Ok) { + return 1; + } + + // Confirm that the dataset has been loaded correctly + ET_LOG( + Info, + "Successfully loaded the dataset with %zu samples.", + dataset.size()); + + // Create optimizer. + // Get the params and names + auto param_res = mod.named_parameters("forward"); + if (param_res.error() != Error::Ok) { + ET_LOG( + Error, + "Failed to get named parameters, error: %d", + static_cast(param_res.error())); + return 1; + } + + SGDOptions options{FLAGS_learning_rate, FLAGS_momentum}; + SGD optimizer(param_res.get(), options); + + ET_LOG( + Info, + "Successfully created the optimizer with lr=%.4f, momentum=%.2f.", + FLAGS_learning_rate, + FLAGS_momentum); + + // Initialize random number generator for shuffling + std::random_device rd; + std::mt19937 g(rd()); + + train_model(mod, dataset, optimizer, g); + + // Load test dataset for evaluation + std::vector> + test_dataset; + Error test_data_load_res = + load_data_from_combined_binary(FLAGS_test_data_path, test_dataset); + if (test_data_load_res != Error::Ok) { + ET_LOG(Error, "Failed to load test dataset, skipping evaluation"); + } else { + ET_LOG( + Info, + "Successfully loaded test dataset with %zu samples.", + test_dataset.size()); + + evaluate_on_test_set(mod, test_dataset); + } + + // Save the trained weights + std::map param_map; + for (auto& param : param_res.get()) { + param_map.insert({std::string(param.first.data()), param.second}); + } + + // Define the directory path for saving the model + const std::string model_path = FLAGS_ptd_save_path + "trained_cifar_cpp.ptd"; + + // Create the directory if it doesn't exist + int dir_fd = open(FLAGS_ptd_save_path.c_str(), O_RDONLY); + if (dir_fd == -1) { + // Directory doesn't exist or can't be accessed, create it + ET_LOG(Info, "Creating directory: %s", FLAGS_ptd_save_path.c_str()); + int result = mkdir( + FLAGS_ptd_save_path.c_str(), + 0755); // Create with permissions rwxr-xr-x + if (result != 0) { + ET_LOG( + Error, "Failed to create directory: %s", FLAGS_ptd_save_path.c_str()); + return 1; + } + } else { + // Directory exists, check if it's actually a directory + struct stat info {}; + if (fstat(dir_fd, &info) == 0 && !(info.st_mode & S_IFDIR)) { + close(dir_fd); + ET_LOG( + Error, + "Path exists but is not a directory: %s", + FLAGS_ptd_save_path.c_str()); + return 1; + } + close(dir_fd); + } + + executorch::extension::flat_tensor::save_ptd( + model_path.c_str(), param_map, 16); + ET_LOG(Info, "Trained weights saved to %s", model_path.c_str()); + + return 0; +} diff --git a/extension/training/examples/CIFAR/utils.py b/extension/training/examples/CIFAR/utils.py new file mode 100644 index 00000000000..1d9310d458e --- /dev/null +++ b/extension/training/examples/CIFAR/utils.py @@ -0,0 +1,290 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import os +import json +import pickle +import typing +from collections import defaultdict + +import numpy as np + +import torch +import torchvision +from torch.utils.data import DataLoader, Dataset, Subset + +from PIL import Image + + +class BalancedCIFARDataset(Dataset): + """Custom dataset class to load balanced CIFAR-10 data from binary file.""" + + def __init__( + self, + data_path: str, + transform: typing.Optional[torchvision.transforms.Compose] = None, + ) -> None: + """ + Args: + data_path: Path to the balanced dataset binary file + transform: Optional transformation to be applied on a sample + """ + self.data = [] + self.labels = [] + + # Read binary format: 1 byte label + 3072 bytes image data per record + with open(data_path, "rb") as f: + while True: + # Read label (1 byte) + label_byte = f.read(1) + if not label_byte: # End of file + break + label = int.from_bytes(label_byte, byteorder="big") + + # Read image data (3 * 32 * 32 = 3072 bytes) + image_bytes = f.read(3072) + if len(image_bytes) != 3072: + break # Incomplete record + + # Convert bytes to numpy array + image_data = np.frombuffer(image_bytes, dtype=np.uint8) + + self.data.append(image_data) + self.labels.append(label) + + self.data = np.array(self.data) + self.labels = np.array(self.labels) + self.transform = transform + + print(f"Loaded {len(self.data)} images from {data_path}") + + def __len__(self) -> int: + return len(self.data) + + def __getitem__(self, idx: int) -> typing.Tuple[Image.Image, int]: + # Reshape from (3072,) to (32, 32, 3) and convert to PIL Image + image_data = self.data[idx].reshape(3, 32, 32).transpose(1, 2, 0) + image = Image.fromarray(image_data) + label = self.labels[idx] + + if self.transform: + image = self.transform(image) + + return image, label + + +def create_balanced_cifar_dataset( + data_batch_path: str = "./data/cifar-10/cifar-10-batches-py/data_batch_1", + output_path: str = "./data/cifar-10/extracted_data/train_data.bin", + images_per_class: int = 100, +) -> str: + """ + Reads CIFAR-10 data from data_batch_1 file and creates a balanced dataset + with specified number of images per class, saved in binary format compatible with Android. + + Args: + data_batch_path: Path to the CIFAR-10 data_batch_1 file + output_path: Path where the balanced dataset will be saved + images_per_class: Number of images to extract per class (default: 100) + """ + # Load the CIFAR-10 data batch + with open(data_batch_path, "rb") as f: + data_dict = pickle.load(f, encoding="bytes") + + # Extract data and labels + data = data_dict[b"data"] # Shape: (10000, 3072) + labels = data_dict[b"labels"] # List of 10000 labels + + # Group images by class + class_images = defaultdict(list) + class_labels = defaultdict(list) + + for i, label in enumerate(labels): + if len(class_images[label]) < images_per_class: + class_images[label].append(data[i]) + class_labels[label].append(label) + + # Combine all selected images and labels + selected_data = [] + selected_labels = [] + + for class_id in range(10): # CIFAR-10 has 10 classes (0-9) + if class_id in class_images: + selected_data.extend(class_images[class_id]) + selected_labels.extend(class_labels[class_id]) + print(f"Class {class_id}: {len(class_images[class_id])} images selected") + + # Convert to numpy arrays + selected_data = np.array(selected_data, dtype=np.uint8) + selected_labels = np.array(selected_labels, dtype=np.uint8) + + # Ensure the output directory exists + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + # Save in binary format compatible with Android CIFAR-10 reader + # Format: 1 byte label + 3072 bytes image data per record + with open(output_path, "wb") as f: + for i in range(len(selected_data)): + # Write label as single byte + f.write(bytes([selected_labels[i]])) + # Write image data (3072 bytes) + f.write(selected_data[i].tobytes()) + + print(f"Balanced dataset saved to {output_path}") + print(f"Total images: {len(selected_data)}") + print(f"File size: {os.path.getsize(output_path)} bytes") + print(f"Expected size: {len(selected_data) * (1 + 3072)} bytes") + return output_path + + +def get_data_loaders( + batch_size: int = 4, + num_workers: int = 2, + data_dir: str = "./data", + use_balanced_dataset: bool = True, + images_per_class: int = 100, +) -> typing.Tuple[DataLoader, DataLoader]: + """ + Create data loaders for training, validation, and testing. + + Args: + batch_size: Batch size for data loaders + num_workers: Number of worker processes for data loading + data_dir: Root directory for data + use_balanced_dataset: Whether to use balanced dataset or standard CIFAR-10 + images_per_class: Number of images per class for balanced dataset + """ + transforms = torchvision.transforms.Compose( + [ + torchvision.transforms.RandomCrop(32, padding=4), + torchvision.transforms.RandomHorizontalFlip(), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) + ), + ] + ) + + if use_balanced_dataset: + # Download CIFAR-10 first to ensure the raw data exists + print("Downloading CIFAR-10 dataset...") + torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True) + torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True) + + # The actual path where torchvision stores CIFAR-10 data + cifar_data_dir = os.path.join(data_dir, "cifar-10-batches-py") + + # Create balanced dataset if it doesn't exist + balanced_data_path = os.path.join( + data_dir, "cifar-10/extracted_data/train_data.bin" + ) + data_batch_path = os.path.join(cifar_data_dir, "data_batch_1") + + # Ensure the output directory exists + os.makedirs(os.path.dirname(balanced_data_path), exist_ok=True) + + # Create balanced dataset if it doesn't exist + if not os.path.exists(balanced_data_path): + print("Creating balanced train dataset...") + create_balanced_cifar_dataset( + data_batch_path=data_batch_path, + output_path=balanced_data_path, + images_per_class=images_per_class, + ) + + # Use balanced dataset for training + trainset = BalancedCIFARDataset(balanced_data_path, transform=transforms) + + indices = torch.randperm(len(trainset)).tolist() + + train_subset = Subset(trainset, indices) + + balanced_test_data_path = os.path.join( + data_dir, "cifar-10/extracted_data/test_data.bin" + ) + test_data_batch_path = os.path.join(cifar_data_dir, "test_batch") + # Ensure the output directory exists + os.makedirs(os.path.dirname(balanced_test_data_path), exist_ok=True) + # Create balanced dataset if it doesn't exist + if not os.path.exists(balanced_test_data_path): + print("Creating balanced test dataset...") + create_balanced_cifar_dataset( + data_batch_path=test_data_batch_path, + output_path=balanced_test_data_path, + images_per_class=images_per_class, + ) + # Use balanced dataset for testing + test_set = BalancedCIFARDataset(balanced_test_data_path, transform=transforms) + + else: + # Use standard CIFAR-10 dataset + trainset = torchvision.datasets.CIFAR10( + root=data_dir, train=True, download=True, transform=transforms + ) + + train_set_indices = torch.randperm(len(trainset)).tolist() + + train_subset = Subset(trainset, train_set_indices) + + # Test set always uses standard CIFAR-10 + test_set = torchvision.datasets.CIFAR10( + root=data_dir, train=False, download=True, transform=transforms + ) + + train_loader = DataLoader( + train_subset, batch_size=batch_size, shuffle=True, num_workers=num_workers + ) + + test_loader = DataLoader( + test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers + ) + + return train_loader, test_loader + + +def count_images_per_class(loader: DataLoader) -> typing.Dict[int, int]: + """ + Count the number of images per class in a DataLoader. + + This function iterates through a DataLoader and counts how many images + belong to each class based on their labels. + + Args: + loader (DataLoader): The DataLoader containing image-label pairs + + Returns: + Dict[int, int]: A dictionary mapping class IDs to their counts + """ + class_counts = defaultdict(int) + for _, labels in loader: + for label in labels: + class_counts[label.item()] += 1 + return class_counts + + +def save_json( + history: typing.Dict[int, typing.Dict[str, float]], json_path: str +) -> str: + """ + Save training/validation history to a JSON file. + + This function takes a dictionary containing training/validation metrics + organized by epoch and saves it to a JSON file at the specified path. + + Args: + history (Dict[int, Dict[str, float]]): Dictionary with epoch numbers as keys + and dictionaries of metrics (loss, accuracy, etc.) as values. + json_path (str): File path where the JSON file will be saved. + + Returns: + str: The path where the JSON file was saved. + """ + with open(json_path, "w") as f: + json.dump(history, f, indent=4) + print(f"History saved to {json_path}") + return json_path