resubmitting the PR for later merging #28

Open · wants to merge 3 commits into base: main
19 changes: 9 additions & 10 deletions doctr/models/detection/predictor/pytorch.py
@@ -13,6 +13,8 @@
 
 __all__ = ["DetectionPredictor"]
 
+from doctr.utils.gpu import select_gpu_device
+
 
 class DetectionPredictor(nn.Module):
     """Implements an object able to localize text elements in a document
@@ -27,29 +29,26 @@ def __init__(
         pre_processor: PreProcessor,
         model: nn.Module,
     ) -> None:
-
         super().__init__()
         self.model = model.eval()
         self.pre_processor = pre_processor
         self.postprocessor = self.model.postprocessor
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        if os.environ.get("CUDA_VISIBLE_DEVICES", []) == "":
-            self.device = torch.device("cpu")
-        elif len(os.environ.get("CUDA_VISIBLE_DEVICES", [])) > 0:
-            self.device = torch.device("cuda")
-        if "onnx" not in str((type(self.model))) and (self.device == torch.device("cuda")):
 
+        detected_device, selected_device = select_gpu_device()
+        if "onnx" in str((type(self.model))):
+            selected_device = 'cpu'
             # self.model = nn.DataParallel(self.model)
             # self.model = self.model.half()
-            self.model = self.model.to(self.device)
+        self.device = torch.device(selected_device)
+        self.model = self.model.to(self.device)
 
     @torch.no_grad()
     def forward(
         self,
         pages: List[Union[np.ndarray, torch.Tensor]],
-        return_model_output = False,
+        return_model_output=False,
         **kwargs: Any,
     ) -> List[np.ndarray]:
-
         # Dimension check
         if any(page.ndim != 3 for page in pages):
             raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
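A side note on the device logic this hunk removes: os.environ.get("CUDA_VISIBLE_DEVICES", []) used a list as its default value, so both branches compared a list against strings whenever the variable was unset. A standalone sketch (not part of the diff) of what that evaluated to:

    import os

    # With CUDA_VISIBLE_DEVICES unset, the old list default kicks in:
    val = os.environ.get("CUDA_VISIBLE_DEVICES", [])
    print(val == "")     # False: [] never equals ""
    print(len(val) > 0)  # False: the empty list has length 0
    # Neither branch fired, so the initial cuda/cpu guess was kept, which is
    # the same outcome select_gpu_device() now produces explicitly.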
17 changes: 8 additions & 9 deletions doctr/models/recognition/predictor/pytorch.py
@@ -10,6 +10,7 @@
 from torch import nn
 import os
 from doctr.models.preprocessor import PreProcessor
+from doctr.utils.gpu import select_gpu_device
 
 from ._utils import remap_preds, split_crops
 
@@ -31,20 +32,19 @@ def __init__(
         model: nn.Module,
         split_wide_crops: bool = True,
     ) -> None:
-
         super().__init__()
         self.pre_processor = pre_processor
         self.model = model.eval()
         self.postprocessor = self.model.postprocessor
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        if os.environ.get("CUDA_VISIBLE_DEVICES", []) == "":
-            self.device = torch.device("cpu")
-        elif len(os.environ.get("CUDA_VISIBLE_DEVICES", [])) > 0:
-            self.device = torch.device("cuda")
-        if "onnx" not in str((type(self.model))) and (self.device == torch.device("cuda")):
 
+        detected_device, selected_device = select_gpu_device()
+        if "onnx" in str((type(self.model))):
+            selected_device = 'cpu'
             # self.model = nn.DataParallel(self.model)
-            self.model = self.model.to(self.device)
             # self.model = self.model.half()
+        self.device = torch.device(selected_device)
+        self.model = self.model.to(self.device)
 
         self.split_wide_crops = split_wide_crops
         self.critical_ar = 8  # Critical aspect ratio
         self.dil_factor = 1.4  # Dilation factor to overlap the crops
@@ -56,7 +56,6 @@ def forward(
         crops: Sequence[Union[np.ndarray, torch.Tensor]],
         **kwargs: Any,
     ) -> List[Tuple[str, float]]:
-
         if len(crops) == 0:
             return []
         # Dimension check
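Context for the crop-splitting constants kept in this hunk: crops wider than critical_ar (width/height > 8) are split into overlapping sub-crops, and dil_factor dilates each sub-crop so neighbours overlap. A rough, hypothetical illustration of the geometry (the real logic lives in split_crops from ._utils, which this diff does not show):

    # hypothetical numbers, just to show the roles of the two constants
    critical_ar, dil_factor = 8, 1.4
    h, w = 32, 512                                # a very wide word crop (aspect ratio 16)
    if w / h > critical_ar:
        n_tiles = int(w / (h * critical_ar)) + 1  # 3 tiles for this crop
        tile_w = int(w / n_tiles * dil_factor)    # 238 px, so adjacent tiles overlap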
40 changes: 40 additions & 0 deletions doctr/utils/gpu.py
@@ -0,0 +1,40 @@
+import logging
+import os
+from typing import Tuple
+import torch
+
+log = logging.getLogger(__name__)
+
+
+def select_gpu_device() -> Tuple[str, str]:
+    """Tries to find either a CUDA or an Apple MPS GPU accelerator and chooses the most appropriate one,
+    honoring the CUDA_VISIBLE_DEVICES environment variable if it has been set.
+
+    Returns a tuple (best_detected_device, selected_device):
+        best_detected_device reflects the capabilities of the system
+        selected_device is the device that should be used (it might be 'cpu' even if best_detected_device is e.g. 'cuda')
+    """
+    if torch.cuda.is_available():
+        detected_gpu_device = 'cuda'
+    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        detected_gpu_device = 'mps'
+    else:
+        detected_gpu_device = 'cpu'
+
+    selected_gpu_device = detected_gpu_device
+    match detected_gpu_device:  # various exceptions to the above
+        case 'cuda':
+            if os.environ.get("CUDA_VISIBLE_DEVICES") == "":
+                selected_gpu_device = 'cpu'
+        case 'mps':
+            # FIXME: a detected mps device selects cpu here because of the many bugs in the MPS
+            # implementation of torch 1.13's LSTM. As of 5/29/2023, they appear to be actively
+            # fixing them. torch 2.0.1 was also tried, and while the bugs look different it is
+            # still broken. Revisit when later versions of torch are available.
+            # pass
+            selected_gpu_device = 'cpu'
+        case 'cpu':
+            pass
+
+    log.info(f"{detected_gpu_device=} {selected_gpu_device=}")
+    return detected_gpu_device, selected_gpu_device
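
A minimal usage sketch of the new helper (illustrative only, not part of the diff; the MPS outcome follows from the FIXME above):

    import torch
    from doctr.utils.gpu import select_gpu_device

    detected, selected = select_gpu_device()
    # CUDA machine:                detected == 'cuda', selected == 'cuda'
    # Apple Silicon (torch 1.13):  detected == 'mps',  selected == 'cpu'
    # No accelerator at all:       detected == 'cpu',  selected == 'cpu'
    device = torch.device(selected)  # what both predictors move their model to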