-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
app: init backend/frontend framework and docker compose (#9)
* app: init app framework * solve cors issue and successfully call backend server in dev mode * fix production build * update apis * text area slight improvements * init retrieval panel layout * the chatport component * format retrieve response; add clear history and loading indicator * remove unneeded svg * remove redundant readme * fix wrong keys * pa update * chat panel * modularize code * better handling/display of internal server error * minor ui/ux improvements * minor ui * command palette * more modularization * better copy functionalities and toasts * header links and delete chat functionality * tab delete, auto scroll into view, and other minor ui/ux improvements
- Loading branch information
1 parent
42c6a72
commit 06fc9df
Showing
46 changed files
with
9,421 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,7 @@ | ||
__pycache__/ | ||
node_modules/ | ||
dist/ | ||
*.tsbuildinfo | ||
|
||
# Data files | ||
data/ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Default target: build the images, then start the whole stack
all: build run

# Build all images defined in the compose file
build:
	docker-compose build

# Start all services defined in the compose file
run:
	docker-compose up

# [DEV] Update lock files and development environment on the host machine
devlock:
	@cd backend && pipenv lock && pipenv sync -d && pipenv clean
	@cd frontend && pnpm install

# [DEV] Format and lint the codebase
devlint:
	@cd backend && pipenv run bash -c "black . && ruff check --select I --fix . && ruff format . && mypy ."
	@cd frontend && pnpm format && pnpm lint

# [DEV] Check formatting and linting of the codebase
devlintcheck:
	@cd backend && pipenv run bash -c "black --check . && ruff check --select I . && mypy ."
	@cd frontend && pnpm format:check && pnpm lint:check

# [DEV] Run the backend server only; this should be run in a separate terminal
# window before running `devfrontend`
devbackend:
	docker-compose up chromadb backend --build

# [DEV] Run the frontend in development mode; this should be run in a separate
# terminal window when `devbackend` is running
devfrontend:
	cd frontend && VITE_BACKEND_URL=http://localhost:8001 pnpm dev

# None of the targets produce a file of the same name, so declare all of them
# phony; otherwise a stray file or directory named e.g. `all` or `devlint`
# would make the target appear up to date and silently skip its recipe
.PHONY: all build run devlock devlint devlintcheck devbackend devfrontend
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
FROM python:3.11-slim-bookworm

# Environment variables
ENV DEBIAN_FRONTEND=noninteractive
ENV PYENV_SHELL=/bin/bash
ENV LANG=C.UTF-8
# Disable stdout/stderr buffering so that logs show up immediately; note the
# variable is PYTHONUNBUFFERED (the misspelled PYTHONBUFFERED has no effect)
ENV PYTHONUNBUFFERED=1

# Install system dependencies and pipenv
RUN set -ex; \
    apt-get update && \
    apt-get upgrade -y && \
    apt-get install -y build-essential curl && \
    rm -rf /var/lib/apt/lists/* && \
    pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir pipenv

# Set up user and working directory
# NOTE(review): the account password is hard-coded in the image; consider
# dropping `-p` entirely if password login is never needed
RUN set -ex; \
    useradd -ms /bin/bash veritastrial -d /home/veritastrial -u 1000 -p "$(openssl passwd -1 Passw0rd)" && \
    mkdir -p /veritastrial && \
    chown veritastrial:veritastrial /veritastrial
USER veritastrial
WORKDIR /veritastrial

# Copy Pipfile and Pipfile.lock
COPY --chown=veritastrial:veritastrial Pipfile Pipfile.lock /veritastrial/

# Install Python dependencies and clear cache
RUN pipenv sync --clear && \
    rm -rf /home/veritastrial/.cache/pip/* && \
    rm -rf /home/veritastrial/.cache/pipenv/*

# Add the rest of the source code; this is done last to take advantage of
# Docker's layer caching mechanism
COPY --chown=veritastrial:veritastrial *.py /veritastrial/

# Run the app on port 8001
EXPOSE 8001
CMD [ "pipenv", "run", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8001" ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
[[source]] | ||
name = "pypi" | ||
url = "https://pypi.org/simple" | ||
verify_ssl = true | ||
|
||
[packages] | ||
chromadb-client = "*" | ||
fastapi = "*" | ||
FlagEmbedding = "*" | ||
google-cloud-aiplatform = "*" | ||
peft = "*" # XXX: required by FlagEmbedding, but not in setup.py before 1.0.6 | ||
typing-extensions = "*" | ||
uvicorn = "*" | ||
|
||
[dev-packages] | ||
black = "*" | ||
mypy = "*" | ||
ruff = "*" | ||
|
||
[requires] | ||
python_version = "3.11" |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from typing import Literal, NotRequired, TypeAlias | ||
|
||
from typing_extensions import TypedDict | ||
|
||
# Model identifiers accepted by the chat endpoint; the `chat` handler in the
# backend app uses "gemini-1.5-flash-001" as a model name directly and expands
# any other value into a Vertex AI endpoint resource path
ModelType: TypeAlias = Literal["gemini-1.5-flash-001", "6894888983713546240"]
|
||
|
||
class TrialMetadataMeasureOutcomeType(TypedDict):
    """One measure-outcome entry of a trial's metadata.

    This is the element type of the primary, secondary, and other measure
    outcome lists in `TrialMetadataType`. All three keys are required strings.
    """

    measure: str
    description: str
    timeFrame: str
|
||
|
||
class TrialMetadataInterventionType(TypedDict):
    """One intervention entry of a trial's metadata.

    This is the element type of `TrialMetadataType.interventions`. All three
    keys are required strings.
    """

    # The key is literally named "type"; it only shadows the builtin inside
    # this class body
    type: str
    name: str
    description: str
|
||
|
||
class TrialMetadataReferenceType(TypedDict):
    """One reference entry of a trial's metadata.

    This is the element type of `TrialMetadataType.references`.
    """

    # Presumably a PubMed identifier, kept as a string -- TODO confirm
    pmid: str
    citation: str
|
||
|
||
class TrialMetadataDocumentType(TypedDict):
    """One document entry of a trial's metadata.

    This is the element type of `TrialMetadataType.documents`.
    """

    url: str
    # Unit is not specified here (presumably bytes) -- TODO confirm
    size: int
|
||
|
||
class TrialMetadataType(TypedDict):
    """Full metadata record of a single clinical trial.

    Every key is required. Dates are carried as plain strings; nested entries
    use the dedicated `TrialMetadata*Type` dictionaries declared in this
    module.
    """

    # Titles and organization
    shortTitle: str
    longTitle: str
    organization: str

    # Submission, results, update, and verification dates (as strings)
    submitDate: str
    submitDateQc: str
    submitDatePosted: str
    resultsDate: str
    resultsDateQc: str
    resultsDatePosted: str
    lastUpdateDate: str
    lastUpdateDatePosted: str
    verifyDate: str

    # Sponsorship
    sponsor: str
    collaborators: list[str]

    # Free-text descriptions and conditions
    summary: str
    details: str
    conditions: list[str]

    # Study design
    studyPhases: str
    studyType: str
    enrollmentCount: int
    allocation: str
    interventionModel: str
    observationalModel: str
    primaryPurpose: str
    whoMasked: str

    # Interventions and measure outcomes
    interventions: list[TrialMetadataInterventionType]
    primaryMeasureOutcomes: list[TrialMetadataMeasureOutcomeType]
    secondaryMeasureOutcomes: list[TrialMetadataMeasureOutcomeType]
    otherMeasureOutcomes: list[TrialMetadataMeasureOutcomeType]

    # Eligibility; the unit of the age bounds is not specified here
    # (presumably years) -- TODO confirm
    minAge: int
    maxAge: int
    eligibleSex: str
    acceptsHealthy: bool
    inclusionCriteria: str
    exclusionCriteria: str

    # People, places, and attached material
    officials: list[str]
    locations: list[str]
    references: list[TrialMetadataReferenceType]
    documents: list[TrialMetadataDocumentType]
|
||
|
||
class APIHeartbeatResponseType(TypedDict):
    """Response body of the `/heartbeat` endpoint."""

    # The handler fills this with `time.time_ns()`, i.e. nanoseconds
    timestamp: int
|
||
|
||
class APIRetrieveResponseType(TypedDict):
    """Response body of the `/retrieve` endpoint.

    The handler returns the two lists with equal length, one entry per
    retrieved item.
    """

    ids: list[str]
    documents: list[str]
|
||
|
||
class APIMetaResponseType(TypedDict):
    """Response body of the `/meta/{item_id}` endpoint."""

    # Full metadata record of the requested trial
    metadata: TrialMetadataType
|
||
|
||
class APIChatResponseType(TypedDict):
    """Response body of the `/chat/{model}/{item_id}` endpoint."""

    # Text generated by the model, stripped of surrounding whitespace by the
    # handler
    response: str
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
import json | ||
import os | ||
import time | ||
|
||
import chromadb | ||
import vertexai # type: ignore | ||
from fastapi import FastAPI, HTTPException, Request | ||
from fastapi.middleware.cors import CORSMiddleware | ||
from fastapi.openapi.utils import get_openapi | ||
from fastapi.responses import JSONResponse | ||
from FlagEmbedding import FlagModel # type: ignore | ||
from vertexai.generative_models import GenerativeModel # type: ignore | ||
|
||
from localtyping import ( | ||
APIChatResponseType, | ||
APIHeartbeatResponseType, | ||
APIMetaResponseType, | ||
APIRetrieveResponseType, | ||
ModelType, | ||
) | ||
from utils import _get_metadata_from_id, format_exc_details | ||
|
||
# Origin of the frontend that is allowed to make cross-origin requests; can be
# overridden with the FRONTEND_URL environment variable
FRONTEND_URL = os.getenv("FRONTEND_URL", "http://localhost:8080")

# Embedding model and ChromaDB collection handle; both are created once when
# this module is imported and shared by all request handlers
EMBEDDING_MODEL = FlagModel("BAAI/bge-small-en-v1.5", use_fp16=True)
CHROMADB_CLIENT = chromadb.HttpClient(host="chromadb", port=8000)
CHROMADB_COLLECTION = CHROMADB_CLIENT.get_collection("veritas-trial-embeddings")

# Vertex AI project settings; also reused to build endpoint resource paths
GCP_PROJECT_ID = "veritastrial"
GCP_PROJECT_LOCATION = "us-central1"
vertexai.init(project=GCP_PROJECT_ID, location=GCP_PROJECT_LOCATION)

# Swagger UI is disabled (docs_url=None); ReDoc is served at the root path
app = FastAPI(docs_url=None, redoc_url="/")

# Handle cross-origin requests from the frontend
app.add_middleware(
    CORSMiddleware,
    allow_origins=[FRONTEND_URL],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||
|
||
def custom_openapi():
    """Build the customized OpenAPI schema once and return the cached copy.

    The schema is stored on `app.openapi_schema`, so it is only generated on
    the first call; subsequent calls return the stored schema.
    """
    if not app.openapi_schema:
        schema = get_openapi(
            title="VeritasTrial APIs",
            version="0.0.0",
            description="OpenAPI specification for the VeritasTrial APIs.",
            routes=app.routes,
        )
        # TODO: Change to the VeritasTrial logo
        logo = {"url": "https://fastapi.tiangolo.com/img/logo-margin/logo-teal.png"}
        schema["info"]["x-logo"] = logo
        app.openapi_schema = schema
    return app.openapi_schema


# Replace the default schema generator with the customized one
app.openapi = custom_openapi  # type: ignore
|
||
|
||
@app.exception_handler(Exception)
async def custom_exception_handler(request: Request, exc: Exception):
    """Turn any exception matched by this handler into a JSON 500 response.

    The body has a single "details" key holding the output of
    `format_exc_details` for the exception.
    """
    # The CORS headers are attached by hand so that the frontend can still
    # read the error response
    cors_headers = {
        "Access-Control-Allow-Origin": FRONTEND_URL,
        "Access-Control-Allow-Credentials": "true",
    }
    payload = {"details": format_exc_details(exc)}
    return JSONResponse(status_code=500, content=payload, headers=cors_headers)
|
||
|
||
@app.get("/heartbeat")
async def heartbeat() -> APIHeartbeatResponseType:
    """Report that the server is alive by returning the current time.

    The timestamp is taken from `time.time_ns()`, i.e. it is in nanoseconds.
    """
    now_ns = time.time_ns()
    return {"timestamp": now_ns}
|
||
|
||
@app.get("/retrieve")
async def retrieve(query: str, top_k: int) -> APIRetrieveResponseType:
    """Retrieve items from the ChromaDB collection.

    Args:
        query: Free-text query; it is embedded with the embedding model and
            used to query the collection.
        top_k: Number of results to return; must satisfy 0 < top_k <= 30.

    Returns:
        The IDs of the retrieved items and their documents, in the order
        returned by the collection. Documents fall back to empty strings if
        the collection returns none.

    Raises:
        HTTPException: 400 if `top_k` is out of range.
    """
    if not 0 < top_k <= 30:
        # An out-of-range parameter is a client error (400), not a missing
        # resource (404)
        raise HTTPException(status_code=400, detail="Required 0 < top_k <= 30")

    # Embed the query and query the collection
    # NOTE(review): both calls are synchronous and run on the event loop
    query_embedding = EMBEDDING_MODEL.encode(query)
    results = CHROMADB_COLLECTION.query(
        query_embeddings=[query_embedding],
        n_results=top_k,
        include=[chromadb.api.types.IncludeEnum("documents")],
    )

    # Retrieve the results; a single query was sent, so take the first (and
    # only) result list
    ids = results["ids"][0]
    if results["documents"] is not None:
        documents = results["documents"][0]
    else:
        documents = [""] * len(ids)

    return {"ids": ids, "documents": documents}
|
||
|
||
@app.get("/meta/{item_id}")
async def meta(item_id: str) -> APIMetaResponseType:
    """Look up the metadata of one item in the ChromaDB collection.

    Responds with 404 when `_get_metadata_from_id` returns None for the
    given `item_id`.
    """
    found = _get_metadata_from_id(CHROMADB_COLLECTION, item_id)
    if found is not None:
        return {"metadata": found}

    raise HTTPException(status_code=404, detail="Trial metadata not found")
|
||
|
||
@app.get("/chat/{model}/{item_id}")
async def chat(model: ModelType, item_id: str, query: str) -> APIChatResponseType:
    """Chat with a generative model about a specific item.

    Args:
        model: Either the base model name "gemini-1.5-flash-001" or an ID that
            is expanded into a Vertex AI endpoint path in this project.
        item_id: ID of the item whose metadata is given to the model as
            context.
        query: The user's question about the item.

    Returns:
        The model's text response with surrounding whitespace stripped.

    Raises:
        HTTPException: 404 if no metadata is found for `item_id`.
    """
    metadata = _get_metadata_from_id(CHROMADB_COLLECTION, item_id)
    if metadata is None:
        raise HTTPException(status_code=404, detail="Trial metadata not found")

    # Determine the model to use: the base model is addressed by name, any
    # other value is treated as an endpoint ID and expanded to its full path
    model_name: str
    if model == "gemini-1.5-flash-001":
        model_name = model
    else:
        model_name = (
            f"projects/{GCP_PROJECT_ID}/locations/{GCP_PROJECT_LOCATION}/"
            f"endpoints/{model}"
        )

    # Initialize the generative model
    gen_model = GenerativeModel(
        model_name=model_name,
        generation_config={
            "max_output_tokens": 2048,
            "temperature": 0.75,
            "top_p": 0.95,
        },
    )

    # Combine metadata into the query
    prompt = (
        "You will be given the information of a clinical trial and asked a "
        "question. The information is as follows:\n\n"
        f"{json.dumps(metadata, indent=2)}\n\n"
        "## Question\n\n"
        f"{query}"
    )

    # Generate the response; the async variant is awaited so that the model
    # call does not block the event loop (and thereby every other request)
    response = await gen_model.generate_content_async(prompt, stream=False)
    return {"response": response.text.strip()}
Oops, something went wrong.