Skip to content

Commit

Permalink
ci(test): extend backend test suite except for the chat endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
Charlie-XIAO committed Nov 19, 2024
1 parent d4e7f7c commit b6da8fd
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 21 deletions.
48 changes: 32 additions & 16 deletions app/backend/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json

import numpy as np
import pytest

SAMPLE_METADATA = dict(
Expand Down Expand Up @@ -68,48 +69,63 @@
)


class MockEmbeddingModel:
"""Mock embedding model."""

def encode(self, query):
return np.array([0.1, 0.2, 0.3, 0.4, 0.5])


class MockChromadbCollection:
"""Mock ChromaDB collection."""

RECORDS = {"id1", "id2", "id3", "id4"}

def get(
self,
ids=None,
where=None,
limit=None,
offset=None,
where_document=None,
include=["metadatas", "documents"],
):
RECORDS = set(f"id{i}" for i in range(50))

def _result(self, ids, include):
result = dict(
ids=[],
ids=ids,
documents=[] if "documents" in include else None,
metadatas=[] if "metadatas" in include else None,
include=include,
)

for key in ids:
if key not in self.RECORDS:
raise ValueError(f"Record {key} not found.")
result["ids"].append(key)
if "documents" in include:
result["documents"].append(f"doc-{key}")
if "metadatas" in include:
metadata = SAMPLE_METADATA.copy()
metadata["short_title"] = f"Sample Metadata {key}"
result["metadatas"].append(metadata)
return result

def query(self, *, query_embeddings, n_results, include):
result = self._result(list(self.RECORDS)[:n_results], include)
for k in result:
result[k] = [result[k] for _ in range(len(query_embeddings))]
return result

def get(self, *, ids, include):
return self._result(ids, include)


@pytest.fixture
def sample_metadata():
"""Return a fixed sample metadata."""
return SAMPLE_METADATA


@pytest.fixture
def embedding_model():
"""Return a mock embedding model."""
return MockEmbeddingModel()


@pytest.fixture
def chromadb_collection():
"""Return a mock ChromaDB collection."""
return MockChromadbCollection()


@pytest.fixture(autouse=True)
def setup(monkeypatch):
monkeypatch.setattr("main.EMBEDDING_MODEL", MockEmbeddingModel())
monkeypatch.setattr("main.CHROMADB_COLLECTION", MockChromadbCollection())
8 changes: 5 additions & 3 deletions app/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@


@asynccontextmanager
async def lifespan(app: FastAPI):
async def lifespan(app: FastAPI): # pragma: no cover
"""Context manager to handle the lifespan of the FastAPI app."""
global EMBEDDING_MODEL, CHROMADB_COLLECTION
EMBEDDING_MODEL = FlagModel("BAAI/bge-small-en-v1.5", use_fp16=True)
Expand All @@ -61,7 +61,7 @@ async def lifespan(app: FastAPI):
)


def custom_openapi():
def custom_openapi(): # pragma: no cover
"""OpenAPI schema customization."""
if app.openapi_schema:
return app.openapi_schema
Expand All @@ -83,7 +83,9 @@ def custom_openapi():


@app.exception_handler(Exception)
async def custom_exception_handler(request: Request, exc: Exception):
async def custom_exception_handler(
request: Request, exc: Exception
): # pragma: no cover
"""Custom handle for all types of exceptions."""
response = JSONResponse(
status_code=500,
Expand Down
34 changes: 33 additions & 1 deletion app/backend/test_main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Test the main module."""

import pytest
from fastapi.testclient import TestClient

from main import app
Expand All @@ -11,4 +12,35 @@ def test_heartbeat():
"""Test the /heartbeat endpoint."""
response = client.get("/heartbeat")
assert response.status_code == 200
assert "timestamp" in response.json()
response_json = response.json()
assert "timestamp" in response_json


@pytest.mark.parametrize("top_k", [1, 3, 5])
def test_retrieve(top_k):
"""Test the /retrieve endpoint."""
response = client.get(f"/retrieve?query=Dummy Query&top_k={top_k}")
assert response.status_code == 200
response_json = response.json()
assert "ids" in response_json
assert "documents" in response_json
assert len(response_json["ids"]) == top_k
assert len(response_json["documents"]) == top_k


@pytest.mark.parametrize("top_k", [0, 31])
def test_retrieve_invalid_top_k(top_k):
"""Test the /retrieve endpoint with invalid top_k."""
response = client.get(f"/retrieve?query=Dummy Query&top_k={top_k}")
assert response.status_code == 404
assert "Required 0 < top_k <= 30" in response.text


@pytest.mark.parametrize("item_id", ["id0", "id1", "id2"])
def test_meta(item_id):
"""Test the /meta/{item_id} endpoint."""
response = client.get(f"/meta/{item_id}")
assert response.status_code == 200
response_json = response.json()
assert "metadata" in response_json
assert response_json["metadata"]["shortTitle"] == f"Sample Metadata {item_id}"
2 changes: 1 addition & 1 deletion app/backend/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_format_exc_details():
assert 'np.sum("")' in details


@pytest.mark.parametrize("item_id", ["id1", "id2", "id3", "id4"])
@pytest.mark.parametrize("item_id", ["id0", "id1", "id2"])
def test_get_metadata_from_id(item_id, chromadb_collection, sample_metadata):
metadata = get_metadata_from_id(chromadb_collection, item_id)
assert metadata["shortTitle"] == f"Sample Metadata {item_id}"
Expand Down

0 comments on commit b6da8fd

Please sign in to comment.