Skip to content

Commit

Permalink
resolved formatting issues
Browse files Browse the repository at this point in the history
  • Loading branch information
rijuld committed Feb 6, 2025
1 parent 13337c5 commit 60ee1e0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 12 deletions.
6 changes: 3 additions & 3 deletions tests/datasets/test_substation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class TestSubstation:
@pytest.fixture
def dataset(
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> Generator[Substation, None, None]:
) -> Substation:
"""Fixture for the Substation."""
root = os.path.join(os.getcwd(), 'tests', 'data', 'substation')

Expand Down Expand Up @@ -157,15 +157,15 @@ def test_not_downloaded_with_download(
target_image_path = tmp_path / filename
target_mask_path = tmp_path / maskname

def mock_download(_self): # Accept 'self' as an argument
def mock_download(self) -> None:

Check failure on line 160 in tests/datasets/test_substation.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ANN001)

tests/datasets/test_substation.py:160:27: ANN001 Missing type annotation for function argument `self`
shutil.copytree(source_image_path, target_image_path)
shutil.copytree(source_mask_path, target_mask_path)

monkeypatch.setattr(
'torchgeo.datasets.substation.Substation._download', mock_download
)
monkeypatch.setattr(
'torchgeo.datasets.substation.Substation._extract', lambda _self: None
'torchgeo.datasets.substation.Substation._extract', lambda self: None
)

Substation(
Expand Down
14 changes: 5 additions & 9 deletions torchgeo/datamodules/substation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@

import kornia.augmentation as K
import torch
from torch.utils.data import Subset, random_split
from tqdm import tqdm
from torch.utils.data import random_split

from ..datasets import Substation
from .geo import NonGeoDataModule
Expand Down Expand Up @@ -43,10 +42,11 @@ def __init__(
val_split_pct: Percentage of data to use for validation.
test_split_pct: Percentage of data to use for testing.
bands: Number of input channels to use.
model_type: Type of model being used (e.g., 'swin' for specific channel selection).
num_of_timepoints: Number of timepoints to use in the dataset.
aug: Augmentation to apply to the dataset.
train_aug: Augmentation to apply to the training dataset.
timepoint_aggregation: Aggregation method for multiple timepoints.
model_type: Type of model being used (e.g., 'swin' for specific channel selection).
size: Size of the input images.
**kwargs: Additional arguments passed to Substation.
"""
super().__init__(Substation, batch_size, num_workers, **kwargs)
Expand All @@ -68,10 +68,6 @@ def __init__(
K.Resize(size), data_keys=None, keepdim=True
)

def _identity(self, x: torch.Tensor) -> torch.Tensor:
"""Identity function for default transformations."""
return x

def setup(self, stage: str) -> None:
"""Set up datasets.
Expand Down

0 comments on commit 60ee1e0

Please sign in to comment.