Skip to content

Commit 108e081

Browse files
HyeockJinKimseedspirit
authored andcommitted
fix(BA-82): Add request & response policy middleware for web security (#2937)
1 parent b9f0668 commit 108e081

File tree

4 files changed

+200
-1
lines changed

4 files changed

+200
-1
lines changed

changes/2937.fix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add reject middleware for web security

src/ai/backend/web/security.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from typing import Callable, Iterable, Self
2+
3+
from aiohttp import web
4+
from aiohttp.typedefs import Handler
5+
6+
7+
@web.middleware
8+
async def security_policy_middleware(request: web.Request, handler: Handler) -> web.StreamResponse:
9+
security_policy: SecurityPolicy = request.app["security_policy"]
10+
security_policy.check_request_policies(request)
11+
response = await handler(request)
12+
return security_policy.apply_response_policies(response)
13+
14+
15+
class SecurityPolicy:
16+
_request_policies: Iterable[Callable[[web.Request], None]]
17+
_response_policies: Iterable[Callable[[web.StreamResponse], web.StreamResponse]]
18+
19+
def __init__(
20+
self,
21+
request_policies: Iterable[Callable[[web.Request], None]],
22+
response_policies: Iterable[Callable[[web.StreamResponse], web.StreamResponse]],
23+
) -> None:
24+
self._request_policies = request_policies
25+
self._response_policies = response_policies
26+
27+
@classmethod
28+
def default_policy(cls) -> Self:
29+
request_policies = [reject_metadata_local_link_policy, reject_access_for_unsafe_file_policy]
30+
response_policies = [add_self_content_security_policy, set_content_type_nosniff_policy]
31+
return cls(request_policies, response_policies)
32+
33+
def check_request_policies(self, request: web.Request) -> None:
34+
for policy in self._request_policies:
35+
policy(request)
36+
37+
def apply_response_policies(self, response: web.StreamResponse) -> web.StreamResponse:
38+
for policy in self._response_policies:
39+
response = policy(response)
40+
return response
41+
42+
43+
def reject_metadata_local_link_policy(request: web.Request) -> None:
44+
metadata_local_link_map = {
45+
"metadata.google.internal": True,
46+
"169.254.169.254": True,
47+
"100.100.100.200": True,
48+
"alibaba.zaproxy.org": True,
49+
"metadata.oraclecloud.com": True,
50+
}
51+
if metadata_local_link_map.get(request.host):
52+
raise web.HTTPForbidden()
53+
54+
55+
def reject_access_for_unsafe_file_policy(request: web.Request) -> None:
56+
unsafe_file_map = {
57+
"._darcs": True,
58+
".bzr": True,
59+
".hg": True,
60+
"BitKeeper": True,
61+
".bak": True,
62+
".log": True,
63+
".git": True,
64+
".svn": True,
65+
}
66+
file_name = request.path.split("/")[-1]
67+
if unsafe_file_map.get(file_name):
68+
raise web.HTTPForbidden()
69+
70+
71+
def add_self_content_security_policy(response: web.StreamResponse) -> web.StreamResponse:
72+
response.headers["Content-Security-Policy"] = (
73+
"default-src 'self'; frame-ancestors 'none'; form-action 'self';"
74+
)
75+
return response
76+
77+
78+
def set_content_type_nosniff_policy(response: web.StreamResponse) -> web.StreamResponse:
79+
response.headers["X-Content-Type-Options"] = "nosniff"
80+
return response

src/ai/backend/web/server.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from ai.backend.common.web.session import setup as setup_session
3838
from ai.backend.common.web.session.redis_storage import RedisStorage
3939
from ai.backend.logging import BraceStyleAdapter, Logger, LogLevel
40+
from ai.backend.web.security import SecurityPolicy, security_policy_middleware
4041

4142
from . import __version__, user_agent
4243
from .auth import fill_forwarding_hdrs_to_api_session, get_client_ip
@@ -603,8 +604,11 @@ async def server_main(
603604
args: Tuple[Any, ...],
604605
) -> AsyncIterator[None]:
605606
config = args[0]
606-
app = web.Application(middlewares=[decrypt_payload, track_active_handlers])
607+
app = web.Application(
608+
middlewares=[decrypt_payload, track_active_handlers, security_policy_middleware]
609+
)
607610
app["config"] = config
611+
app["security_policy"] = SecurityPolicy.default_policy()
608612
j2env = jinja2.Environment(
609613
extensions=[
610614
"ai.backend.web.template.TOMLField",
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import pytest
2+
from aiohttp import web
3+
from aiohttp.test_utils import make_mocked_request
4+
from aiohttp.typedefs import Handler
5+
6+
from ai.backend.web.security import (
7+
SecurityPolicy,
8+
add_self_content_security_policy,
9+
reject_access_for_unsafe_file_policy,
10+
reject_metadata_local_link_policy,
11+
security_policy_middleware,
12+
set_content_type_nosniff_policy,
13+
)
14+
15+
16+
@pytest.fixture
17+
def default_app():
18+
app = web.Application(middlewares=[security_policy_middleware])
19+
app["security_policy"] = SecurityPolicy.default_policy()
20+
return app
21+
22+
23+
@pytest.fixture
24+
async def async_handler() -> Handler:
25+
async def handler(request):
26+
return web.Response()
27+
28+
return handler
29+
30+
31+
async def test_default_security_policy_reject_metadata_local_link(
32+
default_app, async_handler
33+
) -> None:
34+
request = make_mocked_request("GET", "/", headers={"Host": "169.254.169.254"}, app=default_app)
35+
with pytest.raises(web.HTTPForbidden):
36+
await security_policy_middleware(request, async_handler)
37+
38+
39+
async def test_default_security_policy_response(default_app, async_handler) -> None:
40+
request = make_mocked_request("GET", "/", headers={"Host": "localhost"}, app=default_app)
41+
response = await security_policy_middleware(request, async_handler)
42+
assert (
43+
response.headers["Content-Security-Policy"]
44+
== "default-src 'self'; frame-ancestors 'none'; form-action 'self';"
45+
)
46+
assert response.headers["X-Content-Type-Options"] == "nosniff"
47+
48+
49+
@pytest.mark.parametrize(
50+
"meta_local_link",
51+
[
52+
"metadata.google.internal",
53+
"169.254.169.254",
54+
"100.100.100.200",
55+
"alibaba.zaproxy.org",
56+
"metadata.oraclecloud.com",
57+
],
58+
)
59+
async def test_reject_metadata_local_link_policy(async_handler, meta_local_link) -> None:
60+
test_app = web.Application()
61+
test_app["security_policy"] = SecurityPolicy(
62+
request_policies=[reject_metadata_local_link_policy], response_policies=[]
63+
)
64+
request = make_mocked_request("GET", "/", headers={"Host": meta_local_link}, app=test_app)
65+
with pytest.raises(web.HTTPForbidden):
66+
await security_policy_middleware(request, async_handler)
67+
68+
69+
@pytest.mark.parametrize(
70+
"url_suffix",
71+
[
72+
"._darcs",
73+
".bzr",
74+
".hg",
75+
"BitKeeper",
76+
".bak",
77+
".log",
78+
".git",
79+
".svn",
80+
],
81+
)
82+
async def test_reject_access_for_unsafe_file_policy(async_handler, url_suffix) -> None:
83+
test_app = web.Application()
84+
test_app["security_policy"] = SecurityPolicy(
85+
request_policies=[reject_access_for_unsafe_file_policy], response_policies=[]
86+
)
87+
request = make_mocked_request(
88+
"GET", f"/{url_suffix}", headers={"Host": "localhost"}, app=test_app
89+
)
90+
with pytest.raises(web.HTTPForbidden):
91+
await security_policy_middleware(request, async_handler)
92+
93+
94+
async def test_add_self_content_security_policy(async_handler) -> None:
95+
test_app = web.Application()
96+
test_app["security_policy"] = SecurityPolicy(
97+
request_policies=[], response_policies=[add_self_content_security_policy]
98+
)
99+
request = make_mocked_request("GET", "/", headers={"Host": "localhost"}, app=test_app)
100+
response = await security_policy_middleware(request, async_handler)
101+
assert (
102+
response.headers["Content-Security-Policy"]
103+
== "default-src 'self'; frame-ancestors 'none'; form-action 'self';"
104+
)
105+
106+
107+
async def test_set_content_type_nosniff_policy(async_handler) -> None:
108+
test_app = web.Application()
109+
test_app["security_policy"] = SecurityPolicy(
110+
request_policies=[], response_policies=[set_content_type_nosniff_policy]
111+
)
112+
request = make_mocked_request("GET", "/", headers={"Host": "localhost"}, app=test_app)
113+
response = await security_policy_middleware(request, async_handler)
114+
assert response.headers["X-Content-Type-Options"] == "nosniff"

0 commit comments

Comments
 (0)