diff --git a/torchgeo/datasets/eurosat.py b/torchgeo/datasets/eurosat.py index 5caef0ee4ee..8328dc1c844 100644 --- a/torchgeo/datasets/eurosat.py +++ b/torchgeo/datasets/eurosat.py @@ -15,7 +15,14 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoClassificationDataset -from .utils import Path, check_integrity, download_url, extract_archive, rasterio_loader +from .utils import ( + Path, + check_integrity, + download_url, + extract_archive, + percentile_normalization, + rasterio_loader, +) class EuroSAT(NonGeoClassificationDataset): @@ -268,7 +275,7 @@ def plot( image = np.take(sample['image'].numpy(), indices=rgb_indices, axis=0) image = np.rollaxis(image, 0, 3) - image = np.clip(image / 3000, 0, 1) + image = percentile_normalization(image) label = cast(int, sample['label'].item()) label_class = self.classes[label]