Skip to content

Commit

Permalink
style
Browse files Browse the repository at this point in the history
  • Loading branch information
nilsleh committed Jan 27, 2025
1 parent df21841 commit 67b3653
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 20 deletions.
7 changes: 4 additions & 3 deletions tests/data/bigearthnetV2/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import hashlib
import os
import shutil
import tarfile
from pathlib import Path

import numpy as np
import pandas as pd
from pathlib import Path
import rasterio
import zstandard as zstd
import tarfile

# Constants
IMG_SIZE = 120
Expand Down
14 changes: 3 additions & 11 deletions tests/datasets/test_bigearthnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def test_plot(self, dataset: BigEarthNet) -> None:

class TestBigEarthNetV2:
@pytest.fixture(
params=zip(['all', 's1', 's2'], [19, 19, 19], ['train', 'val', 'test'])
params=zip(['all', 's1', 's2'], ['train', 'val', 'test'])
)
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
Expand Down Expand Up @@ -179,12 +179,12 @@ def dataset(
}
monkeypatch.setattr(BigEarthNetV2, 'metadata_locs', metadata)

bands, num_classes, split = request.param
bands, split = request.param

root = tmp_path
transforms = nn.Identity()
return BigEarthNetV2(
root, split, bands, num_classes, transforms, download=True, checksum=True
root, split, bands, transforms, download=True, checksum=True
)

def test_getitem(self, dataset: BigEarthNetV2) -> None:
Expand All @@ -204,7 +204,6 @@ def test_getitem(self, dataset: BigEarthNetV2) -> None:
assert x['image_s1'].shape == (2, 120, 120)

assert x['mask'].shape == (1, 120, 120)
assert x['label'].shape == (dataset.num_classes,)

assert x['mask'].dtype == torch.int64
assert x['label'].dtype == torch.int64
Expand All @@ -227,7 +226,6 @@ def test_already_downloaded(self, dataset: BigEarthNetV2, tmp_path: Path) -> Non
root=tmp_path,
bands=dataset.bands,
split=dataset.split,
num_classes=dataset.num_classes,
download=True,
)

Expand Down Expand Up @@ -264,7 +262,6 @@ def test_already_downloaded_not_extracted(
root=tmp_path,
bands=dataset.bands,
split=dataset.split,
num_classes=dataset.num_classes,
download=False,
)

Expand All @@ -278,11 +275,6 @@ def test_invalid_bands(self, tmp_path: Path) -> None:
with pytest.raises(AssertionError):
BigEarthNetV2(tmp_path, bands='invalid')

def test_invalid_num_classes(self, tmp_path: Path) -> None:
"""Test error on invalid number of classes."""
with pytest.raises(AssertionError):
BigEarthNetV2(tmp_path, num_classes=20)

def test_plot(self, dataset: BigEarthNetV2) -> None:
"""Test plotting functionality."""
x = dataset[0].copy()
Expand Down
10 changes: 4 additions & 6 deletions torchgeo/datasets/bigearthnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@
import os
import tarfile
import tempfile
import textwrap
from collections.abc import Callable
from typing import ClassVar

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
import numpy as np
import pandas as pd
import textwrap
import rasterio
import torch
from matplotlib.colors import BoundaryNorm, ListedColormap
from matplotlib.figure import Figure
from rasterio.enums import Resampling
from torch import Tensor
Expand Down Expand Up @@ -974,11 +974,10 @@ def plot(

# Create custom colormap
cmap = ListedColormap(colors)
bounds = unique_labels + [unique_labels[-1] + 1]
bounds = [*unique_labels, unique_labels[-1] + 1]
norm = BoundaryNorm(bounds, len(colors))

# Plot mask with custom colormap
im = axes[mask_idx].imshow(mask, cmap=cmap, norm=norm)
axes[mask_idx].imshow(mask, cmap=cmap, norm=norm)

# Add legend with class names
legend_elements = [
Expand All @@ -997,7 +996,6 @@ def plot(
if show_titles:
axes[mask_idx].set_title('Land Cover Map')

# Add classification labels to suptitle
if 'label' in sample:
label_indices = sample['label'].nonzero().squeeze(1).tolist()
label_names = [self.class_set[idx] for idx in label_indices]
Expand Down

0 comments on commit 67b3653

Please sign in to comment.