Add Temporal Transforms #2477

Draft
wants to merge 5 commits into main
83 changes: 83 additions & 0 deletions torchgeo/transforms/temporal.py
@@ -0,0 +1,83 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""TorchGeo temporal transforms."""

from typing import Any, Literal

import kornia.augmentation as K
from einops import rearrange
from torch import Tensor


class TemporalRearrange(K.IntensityAugmentationBase2D):
"""Rearrange temporal and channel dimensions.

This transform allows conversion between:
- B x T x C x H x W (temporal-explicit)
- B x (T*C) x H x W (temporal-channel)
"""

def __init__(
self,
mode: Literal['merge', 'split'],
num_temporal_channels: int,
p: float = 1.0,
p_batch: float = 1.0,
same_on_batch: bool = False,
keepdim: bool = False,
) -> None:
"""Initialize a new TemporalRearrange instance.

Args:
            mode: Whether to 'merge' (B x T x C x H x W -> B x TC x H x W) or
                'split' (B x TC x H x W -> B x T x C x H x W) the temporal dimension
            num_temporal_channels: Number of time steps (T) in the sequence;
                only used when ``mode='split'``
p: Probability for applying the transform element-wise
p_batch: Probability for applying the transform batch-wise
same_on_batch: Apply the same transformation across the batch
keepdim: Whether to keep the output shape the same as input
"""
super().__init__(
p=p, p_batch=p_batch, same_on_batch=same_on_batch, keepdim=keepdim
)
        if mode not in ['merge', 'split']:
            raise ValueError("mode must be either 'merge' or 'split'")
        if num_temporal_channels < 1:
            raise ValueError('num_temporal_channels must be a positive integer')

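        # Kornia hands these static flags back to ``apply_transform`` on each call.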
self.flags = {'mode': mode, 'num_temporal_channels': num_temporal_channels}

    def apply_transform(
        self,
        input: Tensor,
        params: dict[str, Tensor],
        flags: dict[str, Any],
        transform: Tensor | None = None,
    ) -> Tensor:
        """Apply the transform.

        Args:
            input: Input tensor
            params: Generated parameters (unused)
            flags: Static parameters including mode and number of temporal channels
            transform: Geometric transformation matrix (unused)

Returns:
Transformed tensor with rearranged dimensions

Raises:
ValueError: If input tensor dimensions don't match expected shape
"""
mode = flags['mode']
t = flags['num_temporal_channels']

if mode == 'merge':
if input.ndim != 5:
raise ValueError(
f'Expected 5D input tensor (B,T,C,H,W), got shape {input.shape}'
)
return rearrange(input, 'b t c h w -> b (t c) h w')
else:
if input.ndim != 4:
raise ValueError(
f'Expected 4D input tensor (B,TC,H,W), got shape {input.shape}'
)
tc = input.shape[1]
if tc % t != 0:
raise ValueError(
f'Input channels ({tc}) must be divisible by '
f'num_temporal_channels ({t})'
)
c = tc // t
return rearrange(input, 'b (t c) h w -> b t c h w', t=t, c=c)
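
A minimal usage sketch (not part of the diff): it calls apply_transform directly with empty params to illustrate the shape contract, whereas a real pipeline would invoke the transform through kornia's forward machinery; the module path is the one added above.

import torch

from torchgeo.transforms.temporal import TemporalRearrange

merge = TemporalRearrange(mode='merge', num_temporal_channels=4)
split = TemporalRearrange(mode='split', num_temporal_channels=4)

# B x T x C x H x W: 2 samples, 4 time steps, 3 spectral channels
x = torch.rand(2, 4, 3, 64, 64)
flat = merge.apply_transform(x, params={}, flags=merge.flags)          # 2 x 12 x 64 x 64
restored = split.apply_transform(flat, params={}, flags=split.flags)   # 2 x 4 x 3 x 64 x 64
assert torch.equal(x, restored)  # the rearrange is a pure reshape, so it round-trips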