Skip to content

Commit

Permalink
app: initialize the filter framework for trial retrieval (#32)
Browse files Browse the repository at this point in the history
* app: initialize the filter framework for trial retrieval

* WIP: Temporary changes

* Add filters

* fix localtyping

* linting

* AgeRange, EligibleSex, StudyPhases done. AcceptsHealthy and two Dates still have some issues.

* Fix filters: lastUpdateDatePosted and resultsDatePosted

* ui tweaks

* lastupdatedateposted

* mypy

* minor

* Fix studyPhases and acceptsHealthy

* Fix part of the formatting and linting

* fix calendar and other problems

* lint

* fix case

* simplify tests

* remove unrelated

* minor

* Add reset button to date filters

* Fix linting

---------

Co-authored-by: KristineXiao <[email protected]>
  • Loading branch information
Charlie-XIAO and KristineXiao authored Dec 10, 2024
1 parent 2d65be2 commit 3e93fde
Show file tree
Hide file tree
Showing 24 changed files with 784 additions and 20 deletions.
4 changes: 2 additions & 2 deletions app/backend/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,13 @@ def _result(self, ids, include):
result["metadatas"].append(metadata)
return result

def query(self, *, query_embeddings, n_results, include):
def query(self, *, query_embeddings, n_results, include, where=None):
    """Return the first ``n_results`` records, repeated once per query embedding.

    ``where`` is accepted so callers can pass a metadata filter, but this
    implementation does not use it.
    """
    selected_ids = list(self.RECORDS)[:n_results]
    single = self._result(selected_ids, include)
    n_queries = len(query_embeddings)
    # Every query embedding receives the same field values
    for field in single:
        single[field] = [single[field]] * n_queries
    return single

def get(self, *, ids, include):
def get(self, *, ids, include, where=None):
    """Return the records for ``ids``, limited to the ``include`` fields.

    ``where`` is accepted so callers can pass a metadata filter, but this
    implementation does not use it.
    """
    records = self._result(ids, include)
    return records


Expand Down
10 changes: 10 additions & 0 deletions app/backend/localtyping.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ class TrialMetadataType(TypedDict):
documents: list[TrialMetadataDocumentType]


class TrialFilters(TypedDict, total=False):
    """Filters that can be applied when retrieving trials.

    Declared with ``total=False``, so every key may be omitted.
    """

    # Study type selector; the /retrieve endpoint handles "interventional"
    # and "observational"
    studyType: str
    # Requested study phases; /retrieve matches them against phase names such
    # as "EARLY_PHASE1", "PHASE1", ..., "PHASE4"
    studyPhases: list[str]
    # /retrieve only adds a filter when this is False
    acceptsHealthy: bool
    # Unpacked by /retrieve as (min_age, max_age)
    ageRange: tuple[int, int]
    # Sex eligibility selector; the handled values are defined in /retrieve
    eligibleSex: str
    # Unpacked by /retrieve as (from, to) timestamps, divided by 1000 before
    # conversion - presumably epoch milliseconds; confirm against the frontend
    resultsDatePosted: tuple[int, int]
    # Same (from, to) timestamp layout as resultsDatePosted
    lastUpdateDatePosted: tuple[int, int]


class APIHeartbeatResponseType(TypedDict):
timestamp: int

Expand Down
105 changes: 103 additions & 2 deletions app/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import os
import time
from contextlib import asynccontextmanager
from datetime import datetime
from itertools import permutations

import chromadb
import chromadb.api
Expand All @@ -28,6 +30,7 @@
APIMetaResponseType,
APIRetrieveResponseType,
ModelType,
TrialFilters,
)
from utils import format_exc_details, get_metadata_from_id

Expand Down Expand Up @@ -98,7 +101,26 @@ def custom_openapi(): # pragma: no cover
return app.openapi_schema


