diff --git a/dataset/static_dataset.py b/dataset/static_dataset.py index 3b5e6f8c..5800f5f3 100644 --- a/dataset/static_dataset.py +++ b/dataset/static_dataset.py @@ -130,7 +130,7 @@ def __getitem__(self, idx): indices = [idx, *np.random.randint(self.__len__(), size=additional_objects)] merged_images = None - merged_masks = np.zeros((self.num_frames, 384, 384), dtype=np.int32) + merged_masks = np.zeros((self.num_frames, 384, 384), dtype=np.int64) for i, list_id in enumerate(indices): images, masks = self._get_sample(list_id) @@ -148,8 +148,8 @@ def __getitem__(self, idx): target_objects = labels.tolist() # Generate one-hot ground-truth - cls_gt = np.zeros((self.num_frames, 384, 384), dtype=np.int32) - first_frame_gt = np.zeros((1, self.max_num_obj, 384, 384), dtype=np.int32) + cls_gt = np.zeros((self.num_frames, 384, 384), dtype=np.int64) + first_frame_gt = np.zeros((1, self.max_num_obj, 384, 384), dtype=np.int64) for i, l in enumerate(target_objects): this_mask = (masks==l) cls_gt[this_mask] = i+1 diff --git a/dataset/vos_dataset.py b/dataset/vos_dataset.py index a4043a08..be0f8a15 100644 --- a/dataset/vos_dataset.py +++ b/dataset/vos_dataset.py @@ -190,8 +190,8 @@ def __getitem__(self, idx): masks = np.stack(masks, 0) # Generate one-hot ground-truth - cls_gt = np.zeros((self.num_frames, 384, 384), dtype=np.int32) - first_frame_gt = np.zeros((1, self.max_num_obj, 384, 384), dtype=np.int32) + cls_gt = np.zeros((self.num_frames, 384, 384), dtype=np.int64) + first_frame_gt = np.zeros((1, self.max_num_obj, 384, 384), dtype=np.int64) for i, l in enumerate(target_objects): this_mask = (masks==l) cls_gt[this_mask] = i+1