-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Restormer Implementation #8312
base: dev
Are you sure you want to change the base?
Restormer Implementation #8312
Conversation
…nsample class alias
…pass ./runtests.sh -f -u --net --coverage
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good overall but I had a few inline comments, and we should have full docstrings everywhere appropriate. For any classes meant for general purpose use (ie. not just by Restormer) please ensure they have docstring descriptions for the arguments (at the very least for constructor args). Thanks!
monai/networks/utils.py
Outdated
See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution". | ||
|
||
Args: | ||
x: Input tensor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here we should specifically state that x
has shape BCHW[D]
.
monai/networks/utils.py
Outdated
|
||
if any(d % factor != 0 for d in input_size[2:]): | ||
raise ValueError( | ||
f"All spatial dimensions must be divisible by factor {factor}. " f"Got spatial dimensions: {input_size[2:]}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
f"All spatial dimensions must be divisible by factor {factor}. " f"Got spatial dimensions: {input_size[2:]}" | |
f"All spatial dimensions must be divisible by {factor}, spatial shape is: {input_size[2:]}" |
Maybe a little shorter?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
monai/networks/blocks/downsample.py
Outdated
kernel_size_ = ensure_tuple_rep(kernel_size, spatial_dims) | ||
padding = tuple((k - 1) // 2 for k in kernel_size_) | ||
|
||
if down_mode == "conv": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if down_mode == "conv": | |
if down_mode == DownsampleMode.CONV: |
monai/networks/blocks/downsample.py
Outdated
bias=bias, | ||
), | ||
) | ||
elif down_mode == "convgroup": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
elif down_mode == "convgroup": | |
elif down_mode == DownsampleMode.CONVGROUP: |
monai/networks/blocks/downsample.py
Outdated
if post_conv: | ||
self.add_module("postconv", post_conv) | ||
|
||
elif down_mode == "pixelunshuffle": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
elif down_mode == "pixelunshuffle": | |
elif down_mode == DownsampleMode.PIXELSHUFFLE: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, but I used DownsampleMode.PIXELUNSHUFFLE as in downsampling the restormer
uses pixel_unshuffling
while pixel_shuffling
is reserved for upsampling.
monai/networks/blocks/cablock.py
Outdated
"""Multi-DConv Head Transposed Self-Attention (MDTA): Differs from standard self-attention | ||
by operating on feature channels instead of spatial dimensions. Incorporates depth-wise | ||
convolutions for local mixing before attention, achieving linear complexity vs quadratic | ||
in vanilla attention. Based on SW Zamir, et al., 2022 <https://arxiv.org/abs/2111.09881>""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should have a full docstring here describing the arguments for the constructor, and in the previous class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
monai/networks/nets/restormer.py
Outdated
class OverlapPatchEmbed(nn.Module): | ||
"""Initial feature extraction using overlapped convolutions. | ||
Unlike standard patch embeddings that use non-overlapping patches, | ||
this approach maintains spatial continuity through 3x3 convolutions.""" | ||
|
||
def __init__(self, spatial_dims: int, in_c: int = 3, embed_dim: int = 48, bias: bool = False): | ||
super().__init__() | ||
self.proj = Convolution( | ||
spatial_dims=spatial_dims, | ||
in_channels=in_c, | ||
out_channels=embed_dim, | ||
kernel_size=3, | ||
strides=1, | ||
padding=1, | ||
bias=bias, | ||
conv_only=True, | ||
) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
return self.proj(x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class OverlapPatchEmbed(nn.Module): | |
"""Initial feature extraction using overlapped convolutions. | |
Unlike standard patch embeddings that use non-overlapping patches, | |
this approach maintains spatial continuity through 3x3 convolutions.""" | |
def __init__(self, spatial_dims: int, in_c: int = 3, embed_dim: int = 48, bias: bool = False): | |
super().__init__() | |
self.proj = Convolution( | |
spatial_dims=spatial_dims, | |
in_channels=in_c, | |
out_channels=embed_dim, | |
kernel_size=3, | |
strides=1, | |
padding=1, | |
bias=bias, | |
conv_only=True, | |
) | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
return self.proj(x) | |
class OverlapPatchEmbed(Convolution): | |
""" | |
Initial feature extraction using overlapped convolutions. Unlike standard patch embeddings | |
that use non-overlapping patches, this approach maintains spatial continuity through 3x3 convolutions. | |
""" | |
def __init__(self, spatial_dims: int, in_c: int = 3, embed_dim: int = 48, bias: bool = False): | |
super().__init__( | |
spatial_dims=spatial_dims, | |
in_channels=in_c, | |
out_channels=embed_dim, | |
kernel_size=3, | |
strides=1, | |
padding=1, | |
bias=bias, | |
conv_only=True, | |
) |
Would it work to inherit directly from Convolution
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Works! very elegant suggestion btw!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for this outstanding contribution!
return x | ||
|
||
|
||
Downsample = DownSample |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ericspod - do we normally provide alternative capitalizations to functions? "Downsample" is the generally accepted term (vs "Down Sample" which is less common).
I suggest using "Downsample" throughout, unless we offer alternative usage elsewhere that I haven't encountered. IDE auto-complete can help folks get the right capitalization.
Looking at your enums, you use "Downsample" which confirms (in my mind) that we should be using "Downsample" throughout.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have had this mechanism for other classes so other capitalisation could be used in scripts, such as "Transformd" and "TransformD". Here I'd say we don't need it though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here, I mirrored the naming style of the Upsample class. Happy either way to remove it or keep it:
monai/networks/blocks/upsample.py
Upsample = UpSample
Subpixelupsample = SubpixelUpSample = SubpixelUpsample
monai/networks/nets/restormer.py
Outdated
Unlike standard patch embeddings that use non-overlapping patches, | ||
this approach maintains spatial continuity through 3x3 convolutions.""" | ||
|
||
def __init__(self, spatial_dims: int, in_c: int = 3, embed_dim: int = 48, bias: bool = False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Spell-out in_c to in_channels
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
monai/networks/nets/restormer.py
Outdated
def __init__( | ||
self, | ||
spatial_dims=2, | ||
inp_channels=3, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in_channels
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
monai/networks/nets/restormer.py
Outdated
num_refinement_blocks=4, | ||
ffn_expansion_factor=2.66, | ||
bias=False, | ||
LayerNorm_type="WithBias", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make enum or convert to Bool (e.g., layer_norm_use_bias).
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good .... as long as the runtime is reasonable.
I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 3db93ce I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 9693e04 I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: a89f299 I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 450691f I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: d0920d8 I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 1a48d4d I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: fe47807 I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 86155cd I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 137a7f2 I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: fb17baf I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 5ff0baa I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 2566db1 I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: ac4047b I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 2b74270 I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 9b74533 I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 1ab34f6 I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 4f4c62c I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 068688f I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: e2e1070 I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 35c7ee4 I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: d8cb6c1 I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 6d96816 I, tisalon <[email protected]>, hereby add my Signed-off-by to this commit: 8a688fb Signed-off-by: tisalon <[email protected]>
Signed-off-by: tisalon <[email protected]>
Fixes Project-MONAI#8298 ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <[email protected]> Co-authored-by: Eric Kerfoot <[email protected]>
Fixes Project-MONAI#8267 . ### Description Fix channel-wise intensity normalization for integer type inputs. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [ ] Non-breaking change (fix or new feature that would not break existing functionality). - [x] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: advcu987 <[email protected]> Signed-off-by: advcu <[email protected]> Co-authored-by: Eric Kerfoot <[email protected]>
Fixes Project-MONAI#8306 This previous api has been deprecated, update based on: https://docs.ngc.nvidia.com/api/?urls.primaryName=Private%20Artifacts%20(Models)%20API#/artifact-file-controller/downloadAllArtifactFiles ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <[email protected]>
Fixes Project-MONAI#8298 ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <[email protected]> Co-authored-by: Eric Kerfoot <[email protected]>
Related to Project-MONAI#8241 . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Yiheng Wang <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
) Fixes Project-MONAI#8298. ### Description This includes the tests for the `compressor` argument when testing with Zarr before version 3.0 when this argument was deprecated. A fix to upgrade the version of `pycln` used is also included. The version of PyTorch is also fixed to below 2.6 to avoid issues with misuse of `torch.load` which must be addressed later. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Eric Kerfoot <[email protected]>
…ns and simplify ValueError message in pixelunshuffle
…n Restormer model and update assert in forward layer to support 3D images
… Restormer class.
…ument descriptions and error handling details.
…ted changes Signed-off-by: tisalon <[email protected]>
Signed-off-by: tisalon <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: tisalon <[email protected]>
Signed-off-by: tisalon <[email protected]>
Signed-off-by: tisalon <[email protected]>
Signed-off-by: tisalon <[email protected]>
Fixes # .
Description
This PR implements the Restormer architecture for high-resolution image restoration in MONAI following the discussion in issue #8261. The implementation supports both 2D and 3D images using MONAI's convolution as the base. Key additions include:
The implementation follows MONAI's coding patterns and includes performance validations against native PyTorch operations where applicable.
Types of changes
./runtests.sh -f -u --net --coverage
../runtests.sh --quick --unittests --disttests
.make html
command in thedocs/
folder.