app.openapi = custom_openapi # type: ignore
def filter_by_date(results_full, date_key, from_date, to_date):
    """Keep only the results whose metadata date lies within a range.

    Only the first query of ``results_full`` (index 0 of each field) is
    examined. Entries are skipped, never raised on, when their metadata is
    missing, lacks ``date_key``, or holds a value that is not a valid
    ``YYYY-MM-DD`` string.

    Args:
        results_full: Query result with parallel ``"ids"``, ``"documents"``
            and ``"metadatas"`` fields, each a list holding one list per query.
        date_key: Metadata key whose value is a ``YYYY-MM-DD`` date string.
        from_date: Inclusive lower bound, as a ``datetime``.
        to_date: Inclusive upper bound, as a ``datetime``.

    Returns:
        A dict with ``"ids"``, ``"documents"`` and ``"metadatas"``, each
        wrapping the retained entries in a single-element list. Since the
        metadata is retained, the return value can be passed back in to apply
        a further date filter.
    """
    import logging

    logger = logging.getLogger(__name__)

    ids = results_full["ids"][0]
    documents = results_full["documents"][0]
    metadatas = results_full["metadatas"][0]

    kept_ids = []
    kept_documents = []
    kept_metadatas = []

    for doc_id, document, meta in zip(ids, documents, metadatas):
        # A record may have no metadata at all, or no value for this key
        date_str = meta.get(date_key) if meta else None
        if not date_str:
            continue

        # Keep the try body minimal: only the parsing is expected to fail
        try:
            date_obj = datetime.strptime(date_str, "%Y-%m-%d")
        except (TypeError, ValueError) as e:
            logger.warning("Error processing metadata: %s. Error: %s", meta, e)
            continue

        if from_date <= date_obj <= to_date:
            kept_ids.append(doc_id)
            kept_documents.append(document)
            kept_metadatas.append(meta)

    return {
        "ids": [kept_ids],
        "documents": [kept_documents],
        "metadatas": [kept_metadatas],
    }


@app.exception_handler(Exception)
Expand All @@ -124,7 +146,9 @@ async def heartbeat() -> APIHeartbeatResponseType:


@app.get("/retrieve")
async def retrieve(query: str, top_k: int) -> APIRetrieveResponseType:
async def retrieve(
query: str, top_k: int, filters_serialized: str
) -> APIRetrieveResponseType:
"""Retrieve items from the ChromaDB collection."""
if EMBEDDING_MODEL is None:
raise RuntimeError("Embedding model not initialized")
Expand All @@ -134,15 +158,92 @@ async def retrieve(query: str, top_k: int) -> APIRetrieveResponseType:
if top_k <= 0 or top_k > 30:
raise HTTPException(status_code=404, detail="Required 0 < top_k <= 30")

# Construct the filters: each entry of processed_filters is one ChromaDB
# metadata condition, and all of them must hold at once
filters: TrialFilters = json.loads(filters_serialized)
processed_filters = []

# Request values are lowercase while the stored metadata values are
# uppercase; any unrecognized value adds no condition
if "studyType" in filters:
    if filters["studyType"] == "interventional":
        processed_filters.append({"study_type": "INTERVENTIONAL"})
    elif filters["studyType"] == "observational":
        processed_filters.append({"study_type": "OBSERVATIONAL"})

if "acceptsHealthy" in filters and not filters["acceptsHealthy"]:
    # NOTE: If this filter is True, it means to accept healthy participants,
    # and unhealthy participants are always accepted so it is equivalent to
    # not having this filter at all
    processed_filters.append({"accepts_healthy": False})  # type: ignore # TODO: Fix this

if "eligibleSex" in filters:
    if filters["eligibleSex"] == "female":
        processed_filters.append({"eligible_sex": "FEMALE"})
    elif filters["eligibleSex"] == "male":
        # This branch used to compare against "observational" (copied from
        # the study type branch), so the MALE filter could never be selected
        processed_filters.append({"eligible_sex": "MALE"})

# TODO: Change this to a post-filter
if "studyPhases" in filters and len(filters["studyPhases"]) > 0:
    # NOTE: ChromaDB does not support the $contains operator on strings for
    # metadata fields. Therefore, we need to generate all possible
    # combinations and do exact matching
    all_phases = ["EARLY_PHASE1", "PHASE1", "PHASE2", "PHASE3", "PHASE4"]
    possible_values = []
    for size in range(1, len(all_phases) + 1):
        for combo in permutations(all_phases, size):
            if any(phase in combo for phase in filters["studyPhases"]):
                possible_values.append(", ".join(combo))
    processed_filters.append({"study_phases": {"$in": possible_values}})  # type: ignore # TODO: Fix this

