Skip to content

Commit

Permalink
Add MANIFEST and inti
Browse files Browse the repository at this point in the history
  • Loading branch information
Dref360 committed Oct 15, 2018
1 parent a058371 commit d24ce04
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 5 deletions.
4 changes: 4 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
include LICENSE
include README.md
include CONTRIBUTING.md
graft tests
12 changes: 7 additions & 5 deletions keras_preprocessing/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,6 @@ def __init__(self,
self.rotation_range = rotation_range
self.width_shift_range = width_shift_range
self.height_shift_range = height_shift_range
self.brightness_range = brightness_range
self.shear_range = shear_range
self.zoom_range = zoom_range
self.channel_shift_range = channel_shift_range
Expand Down Expand Up @@ -856,6 +855,13 @@ def __init__(self,
'`samplewise_std_normalization`, '
'which overrides setting of '
'`samplewise_center`.')
if brightness_range is not None:
if (not isinstance(brightness_range, (tuple, list))
or len(brightness_range) != 2):
raise ValueError(
'`brightness_range should be tuple or list of two floats. '
'Received: %s' % (brightness_range,))
self.brightness_range = brightness_range

def flow(self, x,
y=None, batch_size=32, shuffle=True,
Expand Down Expand Up @@ -1227,10 +1233,6 @@ def get_random_transform(self, img_shape, seed=None):

brightness = None
if self.brightness_range is not None:
if len(self.brightness_range) != 2:
raise ValueError(
'`brightness_range should be tuple or list of two floats. '
'Received: %s' % (self.brightness_range,))
brightness = np.random.uniform(self.brightness_range[0],
self.brightness_range[1])

Expand Down
4 changes: 4 additions & 0 deletions tests/image_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,10 @@ def preprocessing_function(x):
with pytest.raises(ValueError):
x1, y1 = dir_seq[9]

def test_valid_args(self):
with pytest.raises(ValueError):
dt = image.ImageDataGenerator(brightness_range=0.1)

def test_dataframe_iterator_class_mode_input(self, tmpdir):
# save the images in the paths
count = 0
Expand Down

0 comments on commit d24ce04

Please sign in to comment.