Skip to content

Commit afea92a

Browse files
committed
Changes since PR ethereum#3048 was put up for review
- Add web3.providers.websocket.rst to .gitignore - Put back formatters for eth_getCode / remove unnecessary ``compose()`` - Add read-friendly comment splitting Web3 from AsyncWeb3 in main.py - Use correct class name in docstring - Friendlier message when exception is raised connecting to websocket endpoint - Friendlier message for websocket restricted_kwargs; use a merge of default + provided websocket_kwargs with the provided values taking precedence - Validate "ws://" / "wss://" in websocket endpoint - Use Dict[str, Any] for websocket_kwargs type
1 parent 71f5205 commit afea92a

File tree

5 files changed

+32
-39
lines changed

5 files changed

+32
-39
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ docs/web3.gas_strategies.rst
4949
docs/web3.middleware.rst
5050
docs/web3.providers.eth_tester.rst
5151
docs/web3.providers.rst
52+
docs/web3.providers.websocket.rst
5253
docs/web3.rst
5354
docs/web3.scripts.release.rst
5455
docs/web3.scripts.rst

docs/web3.providers.websocket.rst

-29
This file was deleted.

web3/_utils/method_formatters.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def apply_list_to_array_formatter(formatter: Any) -> Callable[..., Any]:
484484
to_hex_if_integer,
485485
0,
486486
),
487-
RPC.eth_getCode: compose(apply_formatter_at_index(to_hex_if_integer, 1)),
487+
RPC.eth_getCode: apply_formatter_at_index(to_hex_if_integer, 1),
488488
RPC.eth_getStorageAt: apply_formatter_at_index(to_hex_if_integer, 2),
489489
RPC.eth_getTransactionByBlockNumberAndIndex: compose(
490490
apply_formatter_at_index(to_hex_if_integer, 0),

web3/main.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,9 @@ def ens(self, new_ens: Union[ENS, "Empty"]) -> None:
430430
self._ens = new_ens
431431

432432

433+
# -- async -- #
434+
435+
433436
class AsyncWeb3(BaseWeb3):
434437
# mypy Types
435438
eth: AsyncEth
@@ -505,7 +508,7 @@ def persistent_websocket(
505508
) -> "_PersistentConnectionWeb3":
506509
"""
507510
Establish a persistent connection via websockets to a websocket provider using
508-
a WebsocketProviderV2 instance.
511+
a ``PersistentConnectionProvider`` instance.
509512
"""
510513
return _PersistentConnectionWeb3(
511514
provider,

web3/providers/websocket/websocket_v2.py

+26-8
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,17 @@
33
import os
44
from typing import (
55
Any,
6+
Dict,
67
Optional,
78
Union,
89
)
910

1011
from eth_typing import (
1112
URI,
1213
)
14+
from toolz import (
15+
merge,
16+
)
1317
from websockets.client import (
1418
connect,
1519
)
@@ -31,6 +35,7 @@
3135
DEFAULT_PING_INTERVAL = 30 # 30 seconds
3236
DEFAULT_PING_TIMEOUT = 300 # 5 minutes
3337

38+
VALID_WEBSOCKET_URI_PREFIXES = {"ws://", "wss://"}
3439
RESTRICTED_WEBSOCKET_KWARGS = {"uri", "loop"}
3540
DEFAULT_WEBSOCKET_KWARGS = {
3641
# set how long to wait between pings from the server
@@ -51,26 +56,34 @@ class WebsocketProviderV2(PersistentConnectionProvider):
5156
def __init__(
5257
self,
5358
endpoint_uri: Optional[Union[URI, str]] = None,
54-
websocket_kwargs: Optional[Any] = None,
59+
websocket_kwargs: Optional[Dict[str, Any]] = None,
5560
call_timeout: Optional[int] = None,
5661
) -> None:
5762
self.endpoint_uri = URI(endpoint_uri)
5863
if self.endpoint_uri is None:
5964
self.endpoint_uri = get_default_endpoint()
6065

61-
if websocket_kwargs is None:
62-
websocket_kwargs = DEFAULT_WEBSOCKET_KWARGS
63-
else:
66+
if not any(
67+
self.endpoint_uri.startswith(prefix)
68+
for prefix in VALID_WEBSOCKET_URI_PREFIXES
69+
):
70+
raise Web3ValidationError(
71+
f"Websocket endpoint uri must begin with 'ws://' or 'wss://': "
72+
f"{self.endpoint_uri}"
73+
)
74+
75+
if websocket_kwargs is not None:
6476
found_restricted_keys = set(websocket_kwargs).intersection(
6577
RESTRICTED_WEBSOCKET_KWARGS
6678
)
6779
if found_restricted_keys:
6880
raise Web3ValidationError(
69-
f"{RESTRICTED_WEBSOCKET_KWARGS} are not allowed "
70-
f"in websocket_kwargs, found: {found_restricted_keys}"
81+
f"Found restricted keys for websocket_kwargs: "
82+
f"{found_restricted_keys}."
7183
)
7284

73-
self.websocket_kwargs = websocket_kwargs
85+
self.websocket_kwargs = merge(DEFAULT_WEBSOCKET_KWARGS, websocket_kwargs or {})
86+
7487
super().__init__(endpoint_uri, call_timeout=call_timeout)
7588

7689
def __str__(self) -> str:
@@ -93,7 +106,12 @@ async def is_connected(self, show_traceback: bool = False) -> bool:
93106
return False
94107

95108
async def connect(self) -> None:
96-
self.ws = await connect(self.endpoint_uri, **self.websocket_kwargs)
109+
try:
110+
self.ws = await connect(self.endpoint_uri, **self.websocket_kwargs)
111+
except Exception as e:
112+
raise ProviderConnectionError(
113+
f"Could not connect to endpoint: {self.endpoint_uri}"
114+
) from e
97115

98116
async def disconnect(self) -> None:
99117
await self.ws.close()

0 commit comments

Comments
 (0)