Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
Ruilong Li committed Sep 29, 2023
1 parent a326334 commit 46fe277
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 17 deletions.
24 changes: 18 additions & 6 deletions diff_rast/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from torch import Tensor


def compute_sh_color(viewdirs: Float[Tensor, "*batch 3"], sh_coeffs: Float[Tensor, "*batch D C"]):
def compute_sh_color(
viewdirs: Float[Tensor, "*batch 3"], sh_coeffs: Float[Tensor, "*batch D C"]
):
"""
:param viewdirs (*, C)
:param sh_coeffs (*, D, C) sh coefficients for each color channel
Expand Down Expand Up @@ -66,7 +68,9 @@ def eval_sh_bases(basis_dim: int, dirs: torch.Tensor):
:return: torch.Tensor (..., basis_dim)
"""
result = torch.empty((*dirs.shape[:-1], basis_dim), dtype=dirs.dtype, device=dirs.device)
result = torch.empty(
(*dirs.shape[:-1], basis_dim), dtype=dirs.dtype, device=dirs.device
)
result[..., 0] = SH_C0
if basis_dim > 1:
x, y, z = dirs.unbind(-1)
Expand Down Expand Up @@ -100,7 +104,9 @@ def eval_sh_bases(basis_dim: int, dirs: torch.Tensor):
result[..., 21] = SH_C4[5] * xz * (7 * zz - 3)
result[..., 22] = SH_C4[6] * (xx - yy) * (7 * zz - 1)
result[..., 23] = SH_C4[7] * xz * (xx - 3 * yy)
result[..., 24] = SH_C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))
result[..., 24] = SH_C4[8] * (
xx * (xx - 3 * yy) - yy * (3 * xx - yy)
)
return result


Expand Down Expand Up @@ -148,7 +154,9 @@ def scale_rot_to_cov3d(scale: Tensor, glob_scale: float, quat: Tensor) -> Tensor
return M @ M.transpose(-1, -2) # (..., 3, 3)


def project_cov3d_ewa(mean3d: Tensor, cov3d: Tensor, viewmat: Tensor, fx: float, fy: float) -> Tensor:
def project_cov3d_ewa(
mean3d: Tensor, cov3d: Tensor, viewmat: Tensor, fx: float, fy: float
) -> Tensor:
assert mean3d.shape[-1] == 3, mean3d.shape
assert cov3d.shape[-2:] == (3, 3), cov3d.shape
assert viewmat.shape[-2:] == (4, 4), viewmat.shape
Expand Down Expand Up @@ -212,7 +220,9 @@ def clip_near_plane(p, viewmat, thresh=0.1):


def get_tile_bbox(pix_center, pix_radius, tile_bounds, BLOCK_X=16, BLOCK_Y=16):
tile_size = torch.tensor([BLOCK_X, BLOCK_Y], dtype=torch.float32, device=pix_center.device)
tile_size = torch.tensor(
[BLOCK_X, BLOCK_Y], dtype=torch.float32, device=pix_center.device
)
tile_center = pix_center / tile_size
tile_radius = pix_radius[..., None] / tile_size

Expand Down Expand Up @@ -253,7 +263,9 @@ def project_gaussians_forward(
conic, radius, det_valid = compute_cov2d_bounds(cov2d)
center = project_pix(projmat, means3d, img_size)
tile_min, tile_max = get_tile_bbox(center, radius, tile_bounds)
tile_area = (tile_max[..., 0] - tile_min[..., 0]) * (tile_max[..., 1] - tile_min[..., 1])
tile_area = (tile_max[..., 0] - tile_min[..., 0]) * (
tile_max[..., 1] - tile_min[..., 1]
)
mask = (tile_area > 0) & (~is_close) & det_valid

num_tiles_hit = tile_area
Expand Down
4 changes: 3 additions & 1 deletion diff_rast/cov2d_bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def forward(
), f"Expected input cov2d to be of shape (*batch, 3) (upper triangular values), but got {tuple(cov2d.shape)}"
num_pts = cov2d.shape[0]
assert num_pts > 0
conic, radius = _C.compute_cov2d_bounds_forward(num_pts, cov2d.contiguous().cuda())
conic, radius = _C.compute_cov2d_bounds_forward(
num_pts, cov2d.contiguous().cuda()
)
return (conic, radius)

@staticmethod
Expand Down
10 changes: 2 additions & 8 deletions diff_rast/project_gaussians.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def forward(
img_height: int,
img_width: int,
tile_bounds: Tuple[int, int, int],
clip_thresh:float=0.01
clip_thresh: float = 0.01,
):
num_points = means3d.shape[-2]

Expand Down Expand Up @@ -102,13 +102,7 @@ def backward(ctx, v_xys, v_depths, v_radii, v_conics, v_num_tiles_hit, v_cov3d):
conics,
) = ctx.saved_tensors

(
v_cov2d,
v_cov3d,
v_mean3d,
v_scale,
v_quat,
) = _C.project_gaussians_backward(
(v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat,) = _C.project_gaussians_backward(
ctx.num_points,
means3d,
scales,
Expand Down
11 changes: 9 additions & 2 deletions tests/test_cov2d_bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,16 @@ def compare_binding_to_pytorch():

num_cov2ds = 2

_covs2d = torch.rand((num_cov2ds, 2, 2), dtype=torch.float32, device=device, requires_grad=True)
_covs2d = torch.rand(
(num_cov2ds, 2, 2), dtype=torch.float32, device=device, requires_grad=True
)
covs2d = torch.stack(
[torch.triu(_covs2d)[:, 0, 0], torch.triu(_covs2d)[:, 0, 1], torch.triu(_covs2d)[:, 1, 1]], dim=-1
[
torch.triu(_covs2d)[:, 0, 0],
torch.triu(_covs2d)[:, 0, 1],
torch.triu(_covs2d)[:, 1, 1],
],
dim=-1,
)

conic, radii = compute_cov2d_bounds.apply(covs2d)
Expand Down

0 comments on commit 46fe277

Please sign in to comment.