diff --git a/.gitignore b/.gitignore index c7b987a..f700394 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ node_modules/ dist/ *.tsbuildinfo .coverage +.vscode/ # Generated by google-github-actions/auth action gha-creds-*.json diff --git a/app/backend/conftest.py b/app/backend/conftest.py index 4d0438c..8474e4c 100644 --- a/app/backend/conftest.py +++ b/app/backend/conftest.py @@ -120,20 +120,19 @@ def sample_metadata(): return SAMPLE_METADATA -@pytest.fixture -def embedding_model(): - """Return a mock embedding model.""" - return MockEmbeddingModel() - - @pytest.fixture def chromadb_client(): """Return a mock ChromaDB client.""" return MockChromadbClient() -@pytest.fixture(autouse=True) +@pytest.fixture() def setup(monkeypatch): - monkeypatch.setattr("vertexai.init", lambda: None) - monkeypatch.setattr("main.EMBEDDING_MODEL", MockEmbeddingModel()) - monkeypatch.setattr("main.CHROMADB_CLIENT", MockChromadbClient()) + def _setup(init_embedding_model=True, init_chromadb_client=True): + monkeypatch.setattr("vertexai.init", lambda: None) + if init_embedding_model: + monkeypatch.setattr("main.EMBEDDING_MODEL", MockEmbeddingModel()) + if init_chromadb_client: + monkeypatch.setattr("main.CHROMADB_CLIENT", MockChromadbClient()) + + return _setup diff --git a/app/backend/main.py b/app/backend/main.py index 74244d9..c917a5d 100644 --- a/app/backend/main.py +++ b/app/backend/main.py @@ -5,8 +5,6 @@ import os import time from contextlib import asynccontextmanager -from datetime import datetime -from itertools import permutations import chromadb import chromadb.api @@ -32,7 +30,12 @@ ModelType, TrialFilters, ) -from utils import format_exc_details, get_metadata_from_id +from utils import ( + construct_filters, + format_exc_details, + get_metadata_from_id, + post_filter, +) logger = logging.getLogger("uvicorn.error") @@ -101,28 +104,6 @@ def custom_openapi(): # pragma: no cover return app.openapi_schema -def filter_by_date(results_full, date_key, from_date, to_date): - 
filtered_documents = [] - filtered_ids = [] - - metadata = results_full["metadatas"][0] - documents = results_full["documents"][0] - ids = results_full["ids"][0] - - for idx, meta in enumerate(metadata): - try: - date_str = meta.get(date_key) - if date_str: - date_obj = datetime.strptime(date_str, "%Y-%m-%d") - if from_date <= date_obj <= to_date: - filtered_documents.append(documents[idx]) - filtered_ids.append(ids[idx]) - except (ValueError, KeyError) as e: - print(f"Error processing metadata: {meta}. Error: {e}") - - return {"ids": [filtered_ids], "documents": [filtered_documents]} - - @app.exception_handler(Exception) async def custom_exception_handler( request: Request, exc: Exception @@ -158,55 +139,14 @@ async def retrieve( if top_k <= 0 or top_k > 30: raise HTTPException(status_code=404, detail="Required 0 < top_k <= 30") - # Construct the filters + # Construct the filters; we will need to include the full metadata in the query + # results if post-filtering is needed, otherwise only documents are needed; TODO: + # we should avoid post-processing filters if possible filters: TrialFilters = json.loads(filters_serialized) - processed_filters = [] - - if "studyType" in filters: - if filters["studyType"] == "interventional": - processed_filters.append({"study_type": "INTERVENTIONAL"}) - elif filters["studyType"] == "observational": - processed_filters.append({"study_type": "OBSERVATIONAL"}) - - if "acceptsHealthy" in filters and not filters["acceptsHealthy"]: - # NOTE: If this filter is True, it means to accept healthy participants, - # and unhealthy participants are always accepted so it is equivalent to - # not having this filter at all - processed_filters.append({"accepts_healthy": False}) # type: ignore # TODO: Fix this - - if "eligibleSex" in filters: - if filters["eligibleSex"] == "female": - processed_filters.append({"eligible_sex": "FEMALE"}) - elif filters["eligibleSex"] == "male": - processed_filters.append({"eligible_sex": "MALE"}) - - # TODO: Change 
this to a post-filter - if "studyPhases" in filters and len(filters["studyPhases"]) > 0: - # NOTE: ChromaDB does not support the $contains operator on strings for - # metadata fields. Therefore, we need to generate all possible - # combinations and do exact matching - all_phases = ["EARLY_PHASE1", "PHASE1", "PHASE2", "PHASE3", "PHASE4"] - possible_values = [] - for r in range(len(all_phases)): - for combo in permutations(all_phases, r + 1): - if any(phase in combo for phase in filters["studyPhases"]): - possible_values.append(", ".join(combo)) - processed_filters.append({"study_phases": {"$in": possible_values}}) # type: ignore # TODO: Fix this - - if "ageRange" in filters: - # NOTE: We want the age range to intersect with the desired range - min_age, max_age = filters["ageRange"] - processed_filters.append({"min_age": {"$lte": max_age}}) # type: ignore # TODO: Fix this - processed_filters.append({"max_age": {"$gte": min_age}}) # type: ignore # TODO: Fix this - - # Construct the where clause - where: chromadb.Where | None - if len(processed_filters) == 0: - where = None - elif len(processed_filters) == 1: - where = processed_filters[0] # type: ignore # TODO: Fix this - else: - where = {"$and": processed_filters} # type: ignore # TODO: Fix this + needs_post_filter, where = construct_filters(filters) + include = [chromadb.api.types.IncludeEnum("documents")] + if needs_post_filter: + include.append(chromadb.api.types.IncludeEnum("metadatas")) # Embed the query and query the collection query_embedding = EMBEDDING_MODEL.encode(query) @@ -214,44 +154,18 @@ async def retrieve( results = collection.query( query_embeddings=[query_embedding], n_results=top_k, - include=[chromadb.api.types.IncludeEnum("documents")], + include=include, where=where, ) - results_full = collection.query( - query_embeddings=[query_embedding], - n_results=top_k, - include=[ - chromadb.api.types.IncludeEnum("documents"), - chromadb.api.types.IncludeEnum("metadatas"), - ], - where=where, - ) - - if 
"lastUpdateDatePosted" in filters: - from_timestamp, to_timestamp = filters["lastUpdateDatePosted"] - from_date = datetime.fromtimestamp(from_timestamp / 1000) - to_date = datetime.fromtimestamp(to_timestamp / 1000) - results = filter_by_date( - results_full, "last_update_date_posted", from_date, to_date - ) + # Post-filter the results if needed + if needs_post_filter: + return post_filter(results, filters) - if "resultsDatePosted" in filters: - from_timestamp, to_timestamp = filters["resultsDatePosted"] - from_date = datetime.fromtimestamp(from_timestamp / 1000) - to_date = datetime.fromtimestamp(to_timestamp / 1000) - results = filter_by_date( - results_full, "results_date_posted", from_date, to_date - ) - - # Retrieve the results + # Retrieve the results as is ids = results["ids"][0] - if results["documents"] is not None: - documents = results["documents"][0] - else: - documents = [""] * len(ids) - - return {"ids": ids, "documents": documents} + assert results["documents"] is not None, "Missing documents in query results" + return {"ids": ids, "documents": results["documents"][0]} @app.get("/meta/{item_id}") diff --git a/app/backend/test_main.py b/app/backend/test_main.py index 3dc816e..a650160 100644 --- a/app/backend/test_main.py +++ b/app/backend/test_main.py @@ -20,11 +20,12 @@ def test_heartbeat(): @pytest.mark.parametrize("top_k", [1, 3, 5]) -def test_retrieve(top_k): +def test_retrieve(top_k, setup): """Test the /retrieve endpoint.""" + setup() filters_serialized = quote(json.dumps({})) response = client.get( - f"/retrieve?query=Dummy Query&{top_k=}&filters_serialized={filters_serialized}" + f"/retrieve?query=Dummy&{top_k=}&filters_serialized={filters_serialized}" ) assert response.status_code == 200 response_json = response.json() @@ -35,19 +36,49 @@ def test_retrieve(top_k): @pytest.mark.parametrize("top_k", [0, 31]) -def test_retrieve_invalid_top_k(top_k): +def test_retrieve_invalid_top_k(top_k, setup): """Test the /retrieve endpoint with invalid 
top_k.""" + setup() filters_serialized = quote(json.dumps({})) response = client.get( - f"/retrieve?query=Dummy Query&{top_k=}&filters_serialized={filters_serialized}" + f"/retrieve?query=Dummy&{top_k=}&filters_serialized={filters_serialized}" ) assert response.status_code == 404 assert "Required 0 < top_k <= 30" in response.text +@pytest.mark.parametrize("init_embedding_model", [True, False]) +@pytest.mark.parametrize("init_chromadb_client", [True, False]) +def test_retrieve_partial_setup(setup, init_embedding_model, init_chromadb_client): + """Test the /retrieve endpoint with partial setup.""" + setup( + init_embedding_model=init_embedding_model, + init_chromadb_client=init_chromadb_client, + ) + + def _run(): + filters_serialized = quote(json.dumps({})) + return client.get( + f"/retrieve?query=Dummy&top_k=3&filters_serialized={filters_serialized}" + ) + + if not init_embedding_model: + with pytest.raises(RuntimeError, match="Embedding model not initialized"): + _run() + return + if not init_chromadb_client: + with pytest.raises(RuntimeError, match="ChromaDB not reachable"): + _run() + return + + response = _run() + assert response.status_code == 200 + + @pytest.mark.parametrize("item_id", ["id0", "id1", "id2"]) -def test_meta(item_id): +def test_meta(item_id, setup): """Test the /meta/{item_id} endpoint.""" + setup() response = client.get(f"/meta/{item_id}") assert response.status_code == 200 response_json = response.json() @@ -55,4 +86,11 @@ def test_meta(item_id): assert response_json["metadata"]["shortTitle"] == f"Sample Metadata {item_id}" +def test_meta_partial_setup(setup): + """Test the /meta/{item_id} endpoint with partial setup.""" + setup(init_chromadb_client=False) + with pytest.raises(RuntimeError, match="ChromaDB not reachable"): + client.get("/meta/id0") + + # TODO: Add tests for the /chat/{model}/{item_id} endpoint diff --git a/app/backend/test_utils.py b/app/backend/test_utils.py index add5408..1aae437 100644 --- a/app/backend/test_utils.py +++ 
b/app/backend/test_utils.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from utils import format_exc_details, get_metadata_from_id +from utils import construct_filters, format_exc_details, get_metadata_from_id def _snake_to_camel(snake_str): @@ -32,6 +32,7 @@ def test_format_exc_details(): @pytest.mark.parametrize("item_id", ["id0", "id1", "id2"]) def test_get_metadata_from_id(item_id, chromadb_client, sample_metadata): + """Test getting metadata from an item ID.""" collection = chromadb_client.get_collection("test") metadata = get_metadata_from_id(collection, item_id) assert metadata["shortTitle"] == f"Sample Metadata {item_id}" @@ -60,3 +61,47 @@ def test_get_metadata_from_id(item_id, chromadb_client, sample_metadata): assert item["timeFrame"] == item_exp["time_frame"] else: assert metadata[_snake_to_camel(key)] == value + + +@pytest.mark.parametrize( + "filters, expected_where, expected_needs_post_filter", + [ + # Single filter + ({"studyType": "interventional"}, {"study_type": "INTERVENTIONAL"}, False), + ({"acceptsHealthy": False}, {"accepts_healthy": False}, False), + ({"acceptsHealthy": True}, None, False), + ({"eligibleSex": "female"}, {"eligible_sex": "FEMALE"}, False), + ( + {"ageRange": (18, 30)}, + {"$and": [{"min_age": {"$lte": 30}}, {"max_age": {"$gte": 18}}]}, + False, + ), + # Multiple filters + ( + {"studyType": "interventional", "acceptsHealthy": False}, + {"$and": [{"study_type": "INTERVENTIONAL"}, {"accepts_healthy": False}]}, + False, + ), + ( + {"eligibleSex": "female", "ageRange": (18, 30)}, + { + "$and": [ + {"eligible_sex": "FEMALE"}, + {"min_age": {"$lte": 30}}, + {"max_age": {"$gte": 18}}, + ] + }, + False, + ), + # Expected post-processing + ({"studyPhases": ["PHASE1"]}, None, True), + ({"lastUpdateDatePosted": (0, 0)}, None, True), + ({"resultsDatePosted": (0, 0)}, None, True), + ], +) +def test_construct_filters(filters, expected_where, expected_needs_post_filter): + """Test constructing filters.""" + assert 
construct_filters(filters) == (expected_needs_post_filter, expected_where) + + +# TODO: Add tests for the post_filter function diff --git a/app/backend/utils.py b/app/backend/utils.py index 309e5ec..c39d5ae 100644 --- a/app/backend/utils.py +++ b/app/backend/utils.py @@ -2,13 +2,15 @@ import json import traceback +from datetime import datetime +from functools import partial from pathlib import Path -from typing import Any +from typing import Any, Literal import chromadb import chromadb.api -from localtyping import TrialMetadataType +from localtyping import APIRetrieveResponseType, TrialFilters, TrialMetadataType def format_exc_details(exc: Exception) -> str: @@ -122,3 +124,114 @@ def get_metadata_from_id( if metadatas is None: return None return _clean_metadata(metadatas[0]) + + +def construct_filters(filters: TrialFilters) -> tuple[bool, chromadb.Where | None]: + """Construct filters for querying trials.""" + processed_filters: list[chromadb.Where] = [] + + if (study_type := filters.get("studyType")) is not None: + processed_filters.append({"study_type": study_type.upper()}) + + if ( + accepts_healthy := filters.get("acceptsHealthy") + ) is not None and not accepts_healthy: + # NOTE: The accepts_healthy filter being True means that the study accepts + # healthy participants; yet unhealthy participants are always accepted, so it is + # equivalent to not having this filter at all + processed_filters.append({"accepts_healthy": False}) + + if (eligible_sex := filters.get("eligibleSex")) is not None: + processed_filters.append({"eligible_sex": eligible_sex.upper()}) + + if (age_range := filters.get("ageRange")) is not None: + # NOTE: We want the age range to intersect with the desired range, so it + # suffices to have actual minimum <= desired maximum and actual maximum >= + # desired minimum + min_age, max_age = age_range + processed_filters.append({"min_age": {"$lte": max_age}}) # type: ignore + processed_filters.append({"max_age": {"$gte": min_age}}) # type: ignore 
+ + # Construct the where clause + where: chromadb.Where | None = None + if len(processed_filters) == 1: + where = processed_filters[0] + elif len(processed_filters) > 1: + where = {"$and": processed_filters} + + # Determine if there are post-processing filters required + needs_post_filter = any( + key in filters + for key in ["studyPhases", "lastUpdateDatePosted", "resultsDatePosted"] + ) + + return needs_post_filter, where + + +def post_filter( + results: chromadb.QueryResult, filters: TrialFilters +) -> APIRetrieveResponseType: + """Post-filtering of query results.""" + filtered_ids, filtered_documents = [], [] + + assert ( + results["documents"] is not None and results["metadatas"] is not None + ), "Missing documents or metadatas required for post-filtering" + ids = results["ids"][0] + documents = results["documents"][0] + metadatas = results["metadatas"][0] + + def _accept_by_study_phases(metadata: Any, study_phases_filter: list[str]) -> bool: + # If any of the study phases as in the metadata is in the desired study phases, + # then we accept this metadata; TODO: ChromaDB should support more flexible + # string matching on metadata fields, e.g., the $contains operator is currently + # only supported for document filters but not metadata field filters; by then we + # will be able to move this post-filtering logic to the database query + study_phases = metadata["study_phases"].split(", ") + return any(phase in study_phases_filter for phase in study_phases) + + def _accept_by_date( + metadata: Any, + key: Literal["last_update_date_posted", "results_date_posted"], + date_range_filter: tuple[int, int], + ) -> bool: + # If the date field is within the desired range, then we accept this metadata; + # we note that some data fields in the metadata do not have the day part, but + # the two types we accept here both do; TODO: we should consider storing the + # date fields as timestamps in the database so that we can move this + # post-filtering logic to the database 
query + date_from_filter, date_to_filter = date_range_filter + date = datetime.strptime(metadata[key], "%Y-%m-%d").timestamp() * 1000 + return date_from_filter <= date <= date_to_filter + + # Construct the list of post-filtering functions + post_filter_funcs = [] + if (study_phase_filter := filters.get("studyPhases")) is not None: + post_filter_funcs.append( + partial(_accept_by_study_phases, study_phases_filter=study_phase_filter) + ) + if (last_update_date_filter := filters.get("lastUpdateDatePosted")) is not None: + post_filter_funcs.append( + partial( + _accept_by_date, + key="last_update_date_posted", + date_range_filter=last_update_date_filter, + ) + ) + if (results_date_filter := filters.get("resultsDatePosted")) is not None: + post_filter_funcs.append( + partial( + _accept_by_date, + key="results_date_posted", + date_range_filter=results_date_filter, + ) + ) + + # For each item, append to the filtered list only if all post-filtering functions + # accept that item + for _id, document, metadata in zip(ids, documents, metadatas): + if all(func(metadata) for func in post_filter_funcs): + filtered_ids.append(_id) + filtered_documents.append(document) + + return {"ids": filtered_ids, "documents": filtered_documents} diff --git a/app/frontend/src/components/ChatErrorMessage.tsx b/app/frontend/src/components/ChatErrorMessage.tsx index a13806f..c123a26 100644 --- a/app/frontend/src/components/ChatErrorMessage.tsx +++ b/app/frontend/src/components/ChatErrorMessage.tsx @@ -20,6 +20,7 @@ export const ChatErrorMessage = ({ error }: ChatErrorMessageProps) => { css={{ lineHeight: "calc(var(--font-size-2) * 1.3)", fontFamily: "var(--code-font-family)", + whiteSpace: "pre-wrap", }} > {error} diff --git a/app/frontend/src/components/Sidebar.tsx b/app/frontend/src/components/Sidebar.tsx index 619cb1f..9a6e915 100644 --- a/app/frontend/src/components/Sidebar.tsx +++ b/app/frontend/src/components/Sidebar.tsx @@ -17,6 +17,7 @@ import { Dispatch, MutableRefObject, SetStateAction }
from "react"; import { TbDatabaseSearch } from "react-icons/tb"; import { MetaInfo } from "../types"; import { MdDeleteOutline } from "react-icons/md"; +import { toast } from "sonner"; interface SidebarProps { tabRefs: MutableRefObject>; @@ -111,8 +112,11 @@ export const Sidebar = ({ size="2" variant="surface" color="red" - onClick={clearTabs} disabled={metaMapping.size === 0} + onClick={() => { + clearTabs(); + toast.success("All chats deleted"); + }} > Clear chats diff --git a/deploy/pipeline/pipeline.py b/deploy/pipeline/pipeline.py index 60935de..6c98ecb 100755 --- a/deploy/pipeline/pipeline.py +++ b/deploy/pipeline/pipeline.py @@ -48,15 +48,7 @@ def embedding_model(): @dsl.pipeline def pipeline(): data_pipeline_task = data_pipeline().set_display_name("Data pipeline") - - # NOTE: This requires GPU quota for custom Vertex AI training jobs - embedding_model_task = ( - embedding_model() - .set_display_name("Embedding model") - .set_accelerator_type("NVIDIA_TESLA_T4") - .set_accelerator_limit(1) - .after(data_pipeline_task) - ) + embedding_model().set_display_name("Embedding model").after(data_pipeline_task) def main():