Skip to content

Commit

Permalink
preview endpoint preliminary support
Browse files Browse the repository at this point in the history
  • Loading branch information
satchelbaldwin committed Feb 1, 2024
1 parent 979c8f2 commit 33d3aef
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 44 deletions.
76 changes: 33 additions & 43 deletions api/processing/providers/esgf.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
from .. import filters
from api.settings import default_settings
import xarray
Expand All @@ -7,7 +8,8 @@
import s3fs
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.basemap import Basemap
import io
import cartopy.crs as ccrs

# we have to operate on urls / dataset_ids due to the fact that
# rq jobs can't pass the context of a loaded xarray dataset in memory (json serialization)
Expand All @@ -25,16 +27,17 @@ def open_remote_dataset(urls: List[str]) -> xarray.Dataset:
except IOError as e:
print(f"failed to open parallel: {e}")
try:
ds = xarray.open_mfdataset(urls)
ds = xarray.open_mfdataset(urls, concat_dim="time", combine="nested")
except IOError as e:
print(f"failed to open sequentially: {e}")
raise IOError(e)
print(f"failed to open sequentially, falling back to s3: {e}")
return open_remote_dataset_s3(urls)
return ds


def open_remote_dataset_s3(urls: List[str]) -> xarray.Dataset:
    """Open the given datasets from the public esgf-world S3 mirror and merge them.

    Fallback path for when the HTTP access URLs fail: each URL is rewritten to
    its mirror location under ``s3://esgf-world`` (keyed by its ``/CMIP6`` path
    suffix), opened anonymously, and the resulting datasets merged.
    """
    fs = s3fs.S3FileSystem(anon=True)
    # rewrite each access URL to its object key in the esgf-world bucket
    urls = ["s3://esgf-world" + u[u.find("/CMIP6") :] for u in urls]
    print(urls, flush=True)
    datasets = []
    for u in urls:
        # chunk along time so the merge stays lazy/dask-backed
        datasets.append(xarray.open_dataset(fs.open(u), chunks={"time": 10}))
    return xarray.merge(datasets)

Expand Down Expand Up @@ -83,12 +86,19 @@ def slice_and_store_dataset(
return {"status": "failed", "error": str(e), "dataset_id": ""}


def buffer_to_b64_png(buffer: io.BytesIO) -> str:
    """Return the PNG bytes held in *buffer* as a base64 ``data:`` URI string.

    The buffer is rewound before reading, so callers may pass it directly
    after writing (e.g. from ``plt.savefig``) without seeking first.
    """
    buffer.seek(0)
    encoded = base64.b64encode(buffer.read()).decode("utf-8")
    return "data:image/png;base64," + encoded


def render_preview_for_dataset(
urls: List[str],
dataset_id: str,
variable_index: str,
time_index: str,
timestamps: str,
variable_index: str = "",
time_index: str = "",
timestamps: str = "",
**kwargs,
):
ds = open_remote_dataset(urls)
axes = {}
Expand All @@ -109,39 +119,19 @@ def render_preview_for_dataset(
ds = ds.sel({time_index: ds[time_index][0]})
else:
ds = ds.sel({time_index: slice(timestamps.split(","))})

# we're plotting x, y, time - others need to be shortened to the first element
other_axes = [axis for axis in axes if axis not in ["X", "Y", "T"]]
for axis in other_axes:
ds = ds.sel({axes[axis]: ds[axes[axis]][0]})

ds = ds[variable_index]
x_points = ds[axes["X"]][:]
y_points = ds[axes["Y"]][:]
units = ds.units
name = ds.long_name
center_x = x_points.mean()
center_y = y_points.mean()

m = Basemap(
width=5000000,
height=3500000,
resolution="l",
projection="merc",
llcrnrlat=-80,
urcrnrlat=80,
llcrnrlon=0,
urcrnrlon=360,
)
lon, lat = np.meshgrid(x_points, y_points)
xi, yi = m(lon, lat)
cs = m.pcolor(xi, yi, np.squeeze(ds))
m.drawparallels(np.arange(-80.0, 81.0, 10.0), labels=[1, 0, 0, 0], fontsize=10)
m.drawmeridians(np.arange(-180.0, 181.0, 10.0), labels=[0, 0, 0, 1], fontsize=10)

# Add Coastlines, States, and Country Boundaries
m.drawcoastlines()
m.drawstates()
m.drawcountries()

# Add Colorbar
cbar = m.colorbar(cs, location="bottom", pad="10%")
cbar.set_label(units)

# Add Title
plt.title(name)
plt.savefig("x.png")

fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()})
ds.plot(transform=ccrs.PlateCarree(), x=axes["X"], y=axes["Y"], add_colorbar=True)
ax.coastlines()

buffer = io.BytesIO()
plt.savefig(buffer, format="png")

return {"png": buffer_to_b64_png(buffer)}
1 change: 1 addition & 0 deletions api/search/providers/esgf.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
"activity_id",
"nominal_resolution",
"frequency",
"realm",
],
# only take exact matches
"exact": [
Expand Down
14 changes: 13 additions & 1 deletion api/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from fastapi import FastAPI, Request, Depends
from api.search.providers.esgf import ESGFProvider
from api.processing.providers.esgf import slice_and_store_dataset
from api.processing.providers.esgf import (
render_preview_for_dataset,
slice_and_store_dataset,
)
from api.dataset.job_queue import create_job, fetch_job_status, get_redis
from openai import OpenAI
from urllib.parse import parse_qs
Expand Down Expand Up @@ -43,6 +46,15 @@ async def esgf_subset(request: Request, redis=Depends(get_redis), dataset_id: st
return job


@app.get(path="/preview/esgf")
async def esgf_preview(dataset_id: str, redis=Depends(get_redis)):
    """Queue a background job that renders a preview image for an ESGF dataset.

    Resolves the dataset's access URLs by id and enqueues the render on the
    ``preview`` queue; returns the created job record for status polling.
    """
    access_urls = esgf.get_access_urls_by_id(dataset_id)
    return create_job(
        func=render_preview_for_dataset, args=[access_urls], redis=redis, queue="preview"
    )


@app.get(path="/status/{job_id}")
async def job_status(job_id: str, redis=Depends(get_redis)):
    """Report the current status of a previously submitted background job."""
    status = fetch_job_status(job_id, redis=redis)
    return status

0 comments on commit 33d3aef

Please sign in to comment.