Skip to content

Commit

Permalink
app: initialize the filter framework for trial retrieval
Browse files Browse the repository at this point in the history
  • Loading branch information
Charlie-XIAO committed Nov 20, 2024
1 parent cc4ebf5 commit 011475b
Show file tree
Hide file tree
Showing 12 changed files with 175 additions and 21 deletions.
11 changes: 6 additions & 5 deletions app/backend/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,14 @@ class MockChromadbCollection:

RECORDS = set(f"id{i}" for i in range(50))

def _result(self, ids, include):
def _result(self, ids, include, where):
result = dict(
ids=ids,
documents=[] if "documents" in include else None,
metadatas=[] if "metadatas" in include else None,
include=include,
)

for key in ids:
if "documents" in include:
result["documents"].append(f"doc-{key}")
Expand All @@ -97,14 +98,14 @@ def _result(self, ids, include):
result["metadatas"].append(metadata)
return result

def query(self, *, query_embeddings, n_results, include):
result = self._result(list(self.RECORDS)[:n_results], include)
def query(self, *, query_embeddings, n_results, include, where=None):
result = self._result(list(self.RECORDS)[:n_results], include, where)
for k in result:
result[k] = [result[k] for _ in range(len(query_embeddings))]
return result

def get(self, *, ids, include):
return self._result(ids, include)
def get(self, *, ids, include, where=None):
return self._result(ids, include, where)


@pytest.fixture
Expand Down
4 changes: 4 additions & 0 deletions app/backend/localtyping.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ class TrialMetadataType(TypedDict):
documents: list[TrialMetadataDocumentType]


class TrialFilters(TypedDict, total=False):
studyType: str


class APIHeartbeatResponseType(TypedDict):
timestamp: int

Expand Down
24 changes: 23 additions & 1 deletion app/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
APIMetaResponseType,
APIRetrieveResponseType,
ModelType,
TrialFilters,
)
from utils import format_exc_details, get_metadata_from_id

Expand Down Expand Up @@ -115,7 +116,9 @@ async def heartbeat() -> APIHeartbeatResponseType:


@app.get("/retrieve")
async def retrieve(query: str, top_k: int) -> APIRetrieveResponseType:
async def retrieve(
query: str, top_k: int, filters_serialized: str
) -> APIRetrieveResponseType:
"""Retrieve items from the ChromaDB collection."""
if EMBEDDING_MODEL is None:
raise RuntimeError("Embedding model not initialized")
Expand All @@ -125,12 +128,31 @@ async def retrieve(query: str, top_k: int) -> APIRetrieveResponseType:
if top_k <= 0 or top_k > 30:
raise HTTPException(status_code=404, detail="Required 0 < top_k <= 30")

# Construct the filters
filters: TrialFilters = json.loads(filters_serialized)
processed_filters = []
if "studyType" in filters:
if filters["studyType"] == "interventional":
processed_filters.append({"study_type": "INTERVENTIONAL"})
elif filters["studyType"] == "observational":
processed_filters.append({"study_type": "OBSERVATIONAL"})

# Construct the where clause
where: chromadb.Where | None
if len(processed_filters) == 0:
where = None
elif len(processed_filters) == 1:
where = processed_filters[0] # type: ignore # TODO: Fix this
else:
where = {"$and": processed_filters} # type: ignore # TODO: Fix this

# Embed the query and query the collection
query_embedding = EMBEDDING_MODEL.encode(query)
results = CHROMADB_COLLECTION.query(
query_embeddings=[query_embedding],
n_results=top_k,
include=[chromadb.api.types.IncludeEnum("documents")],
where=where,
)

