Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

📤 Add export task (coreml and tflite) #174

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions yolo/config/task/export.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
task: export

format: onnx
5 changes: 4 additions & 1 deletion yolo/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
sys.path.append(str(project_root))

from yolo.config.config import Config
from yolo.tools.solver import InferenceModel, TrainModel, ValidateModel
from yolo.tools.solver import ExportModel, InferenceModel, TrainModel, ValidateModel
from yolo.utils.logging_utils import setup


Expand Down Expand Up @@ -39,6 +39,9 @@ def main(cfg: Config):
if cfg.task.task == "inference":
model = InferenceModel(cfg)
trainer.predict(model)
if cfg.task.task == "export":
model = ExportModel(cfg)
model.export()


if __name__ == "__main__":
Expand Down
75 changes: 71 additions & 4 deletions yolo/model/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@ class YOLO(nn.Module):
parameters, and any other relevant configuration details.
"""

def __init__(self, model_cfg: ModelConfig, class_num: int = 80):
def __init__(self, model_cfg: ModelConfig, class_num: int = 80, export_mode: bool = False):
super(YOLO, self).__init__()
self.num_classes = class_num
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
self.model: List[YOLOLayer] = nn.ModuleList()
self.reg_max = getattr(model_cfg.anchor, "reg_max", 16)
self.export_mode = export_mode
self.build_model(model_cfg.model)

def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
Expand Down Expand Up @@ -68,10 +69,44 @@ def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
setattr(layer, "out_c", out_channels)
layer_idx += 1

def generate_anchors(self, image_size: List[int], strides: List[int]):
    """Build the anchor-centre grid and the per-anchor stride for every scale.

    Args:
        image_size: ``[W, H]`` of the network input.
        strides: One stride per prediction head.

    Returns:
        A tuple ``(all_anchors, all_scalers)``. ``all_anchors`` has one
        ``(x, y)`` row per anchor, concatenated over all strides in the order
        given; ``all_scalers`` holds the stride each of those anchors belongs
        to and always has the same length as ``all_anchors``.
    """
    W, H = image_size
    anchors = []
    scaler = []
    for stride in strides:
        # Anchor centres sit in the middle of each stride-sized cell.
        shift = stride // 2
        h = torch.arange(0, H, stride) + shift
        w = torch.arange(0, W, stride) + shift
        # Feature-detect the ``indexing`` keyword instead of comparing version
        # strings: a lexicographic compare ranks "2.10.0" below "2.3.0".
        try:
            anchor_h, anchor_w = torch.meshgrid(h, w, indexing="ij")
        except TypeError:
            # Older torch without ``indexing``; its only behaviour is "ij".
            anchor_h, anchor_w = torch.meshgrid(h, w)
        anchor = torch.stack([anchor_w.flatten(), anchor_h.flatten()], dim=-1)
        anchors.append(anchor)
        # Size the scaler from the anchors actually produced so both outputs
        # stay aligned even when W or H is not a multiple of the stride.
        scaler.append(torch.full((anchor.shape[0],), stride))
    all_anchors = torch.cat(anchors, dim=0)
    all_scalers = torch.cat(scaler, dim=0)
    return all_anchors, all_scalers

def get_strides(self, output, input_width) -> List[int]:
    """Derive one stride per prediction head from its feature-map size.

    For every head in ``output`` the third element's shape is inspected; the
    stride is ``input_width`` floor-divided by the second of the dimensions
    that follow the first two.

    Args:
        output: Iterable of prediction heads, each indexable at position 2
            with an object exposing ``.shape``.
        input_width: Size of the network input used as the numerator.

    Returns:
        The strides, in the same order as the heads.
    """
    return [input_width // list(head[2].shape[2:])[1] for head in output]

def forward(self, x, external: Optional[Dict] = None, shortcut: Optional[str] = None):
input_width, input_height = x.shape[-2:]
y = {0: x, **(external or {})}
output = dict()
for index, layer in enumerate(self.model, start=1):

# Use a simple loop instead of enumerate()
# Needed for torch export compatibility
index = 1
for layer in self.model:
if isinstance(layer.source, list):
model_input = [y[idx] for idx in layer.source]
else:
Expand All @@ -81,12 +116,42 @@ def forward(self, x, external: Optional[Dict] = None, shortcut: Optional[str] =

x = layer(model_input, **external_input)
y[-1] = x

if layer.usable:
y[index] = x

if layer.output:
output[layer.tags] = x
if layer.tags == shortcut:
return output

index += 1

if self.export_mode:

preds_cls, preds_anc, preds_box = [], [], []
for layer_output in output["Main"]:
pred_cls, pred_anc, pred_box = layer_output
preds_cls.append(pred_cls.permute(0, 2, 3, 1).reshape(pred_cls.shape[0], -1, pred_cls.shape[1]))
preds_anc.append(
pred_anc.permute(0, 3, 4, 1, 2).reshape(pred_anc.shape[0], -1, pred_anc.shape[2], pred_anc.shape[1])
)
preds_box.append(pred_box.permute(0, 2, 3, 1).reshape(pred_box.shape[0], -1, pred_box.shape[1]))

preds_cls = torch.concat(preds_cls, dim=1).to(x[0][0].device)
preds_anc = torch.concat(preds_anc, dim=1).to(x[0][0].device)
preds_box = torch.concat(preds_box, dim=1).to(x[0][0].device)

strides = self.get_strides(output["Main"], input_width)
anchor_grid, scaler = self.generate_anchors([input_width, input_height], strides) #
anchor_grid = anchor_grid.to(x[0][0].device)
scaler = scaler.to(x[0][0].device)
pred_LTRB = preds_box * scaler.view(1, -1, 1)
lt, rb = pred_LTRB.chunk(2, dim=-1)
preds_box = torch.cat([anchor_grid - lt, anchor_grid + rb], dim=-1)

return preds_cls, preds_anc, preds_box

return output

def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
Expand Down Expand Up @@ -158,7 +223,9 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]):
self.model.load_state_dict(model_state_dict)


def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80) -> YOLO:
def create_model(
model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80, export_mode: bool = False
) -> YOLO:
"""Constructs and returns a model from a Dictionary configuration file.

Args:
Expand All @@ -168,7 +235,7 @@ def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True,
YOLO: An instance of the model defined by the given configuration.
"""
OmegaConf.set_struct(model_cfg, False)
model = YOLO(model_cfg, class_num)
model = YOLO(model_cfg, class_num, export_mode=export_mode)
if weight_path:
if weight_path == True:
weight_path = Path("weights") / f"{model_cfg.name}.pt"
Expand Down
56 changes: 51 additions & 5 deletions yolo/tools/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,18 @@
from yolo.tools.drawer import draw_bboxes
from yolo.tools.loss_functions import create_loss_function
from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
from yolo.utils.deploy_utils import FastModelLoader
from yolo.utils.export_utils import ModelExporter
from yolo.utils.logger import logger
from yolo.utils.model_utils import PostProcess, create_optimizer, create_scheduler


class BaseModel(LightningModule):
def __init__(self, cfg: Config):
def __init__(self, cfg: Config, export_mode: bool = False):
super().__init__()
self.model = create_model(cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight)
self.model = create_model(
cfg.model, class_num=cfg.dataset.class_num, weight_path=cfg.weight, export_mode=export_mode
)

def forward(self, x):
return self.model(x)
Expand Down Expand Up @@ -109,23 +114,39 @@ def configure_optimizers(self):

class InferenceModel(BaseModel):
def __init__(self, cfg: Config):
super().__init__(cfg)
if hasattr(cfg.model.model, "auxiliary"):
cfg.model.model.auxiliary = {}

export_mode = False
fast_inference = cfg.task.fast_inference
# TODO check if we can use export mode for all formats
if fast_inference == "coreml":
export_mode = True

super().__init__(cfg, export_mode=export_mode)
self.cfg = cfg
# TODO: Add FastModel
self.predict_loader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)

def setup(self, stage):
self.vec2box = create_converter(
self.cfg.model.name, self.model, self.cfg.model.anchor, self.cfg.image_size, self.device
)

if self.cfg.task.fast_inference:
self.fast_model = FastModelLoader(self.cfg, self.model).load_model(self.device)

self.post_process = PostProcess(self.vec2box, self.cfg.task.nms)

def predict_dataloader(self):
return self.predict_loader

def predict_step(self, batch, batch_idx):
images, rev_tensor, origin_frame = batch
predicts = self.post_process(self(images), rev_tensor=rev_tensor)
if hasattr(self, "fast_model") and self.fast_model:
predictions = self.fast_model(images)
else:
predictions = self(images)
predicts = self.post_process(predictions, rev_tensor=rev_tensor)
img = draw_bboxes(origin_frame, predicts, idx2label=self.cfg.dataset.class_list)
if getattr(self.predict_loader, "is_stream", None):
fps = self._display_stream(img)
Expand All @@ -139,3 +160,28 @@ def _save_image(self, img, batch_idx):
save_image_path = Path(self.trainer.default_root_dir) / f"frame{batch_idx:03d}.png"
img.save(save_image_path)
print(f"💾 Saved visualize image at {save_image_path}")


class ExportModel(BaseModel):
    """Model wrapper that exports the network to the format in ``cfg.task.format``.

    Supported formats are ``"onnx"``, ``"tflite"`` and ``"coreml"``. The
    underlying model is built with ``export_mode=True`` only for ``"coreml"``.
    """

    # Maps each supported format to the ModelExporter method that handles it.
    _EXPORTERS = {
        "onnx": "export_onnx",
        "tflite": "export_tflite",
        "coreml": "export_coreml",
    }

    def __init__(self, cfg: Config):
        """Build the model and its exporter.

        Args:
            cfg: Run configuration. ``cfg.task.format`` selects the target
                format; if ``cfg.model.model`` has an ``auxiliary`` entry it
                is reset to an empty dict before the model is created.
        """
        if hasattr(cfg.model.model, "auxiliary"):
            cfg.model.model.auxiliary = {}

        # Keep the format in a local: ``self.format`` does not exist yet, and
        # attributes must not be touched before ``super().__init__()`` runs.
        export_format = cfg.task.format
        # TODO check if we can use export mode for all formats
        export_mode = export_format == "coreml"

        super().__init__(cfg, export_mode=export_mode)
        self.cfg = cfg
        self.format = export_format
        self.model_exporter = ModelExporter(self.cfg, self.model, format=self.format)

    def export(self):
        """Run the exporter matching ``self.format``.

        Raises:
            ValueError: If ``self.format`` is not one of the supported formats.
        """
        method_name = self._EXPORTERS.get(self.format)
        if method_name is None:
            supported = ", ".join(sorted(self._EXPORTERS))
            raise ValueError(f"Unsupported export format {self.format!r}; expected one of: {supported}")
        # Resolve the method lazily so only the selected exporter is looked up.
        getattr(self.model_exporter, method_name)()
7 changes: 4 additions & 3 deletions yolo/utils/bounding_box_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,13 +342,14 @@ def __init__(self, model: YOLO, anchor_cfg: AnchorConfig, image_size, device):
if hasattr(anchor_cfg, "strides"):
logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}")
self.strides = anchor_cfg.strides
else:
elif not model.export_mode:
logger.info(":teddy_bear: Found no stride of model, performed a dummy test for auto-anchor size")
self.strides = self.create_auto_anchor(model, image_size)

anchor_grid, scaler = generate_anchors(image_size, self.strides)
self.image_size = image_size
self.anchor_grid, self.scaler = anchor_grid.to(device), scaler.to(device)
if not model.export_mode:
anchor_grid, scaler = generate_anchors(image_size, self.strides)
self.anchor_grid, self.scaler = anchor_grid.to(device), scaler.to(device)

def create_auto_anchor(self, model: YOLO, image_size):
W, H = image_size
Expand Down
Loading
Loading