diff --git a/sahi/model.py b/sahi/model.py index 8ff9119..e51dbf2 100644 --- a/sahi/model.py +++ b/sahi/model.py @@ -1097,6 +1097,10 @@ def load_model(self): def perform_inference(self, image: np.ndarray): img, self.img_shape, self.src_shape = precess_image(image, img_size=self.image_size, stride=self.stride) + + # move the input tensor to the same device as the model + img = img.to(self.device) + self._original_predictions = self.model(img) def _create_object_prediction_list_from_original_predictions( @@ -1120,7 +1124,7 @@ def _create_object_prediction_list_from_original_predictions( for *xyxy, conf, cls in reversed(det): category_id = int(cls) category_name = COCO_CLASSES[category_id] - score = float(conf.numpy()) + score = float(conf.cpu().numpy()) bbox = [int(xyxy[0]), int(xyxy[1]), int(xyxy[2]), int(xyxy[3])] object_prediction = ObjectPrediction( bbox=bbox,