Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(file_upload): validate mimetype as configured #1459

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion backend/chainlit/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import fnmatch
import glob
import json
import mimetypes
Expand All @@ -9,7 +10,7 @@
import webbrowser
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Optional, Union
from typing import Any, List, Optional, Union

import socketio
from chainlit.auth import create_jwt, get_configuration, get_current_user
Expand Down Expand Up @@ -870,13 +871,67 @@ async def upload_file(
assert file.filename, "No filename for uploaded file"
assert file.content_type, "No content type for uploaded file"

validate_file_upload(file)

file_response = await session.persist_file(
name=file.filename, content=content, mime=file.content_type
)

return JSONResponse(content=file_response)


def validate_file_upload(file: UploadFile):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer to have validation functions raise a well-defined Exception and to have a docstring specifying its behaviour.

For example, in pydantic this is a ValueError and in Django it's a ValidationError. It should definitely not be an HTTPException as the file validation code should not have anything to do with HTTP, it's just validating an UploadFile.

We can catch validation errors and then re-raise them in an HTTP request handler (upload_file).

if config.features.spontaneous_file_upload is None:
return # TODO: if it is not configured what should happen?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The most elegant solution is to change the default value for config.features.spontaneous_file_upload to always be populated with SpontaneousFileUploadFeature then to have enabled default to False and move max_files and max_size_mb default values from the frontend to the backend config. This would remove the if ... is None in favour of more explicit code. But I get it if perhaps that's out of scope and left for a later PR.

The second best would be to follow the frontend in this and not allow uploads unless explicitly enabled. Definitely, we should try to avoid having a TODO in a PR. ;)


if config.features.spontaneous_file_upload.enabled is False:
raise HTTPException(
status_code=400,
detail="File upload is not enabled",
)

validate_file_mime_type(file)
validate_file_size(file)


def validate_file_mime_type(file: UploadFile):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docstring, please! :)

if config.features.spontaneous_file_upload.accept is None:
return
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add an explicit comment here that we'll allow any mime type unless a value is defined.


accept = config.features.spontaneous_file_upload.accept
if isinstance(accept, List):
for pattern in accept:
if fnmatch.fnmatch(file.content_type, pattern):
return
else:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's either a list or a dict. Maybe add an assertion here so that assumption is clear (and the code will balk if we ever haphazardly decide to change the typing).

for pattern, extensions in accept.items():
if fnmatch.fnmatch(file.content_type, pattern):
if len(extensions) == 0:
return
for extension in extensions:
if file.filename is not None and file.filename.endswith(extension):
return
raise HTTPException(
status_code=400,
detail="File type not allowed",
)


def validate_file_size(file: UploadFile):
    """Reject an upload that is larger than the configured ``max_size_mb``.

    Nothing is enforced when ``max_size_mb`` is unset, and a file whose
    ``size`` is ``None`` is let through unchecked.

    Raises:
        HTTPException: With status code 400 when ``file.size`` is strictly
            greater than ``max_size_mb * 1024 * 1024`` bytes.
    """
    limit_mb = config.features.spontaneous_file_upload.max_size_mb
    if limit_mb is None:
        # No limit configured: accept any size.
        return

    if file.size is None:
        # Size unknown: there is nothing to compare against the limit.
        return

    limit_bytes = limit_mb * 1024 * 1024
    if file.size > limit_bytes:
        raise HTTPException(status_code=400, detail="File size too large")


