diff --git a/detect.py b/detect.py index f791faa0908..6bfb06e0e20 100644 --- a/detect.py +++ b/detect.py @@ -35,6 +35,7 @@ import sys from pathlib import Path +import numpy as np import torch FILE = Path(__file__).resolve() @@ -131,7 +132,7 @@ def run( # Run inference model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz)) # warmup - seen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device)) + seen, windows, dt = 0, [], (Profile(), Profile(), Profile()) for path, im, im0s, vid_cap, s in dataset: with dt[0]: im = torch.from_numpy(im).to(model.device) @@ -139,22 +140,12 @@ def run( im /= 255 # 0 - 255 to 0.0 - 1.0 if len(im.shape) == 3: im = im[None] # expand for batch dim - if model.xml and im.shape[0] > 1: - ims = torch.chunk(im, im.shape[0], 0) # Inference with dt[1]: visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False - if model.xml and im.shape[0] > 1: - pred = None - for image in ims: - if pred is None: - pred = model(image, augment=augment, visualize=visualize).unsqueeze(0) - else: - pred = torch.cat((pred, model(image, augment=augment, visualize=visualize).unsqueeze(0)), dim=0) - pred = [pred, None] - else: - pred = model(im, augment=augment, visualize=visualize) + pred = model(im, augment=augment, visualize=visualize) + # NMS with dt[2]: pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det) @@ -167,7 +158,6 @@ def run( # Create or append to the CSV file def write_to_csv(image_name, prediction, confidence): - """Writes prediction data for an image to a CSV file, appending if the file exists.""" data = {"Image Name": image_name, "Prediction": prediction, "Confidence": confidence} with open(csv_path, mode="a", newline="") as f: writer = csv.DictWriter(f, fieldnames=data.keys()) @@ -191,10 +181,37 @@ def write_to_csv(image_name, prediction, confidence): gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh imc = im0.copy() if save_crop else im0 # for save_crop annotator = Annotator(im0, line_width=line_thickness, example=str(names)) + + # Calculate the center of the image + img_center = np.array([im0.shape[1] // 2, im0.shape[0] // 2]) + min_distance = None + closest_box = None + if len(det): # Rescale boxes from img_size to im0 size det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round() + # Calculate the center of the image + img_center = np.array([im0.shape[1] // 2, im0.shape[0] // 2]) + + # Calculate centers of all detection boxes and find the closest one to the image center + centers = np.array( + [ + [(xyxy[0].cpu() + xyxy[2].cpu()) / 2, (xyxy[1].cpu() + xyxy[3].cpu()) / 2] + for *xyxy, _, _ in reversed(det) + ] + ) + distances = np.linalg.norm(centers - img_center, axis=1) + closest_idx = np.argmin(distances) + + # Draw boxes, marking the closest one in green + for j, (*xyxy, conf, cls) in enumerate(reversed(det)): + color = (0, 255, 0) if j == closest_idx else colors(int(cls), True) + annotator.box_label(xyxy, f"{names[int(cls)]} {conf:.2f}", color=color) + + # Rescale boxes from img_size to im0 size + det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round() + # Print results for c in det[:, 5].unique(): n = (det[:, 5] == c).sum() # detections per class @@ -266,7 +283,6 @@ def write_to_csv(image_name, prediction, confidence): def parse_opt(): - """Parses command-line arguments for YOLOv5 detection, setting inference options and model configurations.""" parser = argparse.ArgumentParser() parser.add_argument("--weights", nargs="+", type=str, default=ROOT / "yolov5s.pt", help="model path or triton URL") parser.add_argument("--source", type=str, default=ROOT / "data/images", help="file/dir/URL/glob/screen/0(webcam)") @@ -303,11 +319,13 @@ def parse_opt(): def main(opt): - """Executes YOLOv5 model inference with given options, checking requirements before running the model.""" check_requirements(ROOT / "requirements.txt", exclude=("tensorboard", "thop")) run(**vars(opt)) +# python detect.py --weights runs/train/exp10/weights/best.pt --source project/test +# python detect.py --weights runs/train/exp10/weights/best.pt --source project/test + if __name__ == "__main__": opt = parse_opt() main(opt)