Skip to content

Commit

Permalink
ci(test)/cd: refactor filtering and add tests; disable GPU in pipeline (
Browse files Browse the repository at this point in the history
…#88)

* modify filtering logic

* tests
  • Loading branch information
Charlie-XIAO authored Dec 11, 2024
1 parent 18c875b commit 3bce63e
Show file tree
Hide file tree
Showing 9 changed files with 241 additions and 134 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ node_modules/
dist/
*.tsbuildinfo
.coverage
.vscode/

# Generated by google-github-actions/auth action
gha-creds-*.json
Expand Down
19 changes: 9 additions & 10 deletions app/backend/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,20 +120,19 @@ def sample_metadata():
return SAMPLE_METADATA


@pytest.fixture
def embedding_model():
"""Return a mock embedding model."""
return MockEmbeddingModel()


@pytest.fixture
def chromadb_client():
"""Return a mock ChromaDB client."""
return MockChromadbClient()


@pytest.fixture(autouse=True)
@pytest.fixture()
def setup(monkeypatch):
monkeypatch.setattr("vertexai.init", lambda: None)
monkeypatch.setattr("main.EMBEDDING_MODEL", MockEmbeddingModel())
monkeypatch.setattr("main.CHROMADB_CLIENT", MockChromadbClient())
def _setup(init_embedding_model=True, init_chromadb_client=True):
monkeypatch.setattr("vertexai.init", lambda: None)
if init_embedding_model:
monkeypatch.setattr("main.EMBEDDING_MODEL", MockEmbeddingModel())
if init_chromadb_client:
monkeypatch.setattr("main.CHROMADB_CLIENT", MockChromadbClient())

return _setup
126 changes: 20 additions & 106 deletions app/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import os
import time
from contextlib import asynccontextmanager
from datetime import datetime
from itertools import permutations

import chromadb
import chromadb.api
Expand All @@ -32,7 +30,12 @@
ModelType,
TrialFilters,
)
from utils import format_exc_details, get_metadata_from_id
from utils import (
construct_filters,
format_exc_details,
get_metadata_from_id,
post_filter,
)

logger = logging.getLogger("uvicorn.error")

Expand Down Expand Up @@ -101,28 +104,6 @@ def custom_openapi(): # pragma: no cover
return app.openapi_schema


def filter_by_date(results_full, date_key, from_date, to_date):
filtered_documents = []
filtered_ids = []

metadata = results_full["metadatas"][0]
documents = results_full["documents"][0]
ids = results_full["ids"][0]

for idx, meta in enumerate(metadata):
try:
date_str = meta.get(date_key)
if date_str:
date_obj = datetime.strptime(date_str, "%Y-%m-%d")
if from_date <= date_obj <= to_date:
filtered_documents.append(documents[idx])
filtered_ids.append(ids[idx])
except (ValueError, KeyError) as e:
print(f"Error processing metadata: {meta}. Error: {e}")

return {"ids": [filtered_ids], "documents": [filtered_documents]}


