4
4
"""TorchGeo temporal transforms."""
5
5
6
6
from typing import Any , Literal
7
- from einops import rearrange
7
+
8
8
import kornia .augmentation as K
9
+ from einops import rearrange
9
10
from torch import Tensor
10
11
11
12
@@ -19,7 +20,7 @@ class TemporalRearrange(K.IntensityAugmentationBase2D):
19
20
20
21
def __init__ (
21
22
self ,
22
- mode : Literal [" merge" , " split" ],
23
+ mode : Literal [' merge' , ' split' ],
23
24
num_temporal_channels : int ,
24
25
p : float = 1.0 ,
25
26
p_batch : float = 1.0 ,
@@ -40,19 +41,12 @@ def __init__(
40
41
super ().__init__ (
41
42
p = p , p_batch = p_batch , same_on_batch = same_on_batch , keepdim = keepdim
42
43
)
43
- if mode not in [" merge" , " split" ]:
44
+ if mode not in [' merge' , ' split' ]:
44
45
raise ValueError ("mode must be either 'merge' or 'split'" )
45
46
46
- self .flags = {
47
- "mode" : mode ,
48
- "num_temporal_channels" : num_temporal_channels ,
49
- }
47
+ self .flags = {'mode' : mode , 'num_temporal_channels' : num_temporal_channels }
50
48
51
- def apply_transform (
52
- self ,
53
- input : Tensor ,
54
- flags : dict [str , Any ],
55
- ) -> Tensor :
49
+ def apply_transform (self , input : Tensor , flags : dict [str , Any ]) -> Tensor :
56
50
"""Apply the transform.
57
51
58
52
Args:
@@ -65,25 +59,25 @@ def apply_transform(
65
59
Raises:
66
60
ValueError: If input tensor dimensions don't match expected shape
67
61
"""
68
- mode = flags [" mode" ]
69
- t = flags [" num_temporal_channels" ]
62
+ mode = flags [' mode' ]
63
+ t = flags [' num_temporal_channels' ]
70
64
71
- if mode == " merge" :
65
+ if mode == ' merge' :
72
66
if input .ndim != 5 :
73
67
raise ValueError (
74
- f" Expected 5D input tensor (B,T,C,H,W), got shape { input .shape } "
68
+ f' Expected 5D input tensor (B,T,C,H,W), got shape { input .shape } '
75
69
)
76
- return rearrange (input , " b t c h w -> b (t c) h w" )
70
+ return rearrange (input , ' b t c h w -> b (t c) h w' )
77
71
else :
78
72
if input .ndim != 4 :
79
73
raise ValueError (
80
- f" Expected 4D input tensor (B,TC,H,W), got shape { input .shape } "
74
+ f' Expected 4D input tensor (B,TC,H,W), got shape { input .shape } '
81
75
)
82
76
tc = input .shape [1 ]
83
77
if tc % t != 0 :
84
78
raise ValueError (
85
- f" Input channels ({ tc } ) must be divisible by "
86
- f" num_temporal_channels ({ t } )"
79
+ f' Input channels ({ tc } ) must be divisible by '
80
+ f' num_temporal_channels ({ t } )'
87
81
)
88
82
c = tc // t
89
- return rearrange (input , " b (t c) h w -> b t c h w" , t = t , c = c )
83
+ return rearrange (input , ' b (t c) h w -> b t c h w' , t = t , c = c )
0 commit comments