Skip to content

Commit

Permalink
Merge pull request #258 from keisku/support-ipv6
Browse files Browse the repository at this point in the history
fix: Support IPv6
  • Loading branch information
keisku authored Jan 8, 2025
1 parent dd38fed commit efeb037
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 3 deletions.
32 changes: 29 additions & 3 deletions bmemcached/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from urllib.parse import SplitResult # type: ignore[import-not-found]

import zlib
from ipaddress import ip_address
from io import BytesIO
import six
from six import binary_type, text_type
Expand Down Expand Up @@ -144,9 +145,7 @@ def _open_connection(self):

try:
if self.host:
self.connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.connection.settimeout(self.socket_timeout)
self.connection.connect((self.host, self.port))
self.connection = socket.create_connection((self.host, self.port), self.socket_timeout)

if self.tls_context:
self.connection = self.tls_context.wrap_socket(
Expand Down Expand Up @@ -174,11 +173,38 @@ def split_host_port(cls, server):
Port defaults to 11211.
When using IPv6 with a specified port, the address must be enclosed in brackets.
If the port is not specified, brackets are optional.
>>> split_host_port('127.0.0.1:11211')
('127.0.0.1', 11211)
>>> split_host_port('127.0.0.1')
('127.0.0.1', 11211)
>>> split_host_port('::1')
('::1', 11211)
>>> split_host_port('[::1]')
('::1', 11211)
>>> split_host_port('[::1]:11211')
('::1', 11211)
"""
default_port = 11211

def is_ip_address(address):
try:
ip_address(address)
return True
except ValueError:
return False

if is_ip_address(server):
return server, default_port

if server.startswith('['):
host, _, port = server[1:].partition(']')
if not is_ip_address(host):
raise ValueError('{} is not a valid IPv6 address'.format(server))
return host, default_port if not port else int(port.lstrip(':'))

u = SplitResult("", server, "", "", "")
return u.hostname, 11211 if u.port is None else u.port

Expand Down
13 changes: 13 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,16 @@ def memcached_socket():
yield p
p.kill()
p.wait()


@pytest.yield_fixture(scope="session", autouse=True)
def memcached_ipv6():
p = subprocess.Popen(
["memcached", "-l::1"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
time.sleep(0.1)
yield p
p.kill()
p.wait()
29 changes: 29 additions & 0 deletions test/test_server_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,38 @@ def testNoPortGiven(self):
self.assertEqual(server.host, os.environ['MEMCACHED_HOST'])
self.assertEqual(server.port, 11211)

def testIPv6(self):
server = bmemcached.protocol.Protocol('[::1]')
self.assertEqual(server.host, '::1')
self.assertEqual(server.port, 11211)
server = bmemcached.protocol.Protocol('::1')
self.assertEqual(server.host, '::1')
self.assertEqual(server.port, 11211)
server = bmemcached.protocol.Protocol('[2001:db8::2]')
self.assertEqual(server.host, '2001:db8::2')
self.assertEqual(server.port, 11211)
server = bmemcached.protocol.Protocol('2001:db8::2')
self.assertEqual(server.host, '2001:db8::2')
self.assertEqual(server.port, 11211)
# Since `2001:db8::2:8080` is a valid IPv6 address,
# it is ambiguous whether to split it into `2001:db8::2` and `8080`
# or treat it as `2001:db8::2:8080`.
# Therefore, it will be treated as `2001:db8::2:8080`.
server = bmemcached.protocol.Protocol('2001:db8::2:8080')
self.assertEqual(server.host, '2001:db8::2:8080')
self.assertEqual(server.port, 11211)
server = bmemcached.protocol.Protocol('[::1]:5000')
self.assertEqual(server.host, '::1')
self.assertEqual(server.port, 5000)
server = bmemcached.protocol.Protocol('[2001:db8::2]:5000')
self.assertEqual(server.host, '2001:db8::2')
self.assertEqual(server.port, 5000)

def testInvalidPort(self):
with self.assertRaises(ValueError):
bmemcached.protocol.Protocol('{}:blah'.format(os.environ['MEMCACHED_HOST']))
with self.assertRaises(ValueError):
bmemcached.protocol.Protocol('[::1]:blah')

def testNonStandardPort(self):
server = bmemcached.protocol.Protocol('{}:5000'.format(os.environ['MEMCACHED_HOST']))
Expand Down

0 comments on commit efeb037

Please sign in to comment.