From 6701f27afa62712b34a17d4b0ff879156b0c7937 Mon Sep 17 00:00:00 2001 From: Muller Hsu Date: Fri, 5 Feb 2021 03:43:30 +0800 Subject: [PATCH] To support io.Bytesio (#339) * To support io.Bytesio * fix doc error * fix format * add unit test * fix assert condition * fix np ndarray compare in an error way * fix type error, cast the LocalPath to pathlib.Path --- keras_preprocessing/image/utils.py | 113 ++++++++++++++++------------- tests/image/utils_test.py | 24 ++++++ 2 files changed, 86 insertions(+), 51 deletions(-) diff --git a/keras_preprocessing/image/utils.py b/keras_preprocessing/image/utils.py index e196066d..91c55804 100644 --- a/keras_preprocessing/image/utils.py +++ b/keras_preprocessing/image/utils.py @@ -3,6 +3,7 @@ import io import os import warnings +from pathlib import Path import numpy as np @@ -77,7 +78,7 @@ def load_img(path, grayscale=False, color_mode='rgb', target_size=None, """Loads an image into PIL format. # Arguments - path: Path to image file. + path: Path (string), pathlib.Path object, or io.BytesIO stream to image file. grayscale: DEPRECATED use `color_mode="grayscale"`. color_mode: The desired image format. One of "grayscale", "rgb", "rgba". "grayscale" supports 8-bit images and 32-bit signed integer images. @@ -101,6 +102,7 @@ def load_img(path, grayscale=False, color_mode='rgb', target_size=None, # Raises ImportError: if PIL is not available. ValueError: if interpolation method is not supported. + TypeError: type of 'path' should be path-like or io.Byteio. """ if grayscale is True: warnings.warn('grayscale is deprecated. Please use ' @@ -109,56 +111,65 @@ def load_img(path, grayscale=False, color_mode='rgb', target_size=None, if pil_image is None: raise ImportError('Could not import PIL.Image. ' 'The use of `load_img` requires PIL.') - with open(path, 'rb') as f: - img = pil_image.open(io.BytesIO(f.read())) - if color_mode == 'grayscale': - # if image is not already an 8-bit, 16-bit or 32-bit grayscale image - # convert it to an 8-bit grayscale image. - if img.mode not in ('L', 'I;16', 'I'): - img = img.convert('L') - elif color_mode == 'rgba': - if img.mode != 'RGBA': - img = img.convert('RGBA') - elif color_mode == 'rgb': - if img.mode != 'RGB': - img = img.convert('RGB') - else: - raise ValueError('color_mode must be "grayscale", "rgb", or "rgba"') - if target_size is not None: - width_height_tuple = (target_size[1], target_size[0]) - if img.size != width_height_tuple: - if interpolation not in _PIL_INTERPOLATION_METHODS: - raise ValueError( - 'Invalid interpolation method {} specified. Supported ' - 'methods are {}'.format( - interpolation, - ", ".join(_PIL_INTERPOLATION_METHODS.keys()))) - resample = _PIL_INTERPOLATION_METHODS[interpolation] - - if keep_aspect_ratio: - width, height = img.size - target_width, target_height = width_height_tuple - - crop_height = (width * target_height) // target_width - crop_width = (height * target_width) // target_height - - # Set back to input height / width - # if crop_height / crop_width is not smaller. - crop_height = min(height, crop_height) - crop_width = min(width, crop_width) - - crop_box_hstart = (height - crop_height) // 2 - crop_box_wstart = (width - crop_width) // 2 - crop_box_wend = crop_box_wstart + crop_width - crop_box_hend = crop_box_hstart + crop_height - crop_box = [crop_box_wstart, crop_box_hstart, - crop_box_wend, crop_box_hend] - - img = img.resize(width_height_tuple, resample, - box=crop_box) - else: - img = img.resize(width_height_tuple, resample) - return img + if isinstance(path, io.BytesIO): + img = pil_image.open(path) + elif isinstance(path, (Path, bytes, str)): + if isinstance(path, Path): + path = str(path.resolve()) + with open(path, 'rb') as f: + img = pil_image.open(io.BytesIO(f.read())) + else: + raise TypeError('path should be path-like or io.BytesIO' + ', not {}'.format(type(path))) + + if color_mode == 'grayscale': + # if image is not already an 8-bit, 16-bit or 32-bit grayscale image + # convert it to an 8-bit grayscale image. + if img.mode not in ('L', 'I;16', 'I'): + img = img.convert('L') + elif color_mode == 'rgba': + if img.mode != 'RGBA': + img = img.convert('RGBA') + elif color_mode == 'rgb': + if img.mode != 'RGB': + img = img.convert('RGB') + else: + raise ValueError('color_mode must be "grayscale", "rgb", or "rgba"') + if target_size is not None: + width_height_tuple = (target_size[1], target_size[0]) + if img.size != width_height_tuple: + if interpolation not in _PIL_INTERPOLATION_METHODS: + raise ValueError( + 'Invalid interpolation method {} specified. Supported ' + 'methods are {}'.format( + interpolation, + ", ".join(_PIL_INTERPOLATION_METHODS.keys()))) + resample = _PIL_INTERPOLATION_METHODS[interpolation] + + if keep_aspect_ratio: + width, height = img.size + target_width, target_height = width_height_tuple + + crop_height = (width * target_height) // target_width + crop_width = (height * target_width) // target_height + + # Set back to input height / width + # if crop_height / crop_width is not smaller. + crop_height = min(height, crop_height) + crop_width = min(width, crop_width) + + crop_box_hstart = (height - crop_height) // 2 + crop_box_wstart = (width - crop_width) // 2 + crop_box_wend = crop_box_wstart + crop_width + crop_box_hend = crop_box_hstart + crop_height + crop_box = [ + crop_box_wstart, crop_box_hstart, crop_box_wend, + crop_box_hend + ] + img = img.resize(width_height_tuple, resample, box=crop_box) + else: + img = img.resize(width_height_tuple, resample) + return img def list_pictures(directory, ext=('jpg', 'jpeg', 'bmp', 'png', 'ppm', 'tif', diff --git a/tests/image/utils_test.py b/tests/image/utils_test.py index 8053401a..f70b5aa7 100644 --- a/tests/image/utils_test.py +++ b/tests/image/utils_test.py @@ -1,4 +1,6 @@ +import io import resource +from pathlib import Path import numpy as np import PIL @@ -193,6 +195,28 @@ def test_load_img(tmpdir): loaded_im_array = utils.img_to_array(loaded_im, dtype='int32') assert loaded_im_array.shape == (25, 25, 1) + # Test different path type + with open(filename_grayscale_32bit, 'rb') as f: + _path = io.BytesIO(f.read()) # io.Bytesio + loaded_im = utils.load_img(_path, color_mode='grayscale') + loaded_im_array = utils.img_to_array(loaded_im, dtype=np.int32) + assert np.all(loaded_im_array == original_grayscale_32bit_array) + + _path = filename_grayscale_32bit # str + loaded_im = utils.load_img(_path, color_mode='grayscale') + loaded_im_array = utils.img_to_array(loaded_im, dtype=np.int32) + assert np.all(loaded_im_array == original_grayscale_32bit_array) + + _path = filename_grayscale_32bit.encode() # bytes + loaded_im = utils.load_img(_path, color_mode='grayscale') + loaded_im_array = utils.img_to_array(loaded_im, dtype=np.int32) + assert np.all(loaded_im_array == original_grayscale_32bit_array) + + _path = Path(tmpdir / 'grayscale_32bit_utils.tiff') # Path + loaded_im = utils.load_img(_path, color_mode='grayscale') + loaded_im_array = utils.img_to_array(loaded_im, dtype=np.int32) + assert np.all(loaded_im_array == original_grayscale_32bit_array) + # Check that exception is raised if interpolation not supported. loaded_im = utils.load_img(filename_rgb, interpolation="unsupported")