Skip to content

Commit b61fa0e

Browse files
authored
Field selection for simple compound types (HDFGroup#173)
* Field selection for simple compound types * Add logging to test_datatype
1 parent 4162811 commit b61fa0e

File tree

4 files changed

+181
-29
lines changed

4 files changed

+181
-29
lines changed

h5pyd/_hl/dataset.py

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,6 +1049,9 @@ def __getitem__(self, args, new_dtype=None):
10491049
req = "/datasets/" + self.id.uuid + "/value"
10501050
params = {}
10511051

1052+
if len(names) > 0:
1053+
params["fields"] = ":".join(names)
1054+
10521055
if self.id._http_conn.mode == "r" and self.id._http_conn.cache_on:
10531056
# enables lambda to be used on server
10541057
self.log.debug("setting nonstrict parameter")
@@ -1483,41 +1486,23 @@ def __setitem__(self, args, val):
14831486
last N dimensions have to match (got %s, but should be %s)" % (valshp, shp,))
14841487
mtype = h5t.py_create(numpy.dtype((val.dtype, shp)))
14851488
mshape = val.shape[0:len(val.shape)-len(shp)]
1489+
"""
14861490

1487-
1488-
# Make a compound memory type if field-name slicing is required
1489-
elif len(names) != 0:
1490-
1491-
mshape = val.shape
1492-
1491+
# Check for field selection
1492+
if len(names) != 0:
14931493
# Catch common errors
14941494
if self.dtype.fields is None:
14951495
raise TypeError("Illegal slicing argument (not a compound dataset)")
14961496
mismatch = [x for x in names if x not in self.dtype.fields]
14971497
if len(mismatch) != 0:
1498-
mismatch = ", ".join('"%s"'%x for x in mismatch)
1498+
mismatch = ", ".join('"%s"' % x for x in mismatch)
14991499
raise ValueError("Illegal slicing argument (fields %s not in dataset type)" % mismatch)
15001500

1501-
# Write non-compound source into a single dataset field
1502-
if len(names) == 1 and val.dtype.fields is None:
1503-
subtype = h5y.py_create(val.dtype)
1504-
mtype = h5t.create(h5t.COMPOUND, subtype.get_size())
1505-
mtype.insert(self._e(names[0]), 0, subtype)
1506-
1507-
# Make a new source type keeping only the requested fields
1508-
else:
1509-
fieldnames = [x for x in val.dtype.names if x in names] # Keep source order
1510-
mtype = h5t.create(h5t.COMPOUND, val.dtype.itemsize)
1511-
for fieldname in fieldnames:
1512-
subtype = h5t.py_create(val.dtype.fields[fieldname][0])
1513-
offset = val.dtype.fields[fieldname][1]
1514-
mtype.insert(self._e(fieldname), offset, subtype)
1515-
15161501
# Use mtype derived from array (let DatasetID.write figure it out)
15171502
else:
15181503
mshape = val.shape
1519-
#mtype = None
1520-
"""
1504+
# mtype = None
1505+
15211506
mshape = val.shape
15221507
self.log.debug(f"mshape: {mshape}")
15231508
self.log.debug(f"data dtype: {val.dtype}")
@@ -1582,6 +1567,10 @@ def __setitem__(self, args, val):
15821567
self.log.debug(f"got select query param: {select_param}")
15831568
params["select"] = select_param
15841569

1570+
# Perform write to subset of named fields within compound datatype, if any
1571+
if len(names) > 0:
1572+
params["fields"] = ":".join(names)
1573+
15851574
self.PUT(req, body=body, format=format, params=params)
15861575
"""
15871576
mspace = h5s.create_simple(mshape_pad, (h5s.UNLIMITED,)*len(mshape_pad))

h5pyd/_hl/h5type.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -441,10 +441,14 @@ def getTypeItem(dt):
441441
type_info['length'] = 'H5T_VARIABLE'
442442
type_info['charSet'] = 'H5T_CSET_UTF8'
443443
type_info['strPad'] = 'H5T_STR_NULLTERM'
444-
elif vlen_check == int:
444+
elif vlen_check in (int, np.int64):
445445
type_info['class'] = 'H5T_VLEN'
446446
type_info['size'] = 'H5T_VARIABLE'
447447
type_info['base'] = 'H5T_STD_I64'
448+
elif vlen_check == np.int32:
449+
type_info['class'] = 'H5T_VLEN'
450+
type_info['size'] = 'H5T_VARIABLE'
451+
type_info['base'] = 'H5T_STD_I32'
448452
elif vlen_check in (float, np.float64):
449453
type_info['class'] = 'H5T_VLEN'
450454
type_info['size'] = 'H5T_VARIABLE'
@@ -456,7 +460,7 @@ def getTypeItem(dt):
456460
type_info['base'] = getTypeItem(vlen_check)
457461
elif vlen_check is not None:
458462
# unknown vlen type
459-
raise TypeError("Unknown h5py vlen type: " + str(vlen_check))
463+
raise TypeError("Unknown h5pyd vlen type: " + str(vlen_check))
460464
elif ref_check is not None:
461465
# a reference type
462466
type_info['class'] = 'H5T_REFERENCE'
@@ -781,7 +785,7 @@ def createBaseDataType(typeItem):
781785
raise TypeError("ArrayType is not supported for variable len types")
782786
if 'base' not in typeItem:
783787
raise KeyError("'base' not provided")
784-
baseType = createBaseDataType(typeItem['base'])
788+
baseType = createDataType(typeItem['base'])
785789
dtRet = special_dtype(vlen=np.dtype(baseType))
786790
elif typeClass == 'H5T_OPAQUE':
787791
if dims:
@@ -842,9 +846,8 @@ def createBaseDataType(typeItem):
842846
else:
843847
# not a boolean enum, use h5py special dtype
844848
dtRet = special_dtype(enum=(dt, mapping))
845-
846849
else:
847-
raise TypeError("Invalid type class")
850+
raise TypeError(f"Invalid base type class: {typeClass}")
848851

849852
return dtRet
850853

test/hl/test_datatype.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
##############################################################################
2+
# Copyright by The HDF Group. #
3+
# All rights reserved. #
4+
# #
5+
# This file is part of H5Serv (HDF5 REST Server) Service, Libraries and #
6+
# Utilities. The full HDF5 REST Server copyright notice, including #
7+
# terms governing use, modification, and redistribution, is contained in #
8+
# the file COPYING, which can be found at the root of the source code #
9+
# distribution tree. If you do not have access to this file, you may #
10+
# request a copy from help@hdfgroup.org. #
11+
##############################################################################
12+
13+
import numpy as np
14+
import math
15+
import logging
16+
import config
17+
18+
if config.get("use_h5py"):
19+
import h5py
20+
else:
21+
import h5pyd as h5py
22+
23+
from common import ut, TestCase
24+
25+
26+
class TestScalarCompound(TestCase):
27+
28+
def setUp(self):
29+
filename = self.getFileName("scalar_compound_dset")
30+
print("filename:", filename)
31+
self.f = h5py.File(filename, "w")
32+
self.data = np.array((42.5, -118, "Hello"), dtype=[('a', 'f'), ('b', 'i'), ('c', '|S10')])
33+
self.dset = self.f.create_dataset('x', data=self.data)
34+
35+
def test_ndim(self):
36+
""" Verify number of dimensions """
37+
self.assertEqual(self.dset.ndim, 0)
38+
39+
def test_shape(self):
40+
""" Verify shape """
41+
self.assertEqual(self.dset.shape, tuple())
42+
43+
def test_size(self):
44+
""" Verify size """
45+
self.assertEqual(self.dset.size, 1)
46+
47+
def test_ellipsis(self):
48+
""" Ellipsis -> scalar ndarray """
49+
out = self.dset[...]
50+
# assertArrayEqual doesn't work with compounds; do manually
51+
self.assertIsInstance(out, np.ndarray)
52+
self.assertEqual(out.shape, self.data.shape)
53+
self.assertEqual(out.dtype, self.data.dtype)
54+
55+
def test_tuple(self):
56+
""" () -> np.void instance """
57+
out = self.dset[()]
58+
self.assertIsInstance(out, np.void)
59+
self.assertEqual(out.dtype, self.data.dtype)
60+
61+
def test_slice(self):
62+
""" slice -> ValueError """
63+
with self.assertRaises(ValueError):
64+
self.dset[0:4]
65+
66+
def test_index(self):
67+
""" index -> ValueError """
68+
with self.assertRaises(ValueError):
69+
self.dset[0]
70+
71+
def test_rt(self):
72+
""" Compound types are read back in correct order (h5py issue 236)"""
73+
74+
dt = np.dtype([('weight', np.float64),
75+
('cputime', np.float64),
76+
('walltime', np.float64),
77+
('parents_offset', np.uint32),
78+
('n_parents', np.uint32),
79+
('status', np.uint8),
80+
('endpoint_type', np.uint8),])
81+
82+
testdata = np.ndarray((16,), dtype=dt)
83+
for key in dt.fields:
84+
testdata[key] = np.random.random((16,)) * 100
85+
86+
self.f['test'] = testdata
87+
outdata = self.f['test'][...]
88+
self.assertTrue(np.all(outdata == testdata))
89+
self.assertEqual(outdata.dtype, testdata.dtype)
90+
91+
def test_assign(self):
92+
dt = np.dtype([('weight', (np.float64)),
93+
('endpoint_type', np.uint8),])
94+
95+
testdata = np.ndarray((16,), dtype=dt)
96+
for key in dt.fields:
97+
testdata[key] = np.random.random(size=testdata[key].shape) * 100
98+
99+
ds = self.f.create_dataset('test', (16,), dtype=dt)
100+
for key in dt.fields:
101+
ds[key] = testdata[key]
102+
103+
outdata = self.f['test'][...]
104+
105+
self.assertTrue(np.all(outdata == testdata))
106+
self.assertEqual(outdata.dtype, testdata.dtype)
107+
108+
def test_read(self):
109+
dt = np.dtype([('weight', (np.float64)),
110+
('endpoint_type', np.uint8),])
111+
112+
testdata = np.ndarray((16,), dtype=dt)
113+
for key in dt.fields:
114+
testdata[key] = np.random.random(size=testdata[key].shape) * 100
115+
116+
ds = self.f.create_dataset('test', (16,), dtype=dt)
117+
118+
# Write to all fields
119+
ds[...] = testdata
120+
121+
for key in dt.fields:
122+
outdata = self.f['test'][key]
123+
np.testing.assert_array_equal(outdata, testdata[key])
124+
self.assertEqual(outdata.dtype, testdata[key].dtype)
125+
126+
"""
127+
TBD
128+
def test_nested_compound_vlen(self):
129+
dt_inner = np.dtype([('a', h5py.vlen_dtype(np.int32)),
130+
('b', h5py.vlen_dtype(np.int32))])
131+
132+
dt = np.dtype([('f1', h5py.vlen_dtype(dt_inner)),
133+
('f2', np.int64)])
134+
135+
inner1 = (np.array(range(1, 3), dtype=np.int32),
136+
np.array(range(6, 9), dtype=np.int32))
137+
138+
inner2 = (np.array(range(10, 14), dtype=np.int32),
139+
np.array(range(16, 21), dtype=np.int32))
140+
141+
data = np.array([(np.array([inner1, inner2], dtype=dt_inner), 2),
142+
(np.array([inner1], dtype=dt_inner), 3)],
143+
dtype=dt)
144+
145+
self.f["ds"] = data
146+
out = self.f["ds"]
147+
148+
# Specifying check_alignment=False because vlen fields have 8 bytes of padding
149+
# because the vlen datatype in hdf5 occupies 16 bytes
150+
self.assertArrayEqual(out, data, check_alignment=False)
151+
"""
152+
153+
154+
if __name__ == '__main__':
155+
loglevel = logging.ERROR
156+
logging.basicConfig(format='%(asctime)s %(message)s', level=loglevel)
157+
ut.main()

testall.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import os
1515
import sys
1616

17+
1718
hl_tests = ('test_attribute',
1819
'test_committedtype',
1920
'test_complex_numbers',
@@ -26,6 +27,7 @@
2627
'test_dataset_pointselect',
2728
'test_dataset_scalar',
2829
'test_dataset_setitem',
30+
'test_datatype',
2931
'test_dimscale',
3032
'test_file',
3133
'test_group',
@@ -34,6 +36,7 @@
3436
'test_vlentype',
3537
'test_folder')
3638

39+
3740
app_tests = ('test_hsinfo', 'test_tall_inspect', 'test_diamond_inspect',
3841
'test_shuffle_inspect')
3942

0 commit comments

Comments
 (0)