Skip to content

Commit

Permalink
Merge pull request #21 from dkillick/nd_sample_point
Browse files Browse the repository at this point in the history
Add convert support for nd arrays
  • Loading branch information
bjlittle authored Nov 23, 2016
2 parents 1021ea3 + b4f1e1c commit d64de36
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 14 deletions.
39 changes: 26 additions & 13 deletions nc_time_axis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,17 +226,24 @@ def default_units(cls, sample_point, axis):
Computes some units for the given data point.
"""
try:
# Try getting the first item. Otherwise we just use this item.
sample_point = sample_point[0]
except (TypeError, IndexError):
pass

if not hasattr(sample_point, 'calendar'):
msg = 'Expecting netcdftimes with an extra "calendar" attribute.'
raise ValueError(msg)

return sample_point.calendar, cls.standard_unit
if hasattr(sample_point, '__iter__'):
# Deal with nD `sample_point` arrays.
if isinstance(sample_point, np.ndarray):
sample_point = sample_point.reshape(-1)
calendars = np.array([point.calendar for point in sample_point])
if np.all(calendars[0] == calendars):
calendar = calendars[0]
else:
raise ValueError('Calendar units are not all equal.')
else:
# Deal with a single `sample_point` value.
if not hasattr(sample_point, 'calendar'):
msg = ('Expecting netcdftimes with an extra '
'"calendar" attribute.')
raise ValueError(msg)
else:
calendar = sample_point.calendar
return calendar, cls.standard_unit

@classmethod
def convert(cls, value, unit, axis):
Expand All @@ -245,11 +252,13 @@ def convert(cls, value, unit, axis):
with :func:`netcdftime.utime().date2num`.
"""
shape = None
if isinstance(value, np.ndarray):
# Don't do anything with numeric types.
if value.dtype != np.object:
return value

shape = value.shape
value = value.reshape(-1)
first_value = value[0]
else:
# Don't do anything with numeric types.
Expand All @@ -270,7 +279,11 @@ def convert(cls, value, unit, axis):
if isinstance(value, CalendarDateTime):
value = [value]

return ut.date2num([v.datetime for v in value])
result = ut.date2num([v.datetime for v in value])
if shape is not None:
result = result.reshape(shape)

return result


# Automatically register NetCDFTimeConverter with matplotlib.unit's converter
Expand Down
37 changes: 36 additions & 1 deletion nc_time_axis/tests/unit/test_NetCDFTimeConverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,55 @@ def test_axis_default_limits(self):


class Test_default_units(unittest.TestCase):
def test_360_day_calendar(self):
def test_360_day_calendar_point(self):
calendar = '360_day'
unit = 'days since 2000-01-01'
val = CalendarDateTime(netcdftime.datetime(2014, 8, 12), calendar)
result = NetCDFTimeConverter().default_units(val, None)
self.assertEqual(result, (calendar, unit))

def test_360_day_calendar_list(self):
calendar = '360_day'
unit = 'days since 2000-01-01'
val = [CalendarDateTime(netcdftime.datetime(2014, 8, 12), calendar)]
result = NetCDFTimeConverter().default_units(val, None)
self.assertEqual(result, (calendar, unit))

def test_360_day_calendar_nd(self):
# Test the case where the input is an nd-array.
calendar = '360_day'
unit = 'days since 2000-01-01'
val = np.array([[CalendarDateTime(netcdftime.datetime(2014, 8, 12),
calendar)],
[CalendarDateTime(netcdftime.datetime(2014, 8, 13),
calendar)]])
result = NetCDFTimeConverter().default_units(val, None)
self.assertEqual(result, (calendar, unit))

def test_nonequal_calendars(self):
# Test that different supplied calendars causes an error.
calendar_1 = '360_day'
calendar_2 = '365_day'
unit = 'days since 2000-01-01'
val = [CalendarDateTime(netcdftime.datetime(2014, 8, 12), calendar_1),
CalendarDateTime(netcdftime.datetime(2014, 8, 13), calendar_2)]
with self.assertRaisesRegexp(ValueError, 'not all equal'):
NetCDFTimeConverter().default_units(val, None)


class Test_convert(unittest.TestCase):
def test_numpy_array(self):
val = np.array([7])
result = NetCDFTimeConverter().convert(val, None, None)
np.testing.assert_array_equal(result, val)

def test_numpy_nd_array(self):
shape = (4, 2)
val = np.arange(8).reshape(shape)
result = NetCDFTimeConverter().convert(val, None, None)
np.testing.assert_array_equal(result, val)
self.assertEqual(result.shape, shape)

def test_numeric(self):
val = 4
result = NetCDFTimeConverter().convert(val, None, None)
Expand Down

0 comments on commit d64de36

Please sign in to comment.