From 2ce0f33b482918f860042dad2c09b88b74d66e5e Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 3 Feb 2025 11:59:11 +0100 Subject: [PATCH 1/3] RasterDataset: assert valid bands --- torchgeo/datasets/geo.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 26a035d427d..e9527b87231 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -435,6 +435,7 @@ def __init__( cache: if True, cache file handle to speed up repeated sampling Raises: + AssertionError: If *bands* are invalid. DatasetNotFoundError: If dataset is not found. .. versionchanged:: 0.5 @@ -446,6 +447,8 @@ def __init__( self.bands = bands or self.all_bands self.cache = cache + assert set(bands) <= set(self.all_bands) + # Populate the dataset index i = 0 filename_regex = re.compile(self.filename_regex, re.VERBOSE) From 640bd34e314aebfacbd2e0638326c3843d6b57c4 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 3 Feb 2025 12:03:49 +0100 Subject: [PATCH 2/3] bands cannot be None --- torchgeo/datasets/geo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index e9527b87231..14a69231754 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -447,7 +447,7 @@ def __init__( self.bands = bands or self.all_bands self.cache = cache - assert set(bands) <= set(self.all_bands) + assert set(self.bands) <= set(self.all_bands) # Populate the dataset index i = 0 From 3c1f8cfb1fa79932593463c36fe3a9038bf6afc8 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 3 Feb 2025 19:23:02 +0100 Subject: [PATCH 3/3] Only check if all_bands is defined --- torchgeo/datasets/geo.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 14a69231754..cdce123ce8f 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -447,7 +447,8 @@ def __init__( self.bands = bands or self.all_bands self.cache = cache - assert set(self.bands) <= set(self.all_bands) + if self.all_bands: + assert set(self.bands) <= set(self.all_bands) # Populate the dataset index i = 0