Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix nerfstudio collate function #2965

Merged
merged 3 commits into from
May 6, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
50 changes: 32 additions & 18 deletions nerfstudio/data/utils/nerfstudio_collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,19 @@ def nerfstudio_collate(batch: Any, extra_mappings: Union[Dict[type, Callable], N
), "All cameras must have distortion parameters or none of them should have distortion parameters.\
Generalized batching will be supported in the future."

if batch[0].metadata is not None:
metadata_keys = batch[0].metadata.keys()
assert all(
(cam.metadata.keys() == metadata_keys for cam in batch)
), "All cameras must have the same metadata keys."
else:
assert all((cam.metadata is None for cam in batch)), "All cameras must have the same metadata keys."

if batch[0].times is not None:
assert all((cam.times is not None for cam in batch)), "All cameras must have times present or absent."
else:
assert all((cam.times is None for cam in batch)), "All cameras must have times present or absent."

# If no batch dimension exists, then we need to stack everything and create a batch dimension on 0th dim
if elem.shape == ():
op = torch.stack
Expand All @@ -163,11 +176,23 @@ def nerfstudio_collate(batch: Any, extra_mappings: Union[Dict[type, Callable], N
op = torch.cat

# Create metadata dictionary
metadata_keys = batch[0].metadata.keys()
assert all(
(cam.metadata.keys() == metadata_keys for cam in batch)
), "All cameras must have the same metadata keys."
metadata = {key: op([cam.metadata[key] for cam in batch], dim=0) for key in metadata_keys}
if batch[0].metadata is not None:
metadata = {key: op([cam.metadata[key] for cam in batch], dim=0) for key in batch[0].metadata.keys()}
else:
metadata = None

if batch[0].distortion_params is not None:
distortion_params = op(
[cameras.distortion_params for cameras in batch],
dim=0,
)
else:
distortion_params = None

if batch[0].times is not None:
times = torch.stack([cameras.times for cameras in batch], dim=0)
else:
times = None

return Cameras(
op([cameras.camera_to_worlds for cameras in batch], dim=0),
Expand All @@ -177,20 +202,9 @@ def nerfstudio_collate(batch: Any, extra_mappings: Union[Dict[type, Callable], N
op([cameras.cy for cameras in batch], dim=0),
height=op([cameras.height for cameras in batch], dim=0),
width=op([cameras.width for cameras in batch], dim=0),
distortion_params=op(
[
cameras.distortion_params
if cameras.distortion_params is not None
else torch.zeros_like(cameras.distortion_params)
for cameras in batch
],
dim=0,
),
distortion_params=distortion_params,
camera_type=op([cameras.camera_type for cameras in batch], dim=0),
times=torch.stack(
[cameras.times if cameras.times is not None else -torch.ones_like(cameras.times) for cameras in batch],
dim=0,
),
times=times,
metadata=metadata,
)

Expand Down