Skip to content

Commit

Permalink
Fix 2DGS for Correct Normal and Median Depth Calculation (#430)
Browse files Browse the repository at this point in the history
* Fix: Add missing camera offsets for render outputs in 2DGS kernel

* Fix: Correct render_normals transformation

* black formatting
  • Loading branch information
khokao authored Sep 30, 2024
1 parent 22540f7 commit 0a200cd
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
4 changes: 4 additions & 0 deletions gsplat/cuda/csrc/rasterize_to_pixels_2dgs_fwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ __global__ void rasterize_to_pixels_fwd_2dgs_kernel(
render_colors += camera_id * image_height * image_width * COLOR_DIM; // get the global offset of the pixel w.r.t the camera
render_alphas += camera_id * image_height * image_width; // get the global offset of the pixel w.r.t the camera
last_ids += camera_id * image_height * image_width; // get the global offset of the pixel w.r.t the camera
render_normals += camera_id * image_height * image_width * 3;
render_distort += camera_id * image_height * image_width;
render_median += camera_id * image_height * image_width;
median_ids += camera_id * image_height * image_width;

// get the global offset of the background and mask
if (backgrounds != nullptr) {
Expand Down
4 changes: 3 additions & 1 deletion gsplat/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1306,7 +1306,9 @@ def rasterization_2dgs(
"gradient_2dgs": densify, # This holds the gradient used for densification for 2dgs
}

render_normals = render_normals @ torch.linalg.inv(viewmats)[0, :3, :3].T
render_normals = torch.einsum(
"...ij,...hwj->...hwi", torch.linalg.inv(viewmats)[..., :3, :3], render_normals
)

return (
render_colors,
Expand Down

0 comments on commit 0a200cd

Please sign in to comment.