Skip to content

Commit

Permalink
close PIL images when loading images to tensor/numpy (#1598)
Browse files Browse the repository at this point in the history
Co-authored-by: Benedikt Fuchs <[email protected]>
  • Loading branch information
helpmefindaname and Benedikt Fuchs authored May 15, 2024
1 parent 3f116ad commit 45c2df3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
3 changes: 2 additions & 1 deletion doctr/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ def crop_bboxes_from_image(img_path: Union[str, Path], geoms: np.ndarray) -> Lis
-------
a list of cropped images
"""
img: np.ndarray = np.array(Image.open(img_path).convert("RGB"))
with Image.open(img_path) as pil_img:
img: np.ndarray = np.array(pil_img.convert("RGB"))
# Polygon
if geoms.ndim == 3 and geoms.shape[1:] == (4, 2):
return extract_rcrops(img, geoms.astype(dtype=int))
Expand Down
10 changes: 4 additions & 6 deletions doctr/io/image/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,8 @@ def read_img_as_tensor(img_path: AbstractPath, dtype: torch.dtype = torch.float3
if dtype not in (torch.uint8, torch.float16, torch.float32):
raise ValueError("insupported value for dtype")

pil_img = Image.open(img_path, mode="r").convert("RGB")

return tensor_from_pil(pil_img, dtype)
with Image.open(img_path, mode="r") as pil_img:
return tensor_from_pil(pil_img.convert("RGB"), dtype)


def decode_img_as_tensor(img_content: bytes, dtype: torch.dtype = torch.float32) -> torch.Tensor:
Expand All @@ -71,9 +70,8 @@ def decode_img_as_tensor(img_content: bytes, dtype: torch.dtype = torch.float32)
if dtype not in (torch.uint8, torch.float16, torch.float32):
raise ValueError("insupported value for dtype")

pil_img = Image.open(BytesIO(img_content), mode="r").convert("RGB")

return tensor_from_pil(pil_img, dtype)
with Image.open(BytesIO(img_content), mode="r") as pil_img:
return tensor_from_pil(pil_img.convert("RGB"), dtype)


def tensor_from_numpy(npy_img: np.ndarray, dtype: torch.dtype = torch.float32) -> torch.Tensor:
Expand Down

0 comments on commit 45c2df3

Please sign in to comment.