Merge pull request #73 from yinchimaoliang/support-fusion
[Feature]: Support fusion
yinchimaoliang authored Oct 2, 2022
2 parents 5f36567 + 019eb93 commit 89527fe
Showing 16 changed files with 909 additions and 49 deletions.
24 changes: 16 additions & 8 deletions dataset/nusc_mv_det_dataset.py → datasets/nusc_det_dataset.py
@@ -9,7 +9,7 @@
from pyquaternion import Quaternion
from torch.utils.data import Dataset

-__all__ = ['NuscMVDetDataset']
+__all__ = ['NuscDetDataset']

map_name_from_general_to_detection = {
'human.pedestrian.adult': 'pedestrian',
@@ -151,7 +151,7 @@ def depth_transform(cam_depth, resize, resize_dims, crop, flip, rotate):
return torch.Tensor(depth_map)


-class NuscMVDetDataset(Dataset):
+class NuscDetDataset(Dataset):
def __init__(self,
ida_aug_conf,
bda_aug_conf,
@@ -166,7 +166,8 @@ def __init__(self,
to_rgb=True),
return_depth=False,
sweep_idxes=list(),
-key_idxes=list()):
+key_idxes=list(),
+use_fusion=False):
"""Dataset used for bevdetection task.
Args:
ida_aug_conf (dict): Config for ida augmentation.
@@ -183,10 +184,13 @@ def __init__(self,
default: list().
key_idxes (list): List of key idxes to be used.
default: list().
+use_fusion (bool): Whether to use lidar data.
+default: False.
"""
super().__init__()
self.infos = mmcv.load(info_path)
self.is_train = is_train
+self.split = 'train' if is_train else 'val'
self.ida_aug_conf = ida_aug_conf
self.bda_aug_conf = bda_aug_conf
self.data_root = data_root
@@ -208,6 +212,7 @@ def __init__(self,
assert sum([key_idx < 0 for key_idx in key_idxes]) == len(key_idxes),\
'All `key_idxes` must less than 0.'
self.key_idxes = [0] + key_idxes
+self.use_fusion = use_fusion

def _get_sample_indices(self):
"""Load annotations from ann_file.
@@ -311,14 +316,15 @@ def get_image(self, cam_infos, cams):
sweep_ida_mats = list()
sweep_sensor2sensor_mats = list()
sweep_timestamps = list()
-gt_depth = list()
+sweep_lidar_depth = list()
for cam in cams:
imgs = list()
sensor2ego_mats = list()
intrin_mats = list()
ida_mats = list()
sensor2sensor_mats = list()
timestamps = list()
+lidar_depth = list()
key_info = cam_infos[0]
resize, resize_dims, crop, flip, \
rotate_ida = self.sample_ida_augmentation(
@@ -384,16 +390,17 @@ def get_image(self, cam_infos, cams):
intrin_mat[3, 3] = 1
intrin_mat[:3, :3] = torch.Tensor(
cam_info[cam]['calibrated_sensor']['camera_intrinsic'])
-if self.return_depth and sweep_idx == 0:
+if self.return_depth and (self.use_fusion or sweep_idx == 0):
file_name = os.path.split(cam_info[cam]['filename'])[-1]
point_depth = np.fromfile(os.path.join(
-self.data_root, 'depth_gt', f'{file_name}.bin'),
+self.data_root, 'depth_gt', self.split,
+f'{file_name}.bin'),
dtype=np.float32,
count=-1).reshape(-1, 3)
point_depth_augmented = depth_transform(
point_depth, resize, self.ida_aug_conf['final_dim'],
crop, flip, rotate_ida)
-gt_depth.append(point_depth_augmented)
+lidar_depth.append(point_depth_augmented)
img, ida_mat = img_transform(
img,
resize=resize,
@@ -415,6 +422,7 @@ def get_image(self, cam_infos, cams):
sweep_ida_mats.append(torch.stack(ida_mats))
sweep_sensor2sensor_mats.append(torch.stack(sensor2sensor_mats))
sweep_timestamps.append(torch.tensor(timestamps))
+sweep_lidar_depth.append(torch.stack(lidar_depth))
# Get mean pose of all cams.
ego2global_rotation = np.mean(
[key_info[cam]['ego_pose']['rotation'] for cam in cams], 0)
@@ -436,7 +444,7 @@ def get_image(self, cam_infos, cams):
img_metas,
]
if self.return_depth:
-ret_list.append(torch.stack(gt_depth))
+ret_list.append(torch.stack(sweep_lidar_depth).permute(1, 0, 2, 3))
return ret_list

def get_gt(self, info, cams):
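
The behavioral changes in the hunks above are the fusion-aware depth-loading condition and the split-aware depth_gt path. The sketch below pulls both out of the diff for readability; depth_gt_path and should_load_depth are hypothetical helper names written for illustration, and the paths are placeholders rather than files from the repository.

import os


def depth_gt_path(data_root, is_train, file_name):
    # Hypothetical helper mirroring the new lookup: the precomputed lidar
    # depth files are now read from a per-split subdirectory
    # ('depth_gt/train' or 'depth_gt/val') instead of a flat 'depth_gt' folder.
    split = 'train' if is_train else 'val'
    return os.path.join(data_root, 'depth_gt', split, f'{file_name}.bin')


def should_load_depth(return_depth, use_fusion, sweep_idx):
    # Mirrors the new condition: without fusion only the key frame
    # (sweep_idx == 0) gets a lidar depth map; with fusion every sweep does.
    return return_depth and (use_fusion or sweep_idx == 0)


print(depth_gt_path('data/nuScenes', True, 'CAM_FRONT_sample.jpg'))
# data/nuScenes/depth_gt/train/CAM_FRONT_sample.jpg.bin
print(should_load_depth(True, False, 1))  # False
print(should_load_depth(True, True, 1))   # True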
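
On the return side, each camera's lidar_depth list is stacked across sweeps, the per-camera results are stacked again, and the result is permuted so the sweep dimension comes first. A toy sketch of the shapes follows, using dummy tensors and assumed sizes (6 cameras, 2 sweeps, a 256x704 final_dim); with use_fusion disabled only the key sweep contributes, so the first dimension collapses to 1.

import torch

# Dummy shapes only; in the dataset the maps come from depth_transform
# applied to the projected lidar points.
num_cams, num_sweeps, H, W = 6, 2, 256, 704

sweep_lidar_depth = []
for _ in range(num_cams):
    # One augmented depth map per sweep for this camera (only the key sweep
    # when use_fusion is False, every sweep when it is True).
    lidar_depth = [torch.zeros(H, W) for _ in range(num_sweeps)]
    sweep_lidar_depth.append(torch.stack(lidar_depth))  # (num_sweeps, H, W)

stacked = torch.stack(sweep_lidar_depth)       # (num_cams, num_sweeps, H, W)
returned = stacked.permute(1, 0, 2, 3)         # (num_sweeps, num_cams, H, W)
print(returned.shape)  # torch.Size([2, 6, 256, 704])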