@router.get("/project/file/{file_id}")
async def get_file(
file_id: str,
Expand Down
186 changes: 180 additions & 6 deletions backend/tests/test_server.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
import datetime # Added import for datetime
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it's a good time to remove this redundant comment. ;)

import os
from pathlib import Path
import pathlib
import tempfile
from pathlib import Path
from typing import Callable
from unittest.mock import AsyncMock, Mock, create_autospec, mock_open
import datetime # Added import for datetime

import pytest
import tempfile
from chainlit.session import WebsocketSession
from chainlit.auth import get_current_user
from chainlit.config import APP_ROOT, ChainlitConfig, load_config
from chainlit.config import (
APP_ROOT,
ChainlitConfig,
SpontaneousFileUploadFeature,
load_config,
)
from chainlit.server import app
from fastapi.testclient import TestClient
from chainlit.session import WebsocketSession
from chainlit.types import FileReference
from chainlit.user import PersistedUser # Added import for PersistedUser
from fastapi.testclient import TestClient


@pytest.fixture
Expand Down Expand Up @@ -509,6 +514,175 @@ def test_upload_file_unauthorized(
assert response.status_code == 422


def test_upload_file_disabled(
    test_client: TestClient,
    test_config: ChainlitConfig,
    mock_session_get_by_id_patched: Mock,
    monkeypatch: pytest.MonkeyPatch,
):
    """Test file upload being disabled by config."""
    # Swap in a feature config that explicitly switches uploads off.
    disabled_feature = SpontaneousFileUploadFeature(enabled=False)
    monkeypatch.setattr(
        test_config.features, "spontaneous_file_upload", disabled_feature
    )

    # A small plain-text upload addressed to the patched session.
    payload = b"Sample file content"
    upload = {"file": ("test_upload.txt", payload, "text/plain")}
    query = {"session_id": mock_session_get_by_id_patched.id}

    response = test_client.post("/project/file", files=upload, params=query)

    # The endpoint must refuse the upload.
    assert response.status_code == 400


@pytest.mark.parametrize(
    ("accept_pattern", "mime_type", "expected_status"),
    [
        # The uploaded file is always named "test_upload.txt", so dict-style
        # patterns listing only image extensions reject it even when the
        # mime type matches.
        ({"image/*": [".png", ".gif", ".jpeg", ".jpg"]}, "image/jpeg", 400),
        # List-style patterns: mime type not covered.
        (["image/*"], "text/plain", 400),
        (["image/*", "application/*"], "text/plain", 400),
        (["image/png", "application/pdf"], "image/jpeg", 400),
        # List-style patterns: mime type covered.
        (["text/*"], "text/plain", 200),
        (["application/*"], "application/pdf", 200),
        (["image/*"], "image/jpeg", 200),
        (["image/*", "text/*"], "text/plain", 200),
        (["*/*"], "text/plain", 200),
        (["*/*"], "image/jpeg", 200),
        (["*/*"], "application/pdf", 200),
        (["image/*", "application/*"], "application/pdf", 200),
        (["image/*", "application/*"], "image/jpeg", 200),
        (["image/png", "application/pdf"], "image/png", 200),
        (["image/png", "application/pdf"], "application/pdf", 200),
        # Dict-style pattern with an empty extension list.
        ({"image/*": []}, "image/jpeg", 200),
        # mime type not allowed
        ({"image/*": [".png", ".gif", ".jpeg", ".jpg"]}, "text/plain", 400),
        # extension allowed
        ({"*/*": [".txt", ".gif", ".jpeg", ".jpg"]}, "text/plain", 200),
        # extension not allowed
        ({"*/*": [".gif", ".jpeg", ".jpg"]}, "text/plain", 400),
    ],
)
def test_upload_file_mime_type_check(
    test_client: TestClient,
    test_config: ChainlitConfig,
    mock_session_get_by_id_patched: Mock,
    monkeypatch: pytest.MonkeyPatch,
    accept_pattern: list[str],
    mime_type: str,
    expected_status: int,
):
    """Test check of mime_type.

    Each case configures ``accept`` (a list of mime patterns or, despite the
    annotation, a dict mapping a mime pattern to allowed extensions), uploads
    ``test_upload.txt`` with the given content type, and checks the status.
    """
    # Enable uploads restricted to the parametrized accept value.
    feature = SpontaneousFileUploadFeature(enabled=True, accept=accept_pattern)
    monkeypatch.setattr(test_config.features, "spontaneous_file_upload", feature)

    payload = b"Sample file content"

    # Stub persistence so an accepted upload returns a known record.
    persisted = {
        "id": "mocked_file_id",
        "name": "test_upload.txt",
        "type": "text/plain",
        "size": len(payload),
    }
    mock_session_get_by_id_patched.persist_file = AsyncMock(return_value=persisted)

    response = test_client.post(
        "/project/file",
        files={"file": ("test_upload.txt", payload, mime_type)},
        params={"session_id": mock_session_get_by_id_patched.id},
    )

    assert response.status_code == expected_status


@pytest.mark.parametrize(
    "file_content, content_multiplier, max_size_mb, expected_status",
    [
        (b"1", 1, 1, 200),  # 1 byte: below the 1 MB limit
        (b"11", 1024 * 1024, 1, 400),  # 2 * 1024 * 1024 bytes: above the limit
    ],
)
def test_upload_file_size_check(
    test_client: TestClient,
    test_config: ChainlitConfig,
    mock_session_get_by_id_patched: Mock,
    monkeypatch: pytest.MonkeyPatch,
    file_content: bytes,
    content_multiplier: int,
    max_size_mb: int,
    expected_status: int,
):
    """Test check of max_size_mb.

    The payload is ``file_content`` repeated ``content_multiplier`` times, so
    large uploads can be parametrized without huge literals.

    This test needs its own name: a second module-level function called
    ``test_upload_file_mime_type_check`` would rebind that name and pytest
    would no longer collect the MIME-type test defined earlier in the module.
    """

    file_content = file_content * content_multiplier

    # Set max_size_mb in config
    monkeypatch.setattr(
        test_config.features,
        "spontaneous_file_upload",
        SpontaneousFileUploadFeature(max_size_mb=max_size_mb),
    )

    # Prepare the files to upload
    files = {
        "file": ("test_upload.txt", file_content, "text/plain"),
    }

    # Mock the persist_file method to return a known value
    expected_file_id = "mocked_file_id"
    mock_session_get_by_id_patched.persist_file = AsyncMock(
        return_value={
            "id": expected_file_id,
            "name": "test_upload.txt",
            "type": "text/plain",
            "size": len(file_content),
        }
    )

    # Make the POST request to upload the file
    response = test_client.post(
        "/project/file",
        files=files,
        params={"session_id": mock_session_get_by_id_patched.id},
    )

    # Verify the response
    assert response.status_code == expected_status


def test_project_translations_file_path_traversal(
test_client: TestClient, monkeypatch: pytest.MonkeyPatch
):
Expand Down