-
Notifications
You must be signed in to change notification settings - Fork 936
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix(file_upload): validate mimetype as configured #1459
base: main
Are you sure you want to change the base?
Changes from all commits
f094ebc
624a0b6
e929b08
bb7b744
d0f41f7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
import asyncio | ||
import fnmatch | ||
import glob | ||
import json | ||
import mimetypes | ||
|
@@ -9,7 +10,7 @@ | |
import webbrowser | ||
from contextlib import asynccontextmanager | ||
from pathlib import Path | ||
from typing import Any, Optional, Union | ||
from typing import Any, List, Optional, Union | ||
|
||
import socketio | ||
from chainlit.auth import create_jwt, get_configuration, get_current_user | ||
|
@@ -870,13 +871,67 @@ async def upload_file( | |
assert file.filename, "No filename for uploaded file" | ||
assert file.content_type, "No content type for uploaded file" | ||
|
||
validate_file_upload(file) | ||
|
||
file_response = await session.persist_file( | ||
name=file.filename, content=content, mime=file.content_type | ||
) | ||
|
||
return JSONResponse(content=file_response) | ||
|
||
|
||
def validate_file_upload(file: UploadFile): | ||
if config.features.spontaneous_file_upload is None: | ||
return # TODO: if it is not configured what should happen? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The most elegant solution is to change the default value for The second best would be to follow the frontend in this and not allow uploads unless explicitly enabled. Definitely, we should try to avoid having a TODO in a PR. ;) |
||
|
||
if config.features.spontaneous_file_upload.enabled is False: | ||
raise HTTPException( | ||
status_code=400, | ||
detail="File upload is not enabled", | ||
) | ||
|
||
validate_file_mime_type(file) | ||
validate_file_size(file) | ||
|
||
|
||
def validate_file_mime_type(file: UploadFile): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Docstring, please! :) |
||
if config.features.spontaneous_file_upload.accept is None: | ||
return | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe add an explicit comment here that we'll allow any mime type unless a value is defined. |
||
|
||
accept = config.features.spontaneous_file_upload.accept | ||
if isinstance(accept, List): | ||
for pattern in accept: | ||
if fnmatch.fnmatch(file.content_type, pattern): | ||
return | ||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's either a list or a dict. Maybe add an assertion here so that assumption is clear (and the code will bulk if we ever haphazardly decide to change the typing). |
||
for pattern, extensions in accept.items(): | ||
if fnmatch.fnmatch(file.content_type, pattern): | ||
if len(extensions) == 0: | ||
return | ||
for extension in extensions: | ||
if file.filename is not None and file.filename.endswith(extension): | ||
return | ||
raise HTTPException( | ||
status_code=400, | ||
detail="File type not allowed", | ||
) | ||
|
||
|
||
def validate_file_size(file: UploadFile): | ||
if config.features.spontaneous_file_upload.max_size_mb is None: | ||
return | ||
|
||
if ( | ||
file.size is not None | ||
and file.size | ||
> config.features.spontaneous_file_upload.max_size_mb * 1024 * 1024 | ||
): | ||
raise HTTPException( | ||
status_code=400, | ||
detail="File size too large", | ||
) | ||
|
||
|
||
@router.get("/project/file/{file_id}") | ||
async def get_file( | ||
file_id: str, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,24 @@ | ||
import datetime # Added import for datetime | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe it's a good time to remove this redundant comment. ;) |
||
import os | ||
from pathlib import Path | ||
import pathlib | ||
import tempfile | ||
from pathlib import Path | ||
from typing import Callable | ||
from unittest.mock import AsyncMock, Mock, create_autospec, mock_open | ||
import datetime # Added import for datetime | ||
|
||
import pytest | ||
import tempfile | ||
from chainlit.session import WebsocketSession | ||
from chainlit.auth import get_current_user | ||
from chainlit.config import APP_ROOT, ChainlitConfig, load_config | ||
from chainlit.config import ( | ||
APP_ROOT, | ||
ChainlitConfig, | ||
SpontaneousFileUploadFeature, | ||
load_config, | ||
) | ||
from chainlit.server import app | ||
from fastapi.testclient import TestClient | ||
from chainlit.session import WebsocketSession | ||
from chainlit.types import FileReference | ||
from chainlit.user import PersistedUser # Added import for PersistedUser | ||
from fastapi.testclient import TestClient | ||
|
||
|
||
@pytest.fixture | ||
|
@@ -509,6 +514,175 @@ def test_upload_file_unauthorized( | |
assert response.status_code == 422 | ||
|
||
|
||
def test_upload_file_disabled( | ||
test_client: TestClient, | ||
test_config: ChainlitConfig, | ||
mock_session_get_by_id_patched: Mock, | ||
monkeypatch: pytest.MonkeyPatch, | ||
): | ||
"""Test file upload being disabled by config.""" | ||
|
||
# Set accept in config | ||
monkeypatch.setattr( | ||
test_config.features, | ||
"spontaneous_file_upload", | ||
SpontaneousFileUploadFeature(enabled=False), | ||
) | ||
|
||
# Prepare the files to upload | ||
file_content = b"Sample file content" | ||
files = { | ||
"file": ("test_upload.txt", file_content, "text/plain"), | ||
} | ||
|
||
# Make the POST request to upload the file | ||
response = test_client.post( | ||
"/project/file", | ||
files=files, | ||
params={"session_id": mock_session_get_by_id_patched.id}, | ||
) | ||
|
||
# Verify the response | ||
assert response.status_code == 400 | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"accept_pattern, mime_type, expected_status", | ||
[ | ||
({"image/*": [".png", ".gif", ".jpeg", ".jpg"]}, "image/jpeg", 400), | ||
(["image/*"], "text/plain", 400), | ||
(["image/*", "application/*"], "text/plain", 400), | ||
(["image/png", "application/pdf"], "image/jpeg", 400), | ||
(["text/*"], "text/plain", 200), | ||
(["application/*"], "application/pdf", 200), | ||
(["image/*"], "image/jpeg", 200), | ||
(["image/*", "text/*"], "text/plain", 200), | ||
(["*/*"], "text/plain", 200), | ||
(["*/*"], "image/jpeg", 200), | ||
(["*/*"], "application/pdf", 200), | ||
(["image/*", "application/*"], "application/pdf", 200), | ||
(["image/*", "application/*"], "image/jpeg", 200), | ||
(["image/png", "application/pdf"], "image/png", 200), | ||
(["image/png", "application/pdf"], "application/pdf", 200), | ||
({"image/*": []}, "image/jpeg", 200), | ||
( | ||
{"image/*": [".png", ".gif", ".jpeg", ".jpg"]}, | ||
"text/plain", | ||
400, | ||
), # mime type not allowed | ||
( | ||
{"*/*": [".txt", ".gif", ".jpeg", ".jpg"]}, | ||
"text/plain", | ||
200, | ||
), # extension allowed | ||
( | ||
{"*/*": [".gif", ".jpeg", ".jpg"]}, | ||
"text/plain", | ||
400, | ||
), # extension not allowed | ||
], | ||
) | ||
def test_upload_file_mime_type_check( | ||
test_client: TestClient, | ||
test_config: ChainlitConfig, | ||
mock_session_get_by_id_patched: Mock, | ||
monkeypatch: pytest.MonkeyPatch, | ||
accept_pattern: list[str], | ||
mime_type: str, | ||
expected_status: int, | ||
): | ||
"""Test check of mime_type.""" | ||
|
||
# Set accept in config | ||
monkeypatch.setattr( | ||
test_config.features, | ||
"spontaneous_file_upload", | ||
SpontaneousFileUploadFeature(enabled=True, accept=accept_pattern), | ||
) | ||
|
||
# Prepare the files to upload | ||
file_content = b"Sample file content" | ||
files = { | ||
"file": ("test_upload.txt", file_content, mime_type), | ||
} | ||
|
||
# Mock the persist_file method to return a known value | ||
expected_file_id = "mocked_file_id" | ||
mock_session_get_by_id_patched.persist_file = AsyncMock( | ||
return_value={ | ||
"id": expected_file_id, | ||
"name": "test_upload.txt", | ||
"type": "text/plain", | ||
"size": len(file_content), | ||
} | ||
) | ||
|
||
# Make the POST request to upload the file | ||
response = test_client.post( | ||
"/project/file", | ||
files=files, | ||
params={"session_id": mock_session_get_by_id_patched.id}, | ||
) | ||
|
||
# Verify the response | ||
assert response.status_code == expected_status | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"file_content, content_multiplier, max_size_mb, expected_status", | ||
[ | ||
(b"1", 1, 1, 200), | ||
(b"11", 1024 * 1024, 1, 400), | ||
], | ||
) | ||
def test_upload_file_mime_type_check( | ||
test_client: TestClient, | ||
test_config: ChainlitConfig, | ||
mock_session_get_by_id_patched: Mock, | ||
monkeypatch: pytest.MonkeyPatch, | ||
file_content: bytes, | ||
content_multiplier: int, | ||
max_size_mb: int, | ||
expected_status: int, | ||
): | ||
"""Test check of max_size_mb.""" | ||
|
||
file_content = file_content * content_multiplier | ||
|
||
# Set accept in config | ||
monkeypatch.setattr( | ||
test_config.features, | ||
"spontaneous_file_upload", | ||
SpontaneousFileUploadFeature(max_size_mb=max_size_mb), | ||
) | ||
|
||
# Prepare the files to upload | ||
files = { | ||
"file": ("test_upload.txt", file_content, "text/plain"), | ||
} | ||
|
||
# Mock the persist_file method to return a known value | ||
expected_file_id = "mocked_file_id" | ||
mock_session_get_by_id_patched.persist_file = AsyncMock( | ||
return_value={ | ||
"id": expected_file_id, | ||
"name": "test_upload.txt", | ||
"type": "text/plain", | ||
"size": len(file_content), | ||
} | ||
) | ||
|
||
# Make the POST request to upload the file | ||
response = test_client.post( | ||
"/project/file", | ||
files=files, | ||
params={"session_id": mock_session_get_by_id_patched.id}, | ||
) | ||
|
||
# Verify the response | ||
assert response.status_code == expected_status | ||
|
||
|
||
def test_project_translations_file_path_traversal( | ||
test_client: TestClient, monkeypatch: pytest.MonkeyPatch | ||
): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would prefer to have validation functions raise a well-defined
Exception
and to have a docstring specifying it's behaviour.For example, in pydantic this is a
ValueError
and in Django it's aValidationError
. It should definitely not be aHTTPException
as the file validation code should not have anything to do with HTTP, it's just validating anUploadFile
.We can catch validation errors and then re-raise them in a HTTP request handler (
upload_file
).