if "ageRange" in filters:
    # NOTE: We want the age range to intersect with the desired range
    min_age, max_age = filters["ageRange"]
    processed_filters.append({"min_age": {"$lte": max_age}})  # type: ignore # TODO: Fix this
    processed_filters.append({"max_age": {"$gte": min_age}})  # type: ignore # TODO: Fix this

# Construct the where clause: a single condition is passed as-is and several
# conditions are combined with $and
where: chromadb.Where | None
if len(processed_filters) == 0:
    where = None
elif len(processed_filters) == 1:
    where = processed_filters[0]  # type: ignore # TODO: Fix this
else:
    where = {"$and": processed_filters}  # type: ignore # TODO: Fix this

# Embed the query and query the collection. The metadata is only needed by
# the date post-filters below, so it is requested only when one of them is
# active; a single query then serves both cases.
has_last_update_filter = "lastUpdateDatePosted" in filters
has_results_date_filter = "resultsDatePosted" in filters

include = [chromadb.api.types.IncludeEnum("documents")]
if has_last_update_filter or has_results_date_filter:
    include.append(chromadb.api.types.IncludeEnum("metadatas"))

query_embedding = EMBEDDING_MODEL.encode(query)
collection = CHROMADB_CLIENT.get_collection(CHROMADB_COLLECTION_NAME)
results = collection.query(
    query_embeddings=[query_embedding],
    n_results=top_k,
    include=include,
    where=where,
)
results_full = results

# NOTE(review): datetime.fromtimestamp converts to the server's local time,
# while the stored dates are parsed as naive dates - confirm that the range
# boundaries behave as intended across timezones
if has_last_update_filter:
    from_timestamp, to_timestamp = filters["lastUpdateDatePosted"]
    from_date = datetime.fromtimestamp(from_timestamp / 1000)
    to_date = datetime.fromtimestamp(to_timestamp / 1000)
    results = filter_by_date(
        results_full, "last_update_date_posted", from_date, to_date
    )

if has_results_date_filter:
    from_timestamp, to_timestamp = filters["resultsDatePosted"]
    from_date = datetime.fromtimestamp(from_timestamp / 1000)
    to_date = datetime.fromtimestamp(to_timestamp / 1000)
    by_results_date = filter_by_date(
        results_full, "results_date_posted", from_date, to_date
    )
    if has_last_update_filter:
        # Both date filters are active: keep only the entries that passed
        # both, instead of letting this filter discard the previous one
        allowed_ids = set(by_results_date["ids"][0])
        kept = [
            (doc_id, document)
            for doc_id, document in zip(results["ids"][0], results["documents"][0])
            if doc_id in allowed_ids
        ]
        results = {  # type: ignore # TODO: Fix this
            "ids": [[doc_id for doc_id, _ in kept]],
            "documents": [[document for _, document in kept]],
        }
    else:
        results = by_results_date

