Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restormer Implementation #8312

Open
wants to merge 52 commits into
base: dev
Choose a base branch
from
Open

Restormer Implementation #8312

wants to merge 52 commits into from

Conversation

phisanti
Copy link

Fixes # .

Description

This PR implements the Restormer architecture for high-resolution image restoration in MONAI following the discussion in issue #8261. The implementation supports both 2D and 3D images using MONAI's convolution as the base. Key additions include:

  • Downsample class for efficient downsampling operations
  • pixel_unshuffle operation complementing existing pixel_shuffle
  • Channel Attention Block (CABlock) with FeedForward layer
  • Multi-DConv Head Transposed Self-Attention (MDTA)
  • OverlapPatchEmbed class
  • Comprehensive unit tests for all new components

The implementation follows MONAI's coding patterns and includes performance validations against native PyTorch operations where applicable.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

Copy link
Member

@ericspod ericspod left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good overall but I had a few inline comments, and we should have full docstrings everywhere appropriate. For any classes meant for general purpose use (ie. not just by Restormer) please ensure they have docstring descriptions for the arguments (at the very least for constructor args). Thanks!

See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution".

Args:
x: Input tensor
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we should specifically state that x has shape BCHW[D].


if any(d % factor != 0 for d in input_size[2:]):
raise ValueError(
f"All spatial dimensions must be divisible by factor {factor}. " f"Got spatial dimensions: {input_size[2:]}"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f"All spatial dimensions must be divisible by factor {factor}. " f"Got spatial dimensions: {input_size[2:]}"
f"All spatial dimensions must be divisible by {factor}, spatial shape is: {input_size[2:]}"

Maybe a little shorter?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

kernel_size_ = ensure_tuple_rep(kernel_size, spatial_dims)
padding = tuple((k - 1) // 2 for k in kernel_size_)

if down_mode == "conv":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if down_mode == "conv":
if down_mode == DownsampleMode.CONV:

bias=bias,
),
)
elif down_mode == "convgroup":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
elif down_mode == "convgroup":
elif down_mode == DownsampleMode.CONVGROUP:

if post_conv:
self.add_module("postconv", post_conv)

elif down_mode == "pixelunshuffle":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
elif down_mode == "pixelunshuffle":
elif down_mode == DownsampleMode.PIXELSHUFFLE:

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, but I used DownsampleMode.PIXELUNSHUFFLE as in downsampling the restormer uses pixel_unshuffling while pixel_shuffling is reserved for upsampling.

Comment on lines 68 to 72
"""Multi-DConv Head Transposed Self-Attention (MDTA): Differs from standard self-attention
by operating on feature channels instead of spatial dimensions. Incorporates depth-wise
convolutions for local mixing before attention, achieving linear complexity vs quadratic
in vanilla attention. Based on SW Zamir, et al., 2022 <https://arxiv.org/abs/2111.09881>"""

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should have a full docstring here describing the arguments for the constructor, and in the previous class.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Comment on lines 51 to 70
class OverlapPatchEmbed(nn.Module):
"""Initial feature extraction using overlapped convolutions.
Unlike standard patch embeddings that use non-overlapping patches,
this approach maintains spatial continuity through 3x3 convolutions."""