# Retrieve the results
Expand Down
19 changes: 16 additions & 3 deletions app/backend/test_main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""Test the main module."""

import json
from urllib.parse import quote

import pytest
from fastapi.testclient import TestClient

Expand All @@ -17,9 +20,16 @@ def test_heartbeat():


@pytest.mark.parametrize("top_k", [1, 3, 5])
def test_retrieve(top_k):
@pytest.mark.parametrize("filters", [{}])
def test_retrieve(top_k, filters):
"""Test the /retrieve endpoint."""
response = client.get(f"/retrieve?query=Dummy Query&top_k={top_k}")
# TODO: parametrize with filters when mock is improved

filters_serialized = quote(json.dumps(filters))
response = client.get(
f"/retrieve?query=Dummy Query&top_k={top_k}"
f"&filters_serialized={filters_serialized}"
)
assert response.status_code == 200
response_json = response.json()
assert "ids" in response_json
Expand All @@ -31,7 +41,10 @@ def test_retrieve(top_k):
@pytest.mark.parametrize("top_k", [0, 31])
def test_retrieve_invalid_top_k(top_k):
"""Test the /retrieve endpoint with invalid top_k."""
response = client.get(f"/retrieve?query=Dummy Query&top_k={top_k}")
filters_serialized = quote(json.dumps({}))
response = client.get(
f"/retrieve?query=Dummy Query&top_k={top_k}&filters_serialized={filters_serialized}"
)
assert response.status_code == 404
assert "Required 0 < top_k <= 30" in response.text

Expand Down
4 changes: 4 additions & 0 deletions app/frontend/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
ChatDisplay,
MetaInfo,
ModelType,
TrialFilters,
UpdateMessagesFunction,
} from "./types";
import { Header } from "./components/Header";
Expand All @@ -29,6 +30,7 @@ export const App = () => {
const [metaMapping, setMetaMapping] = useState<Map<string, MetaInfo>>(
new Map(),
);
const [filters, setFilters] = useState<TrialFilters>({});
const tabRefs = useRef<Map<string, HTMLButtonElement>>(new Map());

// Switch to a different tab, creating a new tab if it does not exist yet
Expand Down Expand Up @@ -154,6 +156,8 @@ export const App = () => {
return newMessagesMapping;
})
}
filters={filters}
setFilters={setFilters}
switchTab={switchTab}
></RetrievePanel>
)}
Expand Down
8 changes: 7 additions & 1 deletion app/frontend/src/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
APIMetaResponseType,
APIRetrieveResponseType,
ModelType,
TrialFilters,
WrapAPI,
} from "./types";

Expand Down Expand Up @@ -66,8 +67,13 @@ const formatNonOkResponse = async (response: Response) => {
export const callRetrieve = async (
query: string,
topK: number,
filters: TrialFilters,
): Promise<WrapAPI<APIRetrieveResponseType>> => {
const params = queryString.stringify({ query, top_k: topK });
const params = queryString.stringify({
query,
top_k: topK,
filters_serialized: JSON.stringify(filters),
});
const url = `${import.meta.env.VITE_BACKEND_URL}/retrieve?${params}`;
try {
const response = await getResponse(url);
Expand Down
4 changes: 2 additions & 2 deletions app/frontend/src/components/ChatPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import { ChatCollapsibleHint } from "./ChatCollapsibleHint";
import { ChatErrorMessage } from "./ChatErrorMessage";
import { addMessageUtilities, scrollToBottom } from "../utils";
import { MessageDocs } from "./MessageDocs";
import { RetrievalPanelCommandPalette } from "./RetrievePanelCommandPalette";
import { ChatPanelCommandPalette } from "./ChatPanelCommandPalette";
import { FCDeleteChatButton } from "./FCDeleteChatButton";

interface ChatPanelProps {
Expand Down Expand Up @@ -144,7 +144,7 @@ export const ChatPanel = ({
</ExternalLink>
}
>
<RetrievalPanelCommandPalette metaInfo={metaInfo} />
<ChatPanelCommandPalette metaInfo={metaInfo} />
</ChatCollapsibleHint>
<ChatInput
query={query}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/**
* @file RetrievalPanelCommandPalette.tsx
* @file ChatPanelCommandPalette.tsx
*
* The command palette component in the retrieval panel.
* The command palette component in the chat panel.
*/

import { Code, DataList, Flex, Text } from "@radix-ui/themes";
Expand All @@ -10,13 +10,13 @@ import { PUBMED_URL } from "../consts";
import { MetaInfo } from "../types";
import { CopyButton } from "./CopyButton";

interface RetrievalPanelCommandPaletteProps {
interface ChatPanelCommandPaletteProps {
metaInfo: MetaInfo;
}

export const RetrievalPanelCommandPalette = ({
export const ChatPanelCommandPalette = ({
metaInfo,
}: RetrievalPanelCommandPaletteProps) => {
}: ChatPanelCommandPaletteProps) => {
const { title } = metaInfo;

return (
Expand Down
23 changes: 19 additions & 4 deletions app/frontend/src/components/RetrievePanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,15 @@
*/

import { Flex, Text } from "@radix-ui/themes";
import { useEffect, useRef, useState } from "react";
import { Dispatch, SetStateAction, useEffect, useRef, useState } from "react";
import { MdFilterList } from "react-icons/md";
import { callRetrieve } from "../api";
import { ChatDisplay, MetaInfo, UpdateMessagesFunction } from "../types";
import {
ChatDisplay,
MetaInfo,
TrialFilters,
UpdateMessagesFunction,
} from "../types";
import { ChatPort } from "./ChatPort";
import { ChatInput } from "./ChatInput";
import { FCSendButton } from "./FCSendButton";
Expand All @@ -20,16 +25,21 @@ import { ChatErrorMessage } from "./ChatErrorMessage";
import { addMessageUtilities, scrollToBottom } from "../utils";
import { MessageRetrieved } from "./MessageRetrieved";
import { FCTopKSelector } from "./FCTopKSelector";
import { RetrievePanelFilters } from "./RetrievePanelFilters";

interface RetrievalPanelProps {
messages: ChatDisplay[];
setMessages: (fn: UpdateMessagesFunction) => void;
filters: TrialFilters;
setFilters: Dispatch<SetStateAction<TrialFilters>>;
switchTab: (tab: string, metaInfo: MetaInfo) => void;
}

export const RetrievePanel = ({
messages,
setMessages,
filters,
setFilters,
switchTab,
}: RetrievalPanelProps) => {
const chatPortRef = useRef<HTMLDivElement>(null);
Expand All @@ -44,7 +54,7 @@ export const RetrievePanel = ({
setQuery(""); // This will take effect only after the next render
addUserMessage(<Text size="2">{query}</Text>, query);

const callResult = await callRetrieve(query, topK);
const callResult = await callRetrieve(query, topK, filters);
if ("error" in callResult) {
addBotMessage(
<ChatErrorMessage error={callResult.error} />,
Expand Down Expand Up @@ -87,6 +97,11 @@ export const RetrievePanel = ({
scrollToBottom(chatPortRef);
}, [messages]);

// TODO: Remove
useEffect(() => {
console.log(filters);
}, [filters]);

return (
<Flex direction="column" justify="end" gap="5" px="3" height="100%">
<ChatPort ref={chatPortRef} messages={messages} loading={loading} />
Expand All @@ -96,7 +111,7 @@ export const RetrievePanel = ({
hintText="Retrieval filters"
HintIcon={MdFilterList}
>
TODO: FILTERS HERE!
<RetrievePanelFilters filters={filters} setFilters={setFilters} />
</ChatCollapsibleHint>
<ChatInput
query={query}
Expand Down
44 changes: 44 additions & 0 deletions app/frontend/src/components/RetrievePanelFilterStudyType.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/**
* @file RetrievePanelFilterStudyType.tsx
*
* The filter for the study type in the retrieval panel.
*/

import { Dispatch, SetStateAction } from "react";
import { Flex, RadioGroup } from "@radix-ui/themes";
import { TrialFilters } from "../types";

interface RetrievePanelFilterStudyTypeProps {
filters: TrialFilters;
setFilters: Dispatch<SetStateAction<TrialFilters>>;
}

export const RetrievePanelFilterStudyType = ({
filters,
setFilters,
}: RetrievePanelFilterStudyTypeProps) => {
const value = filters.studyType ?? "all";

// Handler for the filter change
const handler = (value: string) => {
setFilters((prevFilters) => ({
...prevFilters,
studyType: value === "all" ? undefined : value,
}));
};

return (
<RadioGroup.Root
variant="surface"
asChild
value={value}
onValueChange={handler}
>
<Flex direction="row" gap="4">
<RadioGroup.Item value="all">All</RadioGroup.Item>
<RadioGroup.Item value="interventional">Interventional</RadioGroup.Item>
<RadioGroup.Item value="observational">Observational</RadioGroup.Item>
</Flex>
</RadioGroup.Root>
);
};
41 changes: 41 additions & 0 deletions app/frontend/src/components/RetrievePanelFilters.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/**
* @file RetrievePanelFilters.tsx
*
* The filters component in the retrieval panel.
*/

import { DataList } from "@radix-ui/themes";
import { Dispatch, SetStateAction } from "react";
import { TrialFilters } from "../types";
import { RetrievePanelFilterStudyType } from "./RetrievePanelFilterStudyType";

interface RetrievePanelFiltersProps {
filters: TrialFilters;
setFilters: Dispatch<SetStateAction<TrialFilters>>;
}

export const RetrievePanelFilters = ({
filters,
setFilters,
}: RetrievePanelFiltersProps) => {
return (
<DataList.Root
size="2"
css={{
rowGap: "var(--space-2)",
columnGap: "var(--space-6)",
padding: "var(--space-1) 0",
}}
>
<DataList.Item>
<DataList.Label minWidth="0">Study type</DataList.Label>
<DataList.Value>
<RetrievePanelFilterStudyType
filters={filters}
setFilters={setFilters}
/>
</DataList.Value>
</DataList.Item>
</DataList.Root>
);
};
4 changes: 4 additions & 0 deletions app/frontend/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ export type UpdateMessagesFunction = (
prevMessages: ChatDisplay[],
) => ChatDisplay[];

export interface TrialFilters {
studyType?: string;
}

/* ==========================================================================
* THE FOLLOWING ARE MIRROR DEFINITIONS OF BACKEND TYPES. CHECK OUT THE
* BACKEND API FOR REFERENCE.
Expand Down

0 comments on commit 011475b

Please sign in to comment.