Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

can filter out ipv6 addresses #732

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/aiortc/rtcicetransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,15 @@ class RTCIceGatherer(AsyncIOEventEmitter):
exchanged in signaling.
"""

def __init__(self, iceServers: Optional[List[RTCIceServer]] = None) -> None:
def __init__(self, iceServers: Optional[List[RTCIceServer]] = None, **ice_kwargs) -> None:
super().__init__()

if iceServers is None:
iceServers = self.getDefaultIceServers()
ice_kwargs = connection_kwargs(iceServers)
if ice_kwargs is None:
ice_kwargs = dict()
ice_kwargs_ = connection_kwargs(iceServers)
ice_kwargs.update(ice_kwargs_)

self._connection = Connection(ice_controlling=False, **ice_kwargs)
self._remote_candidates_end = False
Expand Down
5 changes: 3 additions & 2 deletions src/aiortc/rtcpeerconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,9 @@ class RTCPeerConnection(AsyncIOEventEmitter):
:param configuration: An optional :class:`RTCConfiguration`.
"""

def __init__(self, configuration: Optional[RTCConfiguration] = None) -> None:
def __init__(self, configuration: Optional[RTCConfiguration] = None, **ice_kwargs) -> None:
super().__init__()
self.__iceKwargs = ice_kwargs or dict()
self.__certificates = [RTCCertificate.generateCertificate()]
self.__cname = f"{uuid.uuid4()}"
self.__configuration = configuration or RTCConfiguration()
Expand Down Expand Up @@ -1045,7 +1046,7 @@ def __assertTrackHasNoSender(self, track: MediaStreamTrack) -> None:

def __createDtlsTransport(self) -> RTCDtlsTransport:
# create ICE transport
iceGatherer = RTCIceGatherer(iceServers=self.__configuration.iceServers)
iceGatherer = RTCIceGatherer(iceServers=self.__configuration.iceServers, **self.__iceKwargs)
iceGatherer.on("statechange", self.__updateIceGatheringState)
iceTransport = RTCIceTransport(iceGatherer)
iceTransport.on("statechange", self.__updateIceConnectionState)
Expand Down
13 changes: 13 additions & 0 deletions tests/test_rtcicetransport.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from ipaddress import ip_address, IPv6Address, IPv4Address
from unittest import TestCase

import aioice.stun
Expand Down Expand Up @@ -300,6 +301,18 @@ async def test_construct(self):
await connection.addRemoteCandidate(None)
self.assertEqual(connection.getRemoteCandidates(), [candidate])

@asynctest
async def test_noipv6(self):
g1 = RTCIceGatherer(use_ipv4=False)
await g1.gather()
for candidate in g1.getLocalCandidates():
self.assertTrue(type(ip_address(candidate.ip)) is IPv6Address, msg=f"{candidate.ip} is not ipv6")

g2 = RTCIceGatherer(use_ipv6=False)
await g2.gather()
for candidate in g2.getLocalCandidates():
self.assertTrue(type(ip_address(candidate.ip)) is IPv4Address, msg=f"{candidate.ip} is not ipv4")

@asynctest
async def test_connect(self):
gatherer_1 = RTCIceGatherer()
Expand Down
12 changes: 12 additions & 0 deletions tests/test_rtcpeerconnection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import re
from ipaddress import IPv6Address, ip_address, IPv4Address
from unittest import TestCase

import aioice.ice
Expand Down Expand Up @@ -572,6 +573,17 @@ def tearDown(self):
aioice.stun.RETRY_MAX = self.retry_max
aioice.stun.RETRY_RTO = self.retry_rto

@asynctest
async def test_can_construct_with_ipv6_disabled(self):
pc = RTCPeerConnection(use_ipv6=False)
pc.createDataChannel("hello")
offer = await pc.createOffer()
await pc.setLocalDescription(offer)
candidates = pc.sctp.transport.transport.iceGatherer.getLocalCandidates()
self.assertTrue(len(candidates) > 0)
for candidate in candidates:
self.assertTrue(type(ip_address(candidate.ip)) is IPv4Address, msg=f"{candidate.ip} is not ipv4")

@asynctest
async def test_addIceCandidate_no_sdpMid_or_sdpMLineIndex(self):
pc = RTCPeerConnection()
Expand Down