From e47e6d917aeffef79ad818d09e2c20905d791100 Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Mon, 27 Jan 2025 14:29:00 +0100 Subject: [PATCH 1/2] engine refactor Signed-off-by: Ashwin Vaidya --- src/otx/backend/__init__.py | 4 + src/otx/backend/native/__init__.py | 4 + src/otx/backend/native/engine/__init__.py | 8 + .../native}/engine/adaptive_bs/__init__.py | 0 .../engine/adaptive_bs/adaptive_bs_api.py | 0 .../engine/adaptive_bs/bs_search_algo.py | 0 src/otx/backend/native/engine/engine.py | 1206 ++++++++++++++++ .../native}/engine/hpo/__init__.py | 0 .../native}/engine/hpo/hpo_api.py | 4 +- .../native}/engine/hpo/hpo_trial.py | 0 .../{ => backend/native}/engine/hpo/utils.py | 0 .../native}/engine/utils/__init__.py | 0 .../{ => backend/native}/engine/utils/api.py | 0 .../native}/engine/utils/auto_configurator.py | 0 src/otx/engine/__init__.py | 18 + src/otx/engine/engine.py | 1227 +---------------- src/otx/types.py | 14 + 17 files changed, 1290 insertions(+), 1195 deletions(-) create mode 100644 src/otx/backend/__init__.py create mode 100644 src/otx/backend/native/__init__.py create mode 100644 src/otx/backend/native/engine/__init__.py rename src/otx/{ => backend/native}/engine/adaptive_bs/__init__.py (100%) rename src/otx/{ => backend/native}/engine/adaptive_bs/adaptive_bs_api.py (100%) rename src/otx/{ => backend/native}/engine/adaptive_bs/bs_search_algo.py (100%) create mode 100644 src/otx/backend/native/engine/engine.py rename src/otx/{ => backend/native}/engine/hpo/__init__.py (100%) rename src/otx/{ => backend/native}/engine/hpo/hpo_api.py (99%) rename src/otx/{ => backend/native}/engine/hpo/hpo_trial.py (100%) rename src/otx/{ => backend/native}/engine/hpo/utils.py (100%) rename src/otx/{ => backend/native}/engine/utils/__init__.py (100%) rename src/otx/{ => backend/native}/engine/utils/api.py (100%) rename src/otx/{ => backend/native}/engine/utils/auto_configurator.py (100%) create mode 100644 src/otx/types.py diff --git a/src/otx/backend/__init__.py b/src/otx/backend/__init__.py new file mode 100644 index 00000000000..a9ab4f608bc --- /dev/null +++ b/src/otx/backend/__init__.py @@ -0,0 +1,4 @@ +"""OTX backends.""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/src/otx/backend/native/__init__.py b/src/otx/backend/native/__init__.py new file mode 100644 index 00000000000..7b25a86d2b1 --- /dev/null +++ b/src/otx/backend/native/__init__.py @@ -0,0 +1,4 @@ +"""Native backend.""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/src/otx/backend/native/engine/__init__.py b/src/otx/backend/native/engine/__init__.py new file mode 100644 index 00000000000..54f0aa601ce --- /dev/null +++ b/src/otx/backend/native/engine/__init__.py @@ -0,0 +1,8 @@ +"""Native engine.""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .engine import NativeEngine + +__all__ = ["NativeEngine"] diff --git a/src/otx/engine/adaptive_bs/__init__.py b/src/otx/backend/native/engine/adaptive_bs/__init__.py similarity index 100% rename from src/otx/engine/adaptive_bs/__init__.py rename to src/otx/backend/native/engine/adaptive_bs/__init__.py diff --git a/src/otx/engine/adaptive_bs/adaptive_bs_api.py b/src/otx/backend/native/engine/adaptive_bs/adaptive_bs_api.py similarity index 100% rename from src/otx/engine/adaptive_bs/adaptive_bs_api.py rename to src/otx/backend/native/engine/adaptive_bs/adaptive_bs_api.py diff --git a/src/otx/engine/adaptive_bs/bs_search_algo.py 
b/src/otx/backend/native/engine/adaptive_bs/bs_search_algo.py
similarity index 100%
rename from src/otx/engine/adaptive_bs/bs_search_algo.py
rename to src/otx/backend/native/engine/adaptive_bs/bs_search_algo.py
diff --git a/src/otx/backend/native/engine/engine.py b/src/otx/backend/native/engine/engine.py
new file mode 100644
index 00000000000..83079aa7fd1
--- /dev/null
+++ b/src/otx/backend/native/engine/engine.py
@@ -0,0 +1,1206 @@
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+"""Module for OTX engine components."""
+
+from __future__ import annotations
+
+import copy
+import csv
+import inspect
+import logging
+import tempfile
+import time
+from contextlib import contextmanager
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Iterator, Literal
+from warnings import warn
+
+import torch
+from lightning import Trainer, seed_everything
+from lightning.pytorch.plugins.precision import MixedPrecision
+
+from otx.core.config.device import DeviceConfig
+from otx.core.config.explain import ExplainConfig
+from otx.core.config.hpo import HpoConfig
+from otx.core.data.module import OTXDataModule
+from otx.core.model.base import OTXModel, OVModel
+from otx.core.types import PathLike
+from otx.core.types.device import DeviceType
+from otx.core.types.export import OTXExportFormatType
+from otx.core.types.precision import OTXPrecisionType
+from otx.core.types.task import OTXTaskType
+from otx.core.utils.cache import TrainerArgumentsCache
+from otx.engine.engine import Engine
+from otx.utils.device import is_xpu_available
+from otx.utils.utils import measure_flops
+
+from .adaptive_bs import adapt_batch_size
+from .hpo import execute_hpo, update_hyper_parameter
+from .utils.auto_configurator import DEFAULT_CONFIG_PER_TASK
+
+if TYPE_CHECKING:
+    from lightning import Callback
+    from lightning.pytorch.loggers import Logger
+    from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
+    from pytorch_lightning.trainer.connectors.accelerator_connector import _PRECISION_INPUT
+
+    from otx.core.metrics import MetricCallable
+    from otx.types import DATA, MODEL
+
+
+@contextmanager
+def override_metric_callable(model: OTXModel, new_metric_callable: MetricCallable | None) -> Iterator[OTXModel]:
+    """Override `OTXModel.metric_callable` to change the evaluation metric.
+
+    Args:
+        model: Model whose metric callable is overridden.
+        new_metric_callable: If not None, override the model's metric callable with this one.
+            Otherwise, do not override.
+    """
+    if new_metric_callable is None:
+        yield model
+        return
+
+    orig_metric_callable = model.metric_callable
+    try:
+        model.metric_callable = new_metric_callable
+        yield model
+    finally:
+        model.metric_callable = orig_metric_callable
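+
+# A minimal usage sketch (illustrative only; `f1_callable` is a hypothetical
+# stand-in for any MetricCallable):
+#
+#     with override_metric_callable(model, f1_callable) as patched:
+#         trainer.test(model=patched, dataloaders=datamodule)
+#
+# On exit, `model.metric_callable` is restored to its original value.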
+
+
+class NativeEngine(Engine):
+    """Native Engine.
+
+    This class defines the Engine for OTX, which governs each step of the OTX workflow.
+
+    Example:
+        The following examples show how to use the NativeEngine class.
+
+        Create an engine with a custom OTXModel and OTXDataModule::
+
+            engine = NativeEngine(
+                model=OTXModel(...),
+                data=OTXDataModule(...),
+            )
+
+        Create an engine that starts from an existing checkpoint::
+
+            engine = NativeEngine(
+                model=OTXModel(...),
+                data=OTXDataModule(...),
+                checkpoint=,
+            )
+    """
+
+    _EXPORTED_MODEL_BASE_NAME: ClassVar[str] = "exported_model"
+
+    def __init__(
+        self,
+        *,
+        model: OTXModel | str | None = None,
+        data: OTXDataModule | None = None,
+        work_dir: PathLike = "./otx-workspace",
+        checkpoint: PathLike | None = None,
+        device: DeviceType = DeviceType.auto,
+        num_devices: int = 1,
+        **kwargs,
+    ):
+        """Initializes the OTX Engine.
+
+        Args:
+            model (OTXModel | str | None, optional): The model for the engine. Defaults to None.
+            data (OTXDataModule | None, optional): The data module for the engine. Defaults to None.
+            work_dir (PathLike, optional): Working directory for the engine. Defaults to "./otx-workspace".
+            checkpoint (PathLike | None, optional): Path to the checkpoint file. Defaults to None.
+            device (DeviceType, optional): The device type to use. Defaults to DeviceType.auto.
+            num_devices (int, optional): The number of devices to use. If it is 2 or more, it will behave as
+                multi-GPU.
+            **kwargs: Additional keyword arguments for pl.Trainer.
+        """
+        self._cache = TrainerArgumentsCache(**kwargs)
+        self.checkpoint = checkpoint
+        self.work_dir = work_dir
+        self.device = device  # type: ignore[assignment]
+        self.num_devices = num_devices
+
+        self._datamodule: OTXDataModule | None = data
+        self.task = data.task if data is not None else self._auto_configurator.task
+
+        self._trainer: Trainer | None = None
+        get_model_args: dict[str, Any] = {}
+        if data is not None:
+            get_model_args["label_info"] = data.label_info
+            if (input_size := data.input_size) is not None:
+                get_model_args["input_size"] = (input_size, input_size) if isinstance(input_size, int) else input_size
+        self._model: OTXModel = model
+
+    # ------------------------------------------------------------------------ #
+    # General OTX Entry Points
+    # ------------------------------------------------------------------------ #
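+
+    # A hypothetical end-to-end workflow (all argument values are placeholders):
+    #     engine = NativeEngine(model=OTXModel(...), data=OTXDataModule(...))
+    #     engine.train(max_epochs=10)
+    #     engine.test()
+    #     engine.export()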
+
+    def train(
+        self,
+        max_epochs: int = 10,
+        seed: int | None = None,
+        deterministic: bool | Literal["warn"] = False,
+        precision: _PRECISION_INPUT | None = "32",
+        val_check_interval: int | float | None = None,
+        callbacks: list[Callback] | Callback | None = None,
+        logger: Logger | Iterable[Logger] | bool | None = None,
+        resume: bool = False,
+        metric: MetricCallable | None = None,
+        run_hpo: bool = False,
+        hpo_config: HpoConfig = HpoConfig(),  # noqa: B008 https://github.com/omni-us/jsonargparse/issues/423
+        checkpoint: PathLike | None = None,
+        adaptive_bs: Literal["None", "Safe", "Full"] = "None",
+        **kwargs,
+    ) -> dict[str, Any]:
+        r"""Trains the model using the provided LightningModule and OTXDataModule.
+
+        Args:
+            max_epochs (int, optional): The maximum number of epochs. Defaults to 10.
+            seed (int | None, optional): The random seed. Defaults to None.
+            deterministic (bool | Literal["warn"]): Whether to enable deterministic behavior.
+                Can also be set to `warn` to avoid failures, because some operations don't
+                support deterministic mode. Defaults to False.
+            precision (_PRECISION_INPUT | None, optional): The precision of the model. Defaults to 32.
+            val_check_interval (int | float | None, optional): The validation check interval. Defaults to None.
+            callbacks (list[Callback] | Callback | None, optional): The callbacks to be used during training.
+            logger (Logger | Iterable[Logger] | bool | None, optional): The logger(s) to be used. Defaults to None.
+            resume (bool, optional): If True, tries to resume training from an existing checkpoint.
+            metric (MetricCallable | None): If not None, it will override `OTXModel.metric_callable` with the given
+                metric callable. It will temporarily change the evaluation metric for validation and test.
+            run_hpo (bool, optional): If True, optimize hyperparameters before training a model.
+            hpo_config (HpoConfig | None, optional): Configuration for HPO.
+            checkpoint (PathLike | None, optional): Path to the checkpoint file. Defaults to None.
+            adaptive_bs (Literal["None", "Safe", "Full"]):
+                Change the actual batch size depending on the current GPU status.
+                Safe => Prevent GPU out of memory. Full => Find a batch size using most of GPU memory.
+            **kwargs: Additional keyword arguments for pl.Trainer configuration.
+
+        Returns:
+            dict[str, Any]: A dictionary containing the callback metrics from the trainer.
+
+        Example:
+            >>> engine.train(
+            ...     max_epochs=3,
+            ...     seed=1234,
+            ...     deterministic=False,
+            ...     precision="32",
+            ... )
+
+        CLI Usage:
+            1. Train with data_root only; OTX will then provide a default training configuration.
+                ```shell
+                >>> otx train --data_root
+                ```
+            2. Pick a model or datamodule as a config file or class.
+                ```shell
+                >>> otx train \
+                ...     --data_root \
+                ...     --model \
+                ...     --data
+                ```
+            3. Override the various values from the command line.
+                ```shell
+                >>> otx train \
+                ...     --data_root \
+                ...     --max_epochs \
+                ...     --checkpoint
+                ```
+            4. To train with a configuration file, run
+                ```shell
+                >>> otx train --data_root --config
+                ```
+            5. To reproduce an existing training run with its work_dir, run
+                ```shell
+                >>> otx train --work_dir
+                ```
+        """
+        checkpoint = checkpoint if checkpoint is not None else self.checkpoint
+
+        if adaptive_bs != "None":
+            adapt_batch_size(engine=self, **locals(), not_increase=(adaptive_bs != "Full"))
+
+        if run_hpo:
+            best_config, best_trial_weight = execute_hpo(engine=self, **locals())
+            if best_config is not None:
+                update_hyper_parameter(self, best_config)
+            if best_trial_weight is not None:
+                checkpoint = best_trial_weight
+                resume = True
+
+        if seed is not None:
+            seed_everything(seed, workers=True)
+
+        self._build_trainer(
+            logger=logger,
+            callbacks=callbacks,
+            precision=precision,
+            max_epochs=max_epochs,
+            deterministic=deterministic,
+            val_check_interval=val_check_interval,
+            **kwargs,
+        )
+        fit_kwargs: dict[str, Any] = {}
+
+        # NOTE: The model's label info should be converted to the datamodule's label info before ckpt loading,
+        # because smart weight loading checks label names as well as the number of classes.
+        if self.model.label_info != self.datamodule.label_info:
+            msg = (
+                "Model label_info is not equal to the Datamodule label_info. "
+                f"It will be overridden: {self.model.label_info} => {self.datamodule.label_info}"
+            )
+            logging.warning(msg)
+            self.model.label_info = self.datamodule.label_info
+
+        if resume and checkpoint:
+            # NOTE: If both `resume` and `checkpoint` are provided,
+            # load the entire model state from the checkpoint using the pl.Trainer's API.
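+            # (Illustrative: `engine.train(resume=True, checkpoint="last.ckpt")` would
+            # restore optimizer and scheduler state in addition to the weights; the
+            # file name is a hypothetical placeholder.)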
+ fit_kwargs["ckpt_path"] = checkpoint + elif not resume and checkpoint: + # NOTE: If `resume` is not enabled but `checkpoint` is provided, + # load the model state from the checkpoint incrementally. + # This means only the model weights are loaded. If there is a mismatch in label_info, + # perform incremental weight loading for the model's classification layer. + ckpt = torch.load(checkpoint) + self.model.load_state_dict_incrementally(ckpt) + + with override_metric_callable(model=self.model, new_metric_callable=metric) as model: + self.trainer.fit( + model=model, + datamodule=self.datamodule, + **fit_kwargs, + ) + self.checkpoint = self.trainer.checkpoint_callback.best_model_path + + if not isinstance(self.checkpoint, (Path, str)): + msg = "self.checkpoint should be Path or str at this time." + raise TypeError(msg) + + best_checkpoint_symlink = Path(self.work_dir) / "best_checkpoint.ckpt" + if best_checkpoint_symlink.is_symlink(): + best_checkpoint_symlink.unlink() + best_checkpoint_symlink.symlink_to(self.checkpoint) + + return self.trainer.callback_metrics + + def test( + self, + checkpoint: PathLike | None = None, + datamodule: EVAL_DATALOADERS | OTXDataModule | None = None, + metric: MetricCallable | None = None, + **kwargs, + ) -> dict: + r"""Run the testing phase of the engine. + + Args: + checkpoint (PathLike | None, optional): Path to the checkpoint file to load the model from. + Defaults to None. + datamodule (EVAL_DATALOADERS | OTXDataModule | None, optional): The data module containing the test data. + metric (MetricCallable | None): If not None, it will override `OTXModel.metric_callable` with the given + metric callable. It will temporarilly change the evaluation metric for the validation and test. + **kwargs: Additional keyword arguments for pl.Trainer configuration. + + Returns: + dict: Dictionary containing the callback metrics from the trainer. + + Example: + >>> engine.test( + ... datamodule=OTXDataModule(), + ... checkpoint=, + ... ) + + CLI Usage: + 1. To eval model by specifying the work_dir where did the training, run + ```shell + >>> otx test --work_dir + ``` + 2. To eval model a specific checkpoint, run + ```shell + >>> otx test --work_dir --checkpoint + ``` + 3. Can pick a model. + ```shell + >>> otx test \ + ... --model \ + ... --data_root \ + ... --checkpoint + ``` + 4. To eval with configuration file, run + ```shell + >>> otx test --config --checkpoint + ``` + """ + model = self.model + checkpoint = checkpoint if checkpoint is not None else self.checkpoint + datamodule = datamodule if datamodule is not None else self.datamodule + + is_ir_ckpt = Path(str(checkpoint)).suffix in [".xml", ".onnx"] + if is_ir_ckpt and not isinstance(model, OVModel): + model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info) + if self.device.accelerator != "cpu": + msg = "IR model supports inference only on CPU device. The device is changed automatic." + warn(msg, stacklevel=1) + self.device = DeviceType.cpu # type: ignore[assignment] + + # NOTE: Re-initiate datamodule without tiling as model API supports its own tiling mechanism + if isinstance(model, OVModel): + datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test") + + # NOTE, trainer.test takes only lightning based checkpoint. + # So, it can't take the OTX1.x checkpoint. 
+        if checkpoint is not None and not is_ir_ckpt:
+            kwargs_user_input: dict[str, Any] = {}
+            if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
+                # to update user's custom infer_reference_info_root through CLI for zero-shot learning
+                # TODO (sungchul): revisit for better solution
+                kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)
+
+            model_cls = model.__class__
+            model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **kwargs_user_input)
+
+        if model.label_info != self.datamodule.label_info:
+            if (
+                self.task == "SEMANTIC_SEGMENTATION"
+                and "otx_background_lbl" in self.datamodule.label_info.label_names
+                and (len(self.datamodule.label_info.label_names) - len(model.label_info.label_names) == 1)
+            ):
+                # workaround for background label
+                model.label_info = copy.deepcopy(self.datamodule.label_info)
+            else:
+                msg = (
+                    "To launch a test pipeline, the label information should be the same "
+                    "between the training and testing datasets. "
+                    "Please check whether you use the same dataset: "
+                    f"model.label_info={model.label_info}, "
+                    f"datamodule.label_info={self.datamodule.label_info}"
+                )
+                raise ValueError(msg)
+
+        self._build_trainer(**kwargs)
+
+        with override_metric_callable(model=model, new_metric_callable=metric) as model:
+            self.trainer.test(
+                model=model,
+                dataloaders=datamodule,
+            )
+
+        return self.trainer.callback_metrics
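+
+    # A hypothetical predict sketch (argument values are placeholders):
+    #     preds = engine.predict(return_predictions=True, explain=True)
+    # With `explain=True`, saliency maps are post-processed via
+    # `process_saliency_maps_in_pred_entity` before being returned.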
+
+    def predict(
+        self,
+        checkpoint: PathLike | None = None,
+        datamodule: EVAL_DATALOADERS | OTXDataModule | None = None,
+        return_predictions: bool | None = None,
+        explain: bool = False,
+        explain_config: ExplainConfig | None = None,
+        **kwargs,
+    ) -> list | None:
+        r"""Run predictions using the specified model and data.
+
+        Args:
+            checkpoint (PathLike | None, optional): The path to the checkpoint file to load the model from.
+            datamodule (EVAL_DATALOADERS | OTXDataModule | None, optional): The data module to use for predictions.
+            return_predictions (bool | None, optional): Whether to return the predictions or not.
+            explain (bool, optional): Whether to dump "saliency_map" and "feature_vector" or not.
+            explain_config (ExplainConfig | None, optional): Explain configuration used for saliency map
+                post-processing.
+            **kwargs: Additional keyword arguments for pl.Trainer configuration.
+
+        Returns:
+            list | None: The predictions if `return_predictions` is True, otherwise None.
+
+        Example:
+            >>> engine.predict(
+            ...     datamodule=OTXDataModule(),
+            ...     checkpoint=,
+            ...     return_predictions=True,
+            ...     explain=True,
+            ... )
+
+        CLI Usage:
+            1. To predict a model with work_dir, run
+                ```shell
+                >>> otx predict --work_dir
+                ```
+            2. To predict with a specific model, run
+                ```shell
+                >>> otx predict \
+                ...     --work_dir \
+                ...     --checkpoint
+                ```
+            3. To predict with a configuration file, run
+                ```shell
+                >>> otx predict \
+                ...     --config \
+                ...     --checkpoint
+                ```
+        """
+        from otx.algo.utils.xai_utils import process_saliency_maps_in_pred_entity, set_crop_padded_map_flag
+
+        model = self.model
+
+        checkpoint = checkpoint if checkpoint is not None else self.checkpoint
+        datamodule = datamodule if datamodule is not None else self.datamodule
+
+        is_ir_ckpt = checkpoint is not None and Path(checkpoint).suffix in [".xml", ".onnx"]
+        if is_ir_ckpt and not isinstance(model, OVModel):
+            model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info)
+
+        # NOTE: Re-initiate datamodule for OVModel as model API supports its own data pipeline.
+        if isinstance(model, OVModel):
+            datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test")
+
+        if checkpoint is not None and not is_ir_ckpt:
+            kwargs_user_input: dict[str, Any] = {}
+            if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
+                # to update user's custom infer_reference_info_root through CLI for zero-shot learning
+                # TODO (sungchul): revisit for better solution
+                kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)
+
+            model_cls = model.__class__
+            model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **kwargs_user_input)
+
+        if model.label_info != self.datamodule.label_info:
+            msg = (
+                "To launch a predict pipeline, the label information should be the same "
+                "between the training and testing datasets. "
+                "Please check whether you use the same dataset: "
+                f"model.label_info={model.label_info}, "
+                f"datamodule.label_info={self.datamodule.label_info}"
+            )
+            raise ValueError(msg)
+
+        self._build_trainer(**kwargs)
+
+        curr_explain_mode = model.explain_mode
+
+        try:
+            model.explain_mode = explain
+            predict_result = self.trainer.predict(
+                model=model,
+                dataloaders=datamodule,
+                return_predictions=return_predictions,
+            )
+        finally:
+            model.explain_mode = curr_explain_mode
+
+        if explain:
+            if explain_config is None:
+                explain_config = ExplainConfig()
+            explain_config = set_crop_padded_map_flag(explain_config, datamodule)
+
+            predict_result = process_saliency_maps_in_pred_entity(predict_result, explain_config, datamodule.label_info)
+
+        return predict_result
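+
+    # A hypothetical export sketch (the checkpoint path is a placeholder):
+    #     exported = engine.export(
+    #         checkpoint="best_checkpoint.ckpt",
+    #         export_format=OTXExportFormatType.OPENVINO,
+    #         export_precision=OTXPrecisionType.FP16,
+    #     )
+    # returns the path of the exported model under `work_dir`.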
" + ) + warn(msg, stacklevel=1) + export_demo_package = False + + if is_ir_ckpt and not export_demo_package: + msg = "IR model is passed as a checkpoint, export automatically switched to exportable code." + warn(msg, stacklevel=1) + export_demo_package = True + + if is_ir_ckpt and not isinstance(self.model, OVModel): + # create OVModel + self.model = self._auto_configurator.get_ov_model( + model_name=str(checkpoint), + label_info=self.datamodule.label_info, + ) + + if not is_ir_ckpt: + kwargs_user_input: dict[str, Any] = {} + if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING: + # to update user's custom infer_reference_info_root through cli for zero-shot learning + # TODO (sungchul): revisit for better solution + kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root) + + model_cls = self.model.__class__ + self.model = model_cls.load_from_checkpoint( + checkpoint_path=checkpoint, + map_location="cpu", + **kwargs_user_input, + ) + self.model.eval() + + self.model.explain_mode = explain + exported_model_path = self.model.export( + output_dir=Path(self.work_dir), + base_name=self._EXPORTED_MODEL_BASE_NAME, + export_format=export_format, + precision=export_precision, + to_exportable_code=export_demo_package, + ) + + self.model.explain_mode = False + return exported_model_path + + def optimize( + self, + checkpoint: PathLike | None = None, + datamodule: TRAIN_DATALOADERS | OTXDataModule | None = None, + max_data_subset_size: int | None = None, + export_demo_package: bool = False, + ) -> Path: + r"""Applies NNCF.PTQ to the underlying models (now works only for OV models). + + PTQ performs int-8 quantization on the input model, so the resulting model + comes in mixed precision (some operations, however, remain in FP32). + + Args: + checkpoint (str | Path | None, optional): Checkpoint to optimize. Defaults to None. + datamodule (TRAIN_DATALOADERS | OTXDataModule | None, optional): The data module to use for optimization. + max_data_subset_size (int | None): The maximum size of the train subset from `datamodule` that would be + used for model optimization. If not set, NNCF.PTQ will select subset size according to it's + default settings. + export_demo_package (bool): Whether to export demo package with optimized models. + It outputs zip archive with stand-alone demo package. + + Returns: + Path: path to the optimized model. + + Example: + >>> engine.optimize( + ... checkpoint=, + ... datamodule=OTXDataModule(), + ... checkpoint=, + ... ) + + CLI Usage: + 1. To optimize a model with IR Model, run + ```shell + >>> otx optimize \ + ... --work_dir \ + ... --checkpoint + ``` + 2. To optimize a specific OVModel class with XML, run + ```shell + >>> otx optimize \ + ... --data_root \ + ... --checkpoint \ + ... --model \ + ... 
--model.model_name= + ``` + """ + checkpoint = checkpoint if checkpoint is not None else self.checkpoint + optimize_datamodule = datamodule if datamodule is not None else self.datamodule + + is_ir_ckpt = checkpoint is not None and Path(checkpoint).suffix in [".xml", ".onnx"] + if not is_ir_ckpt: + msg = "Engine.optimize() supports only OV IR or ONNX checkpoints" + raise RuntimeError(msg) + + model = self.model + if not isinstance(model, OVModel): + optimize_datamodule = self._auto_configurator.update_ov_subset_pipeline( + datamodule=optimize_datamodule, + subset="train", + ) + model = self._auto_configurator.get_ov_model( + model_name=str(checkpoint), + label_info=optimize_datamodule.label_info, + ) + + ptq_config = {} + if max_data_subset_size is not None: + ptq_config["subset_size"] = max_data_subset_size + + if not export_demo_package: + return model.optimize( + Path(self.work_dir), + optimize_datamodule, + ptq_config, + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_model_path = model.optimize(Path(tmp_dir), optimize_datamodule, ptq_config) + return self.export( + checkpoint=tmp_model_path, + export_demo_package=True, + ) + + def explain( + self, + checkpoint: PathLike | None = None, + datamodule: EVAL_DATALOADERS | OTXDataModule | None = None, + explain_config: ExplainConfig | None = None, + dump: bool | None = False, + **kwargs, + ) -> list | None: + r"""Run XAI using the specified model and data (test subset). + + Args: + checkpoint (PathLike | None, optional): The path to the checkpoint file to load the model from. + datamodule (EVAL_DATALOADERS | OTXDataModule | None, optional): The data module to use for predictions. + explain_config (ExplainConfig | None, optional): Config used to handle saliency maps. + dump (bool): Whether to dump "saliency_map" or not. + **kwargs: Additional keyword arguments for pl.Trainer configuration. + + Returns: + list: Saliency maps. + + Example: + >>> engine.explain( + ... datamodule=OTXDataModule(), + ... checkpoint=, + ... explain_config=ExplainConfig(), + ... dump=True, + ... ) + + CLI Usage: + 1. To run XAI with the torch model in work_dir, run + ```shell + >>> otx explain \ + ... --work_dir + ``` + 2. To run XAI using the specified model (torch or IR), run + ```shell + >>> otx explain \ + ... --work_dir \ + ... --checkpoint + ``` + 3. To run XAI using the configuration, run + ```shell + >>> otx explain \ + ... --config --data_root \ + ... 
--checkpoint + ``` + """ + from otx.algo.utils.xai_utils import ( + dump_saliency_maps, + process_saliency_maps_in_pred_entity, + set_crop_padded_map_flag, + ) + + model = self.model + + checkpoint = checkpoint if checkpoint is not None else self.checkpoint + datamodule = datamodule if datamodule is not None else self.datamodule + + is_ir_ckpt = checkpoint is not None and Path(checkpoint).suffix in [".xml", ".onnx"] + if is_ir_ckpt and not isinstance(model, OVModel): + datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test") + model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info) + + if checkpoint is not None and not is_ir_ckpt: + kwargs_user_input: dict[str, Any] = {} + if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING: + # to update user's custom infer_reference_info_root through cli for zero-shot learning + # TODO (sungchul): revisit for better solution + kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root) + + model_cls = model.__class__ + model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **kwargs_user_input) + + if model.label_info != self.datamodule.label_info: + msg = ( + "To launch a explain pipeline, the label information should be same " + "between the training and testing datasets. " + "Please check whether you use the same dataset: " + f"model.label_info={model.label_info}, " + f"datamodule.label_info={self.datamodule.label_info}" + ) + raise ValueError(msg) + + model.explain_mode = True + + self._build_trainer(**kwargs) + + predict_result = self.trainer.predict( + model=model, + datamodule=datamodule, + ) + + if explain_config is None: + explain_config = ExplainConfig() + explain_config = set_crop_padded_map_flag(explain_config, datamodule) + + predict_result = process_saliency_maps_in_pred_entity(predict_result, explain_config, datamodule.label_info) + if dump: + dump_saliency_maps( + predict_result, + explain_config, + datamodule, + output_dir=Path(self.work_dir), + ) + model.explain_mode = False + return predict_result + + def benchmark( + self, + checkpoint: PathLike | None = None, + batch_size: int = 1, + n_iters: int = 10, + extended_stats: bool = False, + print_table: bool = True, + ) -> dict[str, str]: + r"""Executes model micro benchmarking on random data. + + Benchmark can provide latency, throughput, number of parameters, + and theoretical computational complexity with batch size 1. + The latter two characteristics are available for torch model recipes only. + Before the measurements, a warm-up is done. + + Args: + checkpoint (PathLike | None, optional): Path to checkpoint. Optional for torch models. Defaults to None. + batch_size (int, optional): Batch size for benchmarking. Defaults to 1. + n_iters (int, optional): Number of iterations to average on. Defaults to 10. + extended_stats (bool, optional): Flag that enables printing of per module complexity for torch model. + Defaults to False. + print_table (bool, optional): Flag that enables printing the benchmark results in a rich table. + Defaults to True. + + Returns: + dict[str, str]: a dict with the benchmark results. + + Example: + >>> engine.benchmark( + ... checkpoint=, + ... batch_size=1, + ... n_iters=20, + ... extended_stats=True, + ... ) + + CLI Usage: + 1. To run benchmark by specifying the work_dir where did the training, run + ```shell + >>> otx benchmark --work_dir + ``` + 2. 
To run benchmark by specifying the checkpoint, run + ```shell + >>> otx benchmark \ + ... --work_dir \ + ... --checkpoint + ``` + 3. To run benchmark using the configuration, launch + ```shell + >>> otx benchmark \ + ... --config \ + ... --data_root \ + ... --checkpoint + ``` + """ + checkpoint = checkpoint if checkpoint is not None else self.checkpoint + + if checkpoint is not None: + is_ir_ckpt = Path(checkpoint).suffix in [".xml"] + if is_ir_ckpt and not isinstance(self.model, OVModel): + # create OVModel + self.model = self._auto_configurator.get_ov_model( + model_name=str(checkpoint), + label_info=self.datamodule.label_info, + ) + + if not is_ir_ckpt: + kwargs_user_input: dict[str, Any] = {} + if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING: + # to update user's custom infer_reference_info_root through cli for zero-shot learning + # TODO (sungchul): revisit for better solution + kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root) + + model_cls = self.model.__class__ + self.model = model_cls.load_from_checkpoint( + checkpoint_path=checkpoint, + map_location="cpu", + **kwargs_user_input, + ) + elif isinstance(self.model, OVModel): + msg = "To run benchmark on OV model, checkpoint must be specified." + raise RuntimeError(msg) + + self.model.eval() + + def dummy_infer(model: OTXModel, batch_size: int = 1) -> float: + input_batch = model.get_dummy_input(batch_size) + start = time.perf_counter() + model.forward(input_batch) + end = time.perf_counter() + return end - start + + warmup_iters = max(1, int(n_iters / 10)) + for _ in range(warmup_iters): + dummy_infer(self.model, batch_size) + + total_time = 0.0 + for _ in range(n_iters): + total_time += dummy_infer(self.model, batch_size) + latency = total_time / n_iters + fps = batch_size / latency + + final_stats = {"latency": f"{latency:.3f} s", "throughput": f"{(fps):.3f} FPS"} + + if not isinstance(self.model, OVModel): + try: + from torch.utils.flop_counter import convert_num_with_suffix, get_suffix_str + + input_batch = self.model.get_dummy_input(1) + model_fwd = lambda: self.model.forward(input_batch) + depth = 3 if extended_stats else 0 + fwd_flops = measure_flops(model_fwd, print_stats_depth=depth) + flops_str = convert_num_with_suffix(fwd_flops, get_suffix_str(fwd_flops * 10**3)) + final_stats["complexity"] = flops_str + " MACs" + except Exception as e: + logging.warning(f"Failed to complete complexity estimation: {e}") + + params_num = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + params_num_str = convert_num_with_suffix(params_num, get_suffix_str(params_num * 100)) + final_stats["parameters_number"] = params_num_str + + if print_table: + from rich.console import Console + from rich.table import Column, Table + + console = Console() + table_headers = ["Benchmark", "Value"] + columns = [Column(h, justify="center", style="magenta", width=console.width) for h in table_headers] + columns[0].style = "cyan" + table = Table(*columns) + for name, val in final_stats.items(): + table.add_row(*[f"{name:<20}", f"{val}"]) + console.print(table) + + with (Path(self.work_dir) / "benchmark_report.csv").open("w") as f: + writer = csv.writer(f) + writer.writerow(list(final_stats)) + writer.writerow(list(final_stats.values())) + + return final_stats + + @classmethod + def from_config( + cls, + config_path: PathLike, + data_root: PathLike | None = None, + work_dir: PathLike | None = None, + **kwargs, + ) -> Engine: + """Builds the engine from a configuration file. 
+
+        Args:
+            config_path (PathLike): The configuration file path.
+            data_root (PathLike | None): Root directory for the data.
+                Defaults to None. If data_root is None, use the data_root from the configuration file.
+            work_dir (PathLike | None, optional): Working directory for the engine.
+                Defaults to None. If work_dir is None, use the work_dir from the configuration file.
+            kwargs: Arguments that can override the engine's arguments.
+
+        Returns:
+            Engine: An instance of the Engine class.
+
+        Example:
+            >>> engine = Engine.from_config(
+            ...     config_path="config.yaml",
+            ... )
+        """
+        from otx.cli.utils.jsonargparse import get_instantiated_classes
+
+        # For the Engine argument, prepend 'engine.' for the CLI parser
+        filter_kwargs = ["device", "checkpoint", "task"]
+        for key in filter_kwargs:
+            if key in kwargs:
+                kwargs[f"engine.{key}"] = kwargs.pop(key)
+        instantiated_config, train_kwargs = get_instantiated_classes(
+            config=config_path,
+            data_root=data_root,
+            work_dir=work_dir,
+            **kwargs,
+        )
+        engine_kwargs = {**instantiated_config.get("engine", {}), **train_kwargs}
+
+        # Remove any input that is not currently available in Engine and print a warning message.
+        set_valid_args = TrainerArgumentsCache.get_trainer_constructor_args().union(
+            set(inspect.signature(Engine.__init__).parameters.keys()),
+        )
+        removed_args = []
+        for engine_key in list(engine_kwargs.keys()):
+            if engine_key not in set_valid_args:
+                engine_kwargs.pop(engine_key)
+                removed_args.append(engine_key)
+        if removed_args:
+            msg = (
+                f"Warning: {removed_args} -> not available in the Engine constructor. "
+                "They will be ignored. Use only what you need in the right places."
+            )
+            warn(msg, stacklevel=1)
+
+        if (datamodule := instantiated_config.get("data")) is None:
+            msg = "Cannot instantiate datamodule from config."
+            raise ValueError(msg)
+        if not isinstance(datamodule, OTXDataModule):
+            raise TypeError(datamodule)
+
+        if (model := instantiated_config.get("model")) is None:
+            msg = "Cannot instantiate model from config."
+            raise ValueError(msg)
+        if not isinstance(model, OTXModel):
+            raise TypeError(model)
+
+        model.label_info = datamodule.label_info
+
+        return cls(
+            work_dir=instantiated_config.get("work_dir", work_dir),
+            data=datamodule,
+            model=model,
+            **engine_kwargs,
+        )
+
+    @classmethod
+    def from_model_name(
+        cls,
+        model_name: str,
+        task: OTXTaskType,
+        data_root: PathLike | None = None,
+        work_dir: PathLike | None = None,
+        **kwargs,
+    ) -> Engine:
+        """Builds the engine from a model name.
+
+        Args:
+            model_name (str): The model name.
+            task (OTXTaskType): The type of OTX task.
+            data_root (PathLike | None): Root directory for the data.
+                Defaults to None. If data_root is None, use the data_root from the configuration file.
+            work_dir (PathLike | None, optional): Working directory for the engine.
+                Defaults to None. If work_dir is None, use the work_dir from the configuration file.
+            kwargs: Arguments that can override the engine's arguments.
+
+        Returns:
+            Engine: An instance of the Engine class.
+
+        Example:
+            >>> engine = Engine.from_model_name(
+            ...     model_name="atss_mobilenetv2",
+            ...     task="DETECTION",
+            ...     data_root=,
+            ... )
+
+            If you want to override the default configuration:
+            >>> overriding = {
+            ...     "data.train_subset.batch_size": 2,
+            ...     "data.test_subset.subset_name": "TESTING",
+            ... }
+            >>> engine = Engine.from_model_name(
+            ...     model_name="atss_mobilenetv2",
+            ...     task="DETECTION",
+            ...     data_root=,
+            ...     **overriding,
+            ...
) + """ + default_config = DEFAULT_CONFIG_PER_TASK.get(task) + model_path = str(default_config).split("/") + model_path[-1] = f"{model_name}.yaml" + config = Path("/".join(model_path)) + if not config.exists(): + candidate_list = [model.stem for model in config.parent.glob("*")] + msg = ( + f"Model config file not found: {config}, please check the model name. " + f"Available models for {task} task are {candidate_list}" + ) + raise FileNotFoundError(msg) + + return cls.from_config( + config_path=config, + data_root=data_root, + work_dir=work_dir, + task=task, + **kwargs, + ) + + # ------------------------------------------------------------------------ # + # Property and setter functions provided by Engine. + # ------------------------------------------------------------------------ # + + @property + def work_dir(self) -> PathLike: + """Work directory.""" + return self._work_dir + + @work_dir.setter + def work_dir(self, work_dir: PathLike) -> None: + self._work_dir = work_dir + self._cache.update(default_root_dir=work_dir) + self._cache.is_trainer_args_identical = False + + @property + def device(self) -> DeviceConfig: + """Device engine uses.""" + return self._device + + @device.setter + def device(self, device: DeviceType) -> None: + if is_xpu_available() and device == DeviceType.auto: + device = DeviceType.xpu + self._device = DeviceConfig(accelerator=device) + self._cache.update(accelerator=self._device.accelerator, devices=self._device.devices) + self._cache.is_trainer_args_identical = False + + @property + def num_devices(self) -> int: + """Number of devices for Engine use.""" + return self._device.devices + + @num_devices.setter + def num_devices(self, num_devices: int) -> None: + """Setter function for multi-gpu.""" + self._device.devices = num_devices + self._cache.update(devices=self._device.devices) + self._cache.is_trainer_args_identical = False + + @property + def trainer(self) -> Trainer: + """Returns the trainer object associated with the engine. + + To get this property, you should execute `Engine.train()` function first. + + Returns: + Trainer: The trainer object. + """ + if self._trainer is None: + msg = "Please run train() first" + raise RuntimeError(msg) + return self._trainer + + def _build_trainer(self, **kwargs) -> None: + """Instantiate the trainer based on the model parameters.""" + if self._cache.requires_update(**kwargs) or self._trainer is None: + self._cache.update(**kwargs) + # set up xpu device + if self._device.accelerator == DeviceType.xpu: + self._cache.update(strategy="xpu_single") + # add plugin for Automatic Mixed Precision on XPU + if self._cache.args.get("precision", 32) == 16: + self._cache.update( + plugins=[ + MixedPrecision( + precision="bf16-mixed", + device="xpu", + ), + ], + ) + self._cache.args["precision"] = None + + kwargs = self._cache.args + self._trainer = Trainer(**kwargs) + self._cache.is_trainer_args_identical = True + self._trainer.task = self.task + self.work_dir = self._trainer.default_root_dir + + @property + def trainer_params(self) -> dict: + """Returns the parameters used for training the model. + + Returns: + dict: A dictionary containing the training parameters. + """ + return self._cache.args + + @property + def model(self) -> OTXModel: + """Returns the model object associated with the engine. + + Returns: + OTXModel: The OTXModel object. + """ + return self._model + + @model.setter + def model(self, model: OTXModel | str) -> None: + """Sets the model for the engine. + + Args: + model (OTXModel | str): The model to be set. 
+ + Returns: + None + """ + if isinstance(model, str): + model = self._auto_configurator.get_model(model, label_info=self.datamodule.label_info) + self._model = model + + @property + def datamodule(self) -> OTXDataModule: + """Returns the datamodule object associated with the engine. + + Returns: + OTXDataModule: The OTXDataModule object. + """ + if self._datamodule is None: + msg = "Please include the `data_root` or `datamodule` when creating the Engine." + raise RuntimeError(msg) + return self._datamodule + + @staticmethod + def is_supported(model: MODEL, data: DATA) -> bool: + """Check if the engine is supported for the given model and data.""" + return bool(isinstance(model, OTXModel) and isinstance(data, OTXDataModule)) diff --git a/src/otx/engine/hpo/__init__.py b/src/otx/backend/native/engine/hpo/__init__.py similarity index 100% rename from src/otx/engine/hpo/__init__.py rename to src/otx/backend/native/engine/hpo/__init__.py diff --git a/src/otx/engine/hpo/hpo_api.py b/src/otx/backend/native/engine/hpo/hpo_api.py similarity index 99% rename from src/otx/engine/hpo/hpo_api.py rename to src/otx/backend/native/engine/hpo/hpo_api.py index 3963e3ccd34..28e20c75b32 100644 --- a/src/otx/engine/hpo/hpo_api.py +++ b/src/otx/backend/native/engine/hpo/hpo_api.py @@ -19,12 +19,12 @@ import yaml from lightning import Callback +from otx.backend.native.engine.adaptive_bs import adapt_batch_size from otx.core.config.hpo import HpoConfig from otx.core.optimizer.callable import OptimizerCallableSupportHPO from otx.core.schedulers import LinearWarmupSchedulerCallable, SchedulerCallableSupportHPO from otx.core.types.device import DeviceType from otx.core.types.task import OTXTaskType -from otx.engine.adaptive_bs import adapt_batch_size from otx.hpo import HyperBand, run_hpo_loop from otx.utils.device import is_xpu_available from otx.utils.utils import ( @@ -85,7 +85,7 @@ def execute_hpo( train_args=train_args, ) if ( - train_args.get("adaptive_bs", None) == "Full" + train_args.get("adaptive_bs") == "Full" and "datamodule.train_subset.batch_size" in hpo_configurator.hpo_config["search_space"] ): logger.info("Because adaptive_bs is set as Full, batch size is excluded from HPO.") diff --git a/src/otx/engine/hpo/hpo_trial.py b/src/otx/backend/native/engine/hpo/hpo_trial.py similarity index 100% rename from src/otx/engine/hpo/hpo_trial.py rename to src/otx/backend/native/engine/hpo/hpo_trial.py diff --git a/src/otx/engine/hpo/utils.py b/src/otx/backend/native/engine/hpo/utils.py similarity index 100% rename from src/otx/engine/hpo/utils.py rename to src/otx/backend/native/engine/hpo/utils.py diff --git a/src/otx/engine/utils/__init__.py b/src/otx/backend/native/engine/utils/__init__.py similarity index 100% rename from src/otx/engine/utils/__init__.py rename to src/otx/backend/native/engine/utils/__init__.py diff --git a/src/otx/engine/utils/api.py b/src/otx/backend/native/engine/utils/api.py similarity index 100% rename from src/otx/engine/utils/api.py rename to src/otx/backend/native/engine/utils/api.py diff --git a/src/otx/engine/utils/auto_configurator.py b/src/otx/backend/native/engine/utils/auto_configurator.py similarity index 100% rename from src/otx/engine/utils/auto_configurator.py rename to src/otx/backend/native/engine/utils/auto_configurator.py diff --git a/src/otx/engine/__init__.py b/src/otx/engine/__init__.py index 86f7406c917..9a7b579af11 100644 --- a/src/otx/engine/__init__.py +++ b/src/otx/engine/__init__.py @@ -3,6 +3,24 @@ # Copyright (C) 2023 Intel Corporation # 
SPDX-License-Identifier: Apache-2.0 +from typing import TYPE_CHECKING + +from otx.backend.native.engine import NativeEngine + from .engine import Engine __all__ = ["Engine"] + +if TYPE_CHECKING: + from otx.types import DATA, MODEL + +SUPPORTED_ENGINES = [NativeEngine] + + +def create_engine(model: "MODEL", data: "DATA") -> Engine: + """Create an engine.""" + for engine in SUPPORTED_ENGINES: + if engine.is_supported(model, data): + return engine(model=model, data=data) + msg = f"No engine found for model {model} and data {data}" + raise ValueError(msg) diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py index 9e6664cbcb8..5c261e3739e 100644 --- a/src/otx/engine/engine.py +++ b/src/otx/engine/engine.py @@ -1,1208 +1,49 @@ -# Copyright (C) 2024 Intel Corporation +"""Engine base class.""" + +# Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -# -"""Module for OTX engine components.""" from __future__ import annotations -import copy -import csv -import inspect -import logging -import tempfile -import time -from contextlib import contextmanager -from pathlib import Path -from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Iterator, Literal -from warnings import warn - -import torch -from lightning import Trainer, seed_everything -from lightning.pytorch.plugins.precision import MixedPrecision - -from otx.core.config.device import DeviceConfig -from otx.core.config.explain import ExplainConfig -from otx.core.config.hpo import HpoConfig -from otx.core.data.module import OTXDataModule -from otx.core.model.base import OTXModel, OVModel -from otx.core.types import PathLike -from otx.core.types.device import DeviceType -from otx.core.types.export import OTXExportFormatType -from otx.core.types.precision import OTXPrecisionType -from otx.core.types.task import OTXTaskType -from otx.core.utils.cache import TrainerArgumentsCache -from otx.utils.device import is_xpu_available -from otx.utils.utils import measure_flops - -from .adaptive_bs import adapt_batch_size -from .hpo import execute_hpo, update_hyper_parameter -from .utils.auto_configurator import DEFAULT_CONFIG_PER_TASK, AutoConfigurator +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING if TYPE_CHECKING: - from lightning import Callback - from lightning.pytorch.loggers import Logger - from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS - from pytorch_lightning.trainer.connectors.accelerator_connector import _PRECISION_INPUT - - from otx.core.metrics import MetricCallable - - -@contextmanager -def override_metric_callable(model: OTXModel, new_metric_callable: MetricCallable | None) -> Iterator[OTXModel]: - """Override `OTXModel.metric_callable` to change the evaluation metric. - - Args: - model: Model to override its metric callable - new_metric_callable: If not None, override the model's one with this. Otherwise, do not override. - """ - if new_metric_callable is None: - yield model - return - - orig_metric_callable = model.metric_callable - try: - model.metric_callable = new_metric_callable - yield model - finally: - model.metric_callable = orig_metric_callable - - -class Engine: - """OTX Engine. - - This class defines the Engine for OTX, which governs each step of the OTX workflow. - - Example: - The following examples show how to use the Engine class. 
- - Auto-Configuration with data_root:: - - engine = Engine( - data_root=, - ) - - Create Engine with Custom OTXModel:: - - engine = Engine( - data_root=, - model=OTXModel(...), - checkpoint=, - ) - - Create Engine with Custom OTXDataModule:: - - engine = Engine( - model = OTXModel(...), - datamodule = OTXDataModule(...), - ) - """ - - _EXPORTED_MODEL_BASE_NAME: ClassVar[str] = "exported_model" - - def __init__( - self, - *, - data_root: PathLike | None = None, - task: OTXTaskType | None = None, - work_dir: PathLike = "./otx-workspace", - datamodule: OTXDataModule | None = None, - model: OTXModel | str | None = None, - checkpoint: PathLike | None = None, - device: DeviceType = DeviceType.auto, - num_devices: int = 1, - **kwargs, - ): - """Initializes the OTX Engine. - - Args: - data_root (PathLike | None, optional): Root directory for the data. Defaults to None. - task (OTXTaskType | None, optional): The type of OTX task. Defaults to None. - work_dir (PathLike, optional): Working directory for the engine. Defaults to "./otx-workspace". - datamodule (OTXDataModule | None, optional): The data module for the engine. Defaults to None. - model (OTXModel | str | None, optional): The model for the engine. Defaults to None. - checkpoint (PathLike | None, optional): Path to the checkpoint file. Defaults to None. - device (DeviceType, optional): The device type to use. Defaults to DeviceType.auto. - num_devices (int, optional): The number of devices to use. If it is 2 or more, it will behave as multi-gpu. - **kwargs: Additional keyword arguments for pl.Trainer. - """ - self._cache = TrainerArgumentsCache(**kwargs) - self.checkpoint = checkpoint - self.work_dir = work_dir - self.device = device # type: ignore[assignment] - self.num_devices = num_devices - self._auto_configurator = AutoConfigurator( - data_root=data_root, - task=datamodule.task if datamodule is not None else task, - model_name=None if isinstance(model, OTXModel) else model, - ) - - self._datamodule: OTXDataModule | None = ( - datamodule if datamodule is not None else self._auto_configurator.get_datamodule() - ) - self.task = task if task is not None else self._auto_configurator.task - - self._trainer: Trainer | None = None - get_model_args: dict[str, Any] = {} - if self._datamodule is not None: - get_model_args["label_info"] = self._datamodule.label_info - if (input_size := self._datamodule.input_size) is not None: - get_model_args["input_size"] = (input_size, input_size) if isinstance(input_size, int) else input_size - self._model: OTXModel = ( - model if isinstance(model, OTXModel) else self._auto_configurator.get_model(**get_model_args) - ) - - # ------------------------------------------------------------------------ # - # General OTX Entry Points - # ------------------------------------------------------------------------ # - - def train( - self, - max_epochs: int = 10, - seed: int | None = None, - deterministic: bool | Literal["warn"] = False, - precision: _PRECISION_INPUT | None = "32", - val_check_interval: int | float | None = None, - callbacks: list[Callback] | Callback | None = None, - logger: Logger | Iterable[Logger] | bool | None = None, - resume: bool = False, - metric: MetricCallable | None = None, - run_hpo: bool = False, - hpo_config: HpoConfig = HpoConfig(), # noqa: B008 https://github.com/omni-us/jsonargparse/issues/423 - checkpoint: PathLike | None = None, - adaptive_bs: Literal["None", "Safe", "Full"] = "None", - **kwargs, - ) -> dict[str, Any]: - r"""Trains the model using the provided LightningModule and 
OTXDataModule. - - Args: - max_epochs (int | None, optional): The maximum number of epochs. Defaults to None. - seed (int | None, optional): The random seed. Defaults to None. - deterministic (bool | Literal["warn"]): Whether to enable deterministic behavior. - Also, can be set to `warn` to avoid failures, because some operations don't - support deterministic mode. Defaults to False. - precision (_PRECISION_INPUT | None, optional): The precision of the model. Defaults to 32. - val_check_interval (int | float | None, optional): The validation check interval. Defaults to None. - callbacks (list[Callback] | Callback | None, optional): The callbacks to be used during training. - logger (Logger | Iterable[Logger] | bool | None, optional): The logger(s) to be used. Defaults to None. - resume (bool, optional): If True, tries to resume training from existing checkpoint. - metric (MetricCallable | None): If not None, it will override `OTXModel.metric_callable` with the given - metric callable. It will temporarilly change the evaluation metric for the validation and test. - run_hpo (bool, optional): If True, optimizer hyper parameters before training a model. - hpo_config (HpoConfig | None, optional): Configuration for HPO. - checkpoint (PathLike | None, optional): Path to the checkpoint file. Defaults to None. - adaptive_bs (Literal["None", "Safe", "Full"]): - Change the actual batch size depending on the current GPU status. - Safe => Prevent GPU out of memory. Full => Find a batch size using most of GPU memory. - **kwargs: Additional keyword arguments for pl.Trainer configuration. - - Returns: - dict[str, Any]: A dictionary containing the callback metrics from the trainer. - - Example: - >>> engine.train( - ... max_epochs=3, - ... seed=1234, - ... deterministic=False, - ... precision="32", - ... ) - - CLI Usage: - 1. Can train with data_root only. then OTX will provide default training configuration. - ```shell - >>> otx train --data_root - ``` - 2. Can pick a model or datamodule as Config file or Class. - ```shell - >>> otx train \ - ... --data_root \ - ... --model \ - ... --data - ``` - 3. Of course, can override the various values with commands. - ```shell - >>> otx train \ - ... --data_root \ - ... --max_epochs \ - ... --checkpoint - ``` - 4. To train with configuration file, run - ```shell - >>> otx train --data_root --config - ``` - 5. To reproduce the existing training with work_dir, run - ```shell - >>> otx train --work_dir - ``` - """ - checkpoint = checkpoint if checkpoint is not None else self.checkpoint - - if adaptive_bs != "None": - adapt_batch_size(engine=self, **locals(), not_increase=(adaptive_bs != "Full")) - - if run_hpo: - best_config, best_trial_weight = execute_hpo(engine=self, **locals()) - if best_config is not None: - update_hyper_parameter(self, best_config) - if best_trial_weight is not None: - checkpoint = best_trial_weight - resume = True - - if seed is not None: - seed_everything(seed, workers=True) - - self._build_trainer( - logger=logger, - callbacks=callbacks, - precision=precision, - max_epochs=max_epochs, - deterministic=deterministic, - val_check_interval=val_check_interval, - **kwargs, - ) - fit_kwargs: dict[str, Any] = {} - - # NOTE: Model's label info should be converted datamodule's label info before ckpt loading - # This is due to smart weight loading check label name as well as number of classes. - if self.model.label_info != self.datamodule.label_info: - msg = ( - "Model label_info is not equal to the Datamodule label_info. 
" - f"It will be overriden: {self.model.label_info} => {self.datamodule.label_info}" - ) - logging.warning(msg) - self.model.label_info = self.datamodule.label_info - - if resume and checkpoint: - # NOTE: If both `resume` and `checkpoint` are provided, - # load the entire model state from the checkpoint using the pl.Trainer's API. - fit_kwargs["ckpt_path"] = checkpoint - elif not resume and checkpoint: - # NOTE: If `resume` is not enabled but `checkpoint` is provided, - # load the model state from the checkpoint incrementally. - # This means only the model weights are loaded. If there is a mismatch in label_info, - # perform incremental weight loading for the model's classification layer. - ckpt = torch.load(checkpoint) - self.model.load_state_dict_incrementally(ckpt) - - with override_metric_callable(model=self.model, new_metric_callable=metric) as model: - self.trainer.fit( - model=model, - datamodule=self.datamodule, - **fit_kwargs, - ) - self.checkpoint = self.trainer.checkpoint_callback.best_model_path - - if not isinstance(self.checkpoint, (Path, str)): - msg = "self.checkpoint should be Path or str at this time." - raise TypeError(msg) - - best_checkpoint_symlink = Path(self.work_dir) / "best_checkpoint.ckpt" - if best_checkpoint_symlink.is_symlink(): - best_checkpoint_symlink.unlink() - best_checkpoint_symlink.symlink_to(self.checkpoint) - - return self.trainer.callback_metrics - - def test( - self, - checkpoint: PathLike | None = None, - datamodule: EVAL_DATALOADERS | OTXDataModule | None = None, - metric: MetricCallable | None = None, - **kwargs, - ) -> dict: - r"""Run the testing phase of the engine. - - Args: - checkpoint (PathLike | None, optional): Path to the checkpoint file to load the model from. - Defaults to None. - datamodule (EVAL_DATALOADERS | OTXDataModule | None, optional): The data module containing the test data. - metric (MetricCallable | None): If not None, it will override `OTXModel.metric_callable` with the given - metric callable. It will temporarilly change the evaluation metric for the validation and test. - **kwargs: Additional keyword arguments for pl.Trainer configuration. - - Returns: - dict: Dictionary containing the callback metrics from the trainer. - - Example: - >>> engine.test( - ... datamodule=OTXDataModule(), - ... checkpoint=, - ... ) - - CLI Usage: - 1. To eval model by specifying the work_dir where did the training, run - ```shell - >>> otx test --work_dir - ``` - 2. To eval model a specific checkpoint, run - ```shell - >>> otx test --work_dir --checkpoint - ``` - 3. Can pick a model. - ```shell - >>> otx test \ - ... --model \ - ... --data_root \ - ... --checkpoint - ``` - 4. To eval with configuration file, run - ```shell - >>> otx test --config --checkpoint - ``` - """ - model = self.model - checkpoint = checkpoint if checkpoint is not None else self.checkpoint - datamodule = datamodule if datamodule is not None else self.datamodule - - is_ir_ckpt = Path(str(checkpoint)).suffix in [".xml", ".onnx"] - if is_ir_ckpt and not isinstance(model, OVModel): - model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info) - if self.device.accelerator != "cpu": - msg = "IR model supports inference only on CPU device. The device is changed automatic." 
-    def predict(
-        self,
-        checkpoint: PathLike | None = None,
-        datamodule: EVAL_DATALOADERS | OTXDataModule | None = None,
-        return_predictions: bool | None = None,
-        explain: bool = False,
-        explain_config: ExplainConfig | None = None,
-        **kwargs,
-    ) -> list | None:
-        r"""Run predictions using the specified model and data.
-
-        Args:
-            checkpoint (PathLike | None, optional): The path to the checkpoint file to load the model from.
-            datamodule (EVAL_DATALOADERS | OTXDataModule | None, optional): The data module to use for predictions.
-            return_predictions (bool | None, optional): Whether to return the predictions or not.
-            explain (bool, optional): Whether to dump "saliency_map" and "feature_vector" or not.
-            explain_config (ExplainConfig | None, optional): Explain configuration used for saliency map
-                post-processing.
-            **kwargs: Additional keyword arguments for pl.Trainer configuration.
-
-        Returns:
-            list | None: The predictions if `return_predictions` is True, otherwise None.
-
-        Example:
-            >>> engine.predict(
-            ...     datamodule=OTXDataModule(),
-            ...     checkpoint=<checkpoint/path>,
-            ...     return_predictions=True,
-            ...     explain=True,
-            ... )
-
-        CLI Usage:
-            1. To predict with a model from the work_dir, run
-                ```shell
-                >>> otx predict --work_dir <work/directory>
-                ```
-            2. To predict with a specific model, run
-                ```shell
-                >>> otx predict \
-                ...     --work_dir <work/directory> \
-                ...     --checkpoint <checkpoint/path>
-                ```
-            3. To predict with a configuration file, run
-                ```shell
-                >>> otx predict \
-                ...     --config <config/path> \
-                ...     --checkpoint <checkpoint/path>
-                ```
-        """
-        from otx.algo.utils.xai_utils import process_saliency_maps_in_pred_entity, set_crop_padded_map_flag
-
-        model = self.model
-
-        checkpoint = checkpoint if checkpoint is not None else self.checkpoint
-        datamodule = datamodule if datamodule is not None else self.datamodule
-
-        is_ir_ckpt = checkpoint is not None and Path(checkpoint).suffix in [".xml", ".onnx"]
-        if is_ir_ckpt and not isinstance(model, OVModel):
-            model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info)
-
-        # NOTE: Re-initiate datamodule for OVModel as model API supports its own data pipeline.
-        if isinstance(model, OVModel):
-            datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test")
-
-        if checkpoint is not None and not is_ir_ckpt:
-            kwargs_user_input: dict[str, Any] = {}
-            if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
-                # to update user's custom infer_reference_info_root through cli for zero-shot learning
-                # TODO (sungchul): revisit for better solution
-                kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)
-
-            model_cls = model.__class__
-            model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **kwargs_user_input)
-
-        if model.label_info != self.datamodule.label_info:
-            msg = (
-                "To launch a predict pipeline, the label information should be the same "
-                "between the training and testing datasets. "
-                "Please check whether you use the same dataset: "
-                f"model.label_info={model.label_info}, "
-                f"datamodule.label_info={self.datamodule.label_info}"
-            )
-            raise ValueError(msg)
-
-        self._build_trainer(**kwargs)
-
-        curr_explain_mode = model.explain_mode
-
-        try:
-            model.explain_mode = explain
-            predict_result = self.trainer.predict(
-                model=model,
-                dataloaders=datamodule,
-                return_predictions=return_predictions,
-            )
-        finally:
-            model.explain_mode = curr_explain_mode
-
-        if explain:
-            if explain_config is None:
-                explain_config = ExplainConfig()
-            explain_config = set_crop_padded_map_flag(explain_config, datamodule)
-
-            predict_result = process_saliency_maps_in_pred_entity(predict_result, explain_config, datamodule.label_info)
-
-        return predict_result
-
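A short sketch of the explain-enabled prediction flow above, with a hypothetical checkpoint path:

```python
from otx.core.config.explain import ExplainConfig

# explain=True makes the model attach saliency maps and feature vectors,
# which are then post-processed according to the ExplainConfig.
preds = engine.predict(
    checkpoint="otx-workspace/best_checkpoint.ckpt",  # hypothetical path
    explain=True,
    explain_config=ExplainConfig(),
    return_predictions=True,
)
```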
-    def export(
-        self,
-        checkpoint: PathLike | None = None,
-        export_format: OTXExportFormatType = OTXExportFormatType.OPENVINO,
-        export_precision: OTXPrecisionType = OTXPrecisionType.FP32,
-        explain: bool = False,
-        export_demo_package: bool = False,
-    ) -> Path:
-        r"""Export the trained model to OpenVINO Intermediate Representation (IR) or ONNX formats.
-
-        Args:
-            checkpoint (PathLike | None, optional): Checkpoint to export. Defaults to None.
-            export_format (OTXExportFormatType, optional): Format in which the model is exported.
-                Defaults to OTXExportFormatType.OPENVINO.
-            export_precision (OTXPrecisionType, optional): Precision of the exported model.
-                Defaults to OTXPrecisionType.FP32.
-            explain (bool): Whether to get "saliency_map" and "feature_vector" or not.
-            export_demo_package (bool): Whether to export the demo package with the model.
-                Only an OpenVINO model can be exported with the demo package.
-
-        Returns:
-            Path: Path to the exported model.
-
-        Example:
-            >>> engine.export(
-            ...     checkpoint=<checkpoint/path>,
-            ...     export_format=OTXExportFormatType.OPENVINO,
-            ...     export_precision=OTXPrecisionType.FP32,
-            ...     explain=True,
-            ... )
-
-        CLI Usage:
-            1. To export a model with the default settings (OPENVINO, FP32), run
-                ```shell
-                >>> otx export --work_dir <work/directory>
-                ```
-            2. To export a specific checkpoint, run
-                ```shell
-                >>> otx export --config <config/path> --checkpoint <checkpoint/path>
-                ```
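To make the export switches above concrete, a sketch of an ONNX FP16 export; the checkpoint path is illustrative:

```python
from otx.core.types.export import OTXExportFormatType
from otx.core.types.precision import OTXPrecisionType

onnx_path = engine.export(
    checkpoint="otx-workspace/best_checkpoint.ckpt",  # hypothetical path
    export_format=OTXExportFormatType.ONNX,
    export_precision=OTXPrecisionType.FP16,
)
# Note: export_demo_package=True would be rejected for ONNX, as warned in the code below.
```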
-            3. To export a model with precision FP16 and format ONNX, run
-                ```shell
-                >>> otx export ... \
-                ...     --export_precision FP16 --export_format ONNX
-                ```
-            4. To export a model with 'saliency_map' and 'feature_vector', run
-                ```shell
-                >>> otx export ... \
-                ...     --explain True
-                ```
-        """
-        checkpoint = checkpoint if checkpoint is not None else self.checkpoint
-
-        if checkpoint is None:
-            msg = "To export a model, a checkpoint must be specified."
-            raise RuntimeError(msg)
-        is_ir_ckpt = Path(checkpoint).suffix in [".xml"]
-        if export_demo_package and export_format == OTXExportFormatType.ONNX:
-            msg = (
-                "ONNX export is not supported in exportable code mode. "
-                "The exportable code parameter will be disregarded."
-            )
-            warn(msg, stacklevel=1)
-            export_demo_package = False
-
-        if is_ir_ckpt and not export_demo_package:
-            msg = "An IR model was passed as the checkpoint; export automatically switched to exportable code."
-            warn(msg, stacklevel=1)
-            export_demo_package = True
-
-        if is_ir_ckpt and not isinstance(self.model, OVModel):
-            # create OVModel
-            self.model = self._auto_configurator.get_ov_model(
-                model_name=str(checkpoint),
-                label_info=self.datamodule.label_info,
-            )
-
-        if not is_ir_ckpt:
-            kwargs_user_input: dict[str, Any] = {}
-            if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
-                # to update user's custom infer_reference_info_root through cli for zero-shot learning
-                # TODO (sungchul): revisit for better solution
-                kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)
-
-            model_cls = self.model.__class__
-            self.model = model_cls.load_from_checkpoint(
-                checkpoint_path=checkpoint,
-                map_location="cpu",
-                **kwargs_user_input,
-            )
-            self.model.eval()
-
-        self.model.explain_mode = explain
-        exported_model_path = self.model.export(
-            output_dir=Path(self.work_dir),
-            base_name=self._EXPORTED_MODEL_BASE_NAME,
-            export_format=export_format,
-            precision=export_precision,
-            to_exportable_code=export_demo_package,
-        )
-
-        self.model.explain_mode = False
-        return exported_model_path
-
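A sketch of the PTQ entry point that follows, assuming an already exported IR model:

```python
# optimize() only accepts OV IR or ONNX checkpoints; "exported_model.xml"
# is an assumed path. max_data_subset_size caps the NNCF calibration subset.
int8_path = engine.optimize(
    checkpoint="exported_model.xml",
    max_data_subset_size=300,
)
```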
-    def optimize(
-        self,
-        checkpoint: PathLike | None = None,
-        datamodule: TRAIN_DATALOADERS | OTXDataModule | None = None,
-        max_data_subset_size: int | None = None,
-        export_demo_package: bool = False,
-    ) -> Path:
-        r"""Applies NNCF.PTQ to the underlying models (now works only for OV models).
-
-        PTQ performs int-8 quantization on the input model, so the resulting model
-        comes in mixed precision (some operations, however, remain in FP32).
-
-        Args:
-            checkpoint (str | Path | None, optional): Checkpoint to optimize. Defaults to None.
-            datamodule (TRAIN_DATALOADERS | OTXDataModule | None, optional): The data module to use for optimization.
-            max_data_subset_size (int | None): The maximum size of the train subset from `datamodule` that would be
-                used for model optimization. If not set, NNCF.PTQ will select the subset size according to its
-                default settings.
-            export_demo_package (bool): Whether to export the demo package with the optimized models.
-                It outputs a zip archive with a stand-alone demo package.
-
-        Returns:
-            Path: path to the optimized model.
-
-        Example:
-            >>> engine.optimize(
-            ...     checkpoint=<checkpoint/path>,
-            ...     datamodule=OTXDataModule(),
-            ... )
-
-        CLI Usage:
-            1. To optimize a model with an IR model, run
-                ```shell
-                >>> otx optimize \
-                ...     --work_dir <work/directory> \
-                ...     --checkpoint <IR/model.xml>
-                ```
-            2. To optimize a specific OVModel class with XML, run
-                ```shell
-                >>> otx optimize \
-                ...     --data_root <dataset/path> \
-                ...     --checkpoint <IR/model.xml> \
-                ...     --model <OVModel/class_path> \
-                ...     --model.model_name=<model/xml/path>
-                ```
-        """
-        checkpoint = checkpoint if checkpoint is not None else self.checkpoint
-        optimize_datamodule = datamodule if datamodule is not None else self.datamodule
-
-        is_ir_ckpt = checkpoint is not None and Path(checkpoint).suffix in [".xml", ".onnx"]
-        if not is_ir_ckpt:
-            msg = "Engine.optimize() supports only OV IR or ONNX checkpoints"
-            raise RuntimeError(msg)
-
-        model = self.model
-        if not isinstance(model, OVModel):
-            optimize_datamodule = self._auto_configurator.update_ov_subset_pipeline(
-                datamodule=optimize_datamodule,
-                subset="train",
-            )
-            model = self._auto_configurator.get_ov_model(
-                model_name=str(checkpoint),
-                label_info=optimize_datamodule.label_info,
-            )
-
-        ptq_config = {}
-        if max_data_subset_size is not None:
-            ptq_config["subset_size"] = max_data_subset_size
-
-        if not export_demo_package:
-            return model.optimize(
-                Path(self.work_dir),
-                optimize_datamodule,
-                ptq_config,
-            )
-
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            tmp_model_path = model.optimize(Path(tmp_dir), optimize_datamodule, ptq_config)
-            return self.export(
-                checkpoint=tmp_model_path,
-                export_demo_package=True,
-            )
-
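And the XAI path that follows, sketched with an assumed checkpoint; `dump=True` writes the maps under work_dir:

```python
saliency = engine.explain(
    checkpoint="otx-workspace/best_checkpoint.ckpt",  # hypothetical path
    dump=True,  # saves saliency maps into the work directory
)
```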
-    def explain(
-        self,
-        checkpoint: PathLike | None = None,
-        datamodule: EVAL_DATALOADERS | OTXDataModule | None = None,
-        explain_config: ExplainConfig | None = None,
-        dump: bool | None = False,
-        **kwargs,
-    ) -> list | None:
-        r"""Run XAI using the specified model and data (test subset).
-
-        Args:
-            checkpoint (PathLike | None, optional): The path to the checkpoint file to load the model from.
-            datamodule (EVAL_DATALOADERS | OTXDataModule | None, optional): The data module to use for predictions.
-            explain_config (ExplainConfig | None, optional): Config used to handle saliency maps.
-            dump (bool): Whether to dump "saliency_map" or not.
-            **kwargs: Additional keyword arguments for pl.Trainer configuration.
-
-        Returns:
-            list: Saliency maps.
-
-        Example:
-            >>> engine.explain(
-            ...     datamodule=OTXDataModule(),
-            ...     checkpoint=<checkpoint/path>,
-            ...     explain_config=ExplainConfig(),
-            ...     dump=True,
-            ... )
-
-        CLI Usage:
-            1. To run XAI with the torch model in work_dir, run
-                ```shell
-                >>> otx explain \
-                ...     --work_dir <work/directory>
-                ```
-            2. To run XAI using the specified model (torch or IR), run
-                ```shell
-                >>> otx explain \
-                ...     --work_dir <work/directory> \
-                ...     --checkpoint <checkpoint/path>
-                ```
-            3. To run XAI using the configuration, run
-                ```shell
-                >>> otx explain \
-                ...     --config <config/path> --data_root <dataset/path> \
-                ...     --checkpoint <checkpoint/path>
-                ```
-        """
-        from otx.algo.utils.xai_utils import (
-            dump_saliency_maps,
-            process_saliency_maps_in_pred_entity,
-            set_crop_padded_map_flag,
-        )
-
-        model = self.model
-
-        checkpoint = checkpoint if checkpoint is not None else self.checkpoint
-        datamodule = datamodule if datamodule is not None else self.datamodule
-
-        is_ir_ckpt = checkpoint is not None and Path(checkpoint).suffix in [".xml", ".onnx"]
-        if is_ir_ckpt and not isinstance(model, OVModel):
-            datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test")
-            model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info)
-
-        if checkpoint is not None and not is_ir_ckpt:
-            kwargs_user_input: dict[str, Any] = {}
-            if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
-                # to update user's custom infer_reference_info_root through cli for zero-shot learning
-                # TODO (sungchul): revisit for better solution
-                kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)
-
-            model_cls = model.__class__
-            model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **kwargs_user_input)
-
-        if model.label_info != self.datamodule.label_info:
-            msg = (
-                "To launch an explain pipeline, the label information should be the same "
-                "between the training and testing datasets. "
-                "Please check whether you use the same dataset: "
-                f"model.label_info={model.label_info}, "
-                f"datamodule.label_info={self.datamodule.label_info}"
-            )
-            raise ValueError(msg)
-
-        model.explain_mode = True
-
-        self._build_trainer(**kwargs)
-
-        predict_result = self.trainer.predict(
-            model=model,
-            datamodule=datamodule,
-        )
-
-        if explain_config is None:
-            explain_config = ExplainConfig()
-        explain_config = set_crop_padded_map_flag(explain_config, datamodule)
-
-        predict_result = process_saliency_maps_in_pred_entity(predict_result, explain_config, datamodule.label_info)
-        if dump:
-            dump_saliency_maps(
-                predict_result,
-                explain_config,
-                datamodule,
-                output_dir=Path(self.work_dir),
-            )
-        model.explain_mode = False
-        return predict_result
-
-    def benchmark(
-        self,
-        checkpoint: PathLike | None = None,
-        batch_size: int = 1,
-        n_iters: int = 10,
-        extended_stats: bool = False,
-        print_table: bool = True,
-    ) -> dict[str, str]:
-        r"""Executes model micro-benchmarking on random data.
-
-        Benchmark can provide latency, throughput, number of parameters,
-        and theoretical computational complexity with batch size 1.
-        The latter two characteristics are available for torch model recipes only.
-        Before the measurements, a warm-up is done.
-
-        Args:
-            checkpoint (PathLike | None, optional): Path to checkpoint. Optional for torch models. Defaults to None.
-            batch_size (int, optional): Batch size for benchmarking. Defaults to 1.
-            n_iters (int, optional): Number of iterations to average on. Defaults to 10.
-            extended_stats (bool, optional): Flag that enables printing of per module complexity for torch model.
-                Defaults to False.
-            print_table (bool, optional): Flag that enables printing the benchmark results in a rich table.
-                Defaults to True.
-
-        Returns:
-            dict[str, str]: a dict with the benchmark results.
-
-        Example:
-            >>> engine.benchmark(
-            ...     checkpoint=<checkpoint/path>,
-            ...     batch_size=1,
-            ...     n_iters=20,
-            ...     extended_stats=True,
-            ... )
-
-        CLI Usage:
-            1. To run a benchmark by specifying the work_dir where the training was done, run
-                ```shell
-                >>> otx benchmark --work_dir <work/directory>
-                ```
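For the benchmark above, a sketch that also reads back the CSV report it writes:

```python
import csv
from pathlib import Path

stats = engine.benchmark(batch_size=1, n_iters=20, print_table=False)

# benchmark() writes a two-row report (keys, then values) into work_dir.
with (Path(engine.work_dir) / "benchmark_report.csv").open() as f:
    print(list(csv.reader(f)))
```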
-            2. To run a benchmark by specifying the checkpoint, run
-                ```shell
-                >>> otx benchmark \
-                ...     --work_dir <work/directory> \
-                ...     --checkpoint <checkpoint/path>
-                ```
-            3. To run a benchmark using a configuration file, run
-                ```shell
-                >>> otx benchmark \
-                ...     --config <config/path> \
-                ...     --data_root <dataset/path> \
-                ...     --checkpoint <checkpoint/path>
-                ```
-        """
-        checkpoint = checkpoint if checkpoint is not None else self.checkpoint
-
-        if checkpoint is not None:
-            is_ir_ckpt = Path(checkpoint).suffix in [".xml"]
-            if is_ir_ckpt and not isinstance(self.model, OVModel):
-                # create OVModel
-                self.model = self._auto_configurator.get_ov_model(
-                    model_name=str(checkpoint),
-                    label_info=self.datamodule.label_info,
-                )
-
-            if not is_ir_ckpt:
-                kwargs_user_input: dict[str, Any] = {}
-                if self.task == OTXTaskType.ZERO_SHOT_VISUAL_PROMPTING:
-                    # to update user's custom infer_reference_info_root through cli for zero-shot learning
-                    # TODO (sungchul): revisit for better solution
-                    kwargs_user_input.update(infer_reference_info_root=self.model.infer_reference_info_root)
-
-                model_cls = self.model.__class__
-                self.model = model_cls.load_from_checkpoint(
-                    checkpoint_path=checkpoint,
-                    map_location="cpu",
-                    **kwargs_user_input,
-                )
-        elif isinstance(self.model, OVModel):
-            msg = "To run a benchmark on an OV model, a checkpoint must be specified."
-            raise RuntimeError(msg)
-
-        self.model.eval()
-
-        def dummy_infer(model: OTXModel, batch_size: int = 1) -> float:
-            input_batch = model.get_dummy_input(batch_size)
-            start = time.perf_counter()
-            model.forward(input_batch)
-            end = time.perf_counter()
-            return end - start
-
-        warmup_iters = max(1, int(n_iters / 10))
-        for _ in range(warmup_iters):
-            dummy_infer(self.model, batch_size)
-
-        total_time = 0.0
-        for _ in range(n_iters):
-            total_time += dummy_infer(self.model, batch_size)
-        latency = total_time / n_iters
-        fps = batch_size / latency
-
-        final_stats = {"latency": f"{latency:.3f} s", "throughput": f"{fps:.3f} FPS"}
-
-        if not isinstance(self.model, OVModel):
-            try:
-                from torch.utils.flop_counter import convert_num_with_suffix, get_suffix_str
-
-                input_batch = self.model.get_dummy_input(1)
-                model_fwd = lambda: self.model.forward(input_batch)
-                depth = 3 if extended_stats else 0
-                fwd_flops = measure_flops(model_fwd, print_stats_depth=depth)
-                flops_str = convert_num_with_suffix(fwd_flops, get_suffix_str(fwd_flops * 10**3))
-                final_stats["complexity"] = flops_str + " MACs"
-            except Exception as e:
-                logging.warning(f"Failed to complete complexity estimation: {e}")
-
-            params_num = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
-            params_num_str = convert_num_with_suffix(params_num, get_suffix_str(params_num * 100))
-            final_stats["parameters_number"] = params_num_str
-
-        if print_table:
-            from rich.console import Console
-            from rich.table import Column, Table
-
-            console = Console()
-            table_headers = ["Benchmark", "Value"]
-            columns = [Column(h, justify="center", style="magenta", width=console.width) for h in table_headers]
-            columns[0].style = "cyan"
-            table = Table(*columns)
-            for name, val in final_stats.items():
-                table.add_row(*[f"{name:<20}", f"{val}"])
-            console.print(table)
-
-        with (Path(self.work_dir) / "benchmark_report.csv").open("w") as f:
-            writer = csv.writer(f)
-            writer.writerow(list(final_stats))
-            writer.writerow(list(final_stats.values()))
-
-        return final_stats
-
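A sketch of `from_config` (defined below) with dotted-key overrides, mirroring the override mechanism documented for `from_model_name`; `recipe.yaml` and the dataset path are hypothetical:

```python
engine = Engine.from_config(
    config_path="recipe.yaml",
    data_root="data/flowers",
    **{"data.train_subset.batch_size": 2},  # dotted keys override nested config values
)
```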
-    @classmethod
-    def from_config(
-        cls,
-        config_path: PathLike,
-        data_root: PathLike | None = None,
-        work_dir: PathLike | None = None,
-        **kwargs,
-    ) -> Engine:
-        """Builds the engine from a configuration file.
-
-        Args:
-            config_path (PathLike): The configuration file path.
-            data_root (PathLike | None): Root directory for the data.
-                Defaults to None. If data_root is None, use the data_root from the configuration file.
-            work_dir (PathLike | None, optional): Working directory for the engine.
-                Defaults to None. If work_dir is None, use the work_dir from the configuration file.
-            kwargs: Arguments that can override the engine's arguments.
-
-        Returns:
-            Engine: An instance of the Engine class.
-
-        Example:
-            >>> engine = Engine.from_config(
-            ...     config_path="config.yaml",
-            ... )
-        """
-        from otx.cli.utils.jsonargparse import get_instantiated_classes
-
-        # For the Engine argument, prepend 'engine.' for CLI parser
-        filter_kwargs = ["device", "checkpoint", "task"]
-        for key in filter_kwargs:
-            if key in kwargs:
-                kwargs[f"engine.{key}"] = kwargs.pop(key)
-        instantiated_config, train_kwargs = get_instantiated_classes(
-            config=config_path,
-            data_root=data_root,
-            work_dir=work_dir,
-            **kwargs,
-        )
-        engine_kwargs = {**instantiated_config.get("engine", {}), **train_kwargs}
-
-        # Remove any input that is not currently available in Engine and print a warning message.
-        set_valid_args = TrainerArgumentsCache.get_trainer_constructor_args().union(
-            set(inspect.signature(Engine.__init__).parameters.keys()),
-        )
-        removed_args = []
-        for engine_key in list(engine_kwargs.keys()):
-            if engine_key not in set_valid_args:
-                engine_kwargs.pop(engine_key)
-                removed_args.append(engine_key)
-        if removed_args:
-            msg = (
-                f"Warning: {removed_args} -> not available in the Engine constructor. "
-                "They will be ignored. Please pass them where they are needed."
-            )
-            warn(msg, stacklevel=1)
-
-        if (datamodule := instantiated_config.get("data")) is None:
-            msg = "Cannot instantiate datamodule from config."
-            raise ValueError(msg)
-        if not isinstance(datamodule, OTXDataModule):
-            raise TypeError(datamodule)
-
-        if (model := instantiated_config.get("model")) is None:
-            msg = "Cannot instantiate model from config."
-            raise ValueError(msg)
-        if not isinstance(model, OTXModel):
-            raise TypeError(model)
-
-        model.label_info = datamodule.label_info
-
-        return cls(
-            work_dir=instantiated_config.get("work_dir", work_dir),
-            datamodule=datamodule,
-            model=model,
-            **engine_kwargs,
-        )
-
-    @classmethod
-    def from_model_name(
-        cls,
-        model_name: str,
-        task: OTXTaskType,
-        data_root: PathLike | None = None,
-        work_dir: PathLike | None = None,
-        **kwargs,
-    ) -> Engine:
-        """Builds the engine from a model name.
-
-        Args:
-            model_name (str): The model name.
-            task (OTXTaskType): The type of OTX task.
-            data_root (PathLike | None): Root directory for the data.
-                Defaults to None. If data_root is None, use the data_root from the configuration file.
-            work_dir (PathLike | None, optional): Working directory for the engine.
-                Defaults to None. If work_dir is None, use the work_dir from the configuration file.
-            kwargs: Arguments that can override the engine's arguments.
-
-        Returns:
-            Engine: An instance of the Engine class.
-
-        Example:
-            >>> engine = Engine.from_model_name(
-            ...     model_name="atss_mobilenetv2",
-            ...     task="DETECTION",
-            ...     data_root=<dataset/path>,
-            ... )
-
-            If you want to override configuration from the default config:
-                >>> overriding = {
-                ...     "data.train_subset.batch_size": 2,
-                ...     "data.test_subset.subset_name": "TESTING",
-                ... }
-                >>> engine = Engine.from_model_name(
-                ...     model_name="atss_mobilenetv2",
-                ...     task="DETECTION",
-                ...     data_root=<dataset/path>,
-                ...     **overriding,
-                ... )
-        """
-        default_config = DEFAULT_CONFIG_PER_TASK.get(task)
-        model_path = str(default_config).split("/")
-        model_path[-1] = f"{model_name}.yaml"
-        config = Path("/".join(model_path))
-        if not config.exists():
-            candidate_list = [model.stem for model in config.parent.glob("*")]
-            msg = (
-                f"Model config file not found: {config}, please check the model name. "
-                f"Available models for {task} task are {candidate_list}"
-            )
-            raise FileNotFoundError(msg)
-
-        return cls.from_config(
-            config_path=config,
-            data_root=data_root,
-            work_dir=work_dir,
-            task=task,
-            **kwargs,
-        )
-
) - """ - default_config = DEFAULT_CONFIG_PER_TASK.get(task) - model_path = str(default_config).split("/") - model_path[-1] = f"{model_name}.yaml" - config = Path("/".join(model_path)) - if not config.exists(): - candidate_list = [model.stem for model in config.parent.glob("*")] - msg = ( - f"Model config file not found: {config}, please check the model name. " - f"Available models for {task} task are {candidate_list}" - ) - raise FileNotFoundError(msg) - - return cls.from_config( - config_path=config, - data_root=data_root, - work_dir=work_dir, - task=task, - **kwargs, - ) - - # ------------------------------------------------------------------------ # - # Property and setter functions provided by Engine. - # ------------------------------------------------------------------------ # - - @property - def work_dir(self) -> PathLike: - """Work directory.""" - return self._work_dir - - @work_dir.setter - def work_dir(self, work_dir: PathLike) -> None: - self._work_dir = work_dir - self._cache.update(default_root_dir=work_dir) - self._cache.is_trainer_args_identical = False - - @property - def device(self) -> DeviceConfig: - """Device engine uses.""" - return self._device - - @device.setter - def device(self, device: DeviceType) -> None: - if is_xpu_available() and device == DeviceType.auto: - device = DeviceType.xpu - self._device = DeviceConfig(accelerator=device) - self._cache.update(accelerator=self._device.accelerator, devices=self._device.devices) - self._cache.is_trainer_args_identical = False - - @property - def num_devices(self) -> int: - """Number of devices for Engine use.""" - return self._device.devices - - @num_devices.setter - def num_devices(self, num_devices: int) -> None: - """Setter function for multi-gpu.""" - self._device.devices = num_devices - self._cache.update(devices=self._device.devices) - self._cache.is_trainer_args_identical = False - - @property - def trainer(self) -> Trainer: - """Returns the trainer object associated with the engine. - - To get this property, you should execute `Engine.train()` function first. - - Returns: - Trainer: The trainer object. - """ - if self._trainer is None: - msg = "Please run train() first" - raise RuntimeError(msg) - return self._trainer - - def _build_trainer(self, **kwargs) -> None: - """Instantiate the trainer based on the model parameters.""" - if self._cache.requires_update(**kwargs) or self._trainer is None: - self._cache.update(**kwargs) - # set up xpu device - if self._device.accelerator == DeviceType.xpu: - self._cache.update(strategy="xpu_single") - # add plugin for Automatic Mixed Precision on XPU - if self._cache.args.get("precision", 32) == 16: - self._cache.update( - plugins=[ - MixedPrecision( - precision="bf16-mixed", - device="xpu", - ), - ], - ) - self._cache.args["precision"] = None - - kwargs = self._cache.args - self._trainer = Trainer(**kwargs) - self._cache.is_trainer_args_identical = True - self._trainer.task = self.task - self.work_dir = self._trainer.default_root_dir + from pathlib import Path - @property - def trainer_params(self) -> dict: - """Returns the parameters used for training the model. + from otx.types import ANNOTATIONS, DATA, METRICS, MODEL - Returns: - dict: A dictionary containing the training parameters. - """ - return self._cache.args - @property - def model(self) -> OTXModel: - """Returns the model object associated with the engine. +class Engine(ABC): + """Engine base class.""" - Returns: - OTXModel: The OTXModel object. 
- """ - return self._model + @abstractmethod + def train(self, **kwargs) -> METRICS: + """Train the model.""" + raise NotImplementedError - @model.setter - def model(self, model: OTXModel | str) -> None: - """Sets the model for the engine. + @abstractmethod + def test(self, **kwargs) -> METRICS: + """Test the model.""" + raise NotImplementedError - Args: - model (OTXModel | str): The model to be set. + @abstractmethod + def predict(self, **kwargs) -> ANNOTATIONS: + """Predict on model.""" + raise NotImplementedError - Returns: - None - """ - if isinstance(model, str): - model = self._auto_configurator.get_model(model, label_info=self.datamodule.label_info) - self._model = model + @abstractmethod + def export(self, **kwargs) -> Path: + """Export the model.""" + raise NotImplementedError - @property - def datamodule(self) -> OTXDataModule: - """Returns the datamodule object associated with the engine. + @abstractmethod + def optimize(self, **kwargs) -> Path: + """Optimize the model.""" + raise NotImplementedError - Returns: - OTXDataModule: The OTXDataModule object. - """ - if self._datamodule is None: - msg = "Please include the `data_root` or `datamodule` when creating the Engine." - raise RuntimeError(msg) - return self._datamodule + @staticmethod + @abstractmethod + def is_supported(model: MODEL, data: DATA) -> bool: + """Check if the engine is supported for the given model and data.""" + raise NotImplementedError diff --git a/src/otx/types.py b/src/otx/types.py new file mode 100644 index 00000000000..4a55fab026b --- /dev/null +++ b/src/otx/types.py @@ -0,0 +1,14 @@ +"""Typing hints for OTX.""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Any + +from otx.core.model.base import OTXModel + +METRICS = Any # TODO(ashwinvaidya17): Temporary till metrics is properly defined +ANNOTATIONS = Any # TODO(ashwinvaidya17): Temporary till annotations is properly defined +MODEL = OTXModel # TODO(ashwinvaidya17): Temporary till model is properly defined From 4a8a224156d0e9c50ef85b2c1b2baaf88576f87c Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Thu, 30 Jan 2025 14:23:47 +0100 Subject: [PATCH 2/2] Fix tests Signed-off-by: Ashwin Vaidya --- src/otx/backend/native/engine/__init__.py | 4 +- .../engine/adaptive_bs/adaptive_bs_api.py | 8 ++-- src/otx/backend/native/engine/engine.py | 41 ++++++++++++++----- src/otx/backend/native/engine/hpo/hpo_api.py | 6 +-- .../backend/native/engine/hpo/hpo_trial.py | 12 +++--- src/otx/cli/cli.py | 16 ++++---- src/otx/cli/utils/help_formatter.py | 4 +- src/otx/engine/__init__.py | 32 ++++++++++----- src/otx/tools/converter.py | 13 +++--- src/otx/types.py | 1 + tests/e2e/cli/test_cli.py | 4 +- tests/integration/api/test_augmentation.py | 3 +- .../api/test_auto_configuration.py | 7 ++-- tests/integration/api/test_engine_api.py | 12 +++--- tests/integration/api/test_xai.py | 7 ++-- .../cli/test_auto_configuration.py | 4 +- tests/integration/cli/test_cli.py | 4 +- .../algo/explain/test_saliency_map_dumping.py | 3 +- tests/unit/cli/test_cli.py | 7 ++-- .../adaptive_bs/test_adaptive_bs_api.py | 10 ++++- .../engine/adaptive_bs/test_bs_search_algo.py | 9 +++- tests/unit/engine/hpo/test_hpo_api.py | 11 ++--- tests/unit/engine/hpo/test_hpo_trial.py | 9 ++-- tests/unit/engine/hpo/test_utils.py | 3 +- tests/unit/engine/test_engine.py | 29 ++++++------- tests/unit/engine/utils/test_api.py | 3 +- .../engine/utils/test_auto_configurator.py | 13 +++--- 27 files changed, 167 insertions(+), 108 
diff --git a/src/otx/backend/native/engine/__init__.py b/src/otx/backend/native/engine/__init__.py
index 54f0aa601ce..807cd58e6c4 100644
--- a/src/otx/backend/native/engine/__init__.py
+++ b/src/otx/backend/native/engine/__init__.py
@@ -3,6 +3,6 @@
 # Copyright (C) 2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-from .engine import NativeEngine
+from .engine import OTXEngine
 
-__all__ = ["NativeEngine"]
+__all__ = ["OTXEngine"]
diff --git a/src/otx/backend/native/engine/adaptive_bs/adaptive_bs_api.py b/src/otx/backend/native/engine/adaptive_bs/adaptive_bs_api.py
index d2002ee81da..77334a09e5c 100644
--- a/src/otx/backend/native/engine/adaptive_bs/adaptive_bs_api.py
+++ b/src/otx/backend/native/engine/adaptive_bs/adaptive_bs_api.py
@@ -22,13 +22,13 @@
 if TYPE_CHECKING:
     from lightning import LightningModule, Trainer
 
-    from otx.engine.engine import Engine
+    from otx.backend.native.engine import OTXEngine
 
 logger = logging.getLogger(__name__)
 
 
 def adapt_batch_size(
-    engine: Engine,
+    engine: OTXEngine,
     not_increase: bool = True,
     callbacks: list[Callback] | Callback | None = None,
     **train_args,
@@ -94,7 +94,7 @@ def _adjust_train_args(train_args: dict[str, Any]) -> dict[str, Any]:
     return train_args
 
 
-def _train_model(bs: int, engine: Engine, callbacks: list[Callback] | Callback | None = None, **train_args) -> None:
+def _train_model(bs: int, engine: OTXEngine, callbacks: list[Callback] | Callback | None = None, **train_args) -> None:
     if bs <= 0:
         msg = f"Batch size should be greater than 0, but {bs} is given."
         raise ValueError(msg)
@@ -167,7 +167,7 @@ def _scale_batch_reset_params(trainer: Trainer, steps_per_trial: int) -> None:
     trainer.limit_val_batches = steps_per_trial
 
 
-def _apply_new_batch_size(engine: Engine, new_batch_size: int) -> None:
+def _apply_new_batch_size(engine: OTXEngine, new_batch_size: int) -> None:
     origin_bs = engine.datamodule.train_subset.batch_size
     if new_batch_size == origin_bs:
         return
diff --git a/src/otx/backend/native/engine/engine.py b/src/otx/backend/native/engine/engine.py
index 83079aa7fd1..f97fa86ad4a 100644
--- a/src/otx/backend/native/engine/engine.py
+++ b/src/otx/backend/native/engine/engine.py
@@ -37,7 +37,7 @@
 
 from .adaptive_bs import adapt_batch_size
 from .hpo import execute_hpo, update_hyper_parameter
-from .utils.auto_configurator import DEFAULT_CONFIG_PER_TASK
+from .utils.auto_configurator import DEFAULT_CONFIG_PER_TASK, AutoConfigurator
 
 if TYPE_CHECKING:
     from lightning import Callback
@@ -69,7 +69,7 @@ def override_metric_callable(model: OTXModel, new_metric_callabl
     model.metric_callable = orig_metric_callable
 
 
-class NativeEngine(Engine):
+class OTXEngine(Engine):
     """Native Engine."""
 
     """OTX Engine.
@@ -107,6 +107,9 @@ def __init__(
         self,
         *,
         model: OTXModel | str | None = None,
+        data_root: PathLike | None = None,
+        task: OTXTaskType | None = None,
+        datamodule: OTXDataModule | None = None,
         data: OTXDataModule | None = None,
         work_dir: PathLike = "./otx-workspace",
         checkpoint: PathLike | None = None,
@@ -133,16 +136,33 @@ def __init__(
         self.device = device  # type: ignore[assignment]
         self.num_devices = num_devices
 
-        self._datamodule: OTXDataModule | None = data
-        self.task = data.task if data is not None else self._auto_configurator.task
+        if data_root is not None:
+            msg = "data_root is deprecated. Please pass a datamodule to `data` instead."
+            logging.warning(msg)
+
+        if datamodule is not None:
+            msg = "datamodule is deprecated. Please pass it to `data` instead."
+ logging.warning(msg) + data = datamodule + + self._auto_configurator = AutoConfigurator( + data_root=data_root, + task=data.task if data is not None else task, + model_name=None if isinstance(model, OTXModel) else model, + ) + + self._datamodule: OTXDataModule | None = data if data is not None else self._auto_configurator.get_datamodule() + self.task = task if task is not None else self._auto_configurator.task self._trainer: Trainer | None = None get_model_args: dict[str, Any] = {} - if data is not None: - get_model_args["label_info"] = data.label_info - if (input_size := data.input_size) is not None: + if self._datamodule is not None: + get_model_args["label_info"] = self._datamodule.label_info + if (input_size := self._datamodule.input_size) is not None: get_model_args["input_size"] = (input_size, input_size) if isinstance(input_size, int) else input_size - self._model: OTXModel = model + self._model: OTXModel = ( + model if isinstance(model, OTXModel) else self._auto_configurator.get_model(**get_model_args) + ) # ------------------------------------------------------------------------ # # General OTX Entry Points @@ -504,7 +524,7 @@ def predict( return predict_result - def export( + def export( # type: ignore[override] self, checkpoint: PathLike | None = None, export_format: OTXExportFormatType = OTXExportFormatType.OPENVINO, @@ -606,7 +626,8 @@ def export( self.model.explain_mode = False return exported_model_path - def optimize( + # TODO(ashwinvaidya17): temporary till the base class contains all kwargs + def optimize( # type: ignore[override] self, checkpoint: PathLike | None = None, datamodule: TRAIN_DATALOADERS | OTXDataModule | None = None, diff --git a/src/otx/backend/native/engine/hpo/hpo_api.py b/src/otx/backend/native/engine/hpo/hpo_api.py index 28e20c75b32..cb2e80dc33f 100644 --- a/src/otx/backend/native/engine/hpo/hpo_api.py +++ b/src/otx/backend/native/engine/hpo/hpo_api.py @@ -39,14 +39,14 @@ if TYPE_CHECKING: from lightning.pytorch.cli import OptimizerCallable - from otx.engine.engine import Engine + from otx.backend.native.engine import OTXEngine from otx.hpo.hpo_base import HpoBase logger = logging.getLogger(__name__) def execute_hpo( - engine: Engine, + engine: OTXEngine, max_epochs: int, hpo_config: HpoConfig, callbacks: list[Callback] | Callback | None = None, @@ -163,7 +163,7 @@ class HPOConfigurator: def __init__( self, - engine: Engine, + engine: OTXEngine, max_epochs: int, hpo_config: HpoConfig, hpo_workdir: Path | None = None, diff --git a/src/otx/backend/native/engine/hpo/hpo_trial.py b/src/otx/backend/native/engine/hpo/hpo_trial.py index 6b999256c5d..71ca184cc20 100644 --- a/src/otx/backend/native/engine/hpo/hpo_trial.py +++ b/src/otx/backend/native/engine/hpo/hpo_trial.py @@ -22,10 +22,10 @@ if TYPE_CHECKING: from lightning import LightningModule, Trainer - from otx.engine.engine import Engine + from otx.backend.native.engine import OTXEngine -def update_hyper_parameter(engine: Engine, hyper_parameter: dict[str, Any]) -> None: +def update_hyper_parameter(engine: OTXEngine, hyper_parameter: dict[str, Any]) -> None: """Update hyper parameter in the engine.""" for key, val in hyper_parameter.items(): set_using_dot_delimited_key(key, val, engine) @@ -62,7 +62,7 @@ def run_hpo_trial( hp_config: dict[str, Any], report_func: Callable[[int | float, int | float, bool], None], hpo_workdir: Path, - engine: Engine, + engine: OTXEngine, callbacks: list[Callback] | Callback | None = None, metric_name: str | None = None, **train_args, @@ -107,7 +107,7 @@ def 
run_hpo_trial( report_func(0, 0, done=True) # type: ignore[call-arg] -def _set_trial_hyper_parameter(hyper_parameter: dict[str, Any], engine: Engine, train_args: dict[str, Any]) -> None: +def _set_trial_hyper_parameter(hyper_parameter: dict[str, Any], engine: OTXEngine, train_args: dict[str, Any]) -> None: train_args["max_epochs"] = round(hyper_parameter.pop("iterations")) update_hyper_parameter(engine, hyper_parameter) @@ -119,7 +119,7 @@ def _find_last_weight(weight_dir: Path) -> Path | None: def _register_hpo_callback( report_func: Callable, callbacks: list[Callback] | Callback | None = None, - engine: Engine | None = None, + engine: OTXEngine | None = None, metric_name: str | None = None, ) -> list[Callback]: if isinstance(callbacks, Callback): @@ -148,7 +148,7 @@ def _register_init_weight_callback(callbacks: list[Callback], save_path: Path) - return callbacks -def _change_work_dir(work_dir: str, callbacks: list[Callback], engine: Engine) -> None: +def _change_work_dir(work_dir: str, callbacks: list[Callback], engine: OTXEngine) -> None: for callback in callbacks: if isinstance(callback, ModelCheckpoint): callback.dirpath = work_dir diff --git a/src/otx/cli/cli.py b/src/otx/cli/cli.py index 3bd7a1d308a..3c8e2f11fa4 100644 --- a/src/otx/cli/cli.py +++ b/src/otx/cli/cli.py @@ -33,8 +33,8 @@ _ENGINE_AVAILABLE = True try: + from otx.backend.native.engine import OTXEngine from otx.core.config import register_configs - from otx.engine import Engine register_configs() except ImportError: @@ -144,7 +144,7 @@ def engine_subcommand_parser(subcommand: str, **kwargs) -> tuple[ArgumentParser, ) engine_skip = {"model", "datamodule", "work_dir"} parser.add_class_arguments( - Engine, + OTXEngine, "engine", fail_untyped=False, sub_configs=True, @@ -178,7 +178,7 @@ def engine_subcommand_parser(subcommand: str, **kwargs) -> tuple[ArgumentParser, parser.link_arguments("engine.device", "data.device") added_arguments = parser.add_method_arguments( - Engine, + OTXEngine, subcommand, skip=set(OTXCLI.engine_subcommands()[subcommand]), fail_untyped=False, @@ -254,7 +254,7 @@ def add_subcommands(self) -> None: # If the user specifies the config directly, not set the cache ckpt as default. self._load_cache_ckpt(parser=sub_parser) - fn = getattr(Engine, subcommand) + fn = getattr(OTXEngine, subcommand) description = get_short_docstring(fn) self._subcommand_method_arguments[subcommand] = added_arguments @@ -290,7 +290,7 @@ def _set_default_config(self) -> dict: task = sys.argv[sys.argv.index("--task") + 1] enable_auto_config = data_root is not None and "--config" not in sys.argv if enable_auto_config: - from otx.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK, AutoConfigurator + from otx.backend.native.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK, AutoConfigurator auto_configurator = AutoConfigurator( data_root=data_root, @@ -359,14 +359,14 @@ def instantiate_classes(self, instantiate_engine: bool = True) -> None: if instantiate_engine: self.engine = self.instantiate_engine() - def instantiate_engine(self) -> Engine: + def instantiate_engine(self) -> OTXEngine: """Instantiate an Engine object with the specified parameters. Returns: An instance of the Engine class. 
""" engine_kwargs = self.get_config_value(self.config_init, "engine") - return Engine( + return OTXEngine( model=self.model, datamodule=self.datamodule, work_dir=self.workspace.work_dir, @@ -531,7 +531,7 @@ def run(self) -> None: otx_install(**self.config["install"]) elif self.subcommand == "find": - from otx.engine.utils.api import list_models + from otx.backend.native.engine.utils.api import list_models list_models(print_table=True, **self.config[self.subcommand]) elif self.subcommand in self.engine_subcommands(): diff --git a/src/otx/cli/utils/help_formatter.py b/src/otx/cli/utils/help_formatter.py index 908a19db2ce..cc358824b03 100644 --- a/src/otx/cli/utils/help_formatter.py +++ b/src/otx/cli/utils/help_formatter.py @@ -162,10 +162,10 @@ def render_guide(subcommand: str | None = None) -> list: """ if subcommand is None or subcommand in ("install"): return [] - from otx.engine import Engine + from otx.backend.native.engine import OTXEngine contents: list[Panel | Markdown] = [Markdown(INTRO_MARKDOWN)] - target_command = getattr(Engine, subcommand) + target_command = getattr(OTXEngine, subcommand) cli_usage = get_cli_usage_docstring(target_command) if cli_usage is not None: cli_usage += f"\n{VERBOSE_USAGE.format(subcommand=subcommand)}" diff --git a/src/otx/engine/__init__.py b/src/otx/engine/__init__.py index 9a7b579af11..80d367125e5 100644 --- a/src/otx/engine/__init__.py +++ b/src/otx/engine/__init__.py @@ -3,24 +3,36 @@ # Copyright (C) 2023 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from typing import TYPE_CHECKING +from __future__ import annotations -from otx.backend.native.engine import NativeEngine +from typing import TYPE_CHECKING from .engine import Engine -__all__ = ["Engine"] - if TYPE_CHECKING: from otx.types import DATA, MODEL -SUPPORTED_ENGINES = [NativeEngine] +def create_engine(model: MODEL, data: DATA) -> Engine: + """Create an engine. 
+ + Args: + model: The model to use + data: The data/datamodule to use + + Returns: + An instance of an Engine subclass that supports the model and data + + Raises: + ValueError: If no compatible engine is found + """ + # Get all concrete (non-abstract) subclasses of Engine + engine_classes: list[type[Engine]] = Engine.__subclasses__() + + for engine_cls in engine_classes: + if engine_cls.is_supported(model, data): + # Type ignore since mypy can't verify the constructor signature of subclasses + return engine_cls(model=model, datamodule=data) # type: ignore[call-arg] -def create_engine(model: "MODEL", data: "DATA") -> Engine: - """Create an engine.""" - for engine in SUPPORTED_ENGINES: - if engine.is_supported(model, data): - return engine(model=model, data=data) msg = f"No engine found for model {model} and data {data}" raise ValueError(msg) diff --git a/src/otx/tools/converter.py b/src/otx/tools/converter.py index fdbb7c73f00..6e7f0db1d7b 100644 --- a/src/otx/tools/converter.py +++ b/src/otx/tools/converter.py @@ -9,18 +9,21 @@ import json from copy import deepcopy from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any from warnings import warn from jsonargparse import ArgumentParser, Namespace +from otx.backend.native.engine import OTXEngine +from otx.backend.native.engine.utils.auto_configurator import AutoConfigurator from otx.core.config.data import SamplerConfig, SubsetConfig, TileConfig, UnlabeledDataConfig from otx.core.data.module import OTXDataModule from otx.core.model.base import OTXModel from otx.core.types import PathLike from otx.core.types.task import OTXTaskType -from otx.engine import Engine -from otx.engine.utils.auto_configurator import AutoConfigurator + +if TYPE_CHECKING: + from otx.engine.engine import Engine TEMPLATE_ID_DICT = { # MULTI_CLASS_CLS @@ -468,7 +471,7 @@ def instantiate( # Instantiate Engine config_work_dir = config.pop("work_dir", config["engine"].pop("work_dir", None)) config["engine"]["work_dir"] = work_dir if work_dir is not None else config_work_dir - engine = Engine( + engine = OTXEngine( model=model, datamodule=datamodule, **config.pop("engine"), @@ -477,7 +480,7 @@ def instantiate( # Instantiate Engine.train Arguments engine_parser = ArgumentParser() train_arguments = engine_parser.add_method_arguments( - Engine, + OTXEngine, "train", skip={"accelerator", "devices"}, fail_untyped=False, diff --git a/src/otx/types.py b/src/otx/types.py index 4a55fab026b..7b6e6db3b51 100644 --- a/src/otx/types.py +++ b/src/otx/types.py @@ -12,3 +12,4 @@ METRICS = Any # TODO(ashwinvaidya17): Temporary till metrics is properly defined ANNOTATIONS = Any # TODO(ashwinvaidya17): Temporary till annotations is properly defined MODEL = OTXModel # TODO(ashwinvaidya17): Temporary till model is properly defined +DATA = Any # TODO(ashwinvaidya17): Temporary till data is properly defined diff --git a/tests/e2e/cli/test_cli.py b/tests/e2e/cli/test_cli.py index 3078784a8fd..33dd9412f03 100644 --- a/tests/e2e/cli/test_cli.py +++ b/tests/e2e/cli/test_cli.py @@ -8,9 +8,9 @@ import numpy as np import pytest import yaml -from otx.core.types.task import OTXTaskType -from otx.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK +from otx.backend.native.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK +from otx.core.types.task import OTXTaskType from tests.e2e.cli.utils import run_main from tests.utils import ExportCase2Test diff --git a/tests/integration/api/test_augmentation.py b/tests/integration/api/test_augmentation.py index 
2f0b11a64c2..55d3580140a 100644 --- a/tests/integration/api/test_augmentation.py +++ b/tests/integration/api/test_augmentation.py @@ -7,11 +7,12 @@ import pytest from datumaro import Dataset as DmDataset + +from otx.backend.native.engine.utils.auto_configurator import AutoConfigurator from otx.core.config.data import SamplerConfig, SubsetConfig from otx.core.data.factory import OTXDatasetFactory from otx.core.data.mem_cache import MemCacheHandlerSingleton from otx.core.types.task import OTXTaskType -from otx.engine.utils.auto_configurator import AutoConfigurator def _test_augmentation( diff --git a/tests/integration/api/test_auto_configuration.py b/tests/integration/api/test_auto_configuration.py index 75e9c2e365c..da2f1614523 100644 --- a/tests/integration/api/test_auto_configuration.py +++ b/tests/integration/api/test_auto_configuration.py @@ -4,11 +4,12 @@ from pathlib import Path import pytest + +from otx.backend.native.engine import OTXEngine +from otx.backend.native.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK from otx.core.data.module import OTXDataModule from otx.core.model.base import OTXModel from otx.core.types.task import OTXTaskType -from otx.engine import Engine -from otx.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK @pytest.mark.parametrize("task", pytest.TASK_LIST) @@ -37,7 +38,7 @@ def test_auto_configuration( tmp_path_train = tmp_path / f"auto_train_{task}" data_root = fxt_target_dataset_per_task[task.lower()] - engine = Engine(data_root=data_root, task=task, work_dir=tmp_path_train, device=fxt_accelerator) + engine = OTXEngine(data_root=data_root, task=task, work_dir=tmp_path_train, device=fxt_accelerator) if task.lower() == "zero_shot_visual_prompting": engine.model.infer_reference_info_root = Path(tmp_path_train) # update litmodule.hparams to reflect changed hparams diff --git a/tests/integration/api/test_engine_api.py b/tests/integration/api/test_engine_api.py index b36f658a869..cbb5e51709a 100644 --- a/tests/integration/api/test_engine_api.py +++ b/tests/integration/api/test_engine_api.py @@ -8,14 +8,14 @@ import pytest from datumaro import Dataset as DmDataset from model_api.tilers import Tiler + from otx.algo.classification.efficientnet import EfficientNetForMulticlassCls +from otx.backend.native.engine import OTXEngine +from otx.backend.native.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK, OVMODEL_PER_TASK from otx.core.config.hpo import HpoConfig from otx.core.data.module import OTXDataModule from otx.core.model.base import OTXModel from otx.core.types.task import OTXTaskType -from otx.engine import Engine -from otx.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK, OVMODEL_PER_TASK - from tests.test_helpers import CommonSemanticSegmentationExporter @@ -44,7 +44,7 @@ def test_engine_from_config( ) tmp_path_train = tmp_path / task - engine = Engine.from_config( + engine = OTXEngine.from_config( config_path=DEFAULT_CONFIG_PER_TASK[task], data_root=fxt_target_dataset_per_task[task.value.lower()], work_dir=tmp_path_train, @@ -163,7 +163,7 @@ def test_engine_from_tile_recipe( data_root = tmp_path / "tiling_detection_css" dataset.export(data_root, format=CommonSemanticSegmentationExporter, save_media=True) - engine = Engine.from_config( + engine = OTXEngine.from_config( config_path=recipe, data_root=data_root, work_dir=tmp_path / task, @@ -215,7 +215,7 @@ def test_otx_hpo( num_workers=1, ) work_dir = str(tmp_path) - engine = Engine( + engine = OTXEngine( data_root=fxt_target_dataset_per_task[task.lower()], 
task=task, work_dir=work_dir,
diff --git a/tests/integration/api/test_xai.py b/tests/integration/api/test_xai.py
index d82723470ec..f485b11964f 100644
--- a/tests/integration/api/test_xai.py
+++ b/tests/integration/api/test_xai.py
@@ -6,8 +6,9 @@
 import numpy as np
 import openvino.runtime as ov
 import pytest
+
+from otx.backend.native.engine import OTXEngine
 from otx.core.data.entity.base import OTXBatchPredEntity
-from otx.engine import Engine
 
 RECIPE_LIST_ALL = pytest.RECIPE_LIST
 MULTI_CLASS_CLS = [recipe for recipe in RECIPE_LIST_ALL if "multi_class_cls" in recipe]
@@ -61,7 +62,7 @@ def test_forward_explain(
     if "yolov9" in recipe:
         pytest.skip("yolov9 on detection is not supported yet.")
 
-    engine = Engine.from_config(
+    engine = OTXEngine.from_config(
         config_path=recipe,
         data_root=fxt_target_dataset_per_task[task],
         device=fxt_accelerator,
@@ -127,7 +128,7 @@ def test_predict_with_explain(
         pytest.skip("yolov9 on detection is not supported yet.")
 
     tmp_path = tmp_path / f"otx_xai_{model_name}"
-    engine = Engine.from_config(
+    engine = OTXEngine.from_config(
         config_path=recipe,
         data_root=fxt_target_dataset_per_task[task],
         device=fxt_accelerator,
diff --git a/tests/integration/cli/test_auto_configuration.py b/tests/integration/cli/test_auto_configuration.py
index 6fa7d1f87b7..25d1826f496 100644
--- a/tests/integration/cli/test_auto_configuration.py
+++ b/tests/integration/cli/test_auto_configuration.py
@@ -5,9 +5,9 @@
 from pathlib import Path
 
 import pytest
-from otx.core.types.task import OTXTaskType
-from otx.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK
 
+from otx.backend.native.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK
+from otx.core.types.task import OTXTaskType
 from tests.utils import run_main
diff --git a/tests/integration/cli/test_cli.py b/tests/integration/cli/test_cli.py
index 3c11993ddab..a8bf4f08c86 100644
--- a/tests/integration/cli/test_cli.py
+++ b/tests/integration/cli/test_cli.py
@@ -9,9 +9,9 @@
 import pytest
 import torch
 import yaml
-from otx.core.types.task import OTXTaskType
-from otx.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK
 
+from otx.backend.native.engine.utils.auto_configurator import DEFAULT_CONFIG_PER_TASK
+from otx.core.types.task import OTXTaskType
 from tests.utils import ExportCase2Test, run_main
diff --git a/tests/unit/algo/explain/test_saliency_map_dumping.py b/tests/unit/algo/explain/test_saliency_map_dumping.py
index 643c04b4c43..57ee278588e 100644
--- a/tests/unit/algo/explain/test_saliency_map_dumping.py
+++ b/tests/unit/algo/explain/test_saliency_map_dumping.py
@@ -5,12 +5,13 @@
 import cv2
 import numpy as np
+
 from otx.algo.utils.xai_utils import dump_saliency_maps
+from otx.backend.native.engine.utils.auto_configurator import AutoConfigurator
 from otx.core.config.explain import ExplainConfig
 from otx.core.data.entity.base import ImageInfo
 from otx.core.data.entity.classification import MulticlassClsBatchPredEntity
 from otx.core.types.task import OTXTaskType
-from otx.engine.utils.auto_configurator import AutoConfigurator
 
 NUM_CLASSES = 5
 BATCH_SIZE = 25
diff --git a/tests/unit/cli/test_cli.py b/tests/unit/cli/test_cli.py
index 07aa5d083e2..7ee84ba4f1a 100644
--- a/tests/unit/cli/test_cli.py
+++ b/tests/unit/cli/test_cli.py
@@ -8,9 +8,10 @@
 import pytest
 import torch
 import yaml
-from otx.cli import OTXCLI, main
 from rich.console import Console
 
+from otx.cli import OTXCLI, main
+
 
 class TestOTXCLI:
     def test_init(self, mocker) -> None:
@@ -112,9 +113,9 @@ def test_instantiate_classes(self, fxt_train_command, mocker) -> None:
 
         assert isinstance(cli.datamodule, OTXDataModule)
 
-        from otx.engine import Engine
+        from otx.backend.native.engine import OTXEngine
 
-        assert isinstance(cli.engine, Engine)
+        assert isinstance(cli.engine, OTXEngine)
         assert cli.datamodule == cli.engine.datamodule
         assert cli.model == cli.engine.model
diff --git a/tests/unit/engine/adaptive_bs/test_adaptive_bs_api.py b/tests/unit/engine/adaptive_bs/test_adaptive_bs_api.py
index 4668210ce0c..9976b53f8e9 100644
--- a/tests/unit/engine/adaptive_bs/test_adaptive_bs_api.py
+++ b/tests/unit/engine/adaptive_bs/test_adaptive_bs_api.py
@@ -9,9 +9,15 @@
 import pytest
 from lightning.pytorch.loggers.logger import DummyLogger
+
+from otx.backend.native.engine.adaptive_bs import adaptive_bs_api as target_file
+from otx.backend.native.engine.adaptive_bs.adaptive_bs_api import (
+    BatchSizeFinder,
+    _adjust_train_args,
+    _train_model,
+    adapt_batch_size,
+)
 from otx.core.types.task import OTXTaskType
-from otx.engine.adaptive_bs import adaptive_bs_api as target_file
-from otx.engine.adaptive_bs.adaptive_bs_api import BatchSizeFinder, _adjust_train_args, _train_model, adapt_batch_size
 
 
 @pytest.fixture()
diff --git a/tests/unit/engine/adaptive_bs/test_bs_search_algo.py b/tests/unit/engine/adaptive_bs/test_bs_search_algo.py
index dc4d94d0fa4..0ef9013512f 100644
--- a/tests/unit/engine/adaptive_bs/test_bs_search_algo.py
+++ b/tests/unit/engine/adaptive_bs/test_bs_search_algo.py
@@ -4,8 +4,13 @@
 from unittest.mock import MagicMock
 
 import pytest
-from otx.engine.adaptive_bs import bs_search_algo as target_file
-from otx.engine.adaptive_bs.bs_search_algo import BsSearchAlgo, _get_max_memory_reserved, _get_total_memory_size
+
+from otx.backend.native.engine.adaptive_bs import bs_search_algo as target_file
+from otx.backend.native.engine.adaptive_bs.bs_search_algo import (
+    BsSearchAlgo,
+    _get_max_memory_reserved,
+    _get_total_memory_size,
+)
 
 
 @pytest.fixture()
diff --git a/tests/unit/engine/hpo/test_hpo_api.py b/tests/unit/engine/hpo/test_hpo_api.py
index 8b24dffcf00..3ec308ba059 100644
--- a/tests/unit/engine/hpo/test_hpo_api.py
+++ b/tests/unit/engine/hpo/test_hpo_api.py
@@ -12,17 +12,18 @@
 import pytest
 import torch
 import yaml
-from otx.core.config.hpo import HpoConfig
-from otx.core.optimizer.callable import OptimizerCallableSupportHPO
-from otx.core.schedulers import LinearWarmupSchedulerCallable, SchedulerCallableSupportHPO
-from otx.engine.hpo import hpo_api as target_file
-from otx.engine.hpo.hpo_api import (
+
+from otx.backend.native.engine.hpo import hpo_api as target_file
+from otx.backend.native.engine.hpo.hpo_api import (
     HPOConfigurator,
     _adjust_train_args,
     _remove_unused_model_weights,
     _update_hpo_progress,
     execute_hpo,
 )
+from otx.core.config.hpo import HpoConfig
+from otx.core.optimizer.callable import OptimizerCallableSupportHPO
+from otx.core.schedulers import LinearWarmupSchedulerCallable, SchedulerCallableSupportHPO
 
 if TYPE_CHECKING:
     from pathlib import Path
diff --git a/tests/unit/engine/hpo/test_hpo_trial.py b/tests/unit/engine/hpo/test_hpo_trial.py
index e50a116905b..d29cd675a4e 100644
--- a/tests/unit/engine/hpo/test_hpo_trial.py
+++ b/tests/unit/engine/hpo/test_hpo_trial.py
@@ -11,9 +11,11 @@
 import pytest
 from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
+from torch import tensor
+
 from otx.algo.callbacks.adaptive_train_scheduling import AdaptiveTrainScheduling
-from otx.engine.hpo import hpo_trial as target_file
-from otx.engine.hpo.hpo_trial import (
+from otx.backend.native.engine.hpo import hpo_trial as target_file
+from otx.backend.native.engine.hpo.hpo_trial import (
     HPOCallback,
     HPOInitWeightCallback,
     _get_hpo_initial_weight,
@@ -22,9 +24,8 @@
     run_hpo_trial,
     update_hyper_parameter,
 )
-from otx.engine.hpo.utils import get_hpo_weight_dir
+from otx.backend.native.engine.hpo.utils import get_hpo_weight_dir
 from otx.hpo import TrialStatus
-from torch import tensor
 
 if TYPE_CHECKING:
     from lightning import Callback
diff --git a/tests/unit/engine/hpo/test_utils.py b/tests/unit/engine/hpo/test_utils.py
index c5ddc21d160..04d81286e8b 100644
--- a/tests/unit/engine/hpo/test_utils.py
+++ b/tests/unit/engine/hpo/test_utils.py
@@ -8,7 +8,8 @@
 import pytest
 from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
-from otx.engine.hpo.utils import (
+
+from otx.backend.native.engine.hpo.utils import (
     find_trial_file,
     get_best_hpo_weight,
     get_callable_args_name,
diff --git a/tests/unit/engine/test_engine.py b/tests/unit/engine/test_engine.py
index 3adcc5678d7..99adaa77490 100644
--- a/tests/unit/engine/test_engine.py
+++ b/tests/unit/engine/test_engine.py
@@ -5,23 +5,24 @@
 from unittest.mock import MagicMock
 
 import pytest
+from pytest_mock import MockerFixture
+
 from otx.algo.classification.efficientnet import EfficientNetForMulticlassCls
 from otx.algo.classification.torchvision_model import TVModelForMulticlassCls
+from otx.backend.native.engine import OTXEngine
 from otx.core.model.base import OTXModel, OVModel
 from otx.core.types.export import OTXExportFormatType
 from otx.core.types.label import NullLabelInfo
 from otx.core.types.precision import OTXPrecisionType
-from otx.engine import Engine
-from pytest_mock import MockerFixture
 
 
 @pytest.fixture()
-def fxt_engine(tmp_path) -> Engine:
+def fxt_engine(tmp_path) -> OTXEngine:
     recipe_path = "src/otx/recipe/classification/multi_class_cls/tv_mobilenet_v3_small.yaml"
     data_root = "tests/assets/classification_dataset"
     task_type = "MULTI_CLASS_CLS"
 
-    return Engine.from_config(
+    return OTXEngine.from_config(
         config_path=recipe_path,
         data_root=data_root,
         task=task_type,
@@ -32,11 +33,11 @@
 class TestEngine:
     def test_constructor(self, tmp_path) -> None:
         with pytest.raises(RuntimeError):
-            Engine(work_dir=tmp_path)
+            OTXEngine(work_dir=tmp_path)
 
         # Check auto-configuration
         data_root = "tests/assets/classification_dataset"
-        engine = Engine(work_dir=tmp_path, data_root=data_root)
+        engine = OTXEngine(work_dir=tmp_path, data_root=data_root)
         assert engine.task == "MULTI_CLASS_CLS"
         assert engine.datamodule.task == "MULTI_CLASS_CLS"
         assert isinstance(engine.model, EfficientNetForMulticlassCls)
@@ -50,7 +51,7 @@ def test_constructor(self, tmp_path) -> None:
 
         # Create engine with no data_root
         with pytest.raises(ValueError, match="Given model class (.*) requires a valid label_info to instantiate."):
-            _ = Engine(work_dir=tmp_path, task="MULTI_CLASS_CLS")
+            _ = OTXEngine(work_dir=tmp_path, task="MULTI_CLASS_CLS")
 
     @pytest.fixture()
     def mock_datamodule(self, mocker):
@@ -61,13 +62,13 @@ def mock_datamodule(self, mocker):
         mock_datamodule.input_size = input_size
 
         return mocker.patch(
-            "otx.engine.utils.auto_configurator.AutoConfigurator.get_datamodule",
+            "otx.backend.native.engine.utils.auto_configurator.AutoConfigurator.get_datamodule",
             return_value=mock_datamodule,
         )
 
     def test_model_init(self, tmp_path, mock_datamodule):
         data_root = "tests/assets/classification_dataset"
-        engine = Engine(work_dir=tmp_path, data_root=data_root)
+        engine = OTXEngine(work_dir=tmp_path, data_root=data_root)
 
         assert engine._model.input_size == (1234, 1234)
         assert engine._model.label_info.num_classes == 4321
@@ -75,7 +76,7 @@ def test_model_init(self, tmp_path, mock_datamodule):
     def test_model_init_datamodule_ipt_size_int(self, tmp_path, mock_datamodule):
         mock_datamodule.input_size = 1234
         data_root = "tests/assets/classification_dataset"
-        engine = Engine(work_dir=tmp_path, data_root=data_root)
+        engine = OTXEngine(work_dir=tmp_path, data_root=data_root)
 
         assert engine._model.input_size == (1234, 1234)
         assert engine._model.label_info.num_classes == 4321
@@ -350,7 +351,7 @@ def test_from_config_with_model_name(self, tmp_path) -> None:
             "data.test_subset.subset_name": "TESTING",
         }
 
-        engine = Engine.from_model_name(
+        engine = OTXEngine.from_model_name(
             model_name=model_name,
             data_root=data_root,
             task=task_type,
@@ -363,7 +364,7 @@ def test_from_config_with_model_name(self, tmp_path) -> None:
         assert engine.datamodule.test_subset.subset_name == "TESTING"
 
         with pytest.raises(FileNotFoundError):
-            engine = Engine.from_model_name(
+            engine = OTXEngine.from_model_name(
                 model_name="wrong_model",
                 data_root=data_root,
                 task=task_type,
@@ -381,7 +382,7 @@ def test_from_config(self, tmp_path) -> None:
             "data.test_subset.subset_name": "TESTING",
         }
 
-        engine = Engine.from_config(
+        engine = OTXEngine.from_config(
             config_path=recipe_path,
             data_root=data_root,
             task=task_type,
@@ -430,6 +431,6 @@ def test_num_devices(self, fxt_engine, tmp_path) -> None:
         assert fxt_engine._cache.args.get("devices") == 2
 
         data_root = "tests/assets/classification_dataset"
-        engine = Engine(work_dir=tmp_path, data_root=data_root, num_devices=3)
+        engine = OTXEngine(work_dir=tmp_path, data_root=data_root, num_devices=3)
         assert engine.num_devices == 3
         assert engine._cache.args.get("devices") == 3
diff --git a/tests/unit/engine/utils/test_api.py b/tests/unit/engine/utils/test_api.py
index 7e6ad379052..55abde5912e 100644
--- a/tests/unit/engine/utils/test_api.py
+++ b/tests/unit/engine/utils/test_api.py
@@ -2,8 +2,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import pytest
+
+from otx.backend.native.engine.utils.api import RECIPE_PATH, list_models
 from otx.core.types.task import OTXTaskType
-from otx.engine.utils.api import RECIPE_PATH, list_models
 
 
 def test_list_models() -> None:
diff --git a/tests/unit/engine/utils/test_auto_configurator.py b/tests/unit/engine/utils/test_auto_configurator.py
index 681b1b24639..2824b0de2aa 100644
--- a/tests/unit/engine/utils/test_auto_configurator.py
+++ b/tests/unit/engine/utils/test_auto_configurator.py
@@ -6,17 +6,18 @@
 import pytest
 import torch
+
+from otx.backend.native.engine.utils import auto_configurator as target_file
+from otx.backend.native.engine.utils.auto_configurator import (
+    DEFAULT_CONFIG_PER_TASK,
+    AutoConfigurator,
+    configure_task,
+)
 from otx.core.data.module import OTXDataModule
 from otx.core.model.base import OTXModel
 from otx.core.types.label import LabelInfo, SegLabelInfo
 from otx.core.types.task import OTXTaskType
 from otx.core.types.transformer_libs import TransformLibType
-from otx.engine.utils import auto_configurator as target_file
-from otx.engine.utils.auto_configurator import (
-    DEFAULT_CONFIG_PER_TASK,
-    AutoConfigurator,
-    configure_task,
-)
 from otx.utils.utils import should_pass_label_info