
Commit

Added support for MPS Apple Silicon and quality of life improvements to the GUI

jjhickmon committed Jul 5, 2023
1 parent 083698b commit b8946cb
Showing 9 changed files with 160 additions and 79 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -5,6 +5,7 @@ output/
.vscode/
workspace/
run*.sh
+.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
2 changes: 1 addition & 1 deletion eval.py
@@ -218,7 +218,7 @@
prob = torch.flip(prob, dims=[-1])

# Probability mask -> index mask
-out_mask = torch.argmax(prob, dim=0)
+out_mask = torch.max(prob, dim=0).indices
out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)

if args.save_scores:
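
Note: torch.max(prob, dim=0).indices returns the same index tensor as torch.argmax(prob, dim=0); the swap is presumably for operator coverage on the MPS backend. A minimal sketch of the equivalence (shapes are illustrative, not taken from this repo):

    import torch

    prob = torch.rand(3, 480, 854)             # (num_classes, H, W) probability map
    mask_argmax = torch.argmax(prob, dim=0)    # original formulation
    mask_max = torch.max(prob, dim=0).indices  # formulation used in this commit
    assert torch.equal(mask_argmax, mask_max)  # same index mask (ties aside)
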
6 changes: 5 additions & 1 deletion inference/interact/fbrs/controller.py
@@ -1,4 +1,5 @@
import torch
+from torch import mps

from ..fbrs.inference import clicker
from ..fbrs.inference.predictors import get_predictor
@@ -35,7 +36,10 @@ def add_click(self, x, y, is_positive):
click = clicker.Click(is_positive=is_positive, coords=(y, x))
self.clicker.add_click(click)
pred = self.predictor.get_prediction(self.clicker)
-torch.cuda.empty_cache()
+if self.device.type == 'cuda':
+    torch.cuda.empty_cache()
+elif self.device.type == 'mps':
+    mps.empty_cache()

if self.probs_history:
self.probs_history.append((self.probs_history[-1][0], pred))
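
The cuda/mps branching above recurs in several files; a small helper along these lines (hypothetical, not part of this commit) could centralise it:

    import torch

    def empty_cache(device: torch.device) -> None:
        """Release cached allocator memory on the active backend, if it has one."""
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        elif device.type == 'mps':
            torch.mps.empty_cache()
        # CPU tensors use the regular allocator, so there is nothing to release.
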
160 changes: 106 additions & 54 deletions inference/interact/gui.py

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions inference/interact/gui_utils.py
@@ -1,5 +1,5 @@
-from PyQt5.QtCore import Qt
-from PyQt5.QtWidgets import (QHBoxLayout, QLabel, QSpinBox, QVBoxLayout, QProgressBar)
+from PyQt6.QtCore import Qt
+from PyQt6.QtWidgets import (QHBoxLayout, QLabel, QSpinBox, QVBoxLayout, QProgressBar)


def create_parameter_box(min_val, max_val, text, step=1, callback=None):
@@ -10,12 +10,12 @@ def create_parameter_box(min_val, max_val, text, step=1, callback=None):
dial.setMaximumWidth(150)
dial.setMinimum(min_val)
dial.setMaximum(max_val)
-dial.setAlignment(Qt.AlignRight)
+dial.setAlignment(Qt.AlignmentFlag.AlignRight)
dial.setSingleStep(step)
dial.valueChanged.connect(callback)

label = QLabel(text)
-label.setAlignment(Qt.AlignRight)
+label.setAlignment(Qt.AlignmentFlag.AlignRight)

layout.addWidget(label)
layout.addWidget(dial)
@@ -29,10 +29,10 @@ def create_gauge(text):
gauge = QProgressBar()
gauge.setMaximumHeight(28)
gauge.setMaximumWidth(200)
-gauge.setAlignment(Qt.AlignCenter)
+gauge.setAlignment(Qt.AlignmentFlag.AlignCenter)

label = QLabel(text)
-label.setAlignment(Qt.AlignRight)
+label.setAlignment(Qt.AlignmentFlag.AlignRight)

layout.addWidget(label)
layout.addWidget(gauge)
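
The PyQt5 to PyQt6 migration here is mostly mechanical: enum members become scoped (Qt.AlignRight becomes Qt.AlignmentFlag.AlignRight) and exec_() is renamed to exec(). A standalone sketch of the new spelling, independent of this repo:

    import sys
    from PyQt6.QtCore import Qt
    from PyQt6.QtWidgets import QApplication, QLabel

    app = QApplication(sys.argv)
    label = QLabel("PyQt6 uses scoped enums")
    label.setAlignment(Qt.AlignmentFlag.AlignCenter)  # Qt.AlignCenter no longer resolves in PyQt6
    label.show()
    sys.exit(app.exec())  # exec_() was renamed to exec()
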
2 changes: 1 addition & 1 deletion inference/interact/interaction.py
@@ -125,7 +125,7 @@ def end_path(self):
self.curr_path = [[] for _ in range(self.K + 1)]

def predict(self):
-self.out_prob = index_numpy_to_one_hot_torch(self.drawn_map, self.K+1).cuda()
+self.out_prob = index_numpy_to_one_hot_torch(self.drawn_map, self.K+1)
# self.out_prob = torch.from_numpy(self.drawn_map).float().cuda()
# self.out_prob, _ = pad_divide_by(self.out_prob, 16, self.out_prob.shape[-2:])
# self.out_prob = aggregate_sbg(self.out_prob, keep_bg=True)
17 changes: 11 additions & 6 deletions inference/interact/interactive_utils.py
@@ -15,7 +15,7 @@ def image_to_torch(frame: np.ndarray, device='cuda'):
return frame_norm, frame

def torch_prob_to_numpy_mask(prob):
-mask = torch.argmax(prob, dim=0)
+mask = torch.max(prob, dim=0).indices
mask = mask.cpu().numpy().astype(np.uint8)
return mask

@@ -26,16 +26,21 @@ def index_numpy_to_one_hot_torch(mask, num_classes):
"""
Some constants fro visualization
"""
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+elif torch.backends.mps.is_available():
+    device = torch.device("mps")
+else:
+    device = torch.device("cpu")

color_map_np = np.frombuffer(davis_palette, dtype=np.uint8).reshape(-1, 3).copy()
# scales for better visualization
color_map_np = (color_map_np.astype(np.float32)*1.5).clip(0, 255).astype(np.uint8)
color_map = color_map_np.tolist()
-if torch.cuda.is_available():
-    color_map_torch = torch.from_numpy(color_map_np).cuda() / 255
+color_map_torch = torch.from_numpy(color_map_np).to(device) / 255

grayscale_weights = np.array([[0.3,0.59,0.11]]).astype(np.float32)
-if torch.cuda.is_available():
-    grayscale_weights_torch = torch.from_numpy(grayscale_weights).cuda().unsqueeze(0)
+grayscale_weights_torch = torch.from_numpy(grayscale_weights).to(device).unsqueeze(0)

def get_visualization(mode, image, mask, layer, target_object):
if mode == 'fade':
@@ -112,7 +117,7 @@ def overlay_davis_torch(image, mask, alpha=0.5, fade=False):
# Changes the image in-place to avoid copying
image = image.permute(1, 2, 0)
im_overlay = image
mask = torch.argmax(mask, dim=0)
mask = torch.max(mask, dim=0).indices

colored_mask = color_map_torch[mask]
foreground = image*alpha + (1-alpha)*colored_mask
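
The cuda, then mps, then cpu fallback above is duplicated in interactive_demo.py; a shared helper such as this (hypothetical, not in the commit) would avoid repeating it:

    import torch

    def pick_device() -> torch.device:
        # Prefer CUDA, then Apple Silicon's MPS backend, then plain CPU.
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    device = pick_device()
    color_map = torch.rand(256, 3).to(device)  # tensors then follow the chosen device
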
1 change: 1 addition & 0 deletions inference/interact/s2m_controller.py
@@ -19,6 +19,7 @@ def __init__(self, s2m_net:S2M, num_objects, ignore_class, device='cuda:0'):
self.device = device

def interact(self, image, prev_mask, scr_mask):
+print(self.device)
image = image.to(self.device, non_blocking=True)
prev_mask = prev_mask.unsqueeze(0)

38 changes: 28 additions & 10 deletions interactive_demo.py
@@ -3,6 +3,7 @@
"""

import os
+from os import path
# fix for Windows
if 'QT_QPA_PLATFORM_PLUGIN_PATH' not in os.environ:
os.environ['QT_QPA_PLATFORM_PLUGIN_PATH'] = ''
@@ -17,15 +18,21 @@
from inference.interact.fbrs_controller import FBRSController
from inference.interact.s2m.s2m_network import deeplabv3plus_resnet50 as S2M

-from PyQt5.QtWidgets import QApplication
+from PyQt6.QtWidgets import QApplication
from inference.interact.gui import App
from inference.interact.resource_manager import ResourceManager
+from contextlib import nullcontext

torch.set_grad_enabled(False)

+if torch.cuda.is_available():
+    device = torch.device("cuda")
+elif torch.backends.mps.is_available():
+    device = torch.device("mps")
+else:
+    device = torch.device("cpu")

if __name__ == '__main__':

# Arguments parsing
parser = ArgumentParser()
parser.add_argument('--model', default='./saves/XMem.pth')
@@ -64,32 +71,43 @@
help='Resize the shorter side to this size. -1 to use original resolution. ')
args = parser.parse_args()

+# create temporary workspace if not specified
config = vars(args)
config['enable_long_term'] = True
config['enable_long_term_count_usage'] = True

-with torch.cuda.amp.autocast(enabled=not args.no_amp):
+if config["workspace"] is None:
+    if config["images"] is not None:
+        basename = path.basename(config["images"])
+    elif config["video"] is not None:
+        basename = path.basename(config["video"])[:-4]
+    else:
+        raise NotImplementedError(
+            'Either images, video, or workspace has to be specified')
+
+    config["workspace"] = path.join('./workspace', basename)
+
+with torch.cuda.amp.autocast(enabled=not args.no_amp) if device.type == 'cuda' else nullcontext():
# Load our checkpoint
-network = XMem(config, args.model).cuda().eval()
+network = XMem(config, args.model, map_location=device).to(device).eval()

# Loads the S2M model
if args.s2m_model is not None:
-s2m_saved = torch.load(args.s2m_model)
-s2m_model = S2M().cuda().eval()
+s2m_saved = torch.load(args.s2m_model, map_location=device)
+s2m_model = S2M().to(device).eval()
s2m_model.load_state_dict(s2m_saved)
else:
s2m_model = None

-s2m_controller = S2MController(s2m_model, args.num_objects, ignore_class=255)
+s2m_controller = S2MController(s2m_model, args.num_objects, ignore_class=255, device=device)
if args.fbrs_model is not None:
-fbrs_controller = FBRSController(args.fbrs_model)
+fbrs_controller = FBRSController(args.fbrs_model, device=device)
else:
fbrs_controller = None

# Manages most IO
resource_manager = ResourceManager(config)

app = QApplication(sys.argv)
-ex = App(network, resource_manager, s2m_controller, fbrs_controller, config)
-sys.exit(app.exec_())
+ex = App(network, resource_manager, s2m_controller, fbrs_controller, config, device)
+sys.exit(app.exec())
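
Autocast is only wired up for the CUDA backend, so the script falls back to a no-op context elsewhere. A condensed, self-contained sketch of that pattern (names are illustrative):

    import torch
    from contextlib import nullcontext

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # CUDA gets mixed-precision autocast; other backends get a do-nothing context
    # so the with-block stays valid.
    amp_context = torch.cuda.amp.autocast(enabled=True) if device.type == "cuda" else nullcontext()

    with amp_context:
        x = torch.rand(8, 3, device=device)
        y = torch.nn.Linear(3, 4).to(device)(x)  # runs in reduced precision under CUDA autocast
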