def __init__(self, spatial_dims: int, in_c: int = 3, embed_dim: int = 48, bias: bool = False):
super().__init__()
self.proj = Convolution(
spatial_dims=spatial_dims,
in_channels=in_c,
out_channels=embed_dim,
kernel_size=3,
strides=1,
padding=1,
bias=bias,
conv_only=True,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.proj(x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class OverlapPatchEmbed(nn.Module):
"""Initial feature extraction using overlapped convolutions.
Unlike standard patch embeddings that use non-overlapping patches,
this approach maintains spatial continuity through 3x3 convolutions."""
def __init__(self, spatial_dims: int, in_c: int = 3, embed_dim: int = 48, bias: bool = False):
super().__init__()
self.proj = Convolution(
spatial_dims=spatial_dims,
in_channels=in_c,
out_channels=embed_dim,
kernel_size=3,
strides=1,
padding=1,
bias=bias,
conv_only=True,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.proj(x)
class OverlapPatchEmbed(Convolution):
"""
Initial feature extraction using overlapped convolutions. Unlike standard patch embeddings
that use non-overlapping patches, this approach maintains spatial continuity through 3x3 convolutions.
"""
def __init__(self, spatial_dims: int, in_c: int = 3, embed_dim: int = 48, bias: bool = False):
super().__init__(
spatial_dims=spatial_dims,
in_channels=in_c,
out_channels=embed_dim,
kernel_size=3,
strides=1,
padding=1,
bias=bias,
conv_only=True,
)

Would it work to inherit directly from Convolution?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Works! very elegant suggestion btw!

@ericspod ericspod requested a review from aylward January 24, 2025 13:35
Copy link
Collaborator

@aylward aylward left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for this outstanding contribution!

return x


Downsample = DownSample
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ericspod - do we normally provide alternative capitalizations to functions? "Downsample" is the generally accepted term (vs "Down Sample" which is less common).

I suggest using "Downsample" throughout, unless we offer alternative usage elsewhere that I haven't encountered. IDE auto-complete can help folks get the right capitalization.

Looking at your enums, you use "Downsample" which confirms (in my mind) that we should be using "Downsample" throughout.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have had this mechanism for other classes so other capitalisation could be used in scripts, such as "Transformd" and "TransformD". Here I'd say we don't need it though.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, I mirrored the naming style of the Upsample class. Happy either way to remove it or keep it:

monai/networks/blocks/upsample.py

Upsample = UpSample
Subpixelupsample = SubpixelUpSample = SubpixelUpsample

Unlike standard patch embeddings that use non-overlapping patches,
this approach maintains spatial continuity through 3x3 convolutions."""

def __init__(self, spatial_dims: int, in_c: int = 3, embed_dim: int = 48, bias: bool = False):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spell-out in_c to in_channels

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

def __init__(
self,
spatial_dims=2,
inp_channels=3,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in_channels

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

num_refinement_blocks=4,
ffn_expansion_factor=2.66,
bias=False,
LayerNorm_type="WithBias",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make enum or convert to Bool (e.g., layer_norm_use_bias).



if __name__ == "__main__":
unittest.main()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good .... as long as the runtime is reasonable.

phisanti and others added 28 commits February 7, 2025 13:22
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 3db93ce
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 9693e04
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: a89f299
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 450691f
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: d0920d8
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 1a48d4d
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: fe47807
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 86155cd
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 137a7f2
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: fb17baf
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 5ff0baa
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 2566db1
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: ac4047b
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 2b74270
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 9b74533
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 1ab34f6
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 4f4c62c
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 068688f
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: e2e1070
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 35c7ee4
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: d8cb6c1
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 6d96816
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 8a688fb

Signed-off-by: tisalon <[email protected]>
Fixes Project-MONAI#8298 


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: YunLiu <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
Fixes Project-MONAI#8267 .

### Description

Fix channel-wise intensity normalization for integer type inputs. 

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [ ] Non-breaking change (fix or new feature that would not break
existing functionality).
- [x] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: advcu987 <[email protected]>
Signed-off-by: advcu <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
Fixes Project-MONAI#8306

This previous api has been deprecated, update based on:

https://docs.ngc.nvidia.com/api/?urls.primaryName=Private%20Artifacts%20(Models)%20API#/artifact-file-controller/downloadAllArtifactFiles

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: YunLiu <[email protected]>
Fixes Project-MONAI#8298


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: YunLiu <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
Related to Project-MONAI#8241  .

### Description

A few sentences describing the changes proposed in this pull request.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Yiheng Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
)

Fixes Project-MONAI#8298.

### Description

This includes the tests for the `compressor` argument when testing with
Zarr before version 3.0 when this argument was deprecated. A fix to
upgrade the version of `pycln` used is also included. The version of
PyTorch is also fixed to below 2.6 to avoid issues with misuse of
`torch.load` which must be addressed later.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Eric Kerfoot <[email protected]>
…ns and simplify ValueError message in pixelunshuffle
…n Restormer model and update assert in forward layer to support 3D images
…ument descriptions and error handling details.
Signed-off-by: tisalon <[email protected]>
Signed-off-by: tisalon <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants