diff --git a/segmentation-refinement/segmentation_refinement/eval_helper.py b/segmentation-refinement/segmentation_refinement/eval_helper.py index 8d9bf37..a7f153d 100644 --- a/segmentation-refinement/segmentation_refinement/eval_helper.py +++ b/segmentation-refinement/segmentation_refinement/eval_helper.py @@ -30,11 +30,14 @@ def safe_forward(model, im, seg, inter_s8=None, inter_s4=None): p_inter_s8 = torch.zeros(b, 1, newH, newW, device=im.device) - 1 p_inter_s8[:,:,0:ph,0:pw] = inter_s8 inter_s8 = p_inter_s8 + inter_s8 = inter_s8.half() if inter_s4 is not None: p_inter_s4 = torch.zeros(b, 1, newH, newW, device=im.device) - 1 p_inter_s4[:,:,0:ph,0:pw] = inter_s4 inter_s4 = p_inter_s4 - + inter_s4 = inter_s4.half() + im = im.half() + seg = seg.half() images = model(im, seg, inter_s8, inter_s4) return_im = {} diff --git a/segmentation-refinement/segmentation_refinement/main.py b/segmentation-refinement/segmentation_refinement/main.py index 9891d19..43b35b4 100644 --- a/segmentation-refinement/segmentation_refinement/main.py +++ b/segmentation-refinement/segmentation_refinement/main.py @@ -35,7 +35,8 @@ def __init__(self, device='cpu', model_folder=None, download_and_check_model=Tru new_dict[name] = v self.model.load_state_dict(new_dict) self.model.eval().to(device) - + self.model = self.model.half() + self.im_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(