Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiple endpoints and receive data by train stride #70

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 54 additions & 10 deletions karabo_bridge/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ class Client:

Parameters
----------
endpoint : str
server socket you want to connect to (only support TCP socket).
endpoints : str or Iterable
server socket(s) you want to connect to (only support TCP socket).
sock : str, optional
socket type - supported: REQ, SUB.
ser : str, DEPRECATED
Expand All @@ -58,7 +58,7 @@ class Client:
ZMQError
if provided endpoint is not valid.
"""
def __init__(self, endpoint, sock='REQ', ser='msgpack', timeout=None,
def __init__(self, endpoints, sock='REQ', ser='msgpack', timeout=None,
context=None):

if ser != 'msgpack':
Expand All @@ -74,19 +74,32 @@ def __init__(self, endpoint, sock='REQ', ser='msgpack', timeout=None,
elif sock == 'SUB':
self._socket = self._context.socket(zmq.SUB)
self._socket.setsockopt(zmq.SUBSCRIBE, b'')
elif sock == 'DEALER':
self._socket = self._context.socket(zmq.DEALER)
else:
raise NotImplementedError('Unsupported socket: %s' % str(sock))
self._socket.setsockopt(zmq.LINGER, 0)
self._socket.set_hwm(1)
self._socket.connect(endpoint)

if isinstance(endpoints, str):
endpoints = [endpoints]

self._num_endpoints = len(endpoints)

if self._num_endpoints > 1 and self._socket.type != zmq.DEALER:
raise ValueError('multiple endpoints only supported with DEALER '
'sockets')

for endpoint in endpoints:
self._socket.connect(endpoint)

if timeout is not None:
self._socket.setsockopt(zmq.RCVTIMEO, int(timeout * 1000))
self._recv_ready = False

self._pattern = self._socket.TYPE
self._pattern = self._socket.type

def next(self):
def next(self, divisor=None, remainder=None):
"""Request next data container.

This function call is blocking.
Expand All @@ -107,18 +120,49 @@ def next(self):
TimeoutError
If timeout is reached before receiving data.
"""
if self._pattern == zmq.REQ and not self._recv_ready:
self._socket.send(b'next')

recv_slice = slice(None)

if self._pattern in {zmq.REQ, zmq.DEALER} and not self._recv_ready:
if divisor is not None and remainder is not None:
msg = f'next {divisor} {remainder}'.encode('ascii')
else:
msg = b'next'

if self._pattern == zmq.DEALER:
for _ in range(self._num_endpoints):
self._socket.send_multipart([b'', msg])

# Account for empty delimiter frame.
recv_slice = slice(1, None)
else:
# There can only be a single endpoint with REP.
self._socket.send(msg)

self._recv_ready = True

try:
msg = self._socket.recv_multipart(copy=False)
# TODO: No guarantee to be actually matched if the bridges
# are out of sync themselves.
replies = [self._socket.recv_multipart(copy=False)[recv_slice]
for _ in range(self._num_endpoints)]

except zmq.error.Again:
raise TimeoutError(
'No data received from {} in the last {} ms'.format(
self._socket.getsockopt_string(zmq.LAST_ENDPOINT),
self._socket.getsockopt(zmq.RCVTIMEO)))
self._recv_ready = False
return deserialize(msg)

data = {}
meta = {}

for reply in replies:
data_, meta_ = deserialize(reply)
data.update(data_)
meta.update(meta_)

return data, meta

def __enter__(self):
return self
Expand Down