Skip to content

Commit

Permalink
Switch to kornia AugmentationSequential - for real
Browse files Browse the repository at this point in the history
  • Loading branch information
ashnair1 committed Feb 5, 2025
1 parent 8a716eb commit d588af7
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 15 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ repos:
- pillow>=10.4.0
- pytest>=6.1.2
- scikit-image>=0.22.0
- timm>=1.0.14
- torch>=2.6
- torchmetrics>=0.10
- torchvision>=0.18
Expand Down
28 changes: 13 additions & 15 deletions tests/transforms/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
from torch import Tensor

from torchgeo.transforms import indices, transforms
from torchgeo.transforms import indices
from torchgeo.transforms.transforms import _ExtractPatches

# Kornia is very particular about its boxes:
Expand Down Expand Up @@ -203,12 +203,12 @@ def test_extract_patches() -> None:
'image': torch.randn(size=(b, c, h, w)),
'mask': torch.randint(low=0, high=2, size=(b, h, w)),
}
train_transforms = transforms.AugmentationSequential(
_ExtractPatches(window_size=p), same_on_batch=True, data_keys=['image', 'mask']
train_transforms = K.AugmentationSequential(
_ExtractPatches(window_size=p), same_on_batch=True, data_keys=None
)
output = train_transforms(batch)
assert batch['image'].shape == (b * num_patches, c, p, p)
assert batch['mask'].shape == (b * num_patches, p, p)
assert output['image'].shape == (b * num_patches, c, p, p)
assert output['mask'].shape == (b * num_patches, 1, p, p)

# Test different stride
s = 16
Expand All @@ -217,14 +217,12 @@ def test_extract_patches() -> None:
'image': torch.randn(size=(b, c, h, w)),
'mask': torch.randint(low=0, high=2, size=(b, h, w)),
}
train_transforms = transforms.AugmentationSequential(
_ExtractPatches(window_size=p, stride=s),
same_on_batch=True,
data_keys=['image', 'mask'],
train_transforms = K.AugmentationSequential(
_ExtractPatches(window_size=p, stride=s), same_on_batch=True, data_keys=None
)
output = train_transforms(batch)
assert batch['image'].shape == (b * num_patches, c, p, p)
assert batch['mask'].shape == (b * num_patches, p, p)
assert output['image'].shape == (b * num_patches, c, p, p)
assert output['mask'].shape == (b * num_patches, 1, p, p)

# Test keepdim=False
s = p
Expand All @@ -233,13 +231,13 @@ def test_extract_patches() -> None:
'image': torch.randn(size=(b, c, h, w)),
'mask': torch.randint(low=0, high=2, size=(b, h, w)),
}
train_transforms = transforms.AugmentationSequential(
train_transforms = K.AugmentationSequential(
_ExtractPatches(window_size=p, stride=s, keepdim=False),
same_on_batch=True,
data_keys=['image', 'mask'],
data_keys=None,
)
output = train_transforms(batch)
for k, v in output.items():
print(k, v.shape, v.dtype)
assert batch['image'].shape == (b, num_patches, c, p, p)
assert batch['mask'].shape == (b, num_patches, 1, p, p)
assert output['image'].shape == (b, num_patches, c, p, p)
assert output['mask'].shape == (b, num_patches, 1, p, p)

0 comments on commit d588af7

Please sign in to comment.