preview for hmi ids
satchelbaldwin committed Feb 8, 2024
1 parent 2e4df50 commit 3c1f84f
Showing 4 changed files with 68 additions and 13 deletions.
37 changes: 31 additions & 6 deletions api/dataset/remote.py
@@ -1,7 +1,7 @@
 from concurrent.futures import ThreadPoolExecutor
 import glob
 import xarray
-from typing import List
+from typing import List, Tuple
 from api.search.provider import AccessURLs
 from api.settings import default_settings
 import os
@@ -71,7 +71,9 @@ def open_dataset(paths: AccessURLs, job_id=None) -> xarray.Dataset:
                 raise IOError(
                     "http downloads must have an associated job id for cleanup purposes"
                 )
-            ds = open_remote_dataset_http(http_urls, job_id)
+            ds = open_remote_dataset_http(
+                http_urls, job_id, default_settings.esgf_openid
+            )
             return ds
         except IOError as e:
             print(f"failed to download via plain http: {e}")
@@ -96,23 +98,25 @@ def open_remote_dataset_s3(urls: List[str]) -> xarray.Dataset:
     return xarray.merge(files)
 
 
-def download_file_http(url: str, dir: str):
+def download_file_http(url: str, dir: str, auth: Tuple[str, str] | None = None):
     rs = requests.get(url, stream=True)
     if rs.status_code == 401:
-        rs = requests.get(url, stream=True, auth=default_settings.esgf_openid)
+        rs = requests.get(url, stream=True, auth=auth)
     filename = url.split("/")[-1]
     print("writing ", os.path.join(dir, filename))
     with open(os.path.join(dir, filename), mode="wb") as file:
         for chunk in rs.iter_content(chunk_size=10 * 1024):
             file.write(chunk)
 
 
-def open_remote_dataset_http(urls: List[str], job_id) -> xarray.Dataset:
+def open_remote_dataset_http(
+    urls: List[str], job_id: str, auth: Tuple[str, str]
+) -> xarray.Dataset:
     temp_directory = os.path.join(".", str(job_id))
     if not os.path.exists(temp_directory):
         os.makedirs(temp_directory)
     with ThreadPoolExecutor() as executor:
-        executor.map(lambda url: download_file_http(url, temp_directory), urls)
+        executor.map(lambda url: download_file_http(url, temp_directory, auth), urls)
     files = [os.path.join(temp_directory, f) for f in os.listdir(temp_directory)]
     ds = xarray.open_mfdataset(
         files,
@@ -132,3 +136,24 @@ def cleanup_potential_artifacts(job_id):
         for file in glob.glob(os.path.join(temp_directory, "*.nc")):
             os.remove(file)
         os.removedirs(temp_directory)
+
+
+def open_remote_dataset_hmi(dataset_id: str, job_id: str) -> xarray.Dataset:
+    base_url = f"{default_settings.terarium_url}/datasets/{dataset_id}"
+    auth = (default_settings.terarium_user, default_settings.terarium_pass)
+    response = requests.get(base_url, auth=auth)
+    if response.status_code != 200:
+        errors = {
+            204: "does not exist (204)",
+            400: "malformed request (400)",
+            500: "upstream error (500)",
+            401: "auth error (401)",
+        }
+        msg = errors.get(response.status_code, f"unknown error {response.status_code}")
+        raise IOError(f"Dataset not found: {msg}")
+    filenames = response.json().get("fileNames", [])
+    if len(filenames) == 0:
+        raise IOError("Dataset has no associated files")
+    filenames = [f"{base_url}/download-file?filename={f}" for f in filenames]
+
+    return open_remote_dataset_http(filenames, job_id, auth)
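
For illustration, a minimal sketch of the new HMI download path (the Terarium host, dataset id, and file names below are made up; only the {"fileNames": [...]} response shape and the download-file URL pattern come from the code above):

# Hypothetical response from GET {terarium_url}/datasets/{dataset_id}:
#   {"fileNames": ["tas_Amon_CESM2_historical.nc", "pr_Amon_CESM2_historical.nc"]}
base_url = "https://terarium.example/datasets/3fa85f64-5717-4562-b3fc-2c963f66afa6"
filenames = ["tas_Amon_CESM2_historical.nc", "pr_Amon_CESM2_historical.nc"]

# open_remote_dataset_hmi rewrites each entry into a per-file download URL and hands
# the list, the job id, and the Terarium basic-auth pair to open_remote_dataset_http:
urls = [f"{base_url}/download-file?filename={f}" for f in filenames]
# -> ".../download-file?filename=tas_Amon_CESM2_historical.nc", ...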
17 changes: 16 additions & 1 deletion api/preview/render.py
@@ -5,7 +5,11 @@
 import xarray
 from matplotlib import pyplot as plt
 from typing import List
-from api.dataset.remote import cleanup_potential_artifacts, open_dataset
+from api.dataset.remote import (
+    cleanup_potential_artifacts,
+    open_dataset,
+    open_remote_dataset_hmi,
+)
 
 
 def buffer_to_b64_png(buffer: io.BytesIO) -> str:
@@ -33,6 +37,17 @@ def render_preview_for_dataset(
         return {"error": f"upstream hosting is likely having a problem. {e}"}
 
 
+def render_preview_for_hmi(uuid: str, **kwargs):
+    job_id = kwargs["job_id"]
+    try:
+        ds = open_remote_dataset_hmi(uuid, job_id)
+        png = render(ds=ds)
+        cleanup_potential_artifacts(job_id)
+        return {"png": png}
+    except IOError as e:
+        return {"error": f"failed with error {e}"}
+
+
 def render(
     ds,
     variable_index: str = "",
8 changes: 8 additions & 0 deletions api/search/providers/esgf.py
@@ -1,3 +1,4 @@
+import re
 from api.settings import default_settings
 from api.search.provider import (
     AccessURLs,
@@ -140,6 +141,13 @@ def initialize_embeddings(self, force_refresh=False):
             )
             pickle.dump(self.embeddings, f)
 
+    def is_terarium_hmi_dataset(self, dataset_id: str) -> bool:
+        """
+        checks if a dataset id is HMI or ESGF - uuid regex
+        """
+        p = re.compile(r"^[0-9a-f]{8}-([0-9a-f]{4}-){3}[0-9a-f]{12}$")
+        return bool(p.match(dataset_id.lower()))
+
     def search(
         self, query: str, page: int, force_refresh_cache: bool = False
     ) -> DatasetSearchResults:
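
A quick sketch of what this check accepts, using the same pattern (the example ids below are hypothetical, purely for illustration):

import re

# Same pattern as is_terarium_hmi_dataset above.
p = re.compile(r"^[0-9a-f]{8}-([0-9a-f]{4}-){3}[0-9a-f]{12}$")

hmi_id = "3fa85f64-5717-4562-b3fc-2c963f66afa6"         # UUID-shaped id -> routed to the Terarium/HMI path
esgf_id = "CMIP6.CMIP.NCAR.CESM2.historical.r1i1p1f1"   # dotted ESGF-style id -> falls through to ESGF handling

print(bool(p.match(hmi_id.lower())))   # True
print(bool(p.match(esgf_id.lower())))  # False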
19 changes: 13 additions & 6 deletions api/server.py
@@ -7,7 +7,7 @@
 from openai import OpenAI
 from urllib.parse import parse_qs
 from typing import List, Dict
-from api.preview.render import render_preview_for_dataset
+from api.preview.render import render_preview_for_dataset, render_preview_for_hmi
 
 app = FastAPI(docs_url="/")
 client = OpenAI()
@@ -50,11 +50,18 @@ async def esgf_subset(
 
 @app.get(path="/preview/esgf")
 async def esgf_preview(dataset_id: str, redis=Depends(get_redis)):
-    urls = esgf.get_all_access_paths_by_id(dataset_id)
-    job = create_job(
-        func=render_preview_for_dataset, args=[urls], redis=redis, queue="preview"
-    )
-    return job
+    if esgf.is_terarium_hmi_dataset(dataset_id):
+        print("terarium uuid found", flush=True)
+        job = create_job(
+            func=render_preview_for_hmi, args=[dataset_id], redis=redis, queue="preview"
+        )
+        return job
+    else:
+        urls = esgf.get_all_access_paths_by_id(dataset_id)
+        job = create_job(
+            func=render_preview_for_dataset, args=[urls], redis=redis, queue="preview"
+        )
+        return job
 
 
 @app.get(path="/status/{job_id}")
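
As a rough usage sketch, the same endpoint now serves both id styles; a UUID is queued through render_preview_for_hmi, anything else through the existing ESGF path (host, port, and ids below are placeholders, not part of this commit):

import requests

BASE = "http://localhost:8000"  # assumed local dev address for the FastAPI app

# UUID-shaped id -> render_preview_for_hmi job
hmi_job = requests.get(
    f"{BASE}/preview/esgf",
    params={"dataset_id": "3fa85f64-5717-4562-b3fc-2c963f66afa6"},
).json()

# Dotted ESGF id -> render_preview_for_dataset job
esgf_job = requests.get(
    f"{BASE}/preview/esgf",
    params={"dataset_id": "CMIP6.CMIP.NCAR.CESM2.historical.r1i1p1f1"},
).json()

# Each response describes the queued preview job; results are polled via GET /status/{job_id}.
print(hmi_job, esgf_job)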
