diff --git a/README.md b/README.md index 1249367..c42675f 100644 --- a/README.md +++ b/README.md @@ -15,8 +15,7 @@ Update your configuration.yaml file with ```yaml http: - # Ensure this is turned off, otherwise this integration will only get the IP from the client - use_x_forwarded_for: false + use_x_forwarded_for: true trusted_proxies: - 1.2.3.4/32 # This needs to be set to the IP of your reverse proxy auth_header: diff --git a/custom_components/auth_header/__init__.py b/custom_components/auth_header/__init__.py index 39adad3..f9e309e 100644 --- a/custom_components/auth_header/__init__.py +++ b/custom_components/auth_header/__init__.py @@ -6,7 +6,6 @@ import homeassistant.helpers.config_validation as cv import voluptuous as vol from homeassistant import data_entry_flow -from homeassistant.auth import providers from homeassistant.components.auth import DOMAIN as AUTH_DOMAIN from homeassistant.components.auth import indieauth from homeassistant.components.auth.login_flow import ( @@ -67,6 +66,15 @@ async def async_setup(hass: HomeAssistant, config): return True +def get_actual_ip(request: Request) -> str: + """Get remote from `request` without considering overrides. This is because + when behind a reverse proxy, hass overrides the .remote attributes with the X-Forwarded-For + value. We still need to check the actual remote though, to verify its from a valid proxy.""" + if isinstance(request._transport_peername, (list, tuple)): + return request._transport_peername[0] + return request._transport_peername + + class RequestLoginFlowIndexView(LoginFlowIndexView): debug: bool @@ -101,13 +109,14 @@ async def post(self, request: Request, data): handler = data["handler"] try: - if self.debug: - _LOGGER.warning(request.headers) + _LOGGER.debug(request.headers) + actual_ip = get_actual_ip(request) + _LOGGER.debug("Got actual IP %s", actual_ip) result = await self._flow_mgr.async_init( handler, context={ "request": request, - "ip_address": ip_address(request.remote), + "ip_address": ip_address(actual_ip), "credential_only": data.get("type") == "link_user", }, )