# Retrieve the results
ids = results["ids"][0]
if results["documents"] is not None:
Expand Down
13 changes: 11 additions & 2 deletions app/backend/test_main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""Test the main module."""

import json
from urllib.parse import quote

import pytest
from fastapi.testclient import TestClient

Expand All @@ -19,7 +22,10 @@ def test_heartbeat():
@pytest.mark.parametrize("top_k", [1, 3, 5])
def test_retrieve(top_k):
"""Test the /retrieve endpoint."""
response = client.get(f"/retrieve?query=Dummy Query&top_k={top_k}")
filters_serialized = quote(json.dumps({}))
response = client.get(
f"/retrieve?query=Dummy Query&{top_k=}&filters_serialized={filters_serialized}"
)
assert response.status_code == 200
response_json = response.json()
assert "ids" in response_json
Expand All @@ -31,7 +37,10 @@ def test_retrieve(top_k):
@pytest.mark.parametrize("top_k", [0, 31])
def test_retrieve_invalid_top_k(top_k):
"""Test the /retrieve endpoint with invalid top_k."""
response = client.get(f"/retrieve?query=Dummy Query&top_k={top_k}")
filters_serialized = quote(json.dumps({}))
response = client.get(
f"/retrieve?query=Dummy Query&{top_k=}&filters_serialized={filters_serialized}"
)
assert response.status_code == 404
assert "Required 0 < top_k <= 30" in response.text

Expand Down
1 change: 1 addition & 0 deletions app/frontend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"@radix-ui/themes": "^3.1.6",
"query-string": "^9.1.1",
"react": "^18.3.1",
"react-day-picker": "^9.4.1",
"react-dom": "^18.3.1",
"react-icons": "^5.4.0",
"react-markdown": "^9.0.1",
Expand Down
25 changes: 25 additions & 0 deletions app/frontend/pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion app/frontend/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
ChatDisplay,
MetaInfo,
ModelType,
TrialFilters,
UpdateMessagesFunction,
} from "./types";
import { Header } from "./components/Header";
Expand All @@ -37,6 +38,7 @@ export const App = () => {
const [metaMapping, setMetaMapping] = useState<Map<string, MetaInfo>>(
new Map(),
);
const [filters, setFilters] = useState<TrialFilters>({});
const tabRefs = useRef<Map<string, HTMLButtonElement>>(new Map());

// Switch to a different tab, creating a new tab if it does not exist yet
Expand Down Expand Up @@ -126,7 +128,7 @@ export const App = () => {
},
}}
/>
<Flex css={{ height: "100vh" }}>
<Flex css={{ height: "100dvh" }}>
{/* Left-hand sidebar panel */}
{isSidebarVisible && (
<Box
Expand Down Expand Up @@ -202,6 +204,8 @@ export const App = () => {
return newMessagesMapping;
})
}
filters={filters}
setFilters={setFilters}
switchTab={switchTab}
></RetrievePanel>
)}
Expand Down
8 changes: 7 additions & 1 deletion app/frontend/src/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
APIMetaResponseType,
APIRetrieveResponseType,
ModelType,
TrialFilters,
WrapAPI,
} from "./types";

Expand Down Expand Up @@ -66,8 +67,13 @@ const formatNonOkResponse = async (response: Response) => {
export const callRetrieve = async (
query: string,
topK: number,
filters: TrialFilters,
): Promise<WrapAPI<APIRetrieveResponseType>> => {
const params = queryString.stringify({ query, top_k: topK });
const params = queryString.stringify({
query,
top_k: topK,
filters_serialized: JSON.stringify(filters),
});
const url = `${import.meta.env.VITE_BACKEND_URL}/retrieve?${params}`;
try {
const response = await getResponse(url);
Expand Down
2 changes: 1 addition & 1 deletion app/frontend/src/components/ChatCollapsibleHint.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ export const ChatCollapsibleHint = ({
</Collapsible.Trigger>
<Collapsible.Content
css={{
maxHeight: "20vh",
maxHeight: "30dvh",
'&[data-state="open"]': {
animation: `${slideDown} 300ms ease-in-out`,
},
Expand Down
4 changes: 2 additions & 2 deletions app/frontend/src/components/ChatPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import { ChatCollapsibleHint } from "./ChatCollapsibleHint";
import { ChatErrorMessage } from "./ChatErrorMessage";
import { addMessageUtilities, scrollToBottom } from "../utils";
import { MessageDocs } from "./MessageDocs";
import { RetrievalPanelCommandPalette } from "./RetrievePanelCommandPalette";
import { ChatPanelCommandPalette } from "./ChatPanelCommandPalette";
import { FCDeleteChatButton } from "./FCDeleteChatButton";
import { FCScrollButtons } from "./FCScrollButtons";
import { FCModelSelector } from "./FCModelSelector";
Expand Down Expand Up @@ -171,7 +171,7 @@ export const ChatPanel = ({
</ExternalLink>
}
>
<RetrievalPanelCommandPalette metaInfo={metaInfo} />
<ChatPanelCommandPalette metaInfo={metaInfo} />
</ChatCollapsibleHint>
<ChatInput
query={query}
Expand Down
Loading

0 comments on commit 3e93fde

Please sign in to comment.