Skip to content

GMTDataArrayAccessor: Support applying grid operations on the current xarray.DataArray object #3854

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions pygmt/tests/test_xarray_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,50 @@
from pygmt.datasets import load_earth_relief
from pygmt.enums import GridRegistration, GridType
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers.testing import load_static_earth_relief

_HAS_NETCDF4 = bool(importlib.util.find_spec("netCDF4"))


@pytest.fixture(scope="module", name="grid")
def fixture_grid():
"""
Load the grid data from the sample earth_relief file.
"""
return load_static_earth_relief()


@pytest.fixture(scope="module", name="expected_clipped_grid")
def fixture_expected_clipped_grid():
"""
The expected grdclip grid result.
"""
return xr.DataArray(
data=[
[1000.0, 570.5, -1000.0, -1000.0],
[1000.0, 1000.0, 571.5, 638.5],
[555.5, 556.0, 580.0, 1000.0],
],
coords={"lon": [-52.5, -51.5, -50.5, -49.5], "lat": [-18.5, -17.5, -16.5]},
dims=["lat", "lon"],
)


@pytest.fixture(scope="module", name="expected_equalized_grid")
def fixture_expected_equalized_grid():
"""
The expected grdhisteq grid result.
"""
return xr.DataArray(
data=[[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 1, 1], [1, 1, 1, 1]],
coords={
"lon": [-51.5, -50.5, -49.5, -48.5],
"lat": [-21.5, -20.5, -19.5, -18.5],
},
dims=["lat", "lon"],
)


def test_xarray_accessor_gridline_cartesian():
"""
Check that the accessor returns the correct registration and gtype values for a
Expand Down Expand Up @@ -169,3 +209,25 @@ def test_xarray_accessor_tiled_grid_slice_and_add():
added_grid.gmt.gtype = GridType.GEOGRAPHIC
assert added_grid.gmt.registration is GridRegistration.PIXEL
assert added_grid.gmt.gtype is GridType.GEOGRAPHIC


def test_xarray_accessor_clip(grid, expected_clipped_grid):
"""
Check that the accessor has the clip method and that it works correctly.

This test is adapted from the `test_grdclip_no_outgrid` test.
"""
clipped_grid = grid.gmt.clip(
below=[550, -1000], above=[700, 1000], region=[-53, -49, -19, -16]
)
xr.testing.assert_allclose(a=clipped_grid, b=expected_clipped_grid)


def test_xarray_accessor_equalize(grid, expected_equalized_grid):
"""
Check that the accessor has the equalize_hist method and that it works correctly.

This test is adapted from the `test_equalize_grid_no_outgrid` test.
"""
equalized_grid = grid.gmt.equalize_hist(divisions=2, region=[-52, -48, -22, -18])
xr.testing.assert_allclose(a=equalized_grid, b=expected_equalized_grid)
59 changes: 58 additions & 1 deletion pygmt/xarray/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,25 @@
"""

import contextlib
import functools
from pathlib import Path

import xarray as xr
from pygmt.enums import GridRegistration, GridType
from pygmt.exceptions import GMTInvalidInput
from pygmt.src.grdinfo import grdinfo
from pygmt.src import (
dimfilter,
grdclip,
grdcut,
grdfill,
grdfilter,
grdgradient,
grdhisteq,
grdinfo,
grdproject,
grdsample,
grdtrack,
)


@xr.register_dataarray_accessor("gmt")
Expand All @@ -23,6 +36,11 @@ class GMTDataArrayAccessor:
- ``registration``: Grid registration type :class:`pygmt.enums.GridRegistration`.
- ``gtype``: Grid coordinate system type :class:`pygmt.enums.GridType`.

The *gmt* accessor also provides a set of grid-operation methods that enables
applying GMT's grid processing functionalities directly to the current
:class:`xarray.DataArray` object. See the summary table below for the list of
available methods.

Notes
-----
When accessed the first time, the *gmt* accessor will first be initialized to the
Expand Down Expand Up @@ -150,6 +168,19 @@ class GMTDataArrayAccessor:
>>> zval.gmt.gtype = GridType.GEOGRAPHIC
>>> zval.gmt.registration, zval.gmt.gtype
(<GridRegistration.GRIDLINE: 0>, <GridType.GEOGRAPHIC: 1>)

Instead of calling a grid-processing function and passing the
:class:`xarray.DataArray` object as an input, you can call the corresponding method
directly on the object. For example, the following two are equivalent:

>>> from pygmt.datasets import load_earth_relief
>>> grid = load_earth_relief(resolution="30m", region=[10, 30, 15, 25])
>>> # Create a new grid from an input grid. Set all values below 1,000 to 0 and all
>>> # values above 1,500 to 10,000.
>>> # Option 1:
>>> new_grid = pygmt.grdclip(grid=grid, below=[1000, 0], above=[1500, 10000])
>>> # Option 2:
>>> new_grid = grid.gmt.clip(below=[1000, 0], above=[1500, 10000])
"""

def __init__(self, xarray_obj: xr.DataArray):
Expand Down Expand Up @@ -204,3 +235,29 @@ def gtype(self, value: GridType | int):
)
raise GMTInvalidInput(msg)
self._gtype = GridType(value)

@staticmethod
def _make_method(func):
"""
Create a wrapper method for PyGMT grid-processing methods.

The :class:`xarray.DataArray` object is passed as the first argument.
"""

@functools.wraps(func)
def wrapper(self, *args, **kwargs):
return func(self._obj, *args, **kwargs)

return wrapper

# Accessor methods for grid operations.
clip = _make_method(grdclip)
cut = _make_method(grdcut)
dimfilter = _make_method(dimfilter)
equalize_hist = _make_method(grdhisteq.equalize_grid)
fill = _make_method(grdfill)
filter = _make_method(grdfilter)
gradient = _make_method(grdgradient)
project = _make_method(grdproject)
sample = _make_method(grdsample)
track = _make_method(grdtrack)
Loading