From de396fe3eabc0087d8fe56d0ca4b1b68a0801d53 Mon Sep 17 00:00:00 2001 From: Grenka054 Date: Sat, 13 Jul 2024 16:37:39 +0700 Subject: [PATCH] Update model.py --- sahi/model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sahi/model.py b/sahi/model.py index 8ff9119..e51dbf2 100644 --- a/sahi/model.py +++ b/sahi/model.py @@ -1097,6 +1097,10 @@ def load_model(self): def perform_inference(self, image: np.ndarray): img, self.img_shape, self.src_shape = precess_image(image, img_size=self.image_size, stride=self.stride) + + # move the input tensor to the same device as the model + img = img.to(self.device) + self._original_predictions = self.model(img) def _create_object_prediction_list_from_original_predictions( @@ -1120,7 +1124,7 @@ def _create_object_prediction_list_from_original_predictions( for *xyxy, conf, cls in reversed(det): category_id = int(cls) category_name = COCO_CLASSES[category_id] - score = float(conf.numpy()) + score = float(conf.cpu().numpy()) bbox = [int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])] object_prediction = ObjectPrediction( bbox=bbox,