Skip to content

Commit 545c2c9

Browse files
committed
fix style
1 parent 87f3123 commit 545c2c9

File tree

1 file changed

+15
-21
lines changed

1 file changed

+15
-21
lines changed

torchgeo/transforms/temporal.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
"""TorchGeo temporal transforms."""
55

66
from typing import Any, Literal
7-
from einops import rearrange
7+
88
import kornia.augmentation as K
9+
from einops import rearrange
910
from torch import Tensor
1011

1112

@@ -19,7 +20,7 @@ class TemporalRearrange(K.IntensityAugmentationBase2D):
1920

2021
def __init__(
2122
self,
22-
mode: Literal["merge", "split"],
23+
mode: Literal['merge', 'split'],
2324
num_temporal_channels: int,
2425
p: float = 1.0,
2526
p_batch: float = 1.0,
@@ -40,19 +41,12 @@ def __init__(
4041
super().__init__(
4142
p=p, p_batch=p_batch, same_on_batch=same_on_batch, keepdim=keepdim
4243
)
43-
if mode not in ["merge", "split"]:
44+
if mode not in ['merge', 'split']:
4445
raise ValueError("mode must be either 'merge' or 'split'")
4546

46-
self.flags = {
47-
"mode": mode,
48-
"num_temporal_channels": num_temporal_channels,
49-
}
47+
self.flags = {'mode': mode, 'num_temporal_channels': num_temporal_channels}
5048

51-
def apply_transform(
52-
self,
53-
input: Tensor,
54-
flags: dict[str, Any],
55-
) -> Tensor:
49+
def apply_transform(self, input: Tensor, flags: dict[str, Any]) -> Tensor:
5650
"""Apply the transform.
5751
5852
Args:
@@ -65,25 +59,25 @@ def apply_transform(
6559
Raises:
6660
ValueError: If input tensor dimensions don't match expected shape
6761
"""
68-
mode = flags["mode"]
69-
t = flags["num_temporal_channels"]
62+
mode = flags['mode']
63+
t = flags['num_temporal_channels']
7064

71-
if mode == "merge":
65+
if mode == 'merge':
7266
if input.ndim != 5:
7367
raise ValueError(
74-
f"Expected 5D input tensor (B,T,C,H,W), got shape {input.shape}"
68+
f'Expected 5D input tensor (B,T,C,H,W), got shape {input.shape}'
7569
)
76-
return rearrange(input, "b t c h w -> b (t c) h w")
70+
return rearrange(input, 'b t c h w -> b (t c) h w')
7771
else:
7872
if input.ndim != 4:
7973
raise ValueError(
80-
f"Expected 4D input tensor (B,TC,H,W), got shape {input.shape}"
74+
f'Expected 4D input tensor (B,TC,H,W), got shape {input.shape}'
8175
)
8276
tc = input.shape[1]
8377
if tc % t != 0:
8478
raise ValueError(
85-
f"Input channels ({tc}) must be divisible by "
86-
f"num_temporal_channels ({t})"
79+
f'Input channels ({tc}) must be divisible by '
80+
f'num_temporal_channels ({t})'
8781
)
8882
c = tc // t
89-
return rearrange(input, "b (t c) h w -> b t c h w", t=t, c=c)
83+
return rearrange(input, 'b (t c) h w -> b t c h w', t=t, c=c)

0 commit comments

Comments
 (0)