diff --git a/.vscode/tasks.json b/.vscode/tasks.json new file mode 100644 index 0000000..85ea15d --- /dev/null +++ b/.vscode/tasks.json @@ -0,0 +1,19 @@ +{ + "version": "2.0.0", + "tasks": [ + { + "label": "Restart HASS", + "type": "shell", + "command": "make hass-restart", + "problemMatcher": [], + "group": { + "kind": "build", + "isDefault": true + }, + "presentation": { + "reveal": "never", + "close": true + }, + } + ] +} diff --git a/custom_components/auth_header/__init__.py b/custom_components/auth_header/__init__.py index 5b11f28..8e83d30 100644 --- a/custom_components/auth_header/__init__.py +++ b/custom_components/auth_header/__init__.py @@ -1,7 +1,7 @@ import logging from http import HTTPStatus from ipaddress import ip_address -from typing import Any, OrderedDict +from typing import Any, OrderedDict, TYPE_CHECKING import homeassistant.helpers.config_validation as cv import voluptuous as vol @@ -16,6 +16,10 @@ from . import headers +if TYPE_CHECKING: + from homeassistant.components.http import FastUrlDispatcher + from aiohttp.web_urldispatcher import UrlDispatcher, AbstractResource + DOMAIN = "auth_header" _LOGGER = logging.getLogger(__name__) CONFIG_SCHEMA = vol.Schema( @@ -37,11 +41,22 @@ async def async_setup(hass: HomeAssistant, config): """Register custom view which includes request in context""" # Because we start after auth, we have access to store_result store_result = hass.data[AUTH_DOMAIN] + router: "FastUrlDispatcher" | "UrlDispatcher" = hass.http.app.router # Remove old LoginFlowIndexView - for route in hass.http.app.router._resources: - if route.canonical == "/auth/login_flow": - _LOGGER.debug("Removed original login_flow route") + # HASS < 2023.8 just has a list of all routes, which we can directly remove from + for route in router._resources: + if route.canonical == RequestLoginFlowIndexView.url: + _LOGGER.debug("Removed original login_flow route (UrlDispatcher) %s", route) hass.http.app.router._resources.remove(route) + # HASS 2023.8+ uses the "FastUrlDispatcher", which also keeps a dict for faster lookups + if hasattr(router, "_resource_index"): + resource_index: dict[str, list["AbstractResource"]] = router._resource_index + routes = resource_index.get(RequestLoginFlowIndexView.url, None) + if routes: + for route in routes: + if route.canonical == RequestLoginFlowIndexView.url: + _LOGGER.debug("Removed original login_flow route (FastUrlDispatcher) %s", route) + routes.remove(route) _LOGGER.debug("Add new login_flow route") hass.http.register_view( RequestLoginFlowIndexView( @@ -93,6 +108,7 @@ def __init__(self, flow_mgr, store_result, debug=False) -> None: @log_invalid_auth async def post(self, request: Request, data: dict[str, Any]) -> Response: """Create a new login flow.""" + _LOGGER.debug("post") client_id: str = data["client_id"] redirect_uri: str = data["redirect_uri"] diff --git a/custom_components/auth_header/headers.py b/custom_components/auth_header/headers.py index 54abfdf..8899014 100644 --- a/custom_components/auth_header/headers.py +++ b/custom_components/auth_header/headers.py @@ -5,7 +5,7 @@ import logging from typing import Any, Dict, List, Optional, cast -from aiohttp.web_request import Request +from aiohttp.web import Request from homeassistant.auth.models import Credentials, User, UserMeta from homeassistant.auth.providers import AUTH_PROVIDERS, AuthProvider, LoginFlow from homeassistant.auth.providers.trusted_networks import ( diff --git a/pyproject.toml b/pyproject.toml index 27f2e6f..22d7f01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,3 +13,17 @@ homeassistant = "^2023.7.3" pylint = "*" black = "*" isort = "*" + +[tool.pylint.master] +disable = [ + "protected-access" +] + +[tool.isort] +line_length = 100 + +[tool.black] +line-length = 100 + +[tool.ruff] +line-length = 100