diff --git a/nerfstudio/data/utils/nerfstudio_collate.py b/nerfstudio/data/utils/nerfstudio_collate.py index c259c46589..8c8a633fb8 100644 --- a/nerfstudio/data/utils/nerfstudio_collate.py +++ b/nerfstudio/data/utils/nerfstudio_collate.py @@ -155,6 +155,19 @@ def nerfstudio_collate(batch: Any, extra_mappings: Union[Dict[type, Callable], N ), "All cameras must have distortion parameters or none of them should have distortion parameters.\ Generalized batching will be supported in the future." + if batch[0].metadata is not None: + metadata_keys = batch[0].metadata.keys() + assert all( + (cam.metadata.keys() == metadata_keys for cam in batch) + ), "All cameras must have the same metadata keys." + else: + assert all((cam.metadata is None for cam in batch)), "All cameras must have the same metadata keys." + + if batch[0].times is not None: + assert all((cam.times is not None for cam in batch)), "All cameras must have times present or absent." + else: + assert all((cam.times is None for cam in batch)), "All cameras must have times present or absent." + # If no batch dimension exists, then we need to stack everything and create a batch dimension on 0th dim if elem.shape == (): op = torch.stack @@ -163,11 +176,23 @@ def nerfstudio_collate(batch: Any, extra_mappings: Union[Dict[type, Callable], N op = torch.cat # Create metadata dictionary - metadata_keys = batch[0].metadata.keys() - assert all( - (cam.metadata.keys() == metadata_keys for cam in batch) - ), "All cameras must have the same metadata keys." - metadata = {key: op([cam.metadata[key] for cam in batch], dim=0) for key in metadata_keys} + if batch[0].metadata is not None: + metadata = {key: op([cam.metadata[key] for cam in batch], dim=0) for key in batch[0].metadata.keys()} + else: + metadata = None + + if batch[0].distortion_params is not None: + distortion_params = op( + [cameras.distortion_params for cameras in batch], + dim=0, + ) + else: + distortion_params = None + + if batch[0].times is not None: + times = torch.stack([cameras.times for cameras in batch], dim=0) + else: + times = None return Cameras( op([cameras.camera_to_worlds for cameras in batch], dim=0), @@ -177,20 +202,9 @@ def nerfstudio_collate(batch: Any, extra_mappings: Union[Dict[type, Callable], N op([cameras.cy for cameras in batch], dim=0), height=op([cameras.height for cameras in batch], dim=0), width=op([cameras.width for cameras in batch], dim=0), - distortion_params=op( - [ - cameras.distortion_params - if cameras.distortion_params is not None - else torch.zeros_like(cameras.distortion_params) - for cameras in batch - ], - dim=0, - ), + distortion_params=distortion_params, camera_type=op([cameras.camera_type for cameras in batch], dim=0), - times=torch.stack( - [cameras.times if cameras.times is not None else -torch.ones_like(cameras.times) for cameras in batch], - dim=0, - ), + times=times, metadata=metadata, )