Skip to content

Commit b9404d1

Browse files
authored
Remove AugPipe (#1978)
* Initial commit * Add issue link * Remove explicit keepdim * Fix coverage * test transforms: Switch to kornia AugmentationSequential * Switch boxes key to bbox_xyxy * Switch class -> label * Switch prediction_boxes -> prediction_bbox_xyxy * Switch to relative imports * Switch to kornia AugmentationSequential - for real * Remove AugmentationSequential * Exclude AugmentationSequential
1 parent 2ad5d4f commit b9404d1

26 files changed

+167
-327
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ repos:
2929
- pillow>=10.4.0
3030
- pytest>=6.1.2
3131
- scikit-image>=0.22.0
32+
- timm>=1.0.14
3233
- torch>=2.6
3334
- torchmetrics>=0.10
3435
- torchvision>=0.18

docs/api/transforms.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ torchgeo.transforms
22
===================
33

44
.. automodule:: torchgeo.transforms
5+
:exclude-members: AugmentationSequential

tests/datamodules/test_fair1m.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_plot(self, datamodule: FAIR1MDataModule) -> None:
3535
batch = next(iter(datamodule.val_dataloader()))
3636
sample = {
3737
'image': batch['image'][0],
38-
'boxes': batch['boxes'][0],
38+
'bbox_xyxy': batch['bbox_xyxy'][0],
3939
'label': batch['label'][0],
4040
}
4141
datamodule.plot(sample)

tests/datasets/test_fair1m.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ def test_getitem(self, dataset: FAIR1M) -> None:
7070
assert x['image'].shape[0] == 3
7171

7272
if dataset.split != 'test':
73-
assert isinstance(x['boxes'], torch.Tensor)
73+
assert isinstance(x['bbox_xyxy'], torch.Tensor)
7474
assert isinstance(x['label'], torch.Tensor)
75-
assert x['boxes'].shape[-2:] == (5, 2)
75+
assert x['bbox_xyxy'].shape[-2:] == (5, 2)
7676
assert x['label'].ndim == 1
7777

7878
def test_len(self, dataset: FAIR1M) -> None:
@@ -124,6 +124,6 @@ def test_plot(self, dataset: FAIR1M) -> None:
124124
plt.close()
125125

126126
if dataset.split != 'test':
127-
x['prediction_boxes'] = x['boxes'].clone()
127+
x['prediction_bbox_xyxy'] = x['bbox_xyxy'].clone()
128128
dataset.plot(x)
129129
plt.close()

tests/datasets/test_forestdamage.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_getitem(self, dataset: ForestDamage) -> None:
3636
assert isinstance(x, dict)
3737
assert isinstance(x['image'], torch.Tensor)
3838
assert isinstance(x['label'], torch.Tensor)
39-
assert isinstance(x['boxes'], torch.Tensor)
39+
assert isinstance(x['bbox_xyxy'], torch.Tensor)
4040
assert x['image'].shape[0] == 3
4141
assert x['image'].ndim == 3
4242

@@ -67,6 +67,6 @@ def test_plot(self, dataset: ForestDamage) -> None:
6767

6868
def test_plot_prediction(self, dataset: ForestDamage) -> None:
6969
x = dataset[0].copy()
70-
x['prediction_boxes'] = x['boxes'].clone()
70+
x['prediction_bbox_xyxy'] = x['bbox_xyxy'].clone()
7171
dataset.plot(x, suptitle='Prediction')
7272
plt.close()

tests/datasets/test_idtrees.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,11 @@ def test_getitem(self, dataset: IDTReeS) -> None:
5757

5858
if 'label' in x:
5959
assert isinstance(x['label'], torch.Tensor)
60-
if 'boxes' in x:
61-
assert isinstance(x['boxes'], torch.Tensor)
62-
if x['boxes'].ndim != 1:
63-
assert x['boxes'].ndim == 2
64-
assert x['boxes'].shape[-1] == 4
60+
if 'bbox_xyxy' in x:
61+
assert isinstance(x['bbox_xyxy'], torch.Tensor)
62+
if x['bbox_xyxy'].ndim != 1:
63+
assert x['bbox_xyxy'].ndim == 2
64+
assert x['bbox_xyxy'].shape[-1] == 4
6565

6666
def test_len(self, dataset: IDTReeS) -> None:
6767
assert len(dataset) == 3
@@ -87,8 +87,8 @@ def test_plot(self, dataset: IDTReeS) -> None:
8787
dataset.plot(x, show_titles=False)
8888
plt.close()
8989

90-
if 'boxes' in x:
91-
x['prediction_boxes'] = x['boxes']
90+
if 'bbox_xyxy' in x:
91+
x['prediction_bbox_xyxy'] = x['bbox_xyxy']
9292
dataset.plot(x, show_titles=True)
9393
plt.close()
9494
if 'label' in x:

tests/datasets/test_nasa_marine_debris.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ def test_getitem(self, dataset: NASAMarineDebris) -> None:
2828
x = dataset[0]
2929
assert isinstance(x, dict)
3030
assert isinstance(x['image'], torch.Tensor)
31-
assert isinstance(x['boxes'], torch.Tensor)
31+
assert isinstance(x['bbox_xyxy'], torch.Tensor)
3232
assert x['image'].shape[0] == 3
33-
assert x['boxes'].shape[-1] == 4
33+
assert x['bbox_xyxy'].shape[-1] == 4
3434

3535
def test_len(self, dataset: NASAMarineDebris) -> None:
3636
assert len(dataset) == 5
@@ -50,6 +50,6 @@ def test_plot(self, dataset: NASAMarineDebris) -> None:
5050
plt.close()
5151
dataset.plot(x, show_titles=False)
5252
plt.close()
53-
x['prediction_boxes'] = x['boxes'].clone()
53+
x['prediction_bbox_xyxy'] = x['bbox_xyxy'].clone()
5454
dataset.plot(x)
5555
plt.close()

tests/datasets/test_pastis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_getitem_instance(self, dataset: PASTIS) -> None:
5252
assert isinstance(x, dict)
5353
assert isinstance(x['image'], torch.Tensor)
5454
assert isinstance(x['mask'], torch.Tensor)
55-
assert isinstance(x['boxes'], torch.Tensor)
55+
assert isinstance(x['bbox_xyxy'], torch.Tensor)
5656
assert isinstance(x['label'], torch.Tensor)
5757

5858
def test_len(self, dataset: PASTIS) -> None:

tests/datasets/test_reforestree.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ def test_getitem(self, dataset: ReforesTree) -> None:
3636
assert isinstance(x, dict)
3737
assert isinstance(x['image'], torch.Tensor)
3838
assert isinstance(x['label'], torch.Tensor)
39-
assert isinstance(x['boxes'], torch.Tensor)
39+
assert isinstance(x['bbox_xyxy'], torch.Tensor)
4040
assert isinstance(x['agb'], torch.Tensor)
4141
assert x['image'].shape[0] == 3
4242
assert x['image'].ndim == 3
43-
assert len(x['boxes']) == 2
43+
assert len(x['bbox_xyxy']) == 2
4444

4545
def test_len(self, dataset: ReforesTree) -> None:
4646
assert len(dataset) == 2
@@ -67,6 +67,6 @@ def test_plot(self, dataset: ReforesTree) -> None:
6767

6868
def test_plot_prediction(self, dataset: ReforesTree) -> None:
6969
x = dataset[0].copy()
70-
x['prediction_boxes'] = x['boxes'].clone()
70+
x['prediction_bbox_xyxy'] = x['bbox_xyxy'].clone()
7171
dataset.plot(x, suptitle='Prediction')
7272
plt.close()

tests/datasets/test_vhr10.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ def test_getitem(self, dataset: VHR10) -> None:
4141
assert isinstance(x, dict)
4242
assert isinstance(x['image'], torch.Tensor)
4343
if dataset.split == 'positive':
44-
assert isinstance(x['labels'], torch.Tensor)
45-
assert isinstance(x['boxes'], torch.Tensor)
46-
if 'masks' in x:
47-
assert isinstance(x['masks'], torch.Tensor)
44+
assert isinstance(x['label'], torch.Tensor)
45+
assert isinstance(x['bbox_xyxy'], torch.Tensor)
46+
if 'mask' in x:
47+
assert isinstance(x['mask'], torch.Tensor)
4848

4949
def test_len(self, dataset: VHR10) -> None:
5050
if dataset.split == 'positive':
@@ -82,10 +82,10 @@ def test_plot(self, dataset: VHR10) -> None:
8282
scores = [0.7, 0.3, 0.7]
8383
for i in range(3):
8484
x = dataset[i]
85-
x['prediction_labels'] = x['labels']
86-
x['prediction_boxes'] = x['boxes']
85+
x['prediction_labels'] = x['label']
86+
x['prediction_bbox_xyxy'] = x['bbox_xyxy']
8787
x['prediction_scores'] = torch.Tensor([scores[i]])
88-
if 'masks' in x:
89-
x['prediction_masks'] = x['masks']
88+
if 'mask' in x:
89+
x['prediction_masks'] = x['mask']
9090
dataset.plot(x, show_feats='masks')
9191
plt.close()

0 commit comments

Comments
 (0)