From 7d0468869bf8741c7d4514adf81e00f50839981f Mon Sep 17 00:00:00 2001 From: YuviPanda Date: Tue, 26 Nov 2024 21:50:43 -0800 Subject: [PATCH] Better ipv6 support when checking network bans - Rely on using `in` between IP addresses and IP Networks (https://docs.python.org/3/library/ipaddress.html#networks-as-containers-of-addresses) rather than re-implement that ip address / network matching with CIDRs ourselves. - Brings in *much* better ipv6 support - Adds a few test cases for ipv6 support as well Discovered while investigating the test failures in https://github.com/jupyterhub/binderhub/pull/1856 --- binderhub/base.py | 6 +++--- binderhub/tests/test_utils.py | 19 ++++--------------- binderhub/utils.py | 35 ++++++++++++----------------------- 3 files changed, 19 insertions(+), 41 deletions(-) diff --git a/binderhub/base.py b/binderhub/base.py index 5f198c401..3695f1cdc 100644 --- a/binderhub/base.py +++ b/binderhub/base.py @@ -38,12 +38,12 @@ def check_request_ip(self): match = ip_in_networks( request_ip, ban_networks, - min_prefix_len=self.settings["ban_networks_min_prefix_len"], ) if match: - network, message = match + network_spec = match + message = ban_networks[network_spec] app_log.warning( - f"Blocking request from {request_ip} matching banned network {network}: {message}" + f"Blocking request from {request_ip} matching banned network {network_spec}: {message}" ) raise web.HTTPError(403, f"Requests from {message} are not allowed") diff --git a/binderhub/tests/test_utils.py b/binderhub/tests/test_utils.py index d3d242c10..717eebcbc 100644 --- a/binderhub/tests/test_utils.py +++ b/binderhub/tests/test_utils.py @@ -1,4 +1,3 @@ -import ipaddress from unittest import mock import pytest @@ -116,24 +115,14 @@ def later(): ("192.168.1.2", ["192.168.1.0/24", "255.255.0.0/16"], True), ("192.168.1.2", ["255.255.0.0/16", "192.168.1.0/24"], True), ("192.168.1.2", [], False), + ("2001:db8:0:0:0:0:0:1", ["2001:db8::/32", "192.168.1.1/32"], True), + ("3001:db8:0:0:0:0:0:1", ["2001:db8::/32", "192.168.1.1/32"], False), ], ) def test_ip_in_networks(ip, cidrs, found): - networks = {ipaddress.ip_network(cidr): f"message {cidr}" for cidr in cidrs} - if networks: - min_prefix = min(net.prefixlen for net in networks) - else: - min_prefix = 1 - match = utils.ip_in_networks(ip, networks, min_prefix) + match = utils.ip_in_networks(ip, cidrs) if found: assert match - net, message = match - assert message == f"message {net}" - assert ipaddress.ip_address(ip) in net + assert match in cidrs else: assert match is False - - -def test_ip_in_networks_invalid(): - with pytest.raises(ValueError): - utils.ip_in_networks("1.2.3.4", {}, 0) diff --git a/binderhub/utils.py b/binderhub/utils.py index 400aaa956..59f4a7675 100644 --- a/binderhub/utils.py +++ b/binderhub/utils.py @@ -4,6 +4,7 @@ import time from collections import OrderedDict from hashlib import blake2b +from typing import Iterable from unittest.mock import Mock from kubernetes.client import api_client @@ -167,32 +168,20 @@ def url_path_join(*pieces): return result -def ip_in_networks(ip, networks, min_prefix_len=1): - """Return whether `ip` is in the dict of networks - - This is O(1) regardless of the size of networks - - Implementation based on netaddr.IPSet.__contains__ - - Repeatedly checks if ip/32; ip/31; ip/30; etc. is in networks - for all netmasks that match the given ip, - for a max of 32 dict key lookups for ipv4. +def ip_in_networks(ip_addr: str, networks: Iterable[str]): + """ + Checks if `ip_addr` is contained within any of the networks in `networks` - If all netmasks have a prefix length of e.g. 24 or greater, - min_prefix_len prevents checking wider network masks that can't possibly match. + If ip_addr is in any of the provided networks, return the first network that matches. + If not, return False - Returns `(netmask, networks[netmask])` for matching netmask - in networks, if found; False, otherwise. + Both ipv6 and ipv4 are supported """ - if min_prefix_len < 1: - raise ValueError(f"min_prefix_len must be >= 1, got {min_prefix_len}") - if not networks: - return False - check_net = ipaddress.ip_network(ip) - while check_net.prefixlen >= min_prefix_len: - if check_net in networks: - return check_net, networks[check_net] - check_net = check_net.supernet(1) + ip = ipaddress.ip_address(ip_addr) + for network_spec in networks: + network = ipaddress.ip_network(network_spec) + if ip in network: + return network_spec return False