rasterization_2dgs outputs wrong shape for surf_normals for single camera #611

Open
@shiukaheng

I'm using the main branch of gsplat. When rendering 2D Gaussians with a single camera (C = 1), the surf_normals tensor (the fourth output, named render_normals_from_depth in the script below) is missing the batch dimension: it has shape (H, W, 3), while I expected (C, H, W, 3). Every other output tensor has the batch dimension. Increasing the number of cameras to 2 fixes it.
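
The fact that two cameras work while one does not is the pattern a `squeeze(0)` on the camera axis would produce, because `squeeze(0)` removes that axis only when its size is 1. Whether gsplat actually does this is an assumption that has not been checked against the source. The snippet below only illustrates the PyTorch behavior, on CPU tensors with placeholder sizes:

```python
import torch

H, W = 4, 4  # small placeholder image size, not the 256x256 used in the repro

one_cam = torch.zeros(1, H, W, 3)   # C = 1
two_cams = torch.zeros(2, H, W, 3)  # C = 2

# squeeze(0) drops the leading axis only when it has size 1
one_cam_squeezed = one_cam.squeeze(0)
two_cams_squeezed = two_cams.squeeze(0)

print(tuple(one_cam_squeezed.shape))   # (4, 4, 3)
print(tuple(two_cams_squeezed.shape))  # (2, 4, 4, 3)
```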

Here is some minimal code to reproduce it:

import torch

args = {
    "means": torch.rand(1000, 3) * 2 - 1,
    "quats": torch.rand(1000, 4),
    "scales": torch.rand(1000, 3),
    "opacities": torch.rand(1000),
    "colors": torch.rand(1000, 3),
    "Ks": torch.eye(3).unsqueeze(0) + torch.rand(1, 3, 3) * 0.1,  # Slightly perturbed identity matrix
    "viewmats": torch.eye(4).unsqueeze(0) + torch.rand(1, 4, 4) * 0.1,  # Slightly perturbed identity matrix
    "width": 256,
    "height": 256,
    "near_plane": 0.0001,
    "far_plane": 1000,
    "backgrounds": torch.rand(1, 4),
}

# Normalize quaternions
args["quats"] = args["quats"] / torch.norm(args["quats"], dim=1, keepdim=True)

# Moving everything to the GPU
args = {key: value.cuda() if isinstance(value, torch.Tensor) else value for key, value in args.items()}

# Print the shape of all tensors in args
for key in args:
    # Make sure it's a tensor, otherwise just print it
    if isinstance(args[key], torch.Tensor):
        print(key, args[key].shape)
    else:
        print(key, args[key])

# Uncomment the following lines to double the number of cameras, and the results will now be correct.
# args["Ks"] = torch.cat((args["Ks"], args["Ks"]), dim=0)
# args["viewmats"] = torch.cat((args["viewmats"], args["viewmats"]), dim=0)
# args["backgrounds"] = torch.cat((args["backgrounds"], args["backgrounds"]), dim=0)

from gsplat import rasterization_2dgs

render_colors_and_depth, render_alphas, render_normals, render_normals_from_depth, render_distort, render_median, meta = rasterization_2dgs(**args, render_mode="RGB+ED")
# Print the shape of all tensors in the output
print("render_colors_and_depth", render_colors_and_depth.shape)
print("render_alphas", render_alphas.shape)
print("render_normals", render_normals.shape)
print("render_normals_from_depth", render_normals_from_depth.shape)
print("render_distort", render_distort.shape)
print("render_median", render_median.shape)

assert len(render_normals_from_depth.shape) == 4, f"render_normals_from_depth should have 4 dimensions: (C, H, W, 3), got {render_normals_from_depth.shape}"

Output

(gs) heng-work@WINDOWS-HT7B50O:~/Documents/repos/gs$ uv run ./test.py 
means torch.Size([1000, 3])
quats torch.Size([1000, 4])
scales torch.Size([1000, 3])
opacities torch.Size([1000])
colors torch.Size([1000, 3])
Ks torch.Size([1, 3, 3])
viewmats torch.Size([1, 4, 4])
width 256
height 256
near_plane 0.0001
far_plane 1000
backgrounds torch.Size([1, 4])
/home/heng-work/Documents/repos/gs/.venv/lib/python3.10/site-packages/torch/utils/cpp_extension.py:2059: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
render_colors_and_depth torch.Size([1, 256, 256, 4])
render_alphas torch.Size([1, 256, 256, 1])
render_normals torch.Size([1, 256, 256, 3])
render_normals_from_depth torch.Size([256, 256, 3])
render_distort torch.Size([1, 256, 256, 1])
render_median torch.Size([1, 256, 256, 1])
Traceback (most recent call last):
  File "/home/heng-work/Documents/repos/gs/./test.py", line 51, in <module>
    assert len(render_normals_from_depth.shape) == 4, f"render_normals_from_depth should have 4 dimensions: (C, H, W, 3), got {render_normals_from_depth.shape}"
AssertionError: render_normals_from_depth should have 4 dimensions: (C, H, W, 3), got torch.Size([256, 256, 3])
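
Until the shape is fixed, a caller-side workaround is to restore the missing camera axis when it is absent. The helper below is a minimal sketch: `ensure_camera_dim` is a hypothetical name, not part of gsplat, and it is demonstrated on placeholder CPU tensors rather than on real rasterizer output.

```python
import torch


def ensure_camera_dim(image: torch.Tensor) -> torch.Tensor:
    """Return `image` as (C, H, W, D), adding a leading axis if it is (H, W, D).

    Hypothetical helper, not part of gsplat.
    """
    if image.dim() == 3:
        return image.unsqueeze(0)
    if image.dim() == 4:
        return image
    raise ValueError(f"expected 3 or 4 dimensions, got shape {tuple(image.shape)}")


# Placeholder tensors standing in for the single-camera and two-camera outputs
missing_batch = torch.zeros(256, 256, 3)
with_batch = torch.zeros(2, 256, 256, 3)

print(tuple(ensure_camera_dim(missing_batch).shape))  # (1, 256, 256, 3)
print(tuple(ensure_camera_dim(with_batch).shape))     # (2, 256, 256, 3)
```

In the script above this would be applied as `render_normals_from_depth = ensure_camera_dim(render_normals_from_depth)` before the assertion.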

Thanks in advance
