Skip to content

Commit fc091ca

Browse files
committed
added cursor class
1 parent 3dab0c4 commit fc091ca

File tree

3 files changed

+97
-6
lines changed

3 files changed

+97
-6
lines changed

h5pyd/_hl/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,7 @@ def __iter__(self):
510510
if shape[0] - i < numrows:
511511
numrows = shape[0] - i
512512
self.log.debug("get {} iter items".format(numrows))
513-
arr = self[i:numrows]
513+
arr = self[i:numrows+i]
514514

515515
yield arr[i%BUFFER_SIZE]
516516

h5pyd/_hl/table.py

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from __future__ import absolute_import
1414
import numpy
1515
import six
16+
from six.moves import xrange
1617
from .base import _decode
1718
from .dataset import Dataset
1819
from .objectid import DatasetID
@@ -21,7 +22,54 @@
2122
from .h5type import check_dtype
2223

2324

25+
class Cursor():
    """
    Cursor for retrieving rows from a table

    Rows are fetched from the table in buffers of up to BUFFER_SIZE rows
    to reduce the number of round trips, and are yielded one at a time.

    Parameters:
        table: object exposing ``nrows``, slice indexing
            (``table[a:b]``) and ``read_where(query, start=, limit=)``
        query: optional condition string handed to ``table.read_where``;
            when None, rows are read directly by slicing the table
        start: first row index to visit (defaults to 0)
        stop: row index to stop before (defaults to ``table.nrows``)
    """
    def __init__(self, table, query=None, start=None, stop=None):
        self._table = table
        self._query = query
        self._start = 0 if start is None else start
        self._stop = table.nrows if stop is None else stop

    def __iter__(self):
        """ Iterate over the rows selected by this cursor.

        Without a query, yields table rows start..stop-1.  With a query,
        yields the rows returned by ``table.read_where``; at most
        (stop - start) rows are yielded in total.

        BEWARE: Modifications to the yielded data are *NOT* written to file.
        """
        nrows = self._table.nrows
        # to reduce round trips, grab BUFFER_SIZE items at a time
        # TBD: set buffersize based on size of each row
        BUFFER_SIZE = 1000

        arr = None
        # Row index that arr[0] corresponds to.  Tracking this (rather than
        # relying on indx % BUFFER_SIZE) lets iteration begin at a start
        # value that is not a multiple of BUFFER_SIZE.
        buf_start = None
        query_complete = False

        indx = self._start
        while indx < self._stop:
            if buf_start is None or indx - buf_start >= BUFFER_SIZE:
                # grab another buffer, beginning at the current row
                buf_start = indx
                read_count = min(BUFFER_SIZE, nrows - indx)
                if self._query is None:
                    arr = self._table[indx:read_count+indx]
                elif query_complete:
                    arr = None  # nothing more to fetch
                else:
                    # call table to return query result
                    # NOTE(review): the next fetch restarts the query at
                    # buf_start + BUFFER_SIZE, which may not be the row after
                    # the last match returned here - confirm against
                    # read_where semantics for results larger than one buffer
                    arr = self._table.read_where(self._query, start=indx, limit=read_count)
                    if arr is not None and arr.shape[0] < read_count:
                        query_complete = True  # we've gotten all the rows
            offset = indx - buf_start
            if arr is not None and offset < arr.shape[0]:
                yield arr[offset]
            indx += 1
2573

2674
class Table(Dataset):
2775

@@ -76,7 +124,7 @@ def read(self, start=None, stop=None, step=None, field=None, out=None):
76124

77125

78126

79-
def read_where(self, condition, condvars=None, field=None, start=None, stop=None, step=None):
127+
def read_where(self, condition, condvars=None, field=None, start=None, stop=None, step=None, limit=None):
80128
"""Read rows from table using pytable-style condition
81129
"""
82130
names = () # todo
@@ -148,10 +196,19 @@ def readtime_dtype(basetype, names):
148196
try:
149197
self.log.debug("params: {}".format(params))
150198
rsp = self.GET(req, params=params)
151-
count = len(rsp["value"])
199+
values = rsp["value"]
200+
count = len(values)
152201
self.log.info("got {} rows".format(count))
153202
if count > 0:
154-
data.extend(rsp['value'])
203+
if limit is None or count + len(data) <= limit:
204+
# add in all the data
205+
data.extend(values)
206+
else:
207+
# we've hit the limit for number of rows to return
208+
add_count = limit - len(data)
209+
self.log.debug("adding {} from {} to rrows".format(add_count, count))
210+
data.extend(values[:add_count])
211+
155212
# advance to next page
156213
cursor += page_size
157214
except IOError as ioe:
@@ -165,7 +222,7 @@ def readtime_dtype(basetype, names):
165222
# otherwise, just raise the exception
166223
self.log.info("Unexpected exception: {}".format(ioe.errno))
167224
raise ioe
168-
if cursor >= stop:
225+
if cursor >= stop or limit and len(data) == limit:
169226
self.log.info("completed iteration, returning: {} rows".format(len(data)))
170227
break
171228

@@ -190,6 +247,13 @@ def readtime_dtype(basetype, names):
190247
arr = numpy.asscalar(arr)
191248

192249
return arr
250+
251+
def create_cursor(self, condition=None, start=None, stop=None):
    """Build a Cursor over this table.

    condition is passed to the Cursor as its query; start and stop are
    passed through unchanged.
    """
    row_cursor = Cursor(self, query=condition, start=start, stop=stop)
    return row_cursor
255+
256+
193257

194258
def append(self, rows):
195259
""" Append rows to end of table

test/test_table.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,21 @@ def test_create_table(self):
4848

4949
self.assertEqual(table.colnames, ['real', 'img'])
5050
self.assertEqual(table.nrows, count)
51+
52+
num_rows = 0
5153
for row in table:
5254
self.assertEqual(len(row), 2)
55+
num_rows += 1
56+
self.assertEqual(num_rows, count)
57+
58+
# try the same thing using cursor object
59+
cursor = table.create_cursor()
60+
num_rows = 0
61+
for row in cursor:
62+
self.assertEqual(len(row), 2)
63+
num_rows += 1
64+
self.assertEqual(num_rows, count)
65+
5366
arr = table.read(start=5, stop=6)
5467
self.assertEqual(arr.shape, (1,))
5568

@@ -96,11 +109,25 @@ def test_query_table(self):
96109
# first two columns will come back as bytes, not strs
97110
self.assertEqual(row[col], item[col])
98111

99-
quotes = table.read_where("symbol == b'AAPL'")
112+
condition = "symbol == b'AAPL'"
113+
quotes = table.read_where(condition)
100114
self.assertEqual(len(quotes), 4)
101115
for i in range(4):
102116
quote = quotes[i]
103117
self.assertEqual(quote[0], b'AAPL')
118+
119+
# read up to 2 rows
120+
quotes = table.read_where(condition, limit=2)
121+
self.assertEqual(len(quotes), 2)
122+
123+
# use a query cursor
124+
cursor = table.create_cursor(condition=condition)
125+
num_rows = 0
126+
for row in cursor:
127+
self.assertEqual(len(row), 4)
128+
num_rows += 1
129+
self.assertEqual(num_rows, 4)
130+
104131
f.close()
105132

106133

0 commit comments

Comments
 (0)