diff --git a/api/processing/providers/esgf.py b/api/processing/providers/esgf.py
index 69d366c..0a8f913 100644
--- a/api/processing/providers/esgf.py
+++ b/api/processing/providers/esgf.py
@@ -1,3 +1,4 @@
+import base64
 from .. import filters
 from api.settings import default_settings
 import xarray
@@ -7,7 +8,8 @@
 import s3fs
 import matplotlib.pyplot as plt
 import numpy as np
-from mpl_toolkits.basemap import Basemap
+import io
+import cartopy.crs as ccrs
 
 # we have to operate on urls / dataset_ids due to the fact that
 # rq jobs can't pass the context of a loaded xarray dataset in memory (json serialization)
@@ -25,16 +27,17 @@ def open_remote_dataset(urls: List[str]) -> xarray.Dataset:
     except IOError as e:
         print(f"failed to open parallel: {e}")
         try:
-            ds = xarray.open_mfdataset(urls)
+            ds = xarray.open_mfdataset(urls, concat_dim="time", combine="nested")
         except IOError as e:
-            print(f"failed to open sequentially: {e}")
-            raise IOError(e)
+            print(f"failed to open sequentially, falling back to s3: {e}")
+            return open_remote_dataset_s3(urls)
     return ds
 
 
 def open_remote_dataset_s3(urls: List[str]) -> xarray.Dataset:
     fs = s3fs.S3FileSystem(anon=True)
     urls = ["s3://esgf-world" + url[url.find("/CMIP6") :] for url in urls]
+    print(urls, flush=True)
     files = [xarray.open_dataset(fs.open(url), chunks={"time": 10}) for url in urls]
     return xarray.merge(files)
 
@@ -83,12 +86,19 @@ def slice_and_store_dataset(
         return {"status": "failed", "error": str(e), "dataset_id": ""}
 
 
+def buffer_to_b64_png(buffer: io.BytesIO) -> str:
+    buffer.seek(0)
+    content = buffer.read()
+    payload = base64.b64encode(content).decode("utf-8")
+    return f"data:image/png;base64,{payload}"
+
+
 def render_preview_for_dataset(
     urls: List[str],
-    dataset_id: str,
-    variable_index: str,
-    time_index: str,
-    timestamps: str,
+    variable_index: str = "",
+    time_index: str = "",
+    timestamps: str = "",
+    **kwargs,
 ):
     ds = open_remote_dataset(urls)
     axes = {}
@@ -109,39 +119,19 @@ def render_preview_for_dataset(
         ds = ds.sel({time_index: ds[time_index][0]})
     else:
         ds = ds.sel({time_index: slice(timestamps.split(","))})
+
+    # we're plotting x, y, time - others need to be shortened to the first element
+    other_axes = [axis for axis in axes if axis not in ["X", "Y", "T"]]
+    for axis in other_axes:
+        ds = ds.sel({axes[axis]: ds[axes[axis]][0]})
+
     ds = ds[variable_index]
-    x_points = ds[axes["X"]][:]
-    y_points = ds[axes["Y"]][:]
-    units = ds.units
-    name = ds.long_name
-    center_x = x_points.mean()
-    center_y = y_points.mean()
-
-    m = Basemap(
-        width=5000000,
-        height=3500000,
-        resolution="l",
-        projection="merc",
-        llcrnrlat=-80,
-        urcrnrlat=80,
-        llcrnrlon=0,
-        urcrnrlon=360,
-    )
-    lon, lat = np.meshgrid(x_points, y_points)
-    xi, yi = m(lon, lat)
-    cs = m.pcolor(xi, yi, np.squeeze(ds))
-    m.drawparallels(np.arange(-80.0, 81.0, 10.0), labels=[1, 0, 0, 0], fontsize=10)
-    m.drawmeridians(np.arange(-180.0, 181.0, 10.0), labels=[0, 0, 0, 1], fontsize=10)
-
-    # Add Coastlines, States, and Country Boundaries
-    m.drawcoastlines()
-    m.drawstates()
-    m.drawcountries()
-
-    # Add Colorbar
-    cbar = m.colorbar(cs, location="bottom", pad="10%")
-    cbar.set_label(units)
-
-    # Add Title
-    plt.title(name)
-    plt.savefig("x.png")
+
+    fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()})
+    ds.plot(transform=ccrs.PlateCarree(), x=axes["X"], y=axes["Y"], add_colorbar=True)
+    ax.coastlines()
+
+    buffer = io.BytesIO()
+    plt.savefig(buffer, format="png")
+
+    return {"png": buffer_to_b64_png(buffer)}
diff --git a/api/search/providers/esgf.py b/api/search/providers/esgf.py
index 6ba53ef..76fb31c 100644
--- a/api/search/providers/esgf.py
+++ b/api/search/providers/esgf.py
@@ -79,6 +79,7 @@
         "activity_id",
         "nominal_resolution",
         "frequency",
+        "realm",
     ],
     # only take exact matches
     "exact": [
diff --git a/api/server.py b/api/server.py
index 220483a..5e05273 100644
--- a/api/server.py
+++ b/api/server.py
@@ -1,6 +1,9 @@
 from fastapi import FastAPI, Request, Depends
 from api.search.providers.esgf import ESGFProvider
-from api.processing.providers.esgf import slice_and_store_dataset
+from api.processing.providers.esgf import (
+    render_preview_for_dataset,
+    slice_and_store_dataset,
+)
 from api.dataset.job_queue import create_job, fetch_job_status, get_redis
 from openai import OpenAI
 from urllib.parse import parse_qs
@@ -43,6 +46,15 @@ async def esgf_subset(request: Request, redis=Depends(get_redis), dataset_id: st
     return job
 
 
+@app.get(path="/preview/esgf")
+async def esgf_preview(dataset_id: str, redis=Depends(get_redis)):
+    urls = esgf.get_access_urls_by_id(dataset_id)
+    job = create_job(
+        func=render_preview_for_dataset, args=[urls], redis=redis, queue="preview"
+    )
+    return job
+
+
 @app.get(path="/status/{job_id}")
 async def job_status(job_id: str, redis=Depends(get_redis)):
     return fetch_job_status(job_id, redis=redis)
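
For reference, a minimal client-side sketch of how the new preview flow could be exercised end to end: submit a job via GET /preview/esgf, poll /status/{job_id}, then decode the "data:image/png;base64,..." string that buffer_to_b64_png produces. The base URL and the "id" / "result" field names on the job and status payloads are assumptions about what create_job and fetch_job_status return, not something this diff specifies.

# Hypothetical usage sketch -- the "id" and "result" field names and the base URL
# are assumed, not taken from this diff; adjust to the actual job_queue payloads.
import base64
import time

import requests

API = "http://localhost:8000"  # assumed dev server address

# kick off a preview job for a dataset id (placeholder value)
job = requests.get(
    f"{API}/preview/esgf", params={"dataset_id": "<esgf-dataset-id>"}
).json()

# poll the status endpoint until the worker has attached a result
while True:
    status = requests.get(f"{API}/status/{job['id']}").json()
    if status.get("result") is not None:
        break
    time.sleep(2)

# render_preview_for_dataset returns {"png": "data:image/png;base64,..."};
# strip the data-URL prefix and decode the remainder to raw PNG bytes
data_url = status["result"]["png"]
with open("preview.png", "wb") as out:
    out.write(base64.b64decode(data_url.split(",", 1)[1]))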