Skip to content

Commit

Permalink
Better ipv6 support when checking network bans
Browse files Browse the repository at this point in the history
- Rely on using `in` between IP addresses and IP Networks
  (https://docs.python.org/3/library/ipaddress.html#networks-as-containers-of-addresses)
  rather than re-implement that ip address / network matching with
  CIDRs ourselves.
- Brings in *much* better ipv6 support
- Adds a few test cases for ipv6 support as well

Discovered while investigating the test failures in
#1856
  • Loading branch information
yuvipanda committed Nov 27, 2024
1 parent 0284be9 commit 7d04688
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 41 deletions.
6 changes: 3 additions & 3 deletions binderhub/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ def check_request_ip(self):
match = ip_in_networks(
request_ip,
ban_networks,
min_prefix_len=self.settings["ban_networks_min_prefix_len"],
)
if match:
network, message = match
network_spec = match
message = ban_networks[network_spec]
app_log.warning(
f"Blocking request from {request_ip} matching banned network {network}: {message}"
f"Blocking request from {request_ip} matching banned network {network_spec}: {message}"
)
raise web.HTTPError(403, f"Requests from {message} are not allowed")

Expand Down
19 changes: 4 additions & 15 deletions binderhub/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import ipaddress
from unittest import mock

import pytest
Expand Down Expand Up @@ -116,24 +115,14 @@ def later():
("192.168.1.2", ["192.168.1.0/24", "255.255.0.0/16"], True),
("192.168.1.2", ["255.255.0.0/16", "192.168.1.0/24"], True),
("192.168.1.2", [], False),
("2001:db8:0:0:0:0:0:1", ["2001:db8::/32", "192.168.1.1/32"], True),
("3001:db8:0:0:0:0:0:1", ["2001:db8::/32", "192.168.1.1/32"], False),
],
)
def test_ip_in_networks(ip, cidrs, found):
networks = {ipaddress.ip_network(cidr): f"message {cidr}" for cidr in cidrs}
if networks:
min_prefix = min(net.prefixlen for net in networks)
else:
min_prefix = 1
match = utils.ip_in_networks(ip, networks, min_prefix)
match = utils.ip_in_networks(ip, cidrs)
if found:
assert match
net, message = match
assert message == f"message {net}"
assert ipaddress.ip_address(ip) in net
assert match in cidrs
else:
assert match is False


def test_ip_in_networks_invalid():
with pytest.raises(ValueError):
utils.ip_in_networks("1.2.3.4", {}, 0)
35 changes: 12 additions & 23 deletions binderhub/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
from collections import OrderedDict
from hashlib import blake2b
from typing import Iterable
from unittest.mock import Mock

from kubernetes.client import api_client
Expand Down Expand Up @@ -167,32 +168,20 @@ def url_path_join(*pieces):
return result


def ip_in_networks(ip, networks, min_prefix_len=1):
"""Return whether `ip` is in the dict of networks
This is O(1) regardless of the size of networks
Implementation based on netaddr.IPSet.__contains__
Repeatedly checks if ip/32; ip/31; ip/30; etc. is in networks
for all netmasks that match the given ip,
for a max of 32 dict key lookups for ipv4.
def ip_in_networks(ip_addr: str, networks: Iterable[str]):
"""
Checks if `ip_addr` is contained within any of the networks in `networks`
If all netmasks have a prefix length of e.g. 24 or greater,
min_prefix_len prevents checking wider network masks that can't possibly match.
If ip_addr is in any of the provided networks, return the first network that matches.
If not, return False
Returns `(netmask, networks[netmask])` for matching netmask
in networks, if found; False, otherwise.
Both ipv6 and ipv4 are supported
"""
if min_prefix_len < 1:
raise ValueError(f"min_prefix_len must be >= 1, got {min_prefix_len}")
if not networks:
return False
check_net = ipaddress.ip_network(ip)
while check_net.prefixlen >= min_prefix_len:
if check_net in networks:
return check_net, networks[check_net]
check_net = check_net.supernet(1)
ip = ipaddress.ip_address(ip_addr)
for network_spec in networks:
network = ipaddress.ip_network(network_spec)
if ip in network:
return network_spec
return False


Expand Down

0 comments on commit 7d04688

Please sign in to comment.