@app.exception_handler(Exception)
async def custom_exception_handler(
request: Request, exc: Exception
Expand Down Expand Up @@ -158,100 +139,33 @@ async def retrieve(
if top_k <= 0 or top_k > 30:
raise HTTPException(status_code=404, detail="Required 0 < top_k <= 30")

# Construct the filters
# Construct the filters; we will need to include the full metadata in the query
# results if post-filtering is needed, otherwise only documents are needed; TODO:
# we should avoid post-processing filters if possible
filters: TrialFilters = json.loads(filters_serialized)
processed_filters = []

if "studyType" in filters:
if filters["studyType"] == "interventional":
processed_filters.append({"study_type": "INTERVENTIONAL"})
elif filters["studyType"] == "observational":
processed_filters.append({"study_type": "OBSERVATIONAL"})

if "acceptsHealthy" in filters and not filters["acceptsHealthy"]:
# NOTE: If this filter is True, it means to accept healthy participants,
# and unhealthy participants are always accepted so it is equivalent to
# not having this filter at all
processed_filters.append({"accepts_healthy": False}) # type: ignore # TODO: Fix this

if "eligibleSex" in filters:
if filters["eligibleSex"] == "female":
processed_filters.append({"eligible_sex": "FEMALE"})
elif filters["eligibleSex"] == "male":
processed_filters.append({"eligible_sex": "MALE"})

# TODO: Change this to a post-filter
if "studyPhases" in filters and len(filters["studyPhases"]) > 0:
# NOTE: ChromaDB does not support the $contains operator on strings for
# metadata fields. Therefore, we need to generate all possible
# combinations and do exact matching
all_phases = ["EARLY_PHASE1", "PHASE1", "PHASE2", "PHASE3", "PHASE4"]
possible_values = []
for r in range(len(all_phases)):
for combo in permutations(all_phases, r + 1):
if any(phase in combo for phase in filters["studyPhases"]):
possible_values.append(", ".join(combo))
processed_filters.append({"study_phases": {"$in": possible_values}}) # type: ignore # TODO: Fix this

if "ageRange" in filters:
# NOTE: We want the age range to intersect with the desired range
min_age, max_age = filters["ageRange"]
processed_filters.append({"min_age": {"$lte": max_age}}) # type: ignore # TODO: Fix this
processed_filters.append({"max_age": {"$gte": min_age}}) # type: ignore # TODO: Fix this

# Construct the where clause
where: chromadb.Where | None
if len(processed_filters) == 0:
where = None
elif len(processed_filters) == 1:
where = processed_filters[0] # type: ignore # TODO: Fix this
else:
where = {"$and": processed_filters} # type: ignore # TODO: Fix this
needs_post_filter, where = construct_filters(filters)
include = [chromadb.api.types.IncludeEnum("documents")]
if needs_post_filter:
include.append(chromadb.api.types.IncludeEnum("metadatas"))

# Embed the query and query the collection
query_embedding = EMBEDDING_MODEL.encode(query)
collection = CHROMADB_CLIENT.get_collection(CHROMADB_COLLECTION_NAME)
results = collection.query(
query_embeddings=[query_embedding],
n_results=top_k,
include=[chromadb.api.types.IncludeEnum("documents")],
include=include,
where=where,
)

results_full = collection.query(
query_embeddings=[query_embedding],
n_results=top_k,
include=[
chromadb.api.types.IncludeEnum("documents"),
chromadb.api.types.IncludeEnum("metadatas"),
],
where=where,
)

if "lastUpdateDatePosted" in filters:
from_timestamp, to_timestamp = filters["lastUpdateDatePosted"]
from_date = datetime.fromtimestamp(from_timestamp / 1000)
to_date = datetime.fromtimestamp(to_timestamp / 1000)
results = filter_by_date(
results_full, "last_update_date_posted", from_date, to_date
)
# Post-filter the results if needed
if needs_post_filter:
return post_filter(results, filters)

if "resultsDatePosted" in filters:
from_timestamp, to_timestamp = filters["resultsDatePosted"]
from_date = datetime.fromtimestamp(from_timestamp / 1000)
to_date = datetime.fromtimestamp(to_timestamp / 1000)
results = filter_by_date(
results_full, "results_date_posted", from_date, to_date
)

# Retrieve the results
# Retrieve the results as is
ids = results["ids"][0]
if results["documents"] is not None:
documents = results["documents"][0]
else:
documents = [""] * len(ids)

return {"ids": ids, "documents": documents}
assert results["documents"] is not None, "Missing documents in query results"
return {"ids": ids, "documents": results["documents"][0]}


@app.get("/meta/{item_id}")
Expand Down
48 changes: 43 additions & 5 deletions app/backend/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ def test_heartbeat():


@pytest.mark.parametrize("top_k", [1, 3, 5])
def test_retrieve(top_k):
def test_retrieve(top_k, setup):
"""Test the /retrieve endpoint."""
setup()
filters_serialized = quote(json.dumps({}))
response = client.get(
f"/retrieve?query=Dummy Query&{top_k=}&filters_serialized={filters_serialized}"
f"/retrieve?query=Dummy&{top_k=}&filters_serialized={filters_serialized}"
)
assert response.status_code == 200
response_json = response.json()
Expand All @@ -35,24 +36,61 @@ def test_retrieve(top_k):


@pytest.mark.parametrize("top_k", [0, 31])
def test_retrieve_invalid_top_k(top_k):
def test_retrieve_invalid_top_k(top_k, setup):
"""Test the /retrieve endpoint with invalid top_k."""
setup()
filters_serialized = quote(json.dumps({}))
response = client.get(
f"/retrieve?query=Dummy Query&{top_k=}&filters_serialized={filters_serialized}"
f"/retrieve?query=Dummy&{top_k=}&filters_serialized={filters_serialized}"
)
assert response.status_code == 404
assert "Required 0 < top_k <= 30" in response.text


@pytest.mark.parametrize("init_embedding_model", [True, False])
@pytest.mark.parametrize("init_chromadb_client", [True, False])
def test_retrieve_partial_setup(setup, init_embedding_model, init_chromadb_client):
"""Test the /retrieve endpoint with partial setup."""
setup(
init_embedding_model=init_embedding_model,
init_chromadb_client=init_chromadb_client,
)

def _run():
filters_serialized = quote(json.dumps({}))
return client.get(
f"/retrieve?query=Dummy&top_k=3&filters_serialized={filters_serialized}"
)

if not init_embedding_model:
with pytest.raises(RuntimeError, match="Embedding model not initialized"):
_run()
return
if not init_chromadb_client:
with pytest.raises(RuntimeError, match="ChromaDB not reachable"):
_run()
return

response = _run()
assert response.status_code == 200


@pytest.mark.parametrize("item_id", ["id0", "id1", "id2"])
def test_meta(item_id):
def test_meta(item_id, setup):
"""Test the /meta/{item_id} endpoint."""
setup()
response = client.get(f"/meta/{item_id}")
assert response.status_code == 200
response_json = response.json()
assert "metadata" in response_json
assert response_json["metadata"]["shortTitle"] == f"Sample Metadata {item_id}"


def test_meta_partial_setup(setup):
"""Test the /meta/{item_id} endpoint with partial setup."""
setup(init_chromadb_client=False)
with pytest.raises(RuntimeError, match="ChromaDB not reachable"):
client.get("/meta/id0")


# TODO: Add tests for the /chat/{model}/{item_id} endpoint
47 changes: 46 additions & 1 deletion app/backend/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pytest

from utils import format_exc_details, get_metadata_from_id
from utils import construct_filters, format_exc_details, get_metadata_from_id


def _snake_to_camel(snake_str):
Expand All @@ -32,6 +32,7 @@ def test_format_exc_details():

@pytest.mark.parametrize("item_id", ["id0", "id1", "id2"])
def test_get_metadata_from_id(item_id, chromadb_client, sample_metadata):
"""Test getting metadata from an item ID."""
collection = chromadb_client.get_collection("test")
metadata = get_metadata_from_id(collection, item_id)
assert metadata["shortTitle"] == f"Sample Metadata {item_id}"
Expand Down Expand Up @@ -60,3 +61,47 @@ def test_get_metadata_from_id(item_id, chromadb_client, sample_metadata):
assert item["timeFrame"] == item_exp["time_frame"]
else:
assert metadata[_snake_to_camel(key)] == value


@pytest.mark.parametrize(
"filters, expected_where, expected_needs_post_filter",
[
# Single filter
({"studyType": "interventional"}, {"study_type": "INTERVENTIONAL"}, False),
({"acceptsHealthy": False}, {"accepts_healthy": False}, False),
({"acceptsHealthy": True}, None, False),
({"eligibleSex": "female"}, {"eligible_sex": "FEMALE"}, False),
(
{"ageRange": (18, 30)},
{"$and": [{"min_age": {"$lte": 30}}, {"max_age": {"$gte": 18}}]},
False,
),
# Multiple filters
(
{"studyType": "interventional", "acceptsHealthy": False},
{"$and": [{"study_type": "INTERVENTIONAL"}, {"accepts_healthy": False}]},
False,
),
(
{"eligibleSex": "female", "ageRange": (18, 30)},
{
"$and": [
{"eligible_sex": "FEMALE"},
{"min_age": {"$lte": 30}},
{"max_age": {"$gte": 18}},
]
},
False,
),
# Expected post-processing
({"studyPhases": ["PHASE1"]}, None, True),
({"lastUpdateDatePosted": (0, 0)}, None, True),
({"resultsDatePosted": (0, 0)}, None, True),
],
)
def test_construct_filters(filters, expected_where, expected_needs_post_filter):
"""Test constructing filters."""
assert construct_filters(filters) == (expected_needs_post_filter, expected_where)


# TODO: Add tests for the post_filter function
Loading

0 comments on commit 3bce63e

Please sign in to comment.