diff --git a/yolo/config/task/export.yaml b/yolo/config/task/export.yaml
new file mode 100644
index 00000000..0256be30
--- /dev/null
+++ b/yolo/config/task/export.yaml
@@ -0,0 +1,3 @@
+task: export
+
+format: onnx
diff --git a/yolo/lazy.py b/yolo/lazy.py
index 0f1cc55b..00c90921 100644
--- a/yolo/lazy.py
+++ b/yolo/lazy.py
@@ -8,7 +8,7 @@
 sys.path.append(str(project_root))
 
 from yolo.config.config import Config
-from yolo.tools.solver import InferenceModel, TrainModel, ValidateModel
+from yolo.tools.solver import ExportModel, InferenceModel, TrainModel, ValidateModel
 from yolo.utils.logging_utils import setup
 
 
@@ -39,6 +39,9 @@ def main(cfg: Config):
     if cfg.task.task == "inference":
         model = InferenceModel(cfg)
         trainer.predict(model)
+    if cfg.task.task == "export":
+        model = ExportModel(cfg)
+        model.export()
 
 
 if __name__ == "__main__":
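With the new task config in place, export is driven the same way as the existing tasks. A usage sketch (assuming the Hydra-style overrides lazy.py already uses; the exact override path for `format` depends on how the config tree is composed):

    # python yolo/lazy.py task=export task.format=onnx
    # ...which, per the dispatch added above, runs the equivalent of:
    model = ExportModel(cfg)
    model.export()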
diff --git a/yolo/model/yolo.py b/yolo/model/yolo.py
index 42d72208..b3977a07 100644
--- a/yolo/model/yolo.py
+++ b/yolo/model/yolo.py
@@ -21,12 +21,13 @@ class YOLO(nn.Module):
         parameters, and any other relevant configuration details.
     """
 
-    def __init__(self, model_cfg: ModelConfig, class_num: int = 80):
+    def __init__(self, model_cfg: ModelConfig, class_num: int = 80, export_mode: bool = False):
         super(YOLO, self).__init__()
         self.num_classes = class_num
         self.layer_map = get_layer_map()  # Get the map Dict[str: Module]
         self.model: List[YOLOLayer] = nn.ModuleList()
         self.reg_max = getattr(model_cfg.anchor, "reg_max", 16)
+        self.export_mode = export_mode
         self.build_model(model_cfg.model)
 
     def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
@@ -68,10 +69,44 @@ def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
             setattr(layer, "out_c", out_channels)
             layer_idx += 1
 
+    def generate_anchors(self, image_size: List[int], strides: List[int]):
+        W, H = image_size
+        anchors = []
+        scaler = []
+        for stride in strides:
+            anchor_num = W // stride * H // stride
+            scaler.append(torch.full((anchor_num,), stride))
+            shift = stride // 2
+            h = torch.arange(0, H, stride) + shift
+            w = torch.arange(0, W, stride) + shift
+            if torch.__version__ >= "2.3.0":
+                anchor_h, anchor_w = torch.meshgrid(h, w, indexing="ij")
+            else:
+                anchor_h, anchor_w = torch.meshgrid(h, w)
+            anchor = torch.stack([anchor_w.flatten(), anchor_h.flatten()], dim=-1)
+            anchors.append(anchor)
+        all_anchors = torch.cat(anchors, dim=0)
+        all_scalers = torch.cat(scaler, dim=0)
+        return all_anchors, all_scalers
+
+    def get_strides(self, output, input_width) -> List[int]:
+        W = input_width
+        strides = []
+        for predict_head in output:
+            # predict_head[2] is the box branch with shape (B, C, H_feat, W_feat)
+            _, _, *anchor_num = predict_head[2].shape
+            strides.append(W // anchor_num[1])
+        return strides
+
     def forward(self, x, external: Optional[Dict] = None, shortcut: Optional[str] = None):
+        input_height, input_width = x.shape[-2:]  # NCHW: the last two dims are (H, W)
         y = {0: x, **(external or {})}
         output = dict()
-        for index, layer in enumerate(self.model, start=1):
+
+        # Use a simple loop instead of enumerate();
+        # needed for torch.export compatibility.
+        index = 1
+        for layer in self.model:
             if isinstance(layer.source, list):
                 model_input = [y[idx] for idx in layer.source]
             else:
@@ -81,12 +116,42 @@ def forward(self, x, external: Optional[Dict] = None, shortcut: Optional[str] =
             x = layer(model_input, **external_input)
             y[-1] = x
+
             if layer.usable:
                 y[index] = x
+
             if layer.output:
                 output[layer.tags] = x
             if layer.tags == shortcut:
                 return output
+
+            index += 1
+
+        if self.export_mode:
+            preds_cls, preds_anc, preds_box = [], [], []
+            for layer_output in output["Main"]:
+                pred_cls, pred_anc, pred_box = layer_output
+                preds_cls.append(pred_cls.permute(0, 2, 3, 1).reshape(pred_cls.shape[0], -1, pred_cls.shape[1]))
+                preds_anc.append(
+                    pred_anc.permute(0, 3, 4, 1, 2).reshape(pred_anc.shape[0], -1, pred_anc.shape[2], pred_anc.shape[1])
+                )
+                preds_box.append(pred_box.permute(0, 2, 3, 1).reshape(pred_box.shape[0], -1, pred_box.shape[1]))
+
+            preds_cls = torch.cat(preds_cls, dim=1).to(x[0][0].device)
+            preds_anc = torch.cat(preds_anc, dim=1).to(x[0][0].device)
+            preds_box = torch.cat(preds_box, dim=1).to(x[0][0].device)
+
+            strides = self.get_strides(output["Main"], input_width)
+            anchor_grid, scaler = self.generate_anchors([input_width, input_height], strides)
+            anchor_grid = anchor_grid.to(x[0][0].device)
+            scaler = scaler.to(x[0][0].device)
+            # Decode LTRB distances (in stride units) into xyxy boxes around the anchor centers
+            pred_LTRB = preds_box * scaler.view(1, -1, 1)
+            lt, rb = pred_LTRB.chunk(2, dim=-1)
+            preds_box = torch.cat([anchor_grid - lt, anchor_grid + rb], dim=-1)
+
+            return preds_cls, preds_anc, preds_box
+
         return output
 
     def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
@@ -158,7 +223,9 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]):
         self.model.load_state_dict(model_state_dict)
 
 
-def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80) -> YOLO:
+def create_model(
+    model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80, export_mode: bool = False
+) -> YOLO:
     """Constructs and returns a model from a Dictionary configuration file.
 
     Args:
@@ -168,7 +235,7 @@ def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True,
         YOLO: An instance of the model defined by the given configuration.
     """
     OmegaConf.set_struct(model_cfg, False)
-    model = YOLO(model_cfg, class_num)
+    model = YOLO(model_cfg, class_num, export_mode=export_mode)
    if weight_path:
        if weight_path == True:
            weight_path = Path("weights") / f"{model_cfg.name}.pt"
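A quick sanity check of the anchor layout the export path produces (a standalone sketch, assuming a 640x640 input and the usual P3/P4/P5 strides of 8/16/32):

    # one anchor per feature-map cell, summed over pyramid levels
    def expected_anchor_count(image_size, strides):
        W, H = image_size
        return sum((W // s) * (H // s) for s in strides)

    print(expected_anchor_count([640, 640], [8, 16, 32]))  # 6400 + 1600 + 400 = 8400

generate_anchors then returns an (8400, 2) grid of cell centers plus an (8400,) per-anchor stride, so export mode emits preds_box with shape (B, 8400, 4) in xyxy coordinates.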
diff --git a/yolo/tools/solver.py b/yolo/tools/solver.py
index c20b1ab3..03bbcd27 100644
--- a/yolo/tools/solver.py
+++ b/yolo/tools/solver.py
@@ -10,13 +10,18 @@
 from yolo.tools.drawer import draw_bboxes
 from yolo.tools.loss_functions import create_loss_function
 from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
+from yolo.utils.deploy_utils import FastModelLoader
+from yolo.utils.export_utils import ModelExporter
+from yolo.utils.logger import logger
 from yolo.utils.model_utils import PostProcess, create_optimizer, create_scheduler
 
 
 class BaseModel(LightningModule):
-    def __init__(self, cfg: Config):
+    def __init__(self, cfg: Config, export_mode: bool = False):
         super().__init__()
-        self.model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight)
+        self.model = create_model(
+            cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight, export_mode=export_mode
+        )
 
     def forward(self, x):
         return self.model(x)
@@ -109,15 +114,27 @@ def configure_optimizers(self):
 class InferenceModel(BaseModel):
     def __init__(self, cfg: Config):
-        super().__init__(cfg)
+        if hasattr(cfg.model.model, "auxiliary"):
+            cfg.model.model.auxiliary = {}
+
+        export_mode = False
+        fast_inference = cfg.task.fast_inference
+        # TODO: check if we can use export mode for all formats
+        if fast_inference == "coreml":
+            export_mode = True
+
+        super().__init__(cfg, export_mode=export_mode)
         self.cfg = cfg
-        # TODO: Add FastModel
         self.predict_loader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
 
     def setup(self, stage):
         self.vec2box = create_converter(
             self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
         )
+
+        if self.cfg.task.fast_inference:
+            self.fast_model = FastModelLoader(self.cfg, self.model).load_model(self.device)
+
         self.post_process = PostProcess(self.vec2box, self.cfg.task.nms)
 
     def predict_dataloader(self):
@@ -125,7 +142,11 @@ def predict_dataloader(self):
     def predict_step(self, batch, batch_idx):
         images, rev_tensor, origin_frame = batch
-        predicts = self.post_process(self(images), rev_tensor=rev_tensor)
+        if hasattr(self, "fast_model") and self.fast_model:
+            predictions = self.fast_model(images)
+        else:
+            predictions = self(images)
+        predicts = self.post_process(predictions, rev_tensor=rev_tensor)
         img = draw_bboxes(origin_frame, predicts, idx2label=self.cfg.dataset.class_list)
         if getattr(self.predict_loader, "is_stream", None):
             fps = self._display_stream(img)
@@ -139,3 +160,28 @@ def _save_image(self, img, batch_idx):
         save_image_path = Path(self.trainer.default_root_dir) / f"frame{batch_idx:03d}.png"
         img.save(save_image_path)
         print(f"💾 Saved visualize image at {save_image_path}")
+
+
+class ExportModel(BaseModel):
+    def __init__(self, cfg: Config):
+        if hasattr(cfg.model.model, "auxiliary"):
+            cfg.model.model.auxiliary = {}
+
+        export_mode = False
+        format = cfg.task.format
+        # TODO: check if we can use export mode for all formats
+        if format == "coreml":
+            export_mode = True
+
+        super().__init__(cfg, export_mode=export_mode)
+        self.cfg = cfg
+        self.format = format
+        self.model_exporter = ModelExporter(self.cfg, self.model, format=self.format)
+
+    def export(self):
+        if self.format == "onnx":
+            self.model_exporter.export_onnx()
+        if self.format == "tflite":
+            self.model_exporter.export_tflite()
+        if self.format == "coreml":
+            self.model_exporter.export_coreml()
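The same class can be driven programmatically (a minimal sketch; `cfg` stands for a fully composed Config whose task is the new export task):

    exporter = ExportModel(cfg)  # drops auxiliary heads; enables export_mode for coreml
    exporter.export()            # writes e.g. v9-c.onnx (hypothetical name), derived from the weight file's stem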
diff --git a/yolo/utils/bounding_box_utils.py b/yolo/utils/bounding_box_utils.py
index 0357bfdc..335aa59c 100644
--- a/yolo/utils/bounding_box_utils.py
+++ b/yolo/utils/bounding_box_utils.py
@@ -342,13 +342,14 @@ def __init__(self, model: YOLO, anchor_cfg: AnchorConfig, image_size, device):
         if hasattr(anchor_cfg, "strides"):
             logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}")
             self.strides = anchor_cfg.strides
-        else:
+        elif not model.export_mode:
             logger.info(":teddy_bear: Found no stride of model, performed a dummy test for auto-anchor size")
             self.strides = self.create_auto_anchor(model, image_size)
 
-        anchor_grid, scaler = generate_anchors(image_size, self.strides)
         self.image_size = image_size
-        self.anchor_grid, self.scaler = anchor_grid.to(device), scaler.to(device)
+        if not model.export_mode:
+            anchor_grid, scaler = generate_anchors(image_size, self.strides)
+            self.anchor_grid, self.scaler = anchor_grid.to(device), scaler.to(device)
 
     def create_auto_anchor(self, model: YOLO, image_size):
         W, H = image_size
diff --git a/yolo/utils/deploy_utils.py b/yolo/utils/deploy_utils.py
index 4a0db991..5b28c86f 100644
--- a/yolo/utils/deploy_utils.py
+++ b/yolo/utils/deploy_utils.py
@@ -4,23 +4,30 @@
 from torch import Tensor
 
 from yolo.config.config import Config
-from yolo.model.yolo import create_model
+from yolo.model.yolo import YOLO, create_model
+from yolo.utils.export_utils import ModelExporter
 from yolo.utils.logger import logger
 
 
 class FastModelLoader:
-    def __init__(self, cfg: Config):
+    def __init__(self, cfg: Config, model: YOLO):
         self.cfg = cfg
-        self.compiler = cfg.task.fast_inference
+        self.model = model
+        self.compiler: str = cfg.task.fast_inference
         self.class_num = cfg.dataset.class_num
         self._validate_compiler()
         if cfg.weight == True:
             cfg.weight = Path("weights") / f"{cfg.model.name}.pt"
-        self.model_path = f"{Path(cfg.weight).stem}.{self.compiler}"
+
+        extension = self.compiler
+        if self.compiler == "coreml":
+            extension = "mlpackage"
+
+        self.model_path = f"{Path(cfg.weight).stem}.{extension}"
 
     def _validate_compiler(self):
-        if self.compiler not in ["onnx", "trt", "deploy"]:
+        if self.compiler not in ["onnx", "trt", "deploy", "coreml", "tflite"]:
             logger.warning(f":warning: Compiler '{self.compiler}' is not supported. Using original model.")
             self.compiler = None
         if self.cfg.device == "mps" and self.compiler == "trt":
@@ -30,13 +37,63 @@ def _validate_compiler(self):
     def load_model(self, device):
         if self.compiler == "onnx":
             return self._load_onnx_model(device)
+        elif self.compiler == "tflite":
+            return self._load_tflite_model(device)
+        elif self.compiler == "coreml":
+            return self._load_coreml_model(device)
         elif self.compiler == "trt":
             return self._load_trt_model().to(device)
         elif self.compiler == "deploy":
             self.cfg.model.model.auxiliary = {}
             return create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).to(device)
 
+    def _load_tflite_model(self, device):
+        if not Path(self.model_path).exists():
+            self._create_tflite_model()
+
+        from ai_edge_litert.interpreter import Interpreter
+
+        try:
+            interpreter = Interpreter(model_path=self.model_path)
+            interpreter.allocate_tensors()
+            logger.info(":rocket: Using TensorFlow Lite as MODEL framework!")
+        except Exception as e:
+            logger.warning(f"🈳 Error loading TFLite model: {e}")
+            interpreter = self._create_tflite_model()
+
+        def tflite_forward(self: Interpreter, x: Tensor):
+            # Get input & output tensor details
+            input_details = self.get_input_details()
+            output_details = sorted(self.get_output_details(), key=lambda d: d["name"])  # Sort by 'name'
+
+            # Convert the input tensor to NumPy and feed it to the interpreter
+            x_numpy = x.cpu().numpy()
+            self.set_tensor(input_details[0]["index"], x_numpy)
+            self.invoke()
+
+            # Regroup the flat output list into (cls, anc, box) triplets per scale
+            model_outputs, layer_output = [], []
+            for idx, output_detail in enumerate(output_details):
+                predict = self.get_tensor(output_detail["index"])
+                layer_output.append(torch.from_numpy(predict).to(device))
+                if idx % 3 == 2:
+                    model_outputs.append(layer_output)
+                    layer_output = []
+            if len(model_outputs) == 6:
+                model_outputs = model_outputs[:3]
+            return {"Main": model_outputs}
+
+        Interpreter.__call__ = tflite_forward
+
+        return interpreter
+
     def _load_onnx_model(self, device):
+
+        # TODO: install onnxruntime or onnxruntime-gpu if needed
+
         from onnxruntime import InferenceSession
 
         def onnx_forward(self: InferenceSession, x: Tensor):
@@ -55,6 +112,8 @@
         if device == "cpu":
             providers = ["CPUExecutionProvider"]
+        elif device == "coreml":
+            providers = ["CoreMLExecutionProvider"]
         else:
             providers = ["CUDAExecutionProvider"]
         try:
@@ -67,21 +126,47 @@ def onnx_forward(self: InferenceSession, x: Tensor):
     def _create_onnx_model(self, providers):
         from onnxruntime import InferenceSession
-        from torch.onnx import export
 
-        model = create_model(self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight).eval()
-        dummy_input = torch.ones((1, 3, *self.cfg.image_size))
-        export(
-            model,
-            dummy_input,
-            self.model_path,
-            input_names=["input"],
-            output_names=["output"],
-            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
-        )
-        logger.info(f":inbox_tray: ONNX model saved to {self.model_path}")
+        model_exporter = ModelExporter(self.cfg, self.model, format="onnx", model_path=self.model_path)
+        model_exporter.export_onnx(dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}})
         return InferenceSession(self.model_path, providers=providers)
 
+    def _load_coreml_model(self, device):
+        from coremltools import models
+
+        def coreml_forward(self, x: Tensor):
+            x = x.cpu().numpy()
+            model_outputs = []
+            predictions = self.predict({"x": x})
+
+            output_keys = ["preds_cls", "preds_anc", "preds_box"]
+            for key in output_keys:
+                model_outputs.append(torch.from_numpy(predictions[key]).to(device))
+
+            return model_outputs
+
+        models.MLModel.__call__ = coreml_forward
+
+        if not Path(self.model_path).exists():
+            self._create_coreml_model()
+
+        try:
+            model_coreml = models.MLModel(self.model_path)
+            logger.info(":rocket: Using CoreML as MODEL framework!")
+        except FileNotFoundError:
+            logger.warning(f"🈳 No model weight found at {self.model_path}")
+            return None
+
+        return model_coreml
+
+    def _create_tflite_model(self):
+        model_exporter = ModelExporter(self.cfg, self.model, format="tflite", model_path=self.model_path)
+        model_exporter.export_tflite()
+
+    def _create_coreml_model(self):
+        model_exporter = ModelExporter(self.cfg, self.model, format="coreml", model_path=self.model_path)
+        model_exporter.export_coreml()
+
     def _load_trt_model(self):
         from torch2trt import TRTModule
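For orientation, a sketch of what `load_model` hands back to `predict_step` per backend, based on the wrappers above (`cfg`, `model`, `device`, and `images` assumed in scope):

    fast_model = FastModelLoader(cfg, model).load_model(device)
    predictions = fast_model(images)
    # "tflite" -> {"Main": [[cls, anc, box] per scale]}, matching the eager model's dict output
    # "coreml" -> [preds_cls, preds_anc, preds_box], already decoded by the export-mode forward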
"batch_size"}}, - ) - logger.info(f":inbox_tray: ONNX model saved to {self.model_path}") + model_exporter = ModelExporter(self.cfg, self.model, format="onnx", model_path=self.model_path) + model_exporter.export_onnx(dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}) return InferenceSession(self.model_path, providers=providers) + def _load_coreml_model(self, device): + from coremltools import models + + def coreml_forward(self, x: Tensor): + x = x.cpu().numpy() + model_outputs = [] + predictions = self.predict({"x": x}) + + output_keys = ["preds_cls", "preds_anc", "preds_box"] + for key in output_keys: + model_outputs.append(torch.from_numpy(predictions[key]).to(device)) + + return model_outputs + + models.MLModel.__call__ = coreml_forward + + if not Path(self.model_path).exists(): + self._create_coreml_model() + + try: + model_coreml = models.MLModel(self.model_path) + logger.info(":rocket: Using CoreML as MODEL frameworks!") + except FileNotFoundError: + logger.warning(f"🈳 No found model weight at {self.model_path}") + return None + + return model_coreml + + def _create_tflite_model(self): + model_exporter = ModelExporter(self.cfg, self.model, format="tflite", model_path=self.model_path) + model_exporter.export_tflite() + + def _create_coreml_model(self): + model_exporter = ModelExporter(self.cfg, self.model, format="coreml", model_path=self.model_path) + model_exporter.export_coreml() + def _load_trt_model(self): from torch2trt import TRTModule diff --git a/yolo/utils/export_utils.py b/yolo/utils/export_utils.py new file mode 100644 index 00000000..906bc568 --- /dev/null +++ b/yolo/utils/export_utils.py @@ -0,0 +1,101 @@ +from pathlib import Path +from typing import Dict, List, Optional + +from yolo.config.config import Config +from yolo.model.yolo import YOLO +from yolo.utils.logger import logger + + +class ModelExporter: + def __init__(self, cfg: Config, model: YOLO, format: str, model_path: Optional[str] = None): + self.model = model + self.cfg = cfg + self.class_num = cfg.dataset.class_num + self.format = format + if cfg.weight == True: + cfg.weight = Path("weights") / f"{cfg.model.name}.pt" + + if model_path: + self.model_path = model_path + else: + extention = self.format + if self.format == "coreml": + extention = "mlpackage" + + self.model_path = f"{Path(self.cfg.weight).stem}.{extention}" + + self.output_names: List[str] = [ + "1_class_scores_small", + "2_box_features_small", + "3_bbox_deltas_small", + "4_class_scores_medium", + "5_box_features_medium", + "6_bbox_deltas_medium", + "7_class_scores_large", + "8_box_features_large", + "9_bbox_deltas_large", + ] + + def export_onnx(self, dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None, model_path: Optional[str] = None): + logger.info(f":package: Exporting model to onnx format") + import torch + + dummy_input = torch.ones((1, 3, *self.cfg.image_size)) + + if model_path: + onnx_model_path = model_path + else: + onnx_model_path = self.model_path + + torch.onnx.export( + self.model, + dummy_input, + onnx_model_path, + input_names=["input"], + output_names=self.output_names, + dynamic_axes=dynamic_axes, + ) + + logger.info(f":inbox_tray: ONNX model saved to {onnx_model_path}") + + return onnx_model_path + + def export_tflite(self): + logger.info(f":package: Exporting model to tflite format") + + import torch + + self.model.eval() + example_inputs = (torch.rand(1, 3, *self.cfg.image_size),) + + import ai_edge_torch + + edge_model = ai_edge_torch.convert(self.model, example_inputs) + 
diff --git a/yolo/utils/model_utils.py b/yolo/utils/model_utils.py
index 9d6c0ce5..a6d4bc65 100644
--- a/yolo/utils/model_utils.py
+++ b/yolo/utils/model_utils.py
@@ -173,9 +173,15 @@ def __call__(
     ) -> List[Tensor]:
         if image_size is not None:
             self.converter.update(image_size)
-        prediction = self.converter(predict["Main"])
+
+        if isinstance(predict, dict):
+            prediction = self.converter(predict["Main"])
+        else:
+            prediction = predict
+
         pred_class, _, pred_bbox = prediction[:3]
         pred_conf = prediction[3] if len(prediction) == 4 else None
+
         if rev_tensor is not None:
             pred_bbox = (pred_bbox - rev_tensor[:, None, 1:]) / rev_tensor[:, 0:1, None]
         pred_bbox = bbox_nms(pred_class, pred_bbox, self.nms, pred_conf)
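This dict-vs-tuple branch is what lets both model flavours share one post-processing path (an illustrative sketch; `post_process`, `raw_output`, and `decoded` are stand-in names):

    predicts = post_process(raw_output)  # eager/dict path: {"Main": [...]}, decoded by the converter
    predicts = post_process(decoded)     # export-mode path: (preds_cls, preds_anc, preds_box), already decoded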