Skip to content

Commit b51b095

Browse files
committed
Support retrieving multiple attributes by name
1 parent cf36e75 commit b51b095

File tree

2 files changed

+44
-4
lines changed

2 files changed

+44
-4
lines changed

h5pyd/_hl/attrs.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def __getitem__(self, name):
166166

167167
return arr
168168

169-
def get_attributes(self, pattern=None, limit=None, marker=None, use_cache=True):
169+
def get_attributes(self, names=None, pattern=None, limit=None, marker=None, use_cache=True):
170170
"""
171171
Get all attributes or a subset of attributes from the target object.
172172
If 'use_cache' is True, use the objdb cache if available.
@@ -179,6 +179,9 @@ def get_attributes(self, pattern=None, limit=None, marker=None, use_cache=True):
179179
if use_cache and (pattern or limit or marker):
180180
raise ValueError("use_cache cannot be used with pattern, limit, or marker parameters")
181181

182+
if names and (pattern or limit or marker or use_cache):
183+
raise ValueError("names cannot be used with pattern, limit, marker, or cache")
184+
182185
if self._objdb_attributes is not None:
183186
# use the objdb cache
184187
out = {}
@@ -189,8 +192,8 @@ def get_attributes(self, pattern=None, limit=None, marker=None, use_cache=True):
189192

190193
# Omit trailing slash
191194
req = self._req_prefix[:-1]
192-
193195
req += "?IncludeData=1"
196+
body = {}
194197

195198
if pattern:
196199
req += "&pattern=" + pattern
@@ -199,7 +202,21 @@ def get_attributes(self, pattern=None, limit=None, marker=None, use_cache=True):
199202
if marker:
200203
req += "&Marker=" + marker
201204

202-
rsp = self._parent.GET(req)
205+
if names:
206+
if isinstance(names, list):
207+
names = [name.decode('utf-8') if isinstance(name, bytes) else name for name in names]
208+
else:
209+
if isinstance(names, bytes):
210+
names = names.decode("utf-8")
211+
names = [names]
212+
213+
body['attr_names'] = names
214+
215+
if body:
216+
rsp = self._parent.POST(req, body=body)
217+
else:
218+
rsp = self._parent.GET(req)
219+
203220
attrs_json = rsp['attributes']
204221
names = [attr['name'] for attr in attrs_json]
205222
values = [attr['value'] for attr in attrs_json]

test/hl/test_attribute.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,8 @@ def test_get_multiple(self):
198198
self.assertTrue("attr5" in values_out)
199199
self.assertTrue(np.array_equal(values_out["attr5"], values[5]))
200200

201-
# get attributes that match the pattern 'att*' (all attributes)
201+
# get only attributes that match the pattern 'att*'
202+
g1.attrs['new_attr'] = np.arange(100)
202203
pattern = "att*"
203204
values_out = g1.attrs.get_attributes(pattern=pattern, use_cache=False)
204205

@@ -228,6 +229,18 @@ def test_get_multiple(self):
228229
self.assertTrue(names[i] in values_out)
229230
self.assertTrue(np.array_equal(values_out[names[i]], values[i]))
230231

232+
# get set of attributes by name
233+
names = ['attr5', 'attr7', 'attr9']
234+
235+
values_out = g1.attrs.get_attributes(names=names, use_cache=False)
236+
237+
self.assertEqual(len(values_out), 3)
238+
239+
for name in names:
240+
self.assertTrue(name in values_out)
241+
i = int(name[4])
242+
self.assertTrue(np.array_equal(values_out[name], values[i]))
243+
231244
def test_delete_multiple(self):
232245
if config.get('use_h5py') or self.hsds_version() < "0.9.0":
233246
return
@@ -266,6 +279,16 @@ def test_delete_multiple(self):
266279
self.assertTrue(names[i] in g1.attrs)
267280
self.assertTrue(np.array_equal(g1.attrs[names[i]], values[i]))
268281

282+
# delete attributes with name that must be URL-encoded
283+
names = ['attr with spaces', 'attr%', 'unicode八attr']
284+
for name in names:
285+
g1.attrs[name] = np.arange(100)
286+
287+
del g1.attrs[names]
288+
289+
for name in names:
290+
self.assertTrue(name not in g1.attrs)
291+
269292

270293
if __name__ == '__main__':
271294
loglevel = logging.ERROR

0 commit comments

Comments
 (0)