diff --git a/README.md b/README.md index 9ce85fd..1ac5bc9 100644 --- a/README.md +++ b/README.md @@ -128,7 +128,7 @@ Ask and answer questions on [MONAI VISTA's GitHub discussions tab](https://githu ## License -The codebase is under Apache 2.0 Licence. The model weight is under special NVIDIA license. +The codebase is under Apache 2.0 Licence. The model weight is under special NVIDIA license. ## Reference diff --git a/configs/infer.yaml b/configs/infer.yaml index 89d018b..8db5bbb 100644 --- a/configs/infer.yaml +++ b/configs/infer.yaml @@ -1,3 +1,4 @@ +trt: true amp: true input_channels: 1 patch_size: [128, 128, 128] diff --git a/data/README.md b/data/README.md index 3fbdde0..03354c3 100644 --- a/data/README.md +++ b/data/README.md @@ -81,7 +81,7 @@ The output of this step is multiple JSON files, each file corresponds to one dataset. ##### 2. Add label_dict.json and label_mapping.json -Add new class indexes to `label_dict.json` and the local to global mapping to `label_mapping.json`. +Add new class indexes to `label_dict.json` and the local to global mapping to `label_mapping.json`. ## SupverVoxel Generation 1. Download the segment anything repo and download the ViT-H weights diff --git a/dices.json b/dices.json new file mode 100644 index 0000000..6febab3 --- /dev/null +++ b/dices.json @@ -0,0 +1,135 @@ +{ + "liver": 0.9999467134475708, + "kidney": 1.0, + "spleen": 0.9998987317085266, + "pancreas": 0.9998106360435486, + "right kidney": 0.9997254610061646, + "aorta": 0.9999536275863647, + "inferior vena cava": 0.9997954964637756, + "right adrenal gland": 1.0, + "left adrenal gland": 0.9971064925193787, + "gallbladder": 1.0, + "esophagus": 0.9997258186340332, + "stomach": 0.9999147653579712, + "duodenum": 0.9995471835136414, + "left kidney": 0.9997535347938538, + "bladder": 0.9998233318328857, + "prostate or uterus (deprecated)": 1.0, + "portal vein and splenic vein": 0.9996570348739624, + "rectum (deprecated)": 1.0, + "small bowel": 0.9995405673980713, + "lung": 1.0, + "bone": 1.0, + "brain": 1.0, + "lung tumor": 1.0, + "pancreatic tumor": 1.0, + "hepatic vessel": 1.0, + "hepatic tumor": 1.0, + "colon cancer primaries": 1.0, + "left lung upper lobe": 0.9999317526817322, + "left lung lower lobe": 0.9999247789382935, + "right lung upper lobe": 1.0, + "right lung middle lobe": 0.9999620318412781, + "right lung lower lobe": 0.9999691843986511, + "vertebrae L5": 0.9999167323112488, + "vertebrae L4": 0.9999210834503174, + "vertebrae L3": 1.0, + "vertebrae L2": 0.9997909665107727, + "vertebrae L1": 0.9998704195022583, + "vertebrae T12": 0.999764084815979, + "vertebrae T11": 0.9997434616088867, + "vertebrae T10": 0.9998674392700195, + "vertebrae T9": 0.9997072815895081, + "vertebrae T8": 0.9992929697036743, + "vertebrae T7": 1.0, + "vertebrae T6": 1.0, + "vertebrae T5": 1.0, + "vertebrae T4": 1.0, + "vertebrae T3": 1.0, + "vertebrae T2": 1.0, + "vertebrae T1": 1.0, + "vertebrae C7": 1.0, + "vertebrae C6": 1.0, + "vertebrae C5": 1.0, + "vertebrae C4": 1.0, + "vertebrae C3": 1.0, + "vertebrae C2": 1.0, + "vertebrae C1": 1.0, + "trachea": 1.0, + "left iliac artery": 0.998672604560852, + "right iliac artery": 0.9997827410697937, + "left iliac vena": 0.9996752142906189, + "right iliac vena": 0.9997751712799072, + "colon": 0.9997839331626892, + "left rib 1": 1.0, + "left rib 2": 1.0, + "left rib 3": 1.0, + "left rib 4": 1.0, + "left rib 5": 1.0, + "left rib 6": 0.9985436797142029, + "left rib 7": 0.9997116327285767, + "left rib 8": 1.0, + "left rib 9": 0.9997071027755737, + "left rib 10": 0.9987931251525879, + "left rib 11": 1.0, + "left rib 12": 1.0, + "right rib 1": 1.0, + "right rib 2": 1.0, + "right rib 3": 1.0, + "right rib 4": 1.0, + "right rib 5": 1.0, + "right rib 6": 0.9992054104804993, + "right rib 7": 0.999552845954895, + "right rib 8": 0.9996969103813171, + "right rib 9": 1.0, + "right rib 10": 0.9995119571685791, + "right rib 11": 1.0, + "right rib 12": 1.0, + "left humerus": 0.9719626307487488, + "right humerus": 0.9873417615890503, + "left scapula": 1.0, + "right scapula": 1.0, + "left clavicula": 1.0, + "right clavicula": 1.0, + "left femur": 0.999920129776001, + "right femur": 0.9998330473899841, + "left hip": 0.9999256730079651, + "right hip": 0.9999226927757263, + "sacrum": 0.9997796416282654, + "left gluteus maximus": 0.9998824000358582, + "right gluteus maximus": 0.9998437166213989, + "left gluteus medius": 0.9997230172157288, + "right gluteus medius": 0.9997458457946777, + "left gluteus minimus": 0.9993826150894165, + "right gluteus minimus": 0.9997991919517517, + "left autochthon": 0.999840259552002, + "right autochthon": 0.9998072981834412, + "left iliopsoas": 0.9998109340667725, + "right iliopsoas": 0.9998148679733276, + "left atrial appendage": 1.0, + "brachiocephalic trunk": 1.0, + "left brachiocephalic vein": 1.0, + "right brachiocephalic vein": 1.0, + "left common carotid artery": 1.0, + "right common carotid artery": 1.0, + "costal cartilages": 0.9993331432342529, + "heart": 0.9998570084571838, + "left kidney cyst": 1.0, + "right kidney cyst": 0.9997888803482056, + "prostate": 1.0, + "pulmonary vein": 1.0, + "skull": 1.0, + "spinal cord": 0.9996580481529236, + "sternum": 1.0, + "left subclavian artery": 1.0, + "right subclavian artery": 1.0, + "superior vena cava": 1.0, + "thyroid gland": 1.0, + "vertebrae S1": 0.9998401999473572, + "bone lesion": 1.0, + "kidney mass (deprecated)": 1.0, + "liver tumor (deprecated)": 1.0, + "vertebrae L6 (deprecated)": 1.0, + "airway": 1.0, + "average": 0.9995385372277462 +} diff --git a/scripts/debugger.py b/scripts/debugger.py index c924862..b2568f4 100644 --- a/scripts/debugger.py +++ b/scripts/debugger.py @@ -123,8 +123,12 @@ def on_button_click(event, ax=ax): print("-- segmenting ---") self.generate_mask() print("-- done ---") - print("-- Note: Point only prompts will only do 128 cubic segmentation, a cropping artefact will be observed. ---") - print("-- Note: Point without class will be treated as supported class, which has worse zero-shot ability. Try class > 132 to perform better zeroshot. ---") + print( + "-- Note: Point only prompts will only do 128 cubic segmentation, a cropping artefact will be observed. ---" + ) + print( + "-- Note: Point without class will be treated as supported class, which has worse zero-shot ability. Try class > 132 to perform better zeroshot. ---" + ) print("-- Note: CTRL + Right Click will be adding negative points. ---") print( "-- Note: Click points on different foreground class will cause segmentation conflicts. Clear first. ---" @@ -132,7 +136,7 @@ def on_button_click(event, ax=ax): print( "-- Note: Click points not matching class prompts will also cause confusion. ---" ) - + self.update_slice(ax) # self.point_start = len(self.clicked_points) diff --git a/scripts/infer.py b/scripts/infer.py index 924a5ab..313a7a2 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -12,6 +12,7 @@ import logging import os import sys +import time from functools import partial import monai @@ -32,6 +33,8 @@ from .train import CONFIG from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point +trt_wrap, TRT_AVAILABLE = optional_import("monai.networks", name="trt_wrap") + rearrange, _ = optional_import("einops", name="rearrange") sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) IGNORE_PROMPT = set( @@ -73,6 +76,7 @@ def __init__(self, config_file="./configs/infer.yaml", **override): parser.update(pairs=_args) self.amp = parser.get_parsed_content("amp") + self.trt = parser.get_parsed_content("trt") input_channels = parser.get_parsed_content("input_channels") patch_size = parser.get_parsed_content("patch_size") self.patch_size = patch_size @@ -128,6 +132,28 @@ def __init__(self, config_file="./configs/infer.yaml", **override): self.save_transforms = transforms.Compose(save_transforms) self.prev_mask = None self.batch_data = None + if self.trt and TRT_AVAILABLE: + bundle_root = parser.get_parsed_content("bundle_root") + ts = os.path.getmtime(config_file) + trt_args = { + "precision": "fp16", + "build_args": { + "builder_optimization_level": 5, + "precision_constraints": "obey", + }, + "timestamp": ts, + } + + trt_wrap( + self.model.image_encoder.encoder, + f"{bundle_root}/image_encoder", + args=trt_args, + ) + trt_wrap( + self.model.class_head, + f"{bundle_root}/class_head", + args=trt_args, + ) return def clear_cache(self): @@ -161,6 +187,7 @@ def infer( used together with prev_mask. If prev_mask is generated by N points, point_start should be N+1 to save time and avoid repeated inference. This is by default disabled. """ + time00 = time.time() self.model.eval() if not isinstance(image_file, dict): image_file = {"image": image_file} @@ -255,12 +282,15 @@ def infer( finished = False if finished: break + print(f"Infer Time: {time.time() - time00}") + if not finished: raise RuntimeError("Infer not finished due to OOM.") return batch_data[0]["pred"] @torch.no_grad() def infer_everything(self, image_file, label_prompt=EVERYTHING_PROMPT, rank=0): + time00 = time.time() self.model.eval() device = f"cuda:{rank}" if not isinstance(image_file, dict): @@ -302,6 +332,8 @@ def infer_everything(self, image_file, label_prompt=EVERYTHING_PROMPT, rank=0): finished = False if finished: break + print(f"InferEverything Time: {time.time() - time00}") + if not finished: raise RuntimeError("Infer not finished due to OOM.") @@ -324,5 +356,7 @@ def batch_infer_everything(self, datalist=str, basedir=str): if __name__ == "__main__": + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True fire, _ = optional_import("fire") fire.Fire(InferClass) diff --git a/vista3d/modeling/segresnetds.py b/vista3d/modeling/segresnetds.py index e8c96d5..ebd1d9c 100644 --- a/vista3d/modeling/segresnetds.py +++ b/vista3d/modeling/segresnetds.py @@ -12,7 +12,7 @@ from __future__ import annotations from collections.abc import Callable -from typing import Union +from typing import List, Tuple, Union import numpy as np import torch @@ -473,7 +473,7 @@ def _forward( f"Input spatial dims {x.shape} must be divisible by {self.shape_factor()}" ) - x_down = self.encoder(x) + x_down = self.encoder(x=x) x_down.reverse() x = x_down.pop(0) @@ -483,8 +483,9 @@ def _forward( outputs: list[torch.Tensor] = [] outputs_auto: list[torch.Tensor] = [] - x_ = x.clone() + if with_point: + x_ = x.clone() i = 0 for level in self.up_layers: x = level["upsample"](x) @@ -496,7 +497,8 @@ def _forward( i = i + 1 outputs.reverse() - x = x_ + x = x_ + if with_label: i = 0 for level in self.up_layers_auto: @@ -522,7 +524,7 @@ def _forward( def forward( self, x: torch.Tensor, with_point=True, with_label=True, **kwargs - ) -> Union[None, torch.Tensor, list[torch.Tensor]]: + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: return self._forward(x, with_point, with_label) def set_auto_grad(self, auto_freeze=False, point_freeze=False): diff --git a/vista3d/modeling/vista3d.py b/vista3d/modeling/vista3d.py index 8352ba9..39e58f5 100644 --- a/vista3d/modeling/vista3d.py +++ b/vista3d/modeling/vista3d.py @@ -306,7 +306,8 @@ def forward( # force releasing memories that set to None torch.cuda.empty_cache() if class_vector is not None: - logits, _ = self.class_head(out_auto, class_vector) + logits, _ = self.class_head(out_auto, class_vector=class_vector) + if point_coords is not None: point_logits = self.point_head( out, point_coords, point_labels, class_vector=prompt_class