Skip to content

Commit

Permalink
Close the WebSocket connection immediately when the stream is stopped.
Browse files Browse the repository at this point in the history
Currently, when the stream is stopped, we set the stream status accordingly and then wait for the `_consume` loop to check the stream status and close the WebSocket connection. The `_consume` loop calls `self._ws.recv()` with a timeout of 5 seconds, so it can take up to 5 seconds for the WebSocket connection to be closed after the stream is stopped. This is unnecessarily inefficient and complicated.

Instead, we could close the WebSocket connection immediately when the stream is stopped. The `_consume` loop would still be broken out of properly because `self._ws.recv()` would raise a `ConnectionClosed` error.
  • Loading branch information
fumoboy007 committed Feb 24, 2023
1 parent bf02b2b commit 8170a32
Showing 1 changed file with 19 additions and 52 deletions.
71 changes: 19 additions & 52 deletions alpaca_trade_api/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import msgpack
import re
import websockets
import queue

from .common import get_base_url, get_data_stream_url, get_credentials, URL
from .entity import Entity
Expand Down Expand Up @@ -59,7 +58,6 @@ def __init__(self,
self._running = False
self._loop = None
self._raw_data = raw_data
self._stop_stream_queue = queue.Queue()
self._handlers = {
'trades': {},
'quotes': {},
Expand Down Expand Up @@ -113,26 +111,14 @@ async def close(self):

async def stop_ws(self):
self._should_run = False
if self._stop_stream_queue.empty():
self._stop_stream_queue.put_nowait({"should_stop": True})
await self.close()

async def _consume(self):
while True:
if not self._stop_stream_queue.empty():
self._stop_stream_queue.get(timeout=1)
await self.close()
break
else:
try:
r = await asyncio.wait_for(self._ws.recv(), 5)
msgs = msgpack.unpackb(r)
for msg in msgs:
await self._dispatch(msg)
except asyncio.TimeoutError:
# ws.recv is hanging when no data is received. by using
# wait_for we break when no data is received, allowing us
# to break the loop when needed
pass
r = await self._ws.recv()
msgs = msgpack.unpackb(r)
for msg in msgs:
await self._dispatch(msg)

def _cast(self, msg_type, msg):
result = msg
Expand Down Expand Up @@ -230,14 +216,10 @@ async def _run_forever(self):
v for k, v in self._handlers.items()
if k not in ("cancelErrors", "corrections")
):
if not self._stop_stream_queue.empty():
# the ws was signaled to stop before starting the loop so
# we break
self._stop_stream_queue.get(timeout=1)
if not self._should_run:
return
await asyncio.sleep(0.1)
log.info(f'started {self._name} stream')
self._should_run = True
self._running = False
while True:
try:
Expand All @@ -253,10 +235,10 @@ async def _run_forever(self):
self._running = True
await self._consume()
except websockets.WebSocketException as wse:
await self.close()
self._running = False
log.warn('data websocket error, restarting connection: ' +
str(wse))
if self._should_run:
await self.close()
log.warn('data websocket error, restarting connection: ' +
str(wse))
except Exception as e:
log.exception('error during websocket '
'communication: {}'.format(str(e)))
Expand Down Expand Up @@ -610,7 +592,6 @@ def __init__(self,
self._running = False
self._loop = None
self._raw_data = raw_data
self._stop_stream_queue = queue.Queue()
self._should_run = True
self._websocket_params = websocket_params

Expand Down Expand Up @@ -675,31 +656,18 @@ async def _start_ws(self):

async def _consume(self):
while True:
if not self._stop_stream_queue.empty():
self._stop_stream_queue.get(timeout=1)
await self.close()
break
else:
try:
r = await asyncio.wait_for(self._ws.recv(), 5)
msg = json.loads(r)
await self._dispatch(msg)
except asyncio.TimeoutError:
# ws.recv is hanging when no data is received. by using
# wait_for we break when no data is received, allowing us
# to break the loop when needed
pass
r = await self._ws.recv()
msg = json.loads(r)
await self._dispatch(msg)

async def _run_forever(self):
self._loop = asyncio.get_running_loop()
# do not start the websocket connection until we subscribe to something
while not self._trade_updates_handler:
if not self._stop_stream_queue.empty():
self._stop_stream_queue.get(timeout=1)
if not self._should_run:
return
await asyncio.sleep(0.1)
log.info('started trading stream')
self._should_run = True
self._running = False
while True:
try:
Expand All @@ -712,10 +680,10 @@ async def _run_forever(self):
self._running = True
await self._consume()
except websockets.WebSocketException as wse:
await self.close()
self._running = False
log.warn('trading stream websocket error, restarting ' +
' connection: ' + str(wse))
if self._should_run:
await self.close()
log.warn('trading stream websocket error, restarting ' +
' connection: ' + str(wse))
except Exception as e:
log.exception('error during websocket '
'communication: {}'.format(str(e)))
Expand All @@ -730,8 +698,7 @@ async def close(self):

async def stop_ws(self):
self._should_run = False
if self._stop_stream_queue.empty():
self._stop_stream_queue.put_nowait({"should_stop": True})
await self.close()

def stop(self):
if self._loop.is_running():
Expand Down

0 comments on commit 8170a32

Please sign in to comment.