
Start making 3d sam model more flexible #1033

Open · wants to merge 5 commits into base: dev
17 changes: 14 additions & 3 deletions micro_sam/instance_segmentation.py
@@ -819,20 +819,31 @@ def get_unetr(

def get_decoder(
image_encoder: torch.nn.Module,
decoder_state: OrderedDict[str, torch.Tensor],
decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
device: Optional[Union[str, torch.device]] = None,
out_channels: int = 3,
flexible_load_checkpoint: bool = False
) -> DecoderAdapter:
"""Get decoder to predict outputs for automatic instance segmentation

Args:
image_encoder: The image encoder of the SAM model.
decoder_state: State to initialize the weights of the UNETR decoder.
decoder_state: Optional decoder state to initialize the weights of the UNETR decoder.
device: The device. By default, automatically chooses the best available device.
out_channels: The number of output channels. By default, set to '3'.
flexible_load_checkpoint: Whether to allow reinitialization of parameters
which could not be found in the provided decoder state. By default, set to 'False'.

Returns:
The decoder for instance segmentation.
"""
unetr = get_unetr(image_encoder, decoder_state, device)
unetr = get_unetr(
image_encoder=image_encoder,
decoder_state=decoder_state,
device=device,
out_channels=out_channels,
flexible_load_checkpoint=flexible_load_checkpoint,
)
return DecoderAdapter(unetr)


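Note on the extended `get_decoder` signature above: a minimal usage sketch, assuming the SAM model is obtained via `get_sam_model` (argument values are placeholders):

```python
import torch

from micro_sam.util import get_sam_model
from micro_sam.instance_segmentation import get_decoder

# Get the underlying SAM model so that its image encoder can be reused.
_, sam = get_sam_model(model_type="vit_b", return_sam=True)

# 'decoder_state' is now optional: without it the UNETR decoder is initialized randomly.
# With 'flexible_load_checkpoint=True', parameters missing from a provided state
# (e.g. a new output head) are reinitialized instead of raising an error.
decoder = get_decoder(
    image_encoder=sam.image_encoder,
    decoder_state=None,
    out_channels=3,
    flexible_load_checkpoint=True,
)
```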
263 changes: 197 additions & 66 deletions micro_sam/models/sam_3d_wrapper.py
@@ -1,81 +1,154 @@
import os
from typing import Any, List, Dict, Type, Union, Optional
from collections import OrderedDict
from typing import Any, List, Dict, Type, Union, Optional, Literal

import torch
import torch.nn as nn

from segment_anything.modeling import Sam
from segment_anything.modeling.image_encoder import window_partition, window_unpartition

from ..util import get_sam_model
from .peft_sam import LoRASurgery
from ..instance_segmentation import get_decoder
from ..util import get_sam_model, _DEFAULT_MODEL


def get_sam_3d_model(
device: Union[str, torch.device],
n_classes: int,
image_size: int,
n_classes: int,
model_type: str = _DEFAULT_MODEL,
lora_rank: Optional[int] = None,
freeze_encoder: bool = False,
model_type: str = "vit_b",
decoder_choice: Literal["default", "unetr"] = "default",
device: Optional[Union[str, torch.device]] = None,
checkpoint_path: Optional[Union[str, os.PathLike]] = None,
):
if lora_rank is None:
peft_kwargs = {}
else:
peft_kwargs = {"rank": lora_rank, "peft_module": LoRASurgery}
) -> nn.Module:
"""Get the SAM 3D model for semantic segmentation.

Args:
image_size: The height / width of the input image.
n_classes: The number of output classes.
model_type: The choice of SAM model.
lora_rank: The rank for the low-rank adaptation (LoRA) of the image encoder's attention layers.
By default, LoRA is not used.
freeze_encoder: Whether to freeze the image encoder.
decoder_choice: Whether to use the SAM mask decoder ('default') or the UNETR decoder ('unetr').
device: The torch device.
checkpoint_path: Optional path to a checkpoint with a finetuned model to load.

Returns:
The SAM 3D model.
"""
if decoder_choice not in ["default", "unetr"]:
raise ValueError(
f"'{decoder_choice}' as the decoder choice is not supported. Please choose either 'default' or 'unetr'."
)

kwargs = {}
if decoder_choice == "default":
kwargs["num_multimask_outputs"] = n_classes

peft_kwargs = {}
if lora_rank is not None:
peft_kwargs["rank"] = lora_rank
peft_kwargs["peft_module"] = LoRASurgery

_, sam = get_sam_model(
_, sam, state = get_sam_model(
model_type=model_type,
device=device,
checkpoint_path=checkpoint_path,
return_sam=True,
return_state=True,
flexible_load_checkpoint=True,
num_multimask_outputs=n_classes,
image_size=image_size,
peft_kwargs=peft_kwargs,
**kwargs
)

# Make sure not to freeze the encoder when using LoRA.
_freeze_encoder = freeze_encoder if lora_rank is None else False
sam_3d = Sam3DWrapper(sam, freeze_encoder=_freeze_encoder, model_type=model_type)
sam_3d.to(device)

return sam_3d
if decoder_choice == "default":
model = Sam3DClassicWrapper(sam_model=sam, model_type=model_type)
else:
model = Sam3DUNETRWrapper(
sam_model=sam,
model_type=model_type,
decoder_state=state.get("decoder_state", None),  # Loads the decoder state automatically, if weights are found.
output_channels=n_classes,
)

return model.to(device)
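For illustration, a hedged sketch of how the new `decoder_choice` argument could be used (keyword names follow the signature above; image size, class count and model type are placeholders):

```python
import torch

from micro_sam.models.sam_3d_wrapper import get_sam_3d_model

device = "cuda" if torch.cuda.is_available() else "cpu"

# Keep the SAM mask decoder (the previous behaviour).
model_default = get_sam_3d_model(
    image_size=512, n_classes=3, model_type="vit_b", decoder_choice="default", device=device,
)

# Use the UNETR decoder instead; a pretrained decoder state is picked up
# from the model state automatically if one is found.
model_unetr = get_sam_3d_model(
    image_size=512, n_classes=3, model_type="vit_b", decoder_choice="unetr", device=device,
)
```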

class Sam3DWrapper(nn.Module):
def __init__(self, sam_model: Sam, freeze_encoder: bool, model_type: str = "vit_b"):
"""Initializes the Sam3DWrapper object.

Args:
sam_model: The Sam model to be wrapped.
freeze_encoder: Whether to freeze the image encoder.
model_type: The choice of segment anything model to wrap adapters for respective model configuration.
"""
class Sam3DWrapperBase(nn.Module):
"""Sam3DWrapperBase is a base class to implement specific SAM-based 3d semantic segmentation models.
"""
def __init__(self, model_type: str = "vit_b"):
super().__init__()
self.embed_dim, self.num_heads = self._get_model_config(model_type)

# Model configurations
def _get_model_config(self, model_type: str):
# Returns the model configuration.
if model_type == "vit_b":
embed_dim, num_heads = 768, 12
return 768, 12
elif model_type == "vit_l":
embed_dim, num_heads = 1024, 16
return 1024, 16
elif model_type == "vit_h":
embed_dim, num_heads = 1280, 16
return 1280, 16
else:
raise ValueError(f"'{model_type}' is not a supported choice of model.")

def _prepare_inputs(self, batched_input: List[Dict[str, Any]]):
batched_images = torch.stack([inp["image"] for inp in batched_input], dim=0)
original_size = batched_input[0]["original_size"]
assert all(inp["original_size"] == original_size for inp in batched_input)

shape = batched_images.shape
assert shape[1] == 3
batch_size, d_size, hw_size = shape[0], shape[2], shape[-2]

batched_images = batched_images.transpose(1, 2).contiguous().view(-1, 3, hw_size, hw_size)
return batched_images, original_size, batch_size, d_size
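The reshaping done by `_prepare_inputs` in a nutshell (a toy shape check, purely illustrative):

```python
import torch

x = torch.rand(2, 3, 8, 64, 64)  # B x 3 x D x H x W
# Move the depth axis in front of the channel axis and fold it into the batch axis,
# so that each z-slice becomes an independent 2d input for the image encoder.
flat = x.transpose(1, 2).contiguous().view(-1, 3, 64, 64)
print(flat.shape)  # torch.Size([16, 3, 64, 64]), i.e. (B * D) x 3 x H x W
```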

def forward(
self, batched_input: List[Dict[str, Any]], multimask_output: bool = False,
) -> List[Dict[str, torch.Tensor]]:
"""Predicts 3D masks for the provided inputs.

Unlike original SAM, this model only supports automatic segmentation and does not support prompts.

Args:
batched_input: A list over input images, each a dictionary with the following keys.
'image': The image as a torch tensor in 3xDxHxW format. Already transformed for the input to the model.
'original_size': The original size of the image (HxW) before transformation.
multimask_output: Whether to predict with the multi- or single-mask head of the mask decoder.

Returns:
A list over input images, where each element is a dictionary with the following keys:
'masks': Mask prediction for this object (IMPORTANT: in accordance with the SAM output style).
'iou_predictions': IOU score prediction for this object for the default mask decoder (OPTIONAL).
'low_res_masks': Low resolution mask prediction for this object for the default mask decoder (OPTIONAL).
"""
raise NotImplementedError(
"Sam3DWrapperBase is just a class template. Use a child class that implements the forward pass."
)
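To make the expected input format concrete, a small sketch (shapes are illustrative; `model` stands for any concrete subclass, e.g. one returned by `get_sam_3d_model` above):

```python
import torch

# One volume with 3 channels, 8 z-slices and a spatial size matching 'image_size'.
volume = torch.rand(3, 8, 512, 512)  # 3 x D x H x W, already resized for the model
batched_input = [{"image": volume, "original_size": (512, 512)}]

with torch.no_grad():
    outputs = model(batched_input, multimask_output=True)

# One output dictionary per input volume.
masks = outputs[0]["masks"]  # 1 x C x D x H x W; C depends on the decoder and multimask_output
```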


class Sam3DClassicWrapper(Sam3DWrapperBase):
def __init__(self, sam_model: Sam, model_type: str = "vit_b"):
"""Initializes the Sam3DClassicWrapper object.

Args:
sam_model: The SAM model to be wrapped.
model_type: The choice of Segment Anything model, used to determine the respective model configuration.
"""
super().__init__(model_type)

sam_model.image_encoder = ImageEncoderViT3DWrapper(
image_encoder=sam_model.image_encoder, num_heads=num_heads, embed_dim=embed_dim,
image_encoder=sam_model.image_encoder,
num_heads=self.num_heads,
embed_dim=self.embed_dim,
)
self.sam_model = sam_model

self.freeze_encoder = freeze_encoder
if self.freeze_encoder:
for param in self.sam_model.image_encoder.parameters():
param.requires_grad = False

def forward(self, batched_input: List[Dict[str, Any]], multimask_output: bool) -> List[Dict[str, torch.Tensor]]:
def forward(
self, batched_input: List[Dict[str, Any]], multimask_output: bool = False
) -> List[Dict[str, torch.Tensor]]:
"""Predict 3D masks for the current inputs.

Unlike original SAM this model only supports automatic segmentation and does not support prompts.
@@ -84,33 +157,19 @@ def forward(self, batched_input: List[Dict[str, Any]], multimask_output: bool) -
batched_input: A list over input images, each a dictionary with the following keys.
'image': The image as a torch tensor in 3xDxHxW format. Already transformed for the input to the model.
'original_size': The original size of the image (HxW) before transformation.
multimask_output: Wheterh to predict with the multi- or single-mask head of the maks decoder.
multimask_output: Whether to predict with the multi- or single-mask head of the mask decoder.

Returns:
A list over input images, where each element is a dictionary with the following keys:
'masks': Mask prediction for this object.
'iou_predictions': IOU score prediction for this object.
'low_res_masks': Low resolution mask prediction for this object.
'iou_predictions': IOU score prediction for this object for the default mask decoder.
'low_res_masks': Low resolution mask prediction for this object for the default mask decoder.
"""
batched_images = torch.stack([inp["image"] for inp in batched_input], dim=0)
original_size = batched_input[0]["original_size"]
assert all(inp["original_size"] == original_size for inp in batched_input)

# dimensions: [b, 3, d, h, w]
shape = batched_images.shape
assert shape[1] == 3
batch_size, d_size, hw_size = shape[0], shape[2], shape[-2]
# Transpose the axes, so that the depth axis is the first axis and the channel
# axis is the second axis. This is expected by the transformer!
batched_images = batched_images.transpose(1, 2)
assert batched_images.shape[1] == d_size
batched_images = batched_images.contiguous().view(-1, 3, hw_size, hw_size)
batched_images, original_size, batch_size, d_size = self._prepare_inputs(batched_input)

input_images = self.sam_model.preprocess(batched_images)
image_embeddings = self.sam_model.image_encoder(input_images, d_size)
sparse_embeddings, dense_embeddings = self.sam_model.prompt_encoder(
points=None, boxes=None, masks=None
)
sparse_embeddings, dense_embeddings = self.sam_model.prompt_encoder(points=None, boxes=None, masks=None)
low_res_masks, iou_predictions = self.sam_model.mask_decoder(
image_embeddings=image_embeddings,
image_pe=self.sam_model.prompt_encoder.get_dense_pe(),
@@ -119,9 +178,7 @@ def forward(self, batched_input: List[Dict[str, Any]], multimask_output: bool) -
multimask_output=multimask_output
)
masks = self.sam_model.postprocess_masks(
low_res_masks,
input_size=batched_images.shape[-2:],
original_size=original_size,
masks=low_res_masks, input_size=batched_images.shape[-2:], original_size=original_size,
)

# Bring the masks and low-res masks into the correct shape:
@@ -130,23 +187,97 @@ def forward(self, batched_input: List[Dict[str, Any]], multimask_output: bool) -

n_channels = masks.shape[1]
masks = masks.view(*(batch_size, d_size, n_channels, masks.shape[-2], masks.shape[-1]))
masks = masks.transpose(1, 2)

low_res_masks = low_res_masks.view(
*(batch_size, d_size, n_channels, low_res_masks.shape[-2], low_res_masks.shape[-1])
)

masks = masks.transpose(1, 2)
low_res_masks = low_res_masks.transpose(1, 2)

# Make the output compatable with the SAM output.
# Make the output compatible with the SAM output.
outputs = [{
"masks": mask.unsqueeze(0),
"iou_predictions": iou_pred,
"low_res_logits": low_res_mask.unsqueeze(0)
"masks": mask.unsqueeze(0), "iou_predictions": iou_pred, "low_res_logits": low_res_mask.unsqueeze(0)
} for mask, iou_pred, low_res_mask in zip(masks, iou_predictions, low_res_masks)]

return outputs


class Sam3DUNETRWrapper(Sam3DWrapperBase):
def __init__(
self,
sam_model: Sam,
model_type: str = "vit_b",
decoder_state: Optional[OrderedDict[str, torch.Tensor]] = None,
output_channels: int = 3,
):
"""Initializes the Sam3DUNETRWrapper object.

Args:
sam_model: The SAM model to be wrapped.
model_type: The choice of Segment Anything model, used to determine the respective model configuration.
decoder_state: Optional pretrained weights for initializing the UNETR decoder.
output_channels: The number of output classes.
"""
super().__init__(model_type)

self.image_encoder = ImageEncoderViT3DWrapper(
image_encoder=sam_model.image_encoder,
num_heads=self.num_heads,
embed_dim=self.embed_dim,
)
self._preprocess = sam_model.preprocess

# NOTE: Remove the output layer weights as we have a new set of target classes for the new task.
if decoder_state is not None:
decoder_state = OrderedDict(
[(k, v) for k, v in decoder_state.items() if not k.startswith("out_conv.")]
)

# Get a custom decoder, which replaces the SAM mask decoder.
self.decoder = get_decoder(
image_encoder=sam_model.image_encoder,
decoder_state=decoder_state,
out_channels=output_channels,
flexible_load_checkpoint=True,
)

def forward(
self, batched_input: List[Dict[str, Any]], multimask_output: bool = False
) -> List[Dict[str, torch.Tensor]]:
"""Predict 3D masks for the current inputs.

Unlike original SAM this model only supports automatic segmentation and does not support prompts.

Args:
batched_input: A list over input images, each a dictionary with the following keys.
'image': The image as a torch tensor in 3xDxHxW format. Already transformed for the input to the model.
'original_size': The original size of the image (HxW) before transformation.
multimask_output: Whether to predict with the multi- or single-mask head of the mask decoder.

Returns:
A list over input images, where each element is a dictionary with the following key:
'masks': Mask prediction for this object.
"""
batched_images, original_size, batch_size, d_size = self._prepare_inputs(batched_input)

input_images = self._preprocess(batched_images)
image_embeddings = self.image_encoder(input_images, d_size)
masks = self.decoder(image_embeddings, batched_images.shape[-2:], original_size)

# Bring the masks into the correct shape:
# - disentangle batches and z-slices
# - rearrange output channels and z-slices

n_channels = masks.shape[1]
masks = masks.view(*(batch_size, d_size, n_channels, masks.shape[-2], masks.shape[-1]))
masks = masks.transpose(1, 2)

# Make the output compatible with the SAM output.
outputs = [{"masks": mask.unsqueeze(0)} for mask in masks]

return outputs
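As a side-by-side of the two wrappers' outputs (a sketch based on the forward docstrings above; `model_default`, `model_unetr` and `batched_input` as in the earlier sketches):

```python
# The classic wrapper mirrors the SAM output style, the UNETR wrapper returns masks only.
outputs_classic = model_default(batched_input, multimask_output=True)
print(sorted(outputs_classic[0].keys()))  # ['iou_predictions', 'low_res_logits', 'masks']

outputs_unetr = model_unetr(batched_input)
print(sorted(outputs_unetr[0].keys()))    # ['masks']
```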


class ImageEncoderViT3DWrapper(nn.Module):
def __init__(self, image_encoder: nn.Module, num_heads: int = 12, embed_dim: int = 768):
