Skip to content

Commit ccbc1aa

Browse files
committed
enable number of subprocesses to be set for direct mode
1 parent 20fce36 commit ccbc1aa

File tree

3 files changed

+56
-17
lines changed

3 files changed

+56
-17
lines changed

h5pyd/_hl/httpconn.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,24 @@ def __init__(self, domain_name, endpoint=None, username=None, password=None, buc
198198
# save lambda function name
199199
self._lambda = endpoint[len("lambda:"):]
200200

201-
elif endpoint == "local":
201+
elif endpoint.startswith("local"):
202202
# create a local hsds server
203-
# set the number of nodes equal to number of cores
204-
dn_count = multiprocessing.cpu_count()
205-
dn_count = -(-dn_count // 2) # get the ceiling of count / 2 (don't include hyperthreading cores)
203+
# set the number of nodes
204+
# if the endpoint is of the form: "local[n]", use n as the number of nodes
205+
# else set the number of nodes equal to number of cores
206+
bracket_start = endpoint.find('[')
207+
bracket_end = endpoint.find(']')
208+
dn_count = None
209+
if bracket_start > 0 and bracket_end > 0:
210+
try:
211+
dn_count = int(endpoint[bracket_start+1:bracket_end])
212+
except ValueError:
213+
# if value is '*' or something just drop down to default
214+
# setup based on cpu count
215+
pass
216+
if not dn_count:
217+
dn_count = multiprocessing.cpu_count()
218+
dn_count = -(-dn_count // 2) # get the ceiling of count / 2 (don't include hyperthreading cores)
206219
if dn_count < 1:
207220
dn_count = 1
208221

h5pyd/_hl/table.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,19 @@
2323
class Cursor():
2424
"""
2525
Cursor for retreiving rows from a table
26+
buffer_rows can be used to control how many rows
27+
will be fetched from the server
2628
"""
27-
def __init__(self, table, query=None, start=None, stop=None):
29+
def __init__(self, table, query=None, start=None, stop=None, buffer_rows=None):
2830
self._table = table
2931
self._query = query
32+
DEFAULT_BUFFER_BYTES = 1000000
33+
if buffer_rows is None:
34+
buffer_rows = DEFAULT_BUFFER_BYTES // table.dtype.itemsize
35+
if buffer_rows < 1:
36+
buffer_rows = 1
37+
self._buffer_rows = buffer_rows
38+
3039
if start is None:
3140
self._start = 0
3241
else:
@@ -41,33 +50,30 @@ def __iter__(self):
4150
4251
BEWARE: Modifications to the yielded data are *NOT* written to file.
4352
"""
44-
nrows = self._table.nrows
45-
# to reduce round trips, grab BUFFER_SIZE items at a time
46-
# TBD: set buffersize based on size of each row
47-
BUFFER_SIZE = 10000
53+
nrows = self._stop - self._start
4854

4955
arr = None
5056
query_complete = False
5157

52-
for indx in range(self._start, self._stop):
53-
if indx%BUFFER_SIZE == 0:
58+
for indx in range(self._stop - self._start):
59+
if indx % self._buffer_rows == 0:
5460
# grab another buffer
55-
read_count = BUFFER_SIZE
61+
read_count = self._buffer_rows
5662
if nrows - indx < read_count:
5763
read_count = nrows - indx
5864
if self._query is None:
59-
60-
arr = self._table[indx:read_count+indx]
65+
print("read row count:", (read_count+indx+self._start)-(indx+self._start))
66+
arr = self._table[indx+self._start:read_count+indx+self._start]
6167
else:
6268
# call table to return query result
6369
if query_complete:
6470
arr = None # nothing more to fetch
6571
else:
66-
arr = self._table.read_where(self._query, start=indx, limit=read_count)
72+
arr = self._table.read_where(self._query, start=indx+self._start, limit=read_count)
6773
if arr is not None and arr.shape[0] < read_count:
6874
query_complete = True # we've gotten all the rows
69-
if arr is not None and indx%BUFFER_SIZE < arr.shape[0]:
70-
yield arr[indx%BUFFER_SIZE]
75+
if arr is not None and indx%self._buffer_rows < arr.shape[0]:
76+
yield arr[indx%self._buffer_rows]
7177

7278
class Table(Dataset):
7379

test/hl/test_table.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,26 @@ def test_query_table(self):
109109
# first two columns will come back as bytes, not strs
110110
self.assertEqual(row[col], item[col])
111111

112+
cursor = table.create_cursor()
113+
indx = 0
114+
for row in cursor:
115+
item = data[indx]
116+
for col in range(2,3):
117+
# first two columns will come back as bytes, not strs
118+
self.assertEqual(row[col], item[col])
119+
indx += 1
120+
self.assertEqual(indx, len(data))
121+
122+
cursor = table.create_cursor(start=2, stop=5)
123+
indx = 2
124+
for row in cursor:
125+
item = data[indx]
126+
for col in range(2,3):
127+
# first two columns will come back as bytes, not strs
128+
self.assertEqual(row[col], item[col])
129+
indx += 1
130+
self.assertEqual(indx, 5)
131+
112132
condition = "symbol == b'AAPL'"
113133
quotes = table.read_where(condition)
114134
self.assertEqual(len(quotes), 4)

0 commit comments

Comments
 (0)