diff --git a/karabo_bridge/client.py b/karabo_bridge/client.py index b2ea2e1..d186471 100644 --- a/karabo_bridge/client.py +++ b/karabo_bridge/client.py @@ -31,8 +31,8 @@ class Client: Parameters ---------- - endpoint : str - server socket you want to connect to (only support TCP socket). + endpoints : str or Iterable + server socket(s) you want to connect to (only support TCP socket). sock : str, optional socket type - supported: REQ, SUB. ser : str, DEPRECATED @@ -58,7 +58,7 @@ class Client: ZMQError if provided endpoint is not valid. """ - def __init__(self, endpoint, sock='REQ', ser='msgpack', timeout=None, + def __init__(self, endpoints, sock='REQ', ser='msgpack', timeout=None, context=None): if ser != 'msgpack': @@ -74,19 +74,32 @@ def __init__(self, endpoint, sock='REQ', ser='msgpack', timeout=None, elif sock == 'SUB': self._socket = self._context.socket(zmq.SUB) self._socket.setsockopt(zmq.SUBSCRIBE, b'') + elif sock == 'DEALER': + self._socket = self._context.socket(zmq.DEALER) else: raise NotImplementedError('Unsupported socket: %s' % str(sock)) self._socket.setsockopt(zmq.LINGER, 0) self._socket.set_hwm(1) - self._socket.connect(endpoint) + + if isinstance(endpoints, str): + endpoints = [endpoints] + + self._num_endpoints = len(endpoints) + + if self._num_endpoints > 1 and self._socket.type != zmq.DEALER: + raise ValueError('multiple endpoints only supported with DEALER ' + 'sockets') + + for endpoint in endpoints: + self._socket.connect(endpoint) if timeout is not None: self._socket.setsockopt(zmq.RCVTIMEO, int(timeout * 1000)) self._recv_ready = False - self._pattern = self._socket.TYPE + self._pattern = self._socket.type - def next(self): + def next(self, divisor=None, remainder=None): """Request next data container. This function call is blocking. @@ -107,18 +120,49 @@ def next(self): TimeoutError If timeout is reached before receiving data. """ - if self._pattern == zmq.REQ and not self._recv_ready: - self._socket.send(b'next') + + recv_slice = slice(None) + + if self._pattern in {zmq.REQ, zmq.DEALER} and not self._recv_ready: + if divisor is not None and remainder is not None: + msg = f'next {divisor} {remainder}'.encode('ascii') + else: + msg = b'next' + + if self._pattern == zmq.DEALER: + for _ in range(self._num_endpoints): + self._socket.send_multipart([b'', msg]) + + # Account for empty delimiter frame. + recv_slice = slice(1, None) + else: + # There can only be a single endpoint with REP. + self._socket.send(msg) + self._recv_ready = True + try: - msg = self._socket.recv_multipart(copy=False) + # TODO: No guarantee to be actually matched if the bridges + # are out of sync themselves. + replies = [self._socket.recv_multipart(copy=False)[recv_slice] + for _ in range(self._num_endpoints)] + except zmq.error.Again: raise TimeoutError( 'No data received from {} in the last {} ms'.format( self._socket.getsockopt_string(zmq.LAST_ENDPOINT), self._socket.getsockopt(zmq.RCVTIMEO))) self._recv_ready = False - return deserialize(msg) + + data = {} + meta = {} + + for reply in replies: + data_, meta_ = deserialize(reply) + data.update(data_) + meta.update(meta_) + + return data, meta def __enter__(self): return self