From a34f00abe6cc1ac22dc9236f1d79c66d360db4b5 Mon Sep 17 00:00:00 2001
From: "Adam J. Stewart" <ajstewart426@gmail.com>
Date: Sat, 18 Jan 2025 16:25:17 +0100
Subject: [PATCH] IDTReeS: remove support for plotting lidar point cloud
 (#2428)

---
 .github/workflows/tests.yaml   |  4 ----
 .pre-commit-config.yaml        |  1 -
 docs/conf.py                   |  1 -
 pyproject.toml                 |  2 --
 requirements/datasets.txt      |  1 -
 requirements/min-reqs.old      |  2 --
 tests/datasets/test_idtrees.py |  8 --------
 torchgeo/datasets/idtrees.py   | 36 +---------------------------------
 8 files changed, 1 insertion(+), 54 deletions(-)

diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index a44a6cf7669..54781b8fa80 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -34,8 +34,6 @@ jobs:
           path: ${{ env.pythonLocation }}
           key: ${{ runner.os }}-${{ runner.arch }}-Python-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('requirements/required.txt', 'requirements/datasets.txt', 'requirements/tests.txt') }}
         if: ${{ runner.os != 'macOS' }}
-      - name: Setup headless display for pyvista
-        uses: pyvista/setup-headless-display-action@v3
       - name: Install pip dependencies
         if: steps.cache.outputs.cache-hit != 'true'
         run: |
@@ -68,8 +66,6 @@ jobs:
         with:
           path: ${{ env.pythonLocation }}
           key: ${{ runner.os }}-${{ runner.arch }}-Python-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('requirements/min-reqs.old') }}
-      - name: Setup headless display for pyvista
-        uses: pyvista/setup-headless-display-action@v3
       - name: Install pip dependencies
         if: steps.cache.outputs.cache-hit != 'true'
         run: |
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index d4d221e3c0d..bc948edcb59 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -28,7 +28,6 @@ repos:
           - numpy>=1.22
           - pillow>=10.4.0
           - pytest>=6.1.2
-          - pyvista>=0.34.2
           - scikit-image>=0.22.0
           - torch>=2.3
           - torchmetrics>=0.10
diff --git a/docs/conf.py b/docs/conf.py
index df1de398185..1b34024565d 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -115,7 +115,6 @@
     'numpy': ('https://numpy.org/doc/stable/', None),
     'python': ('https://docs.python.org/3', None),
     'lightning': ('https://lightning.ai/docs/pytorch/stable/', None),
-    'pyvista': ('https://docs.pyvista.org/version/stable/', None),
     'rasterio': ('https://rasterio.readthedocs.io/en/stable/', None),
     'rtree': ('https://rtree.readthedocs.io/en/stable/', None),
     'segmentation_models_pytorch': ('https://smp.readthedocs.io/en/stable/', None),
diff --git a/pyproject.toml b/pyproject.toml
index 8d7fcc89183..a6e16d1be59 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -93,8 +93,6 @@ datasets = [
     "pandas[parquet]>=2",
     # pycocotools 2.0.7+ required for wheels
     "pycocotools>=2.0.7",
-    # pyvista 0.34.2+ required to avoid ImportError in CI
-    "pyvista>=0.34.2",
     # scikit-image 0.19+ required for Python 3.10 wheels
     "scikit-image>=0.19",
     # scipy 1.7.2+ required for Python 3.10 wheels
diff --git a/requirements/datasets.txt b/requirements/datasets.txt
index 34f81c0d2db..680dae8987d 100644
--- a/requirements/datasets.txt
+++ b/requirements/datasets.txt
@@ -4,6 +4,5 @@ laspy==2.5.4
 opencv-python==4.11.0.86
 pandas[parquet]==2.2.3
 pycocotools==2.0.8
-pyvista==0.44.2
 scikit-image==0.25.0
 scipy==1.15.1
diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old
index a4a57e10968..e1058142ac6 100644
--- a/requirements/min-reqs.old
+++ b/requirements/min-reqs.old
@@ -27,10 +27,8 @@ laspy==2.0.0
 opencv-python==4.5.4.58
 pycocotools==2.0.7
 pyarrow==15.0.0  # Remove when we upgrade min version of pandas to `pandas[parquet]>=2`
-pyvista==0.34.2
 scikit-image==0.19.0
 scipy==1.7.2
-vtk==9.3.1  # PyVista is not yet compatible with VTK 9.4+
 
 # tests
 pytest==7.3.0
diff --git a/tests/datasets/test_idtrees.py b/tests/datasets/test_idtrees.py
index 5fd858ac04f..4c145b4c006 100644
--- a/tests/datasets/test_idtrees.py
+++ b/tests/datasets/test_idtrees.py
@@ -95,11 +95,3 @@ def test_plot(self, dataset: IDTReeS) -> None:
             x['prediction_label'] = x['label']
             dataset.plot(x, show_titles=False)
             plt.close()
-
-    def test_plot_las(self, dataset: IDTReeS) -> None:
-        pyvista = pytest.importorskip('pyvista', minversion='0.34.2')
-        pyvista.OFF_SCREEN = True
-
-        # Test point cloud without colors
-        point_cloud = dataset.plot_las(index=0)
-        pyvista.plot(point_cloud, scalars=point_cloud.points, cpos='yz', cmap='viridis')
diff --git a/torchgeo/datasets/idtrees.py b/torchgeo/datasets/idtrees.py
index 28e890dc69f..2c5777ba350 100644
--- a/torchgeo/datasets/idtrees.py
+++ b/torchgeo/datasets/idtrees.py
@@ -92,10 +92,9 @@ class IDTReeS(NonGeoDataset):
 
     * https://doi.org/10.1101/2021.08.06.453503
 
-    This dataset requires the following additional libraries to be installed:
+    This dataset requires the following additional library to be installed:
 
        * `laspy <https://pypi.org/project/laspy/>`_ to read lidar point clouds
-       * `pyvista <https://pypi.org/project/pyvista/>`_ to plot lidar point clouds
 
     .. versionadded:: 0.2
     """
@@ -552,36 +551,3 @@ def normalize(x: Tensor) -> Tensor:
             plt.suptitle(suptitle)
 
         return fig
-
-    def plot_las(self, index: int) -> 'pyvista.Plotter':  # type: ignore[name-defined] # noqa: F821
-        """Plot a sample point cloud at the index.
-
-        Args:
-            index: index to plot
-
-        Returns:
-            pyvista.PolyData object. Run pyvista.plot(point_cloud, ...) to display
-
-        Raises:
-            DependencyNotFoundError: If laspy or pyvista are not installed.
-
-        .. versionchanged:: 0.4
-           Ported from Open3D to PyVista, *colormap* parameter removed.
-        """
-        laspy = lazy_import('laspy')
-        pyvista = lazy_import('pyvista')
-        path = self.images[index]
-        path = path.replace('RGB', 'LAS').replace('.tif', '.las')
-        las = laspy.read(path)
-        points: np.typing.NDArray[np.int_] = np.stack(
-            [las.x, las.y, las.z], axis=0
-        ).transpose((1, 0))
-        point_cloud = pyvista.PolyData(points)
-
-        # Some point cloud files have no color->points mapping
-        if hasattr(las, 'red'):
-            colors = np.stack([las.red, las.green, las.blue], axis=0)
-            colors = colors.transpose((1, 0)) / np.iinfo(np.uint16).max
-            point_cloud['colors'] = colors
-
-        return point_cloud