Commit
Merge pull request #6 from DARPA-ASKEM/terarium-integrations
Yearly Previews
satchelbaldwin authored Mar 21, 2024
2 parents 705ae59 + 79e0ef5 commit addcc89
Showing 8 changed files with 172 additions and 82 deletions.
9 changes: 7 additions & 2 deletions README.md
@@ -76,7 +76,9 @@ Required Parameters:

Optional Parameters:
* `variable_id`: override the variable to render in the preview.
-* `timestamps`: plot over a list of times. much slower, work in progress
+* `timestamps`: plot over a list of times.
+  * The format should be `start,end` -- two values, comma-separated.
+  * Example: `1970,1979`
* `time_index`: override time index to use.


@@ -109,6 +111,8 @@ Optional Parameters:
* Preserving all other fields, take every third data point from the fields `lat` and `lon`
* `thin_factor=2&thin_fields=!time,lev`
* Preserving all other fields, take every other data point from all fields *except* `time` and `lev`.
+* `variable_id`:
+  * Which variable to render in the preview. Defaults to `""`. If none is specified, it will attempt to choose the most relevant variable.

Output:
Returns a description of the job that has been queued for processing.
@@ -160,7 +164,8 @@ Output:
"urls": [
"http://esgf-data.node.example/...",
"http://esgf-data.node.example/..."
]
],
"metadata": {}
}
```
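
As a usage sketch, a preview request with the new parameters might look like the following; the host, route, and `dataset_id` are illustrative placeholders, not values taken from this commit:

```python
import requests

# hypothetical route and dataset id -- placeholders, not part of this commit
response = requests.get(
    "http://localhost:8000/preview/esgf",
    params={
        "dataset_id": "CMIP6.CMIP.NCAR.CESM2.historical.r1i1p1f1.Amon.tas.gn",
        "variable_id": "tas",       # optional override of the rendered variable
        "timestamps": "1970,1979",  # start,end -- two comma-separated values
    },
)
print(response.json())  # description of the queued preview job
```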

27 changes: 18 additions & 9 deletions api/dataset/terarium_hmi.py
@@ -34,7 +34,9 @@ def generate_description(
    return string


-def enumerate_dataset_skeleton(ds: xarray.Dataset, parent_id: str) -> HMIDataset:
+def enumerate_dataset_skeleton(
+    ds: xarray.Dataset, parent_id: str, variable_id: str = ""
+) -> HMIDataset:
    """
    generates the generic body of the metadata field from a given dataset.
    this function should remain as broadly applicable as possible with the only difference
@@ -46,9 +48,11 @@ def enumerate_dataset_skeleton(ds: xarray.Dataset, parent_id: str) -> HMIDataset
    note: processing continues even if the preview fails with an exception!
    """
    try:
-        preview = render(ds)
+        start = ds.isel(time=0).time.item().year
+        end = ds.isel(time=-1).time.item().year
+        preview = render(ds, timestamps=f"{start},{end}", variable_index=variable_id)
    except Exception as e:
-        preview = ""
+        preview = f"error creating preview: {e}"
        print(e, flush=True)
    hmi_dataset = {
        "userId": "",
@@ -62,9 +66,11 @@ def enumerate_dataset_skeleton(ds: xarray.Dataset, parent_id: str) -> HMIDataset
"dataStructure": {
k: {
"attrs": {
ak: ds[k].attrs[ak].item()
if isinstance(ds[k].attrs[ak], numpy.generic)
else ds[k].attrs[ak]
ak: (
ds[k].attrs[ak].item()
if isinstance(ds[k].attrs[ak], numpy.generic)
else ds[k].attrs[ak]
)
for ak in ds[k].attrs
# _ChunkSizes is an unserializable ndarray, safely ignorable
if ak != "_ChunkSizes"
@@ -75,9 +81,11 @@ def enumerate_dataset_skeleton(ds: xarray.Dataset, parent_id: str) -> HMIDataset
            for k in ds.variables.keys()
        },
        "raw": {
-            k: ds.attrs[k].item()
-            if isinstance(ds.attrs[k], numpy.generic)
-            else ds.attrs[k]
+            k: (
+                ds.attrs[k].item()
+                if isinstance(ds.attrs[k], numpy.generic)
+                else ds.attrs[k]
+            )
            for k in ds.attrs.keys()
        },
    },
@@ -92,6 +100,7 @@ def construct_hmi_dataset(
    parent_dataset_id: str,
    subset_uuid: str,
    opts: DatasetSubsetOptions,
+    variable_id: str = "",
) -> HMIDataset:
    """
    generic function for turning a given subset dataset into a terarium-postable request body.
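
The parenthesized conditionals above follow a common JSON-sanitization pattern: numpy scalar attributes are unwrapped to plain Python types before posting. A minimal standalone sketch of that pattern, with an illustrative `attrs` dict:

```python
import numpy

# illustrative attribute dict; numpy scalars are not JSON-serializable
attrs = {"fill_value": numpy.float32(1e20), "units": "K", "_ChunkSizes": numpy.array([128])}
clean = {
    k: (v.item() if isinstance(v, numpy.generic) else v)
    for k, v in attrs.items()
    if k != "_ChunkSizes"  # unserializable ndarray, safely ignorable
}
print({k: type(v).__name__ for k, v in clean.items()})  # {'fill_value': 'float', 'units': 'str'}
```
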
114 changes: 87 additions & 27 deletions api/preview/render.py
@@ -1,10 +1,10 @@
+import datetime
import io
import base64
from api.search.provider import AccessURLs
import cartopy.crs as ccrs
import xarray
from matplotlib import pyplot as plt
-from typing import List
from api.dataset.remote import (
    cleanup_potential_artifacts,
    open_dataset,
@@ -21,41 +21,44 @@ def buffer_to_b64_png(buffer: io.BytesIO) -> str:

# handles loading here so as to not share xarray objects across rq-worker boundaries
def render_preview_for_dataset(
-    urls: AccessURLs,
+    dataset: AccessURLs | str,
    variable_index: str = "",
    time_index: str = "",
    timestamps: str = "",
    **kwargs,
):
    job_id = kwargs["job_id"]
    try:
-        ds = open_dataset(urls, job_id)
-        png = render(ds, variable_index, time_index, timestamps)
+        ds: xarray.Dataset | None = None
+        # dataset is either an AccessURLs list or a UUID str (a Terarium handle)
+        if isinstance(dataset, list):
+            ds = open_dataset(dataset, job_id)
+        elif isinstance(dataset, str):
+            ds = open_remote_dataset_hmi(dataset, job_id)
+        if timestamps != "":
+            if len(timestamps.split(",")) != 2:
+                return {
+                    "error": f"invalid timestamps '{timestamps}'. ensure it is two timestamps, comma separated"
+                }
+        try:
+            png = render(ds, variable_index, time_index, timestamps)
+        except KeyError as e:
+            return {"error": f"{e}"}
        cleanup_potential_artifacts(job_id)
-        return {"png": png}
+        return {"previews": png}
    except IOError as e:
        return {"error": f"upstream hosting is likely having a problem. {e}"}


-def render_preview_for_hmi(uuid: str, **kwargs):
-    job_id = kwargs["job_id"]
-    try:
-        ds = open_remote_dataset_hmi(uuid, job_id)
-        png = render(ds=ds)
-        cleanup_potential_artifacts(job_id)
-        return {"png": png}
-    except IOError as e:
-        return {"error": f"failed with error {e}"}


def render(
-    ds,
+    ds: xarray.Dataset,
    variable_index: str = "",
    time_index: str = "",
    timestamps: str = "",
    **kwargs,
-):
+) -> list[dict[str, str | int]]:
    axes = {}

    for v in ds.variables.keys():
        if "axis" in ds[v].attrs:
            axes[ds[v].attrs["axis"]] = v
@@ -69,11 +72,22 @@ def render(
        time_index = axes["T"]
    else:
        raise IOError("Dataset has no time axis, please provide time index")

+    # fix nonmonotonic time series without changing original data
+    print("cleaning up duplicates...", flush=True)
+    ds = ds.drop_duplicates(time_index, keep="first")
+    print("cleaned up duplicates.", flush=True)

    if timestamps == "":
        ds = ds.sel({time_index: ds[time_index][0]})
    else:
-        ds = ds.sel({time_index: slice(timestamps.split(","))})
+        ts = [t.strip() for t in timestamps.split(",")]
+        try:
+            ds = ds.sel({time_index: slice(*ts)})
+        except KeyError as e:
+            msg = f"failed to create valid timestamp range: {e}"
+            print(msg, flush=True)
+            raise KeyError(msg)

    # we're plotting x, y, time - others need to be shortened to the first element
    print(axes, flush=True)
    other_axes = [axis for axis in axes if axis not in ["X", "Y", "T"]]
@@ -85,13 +99,59 @@ def render(
f"failed to trim non-relevant axis {axis}: {ds[axes[axis]]}: {e}: (this can be safely ignored if expected)"
)

ds = ds[variable_index]
ds = ds[variable_index] # type: ignore

fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()})
ds.plot(transform=ccrs.PlateCarree(), x=axes["X"], y=axes["Y"], add_colorbar=True)
ax.coastlines()
preview_buffers: list[tuple[str, io.BytesIO]] = []

buffer = io.BytesIO()
plt.savefig(buffer, format="png")
def make_plot(data: xarray.Dataset) -> io.BytesIO:
fig, ax = plt.subplots(subplot_kw={"projection": ccrs.PlateCarree()})
data.plot(
ax=ax,
transform=ccrs.PlateCarree(),
x=axes["X"],
y=axes["Y"],
add_colorbar=True,
)
ax.coastlines()
buffer = io.BytesIO()
plt.savefig(buffer, format="png")
plt.close()
return buffer

return buffer_to_b64_png(buffer)
if axes["T"] in ds.dims:
# get delta of first two elements to see if it's yearly / monthly / daily
delta = ds[axes["T"]][1].item() - ds[axes["T"]][0].item()
steps = 0
if delta > datetime.timedelta(days=32):
steps = 1
elif delta > datetime.timedelta(days=1):
steps = 12
else:
steps = 365

leap_offset = 0
last_year = 0
# skip by frequency such that index points to head of year
for time_i in range(0, len(ds[axes["T"]]), steps):
# handle leap years
index = time_i + leap_offset
if index >= len(ds[axes["T"]]):
break
year_check = ds.isel({axes["T"]: index})[axes["T"]].item().year
if year_check == last_year:
leap_offset += 1
index += 1

data = ds.isel({axes["T"]: index})
date = data[axes["T"]].item()
print(f"rendering: {date}", flush=True)
last_year = date.year
preview_buffers.append((date.year, make_plot(data)))
else:
# single element rather than list
year = ds[axes["T"]].item().year
print(f"rendering: {year}:", flush=True)
preview_buffers.append((year, make_plot(ds)))
renders = [{"year": y, "image": buffer_to_b64_png(b)} for (y, b) in preview_buffers]
print(f"created {len(renders)} previews", flush=True)
return renders
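
The stride logic added to `render` infers the sampling frequency from the spacing of the first two time values, then steps through the data roughly one year at a time; each resulting preview pairs a year with a base64 PNG as `{"year": ..., "image": ...}`. A sketch of the heuristic in isolation, assuming datetime-like time values:

```python
import datetime

def steps_per_year(t0: datetime.datetime, t1: datetime.datetime) -> int:
    """Infer how many samples to skip so consecutive previews are ~one year apart."""
    delta = t1 - t0
    if delta > datetime.timedelta(days=32):
        return 1    # yearly (or coarser) data
    elif delta > datetime.timedelta(days=1):
        return 12   # monthly data
    return 365      # daily data; leap days are handled by the offset logic above

# monthly spacing between the first two samples -> step by 12
print(steps_per_year(datetime.datetime(1970, 1, 1), datetime.datetime(1970, 2, 1)))  # 12
```
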
4 changes: 3 additions & 1 deletion api/processing/providers/esgf.py
@@ -22,6 +22,7 @@ def slice_and_store_dataset(
    parent_id: str,
    dataset_id: str,
    params: Dict[str, Any],
+    variable_id: str,
    **kwargs,
):
    job_id = kwargs["job_id"]
@@ -53,9 +54,10 @@
            parent_id,
            job_id,
            filters.options_from_url_parameters(params),
+            variable_id,
        )
        hmi_id = post_hmi_dataset(dataset, filename)
-        return {"status": "ok", "dataset_id": hmi_id}
+        return {"status": "ok", "dataset_id": hmi_id, "filename": filename}
    except Exception as e:
        return {"status": "failed", "error": str(e), "dataset_id": ""}
    finally:
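
The success path now also reports the stored filename. A sketch of consuming the result dict, mirroring the return shapes above (values are illustrative):

```python
# result shapes mirror the return statements above; values are illustrative
def handle_result(result: dict) -> str:
    if result["status"] == "ok":
        return f"stored {result['filename']} as dataset {result['dataset_id']}"
    return f"subset failed: {result['error']}"

print(handle_result({"status": "ok", "dataset_id": "hmi-1234", "filename": "subset-abcd.nc"}))
print(handle_result({"status": "failed", "error": "boom", "dataset_id": ""}))
```
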
30 changes: 23 additions & 7 deletions api/search/providers/esgf.py
@@ -8,7 +8,7 @@
)
import requests
from urllib.parse import urlencode
-from typing import List, Dict
+from typing import Any, List, Dict
import itertools
import dask
from openai import OpenAI
@@ -173,13 +173,12 @@ def get_mirrors_for_dataset(self, dataset_id: str) -> List[str]:
        full_ids = [d.metadata["id"] for d in response]
        return full_ids

-    def get_access_paths_by_id(self, dataset_id: str) -> Dict[str, List[str]]:
+    def get_datasets_from_id(self, dataset_id: str) -> List[Dict[str, Any]]:
        """
-        returns a list of OPENDAP URLs for use in processing given a dataset.
+        returns a list of datasets for a given ID. includes mirrors.
        """
        if dataset_id == "":
-            return []
-        self.get_mirrors_for_dataset(dataset_id)
+            return {}
        params = urlencode(
            {
                "type": "File",
@@ -195,11 +194,18 @@ def get_access_paths_by_id(self, dataset_id: str) -> Dict[str, List[str]]:
            raise ConnectionError(
                f"Failed to extract files from dataset via file search: {full_url} {response}"
            )
-        files = response["response"]["docs"]
-        if len(files) == 0:
+        datasets = response["response"]["docs"]
+        if len(datasets) == 0:
            raise ConnectionError(
                f"Failed to extract files from dataset: empty list {full_url}"
            )
+        return datasets

+    def get_access_paths_by_id(self, dataset_id: str) -> Dict[str, List[str]]:
+        """
+        returns the OPENDAP and HTTP access URLs for a given dataset, for use in processing.
+        """
+        files = self.get_datasets_from_id(dataset_id)

        # file url responses are lists of strings with their protocols separated by |
        # e.g. https://esgf-node.example|mimetype|OPENDAP
@@ -217,6 +223,16 @@ def select(files, selector):

return {"opendap": opendap_urls, "http": http_urls}

def get_metadata_for_dataset(self, dataset_id: str) -> Dict[str, Any]:
"""
returns a list of OPENDAP URLs for use in processing given a dataset.
"""
datasets = self.get_datasets_from_id(dataset_id)
if len(datasets) == 0:
msg = "no datasets found for given ID"
raise ValueError(msg)
return datasets[0]

def get_access_paths(self, dataset: Dataset) -> AccessURLs:
return self.get_all_access_paths_by_id(dataset.metadata["id"])

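As the comment above notes, ESGF file records list their URLs as `url|mimetype|SERVICE` strings, and `get_access_paths_by_id` splits these into OPENDAP and HTTP lists. A simplified sketch with illustrative records and a `select` helper modeled on, but not identical to, the one in this file:

```python
# illustrative access strings in the "url|mimetype|SERVICE" format
files = [
    "https://esgf-node.example/dodsC/a.nc|application/opendap-html|OPENDAP",
    "https://esgf-node.example/fileServer/a.nc|application/netcdf|HTTPServer",
]

def select(urls: list[str], service: str) -> list[str]:
    # keep the URL portion of every record whose trailing service tag matches
    return [u.split("|")[0] for u in urls if u.split("|")[-1] == service]

print(select(files, "OPENDAP"))     # ['https://esgf-node.example/dodsC/a.nc']
print(select(files, "HTTPServer"))  # ['https://esgf-node.example/fileServer/a.nc']
```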