Merge pull request #94 from jysdoran/main
Multi-GPU support and small fixes
luca-medeiros authored Feb 16, 2025
2 parents 7a055c7 + fd0d210 commit 04eaddf
Showing 5 changed files with 30 additions and 41 deletions.
11 changes: 6 additions & 5 deletions lang_sam/lang_sam.py
@@ -3,15 +3,17 @@
 
 from lang_sam.models.gdino import GDINO
 from lang_sam.models.sam import SAM
+from lang_sam.models.utils import DEVICE
 
 
 class LangSAM:
-    def __init__(self, sam_type="sam2.1_hiera_small", ckpt_path: str | None = None):
+    def __init__(self, sam_type="sam2.1_hiera_small", ckpt_path: str | None = None, device=DEVICE):
         self.sam_type = sam_type
 
         self.sam = SAM()
-        self.sam.build_model(sam_type, ckpt_path)
+        self.sam.build_model(sam_type, ckpt_path, device=device)
         self.gdino = GDINO()
-        self.gdino.build_model()
+        self.gdino.build_model(device=device)
 
     def predict(
         self,
@@ -45,15 +47,14 @@ def predict(
         sam_boxes = []
         sam_indices = []
         for idx, result in enumerate(gdino_results):
+            result = {k: (v.cpu().numpy() if hasattr(v, "numpy") else v) for k, v in result.items()}
             processed_result = {
                 **result,
                 "masks": [],
                 "mask_scores": [],
             }
 
             if result["labels"]:
-                processed_result["boxes"] = result["boxes"].cpu().numpy()
-                processed_result["scores"] = result["scores"].cpu().numpy()
                 sam_images.append(np.asarray(images_pil[idx]))
                 sam_boxes.append(processed_result["boxes"])
                 sam_indices.append(idx)
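With the new `device` argument, callers can pin each LangSAM instance to a specific GPU instead of relying on the auto-detected default. A minimal sketch of the pattern this enables (assumes two visible CUDA devices; the import path follows this repo):

import torch
from lang_sam import LangSAM

# One model replica per GPU, a simple multi-GPU pattern enabled by this change.
model_a = LangSAM(sam_type="sam2.1_hiera_small", device=torch.device("cuda:0"))
model_b = LangSAM(sam_type="sam2.1_hiera_small", device=torch.device("cuda:1"))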
33 changes: 10 additions & 23 deletions lang_sam/models/gdino.py
@@ -2,40 +2,27 @@
 from PIL import Image
 from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
 
-from lang_sam.models.utils import get_device_type
-
-device_type = get_device_type()
-DEVICE = torch.device(device_type)
-
-if torch.cuda.is_available():
-    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
-    if torch.cuda.get_device_properties(0).major >= 8:
-        torch.backends.cuda.matmul.allow_tf32 = True
-        torch.backends.cudnn.allow_tf32 = True
-
+from lang_sam.models.utils import DEVICE
 
 class GDINO:
     def __init__(self):
         self.build_model()
 
-    def build_model(self, ckpt_path: str | None = None):
-        model_id = "IDEA-Research/grounding-dino-base"
+    def build_model(self, ckpt_path: str | None = None, device=DEVICE):
+        model_id = "IDEA-Research/grounding-dino-base" if ckpt_path is None else ckpt_path
         self.processor = AutoProcessor.from_pretrained(model_id)
         self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(
-            DEVICE
+            device
         )
 
     def predict(
         self,
-        pil_images: list[Image.Image],
-        text_prompt: list[str],
+        images_pil: list[Image.Image],
+        texts_prompt: list[str],
         box_threshold: float,
         text_threshold: float,
     ) -> list[dict]:
-        for i, prompt in enumerate(text_prompt):
+        for i, prompt in enumerate(texts_prompt):
             if prompt[-1] != ".":
-                text_prompt[i] += "."
-        inputs = self.processor(images=pil_images, text=text_prompt, return_tensors="pt").to(DEVICE)
+                texts_prompt[i] += "."
+        inputs = self.processor(images=images_pil, text=texts_prompt, return_tensors="pt").to(self.model.device)
         with torch.no_grad():
             outputs = self.model(**inputs)
 
@@ -44,7 +31,7 @@ def predict(
             inputs.input_ids,
             box_threshold=box_threshold,
             text_threshold=text_threshold,
-            target_sizes=[k.size[::-1] for k in pil_images],
+            target_sizes=[k.size[::-1] for k in images_pil],
         )
        return results
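Two behavioral notes on this file: build_model now treats ckpt_path as an override for the Hugging Face model id, and predict moves inputs to self.model.device rather than the module-level constant, so inference follows whatever device the model was built on. A usage sketch (the image path and thresholds are illustrative, not from this diff):

from PIL import Image
from lang_sam.models.gdino import GDINO

gdino = GDINO()                     # __init__ still builds on the default DEVICE first
gdino.build_model(device="cuda:1")  # rebuild on a specific GPU; .to() accepts a device string
image = Image.open("./assets/car.jpeg").convert("RGB")  # hypothetical input image
results = gdino.predict([image], ["wheel"], box_threshold=0.3, text_threshold=0.25)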
15 changes: 3 additions & 12 deletions lang_sam/models/sam.py
@@ -6,16 +6,7 @@
 from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
 from sam2.sam2_image_predictor import SAM2ImagePredictor
 
-from lang_sam.models.utils import get_device_type
-
-DEVICE = torch.device(get_device_type())
-
-if torch.cuda.is_available():
-    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
-    if torch.cuda.get_device_properties(0).major >= 8:
-        torch.backends.cuda.matmul.allow_tf32 = True
-        torch.backends.cudnn.allow_tf32 = True
-
+from lang_sam.models.utils import DEVICE
 
 SAM_MODELS = {
     "sam2.1_hiera_tiny": {
@@ -38,14 +29,14 @@
 
 
 class SAM:
-    def build_model(self, sam_type: str, ckpt_path: str | None = None):
+    def build_model(self, sam_type: str, ckpt_path: str | None = None, device=DEVICE):
         self.sam_type = sam_type
         self.ckpt_path = ckpt_path
         cfg = compose(config_name=SAM_MODELS[self.sam_type]["config"], overrides=[])
         OmegaConf.resolve(cfg)
         self.model = instantiate(cfg.model, _recursive_=True)
         self._load_checkpoint(self.model)
-        self.model = self.model.to(DEVICE)
+        self.model = self.model.to(device)
         self.model.eval()
         self.mask_generator = SAM2AutomaticMaskGenerator(self.model)
         self.predictor = SAM2ImagePredictor(self.model)
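SAM gains the same device parameter, with the module-level DEVICE now serving only as the default. A sketch of building on a chosen device (the input array is a placeholder, and the generate() call assumes sam2's standard SAM2AutomaticMaskGenerator API):

import numpy as np
from lang_sam.models.sam import SAM

sam = SAM()
sam.build_model("sam2.1_hiera_small", ckpt_path=None, device="cuda:0")
# After build_model, both sam.mask_generator and sam.predictor live on that device.
image_np = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder HxWx3 uint8 image
masks = sam.mask_generator.generate(image_np)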
10 changes: 10 additions & 0 deletions lang_sam/models/utils.py
@@ -11,3 +11,13 @@ def get_device_type() -> str:
     else:
         logging.warning("No GPU found, using CPU instead")
         return "cpu"
+
+
+device_type = get_device_type()
+DEVICE = torch.device(device_type)
+
+if torch.cuda.is_available():
+    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+    if torch.cuda.get_device_properties(0).major >= 8:
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
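The device/autocast bootstrap that previously ran at import time in both gdino.py and sam.py now lives once here, and DEVICE is the shared fallback for every device= parameter. A quick sketch of how it resolves (the override line is illustrative):

import torch
from lang_sam.models.utils import DEVICE, get_device_type

print(get_device_type())  # e.g. "cuda" when a GPU is visible, otherwise "cpu" (with a warning)
print(DEVICE)             # torch.device built from that string; the default for build_model(device=...)
override = torch.device("cuda:1") if torch.cuda.device_count() > 1 else DEVICE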
2 changes: 1 addition & 1 deletion lang_sam/server.py
@@ -14,7 +14,7 @@
 class LangSAMAPI(ls.LitAPI):
     def setup(self, device: str) -> None:
         """Initialize or load the LangSAM model."""
-        self.model = LangSAM(sam_type="sam2.1_hiera_small")
+        self.model = LangSAM(sam_type="sam2.1_hiera_small", device=device)
         print("LangSAM model initialized.")
 
     def decode_request(self, request) -> dict:
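litserve hands each worker its assigned device string in setup(), so forwarding it into LangSAM is what lets the server run one replica per GPU. A launch sketch (assumes LangSAMAPI is importable from this module and uses litserve's standard LitServer arguments):

import litserve as ls
from lang_sam.server import LangSAMAPI

# Sketch: two workers, each pinned to its own GPU by litserve, which passes
# "cuda:0" / "cuda:1" into LangSAMAPI.setup() as the device string.
server = ls.LitServer(LangSAMAPI(), accelerator="gpu", devices=2)
server.run(port=8000)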
