Continue adding double and float16 support #277

Open · wants to merge 2 commits into main
8 changes: 4 additions & 4 deletions gsplat/cuda/csrc/quat_scale_to_covar_preci_bwd.cu
@@ -42,8 +42,8 @@ __global__ void quat_scale_to_covar_preci_bwd_kernel(
v_scales += idx * 3;
v_quats += idx * 4;

- vec4<OpT> quat = glm::make_vec4(quats + idx * 4);
- vec3<OpT> scale = glm::make_vec3(scales + idx * 3);
+ vec4<OpT> quat = vec4<OpT>(glm::make_vec4(quats + idx * 4));
+ vec3<OpT> scale = vec3<OpT>(glm::make_vec3(scales + idx * 3));
mat3<OpT> rotmat = quat_to_rotmat<OpT>(quat);

vec4<OpT> v_quat(0.f);
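These kernels do their arithmetic in `OpT` rather than the storage type `T`. The `OpType<T>` trait is defined elsewhere in gsplat; the sketch below is an assumption about the mapping it implements (half/bfloat16 storage upcast to float math, float/double left unchanged), not the actual definition.

```cpp
// Hedged sketch of OpType<T>: pick the arithmetic type used inside the kernels.
// Storage types without fast native device math map to float; float/double stay as-is.
#include <c10/util/Half.h>
#include <c10/util/BFloat16.h>

template <typename T>
struct OpType {
    using type = T;            // float -> float, double -> double
};

template <>
struct OpType<c10::Half> {
    using type = float;        // float16 storage, float32 math
};

template <>
struct OpType<c10::BFloat16> {
    using type = float;        // bfloat16 storage, float32 math
};
```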
@@ -58,7 +58,7 @@ __global__ void quat_scale_to_covar_preci_bwd_kernel(
v_covars[2] * .5f, v_covars[4] * .5f, v_covars[5]);
} else {
v_covars += idx * 9;
- mat3<OpT> v_covar_cast = glm::make_mat3(v_covars);
+ mat3<OpT> v_covar_cast = mat3<OpT>(glm::make_mat3(v_covars));
v_covar = glm::transpose(v_covar_cast);
}
quat_scale_to_covar_vjp<OpT>(quat, scale, rotmat, v_covar, v_quat, v_scale);
@@ -73,7 +73,7 @@ __global__ void quat_scale_to_covar_preci_bwd_kernel(
v_precis[2] * .5f, v_precis[4] * .5f, v_precis[5]);
} else {
v_precis += idx * 9;
- mat3<OpT> v_precis_cast = glm::make_mat3(v_precis);
+ mat3<OpT> v_precis_cast = mat3<OpT>(glm::make_mat3(v_precis));
v_preci = glm::transpose(v_precis_cast);
}
quat_scale_to_preci_vjp<OpT>(quat, scale, rotmat, v_preci, v_quat, v_scale);
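The extra `vec4<OpT>(...)` / `mat3<OpT>(...)` wrappers are needed because `glm::make_vec4` / `glm::make_mat3` build a vector or matrix over the pointer's storage type; constructing the `OpT` type from that result performs the per-component upcast explicitly. A minimal, hypothetical illustration with plain glm types (not code from this PR):

```cpp
// Illustration only: load half-precision storage into a float glm vector.
// glm::make_vec3(p) yields glm::vec<3, c10::Half>; wrapping it in glm::vec3
// converts each component to float via c10::Half's operator float().
#include <glm/glm.hpp>
#include <glm/gtc/type_ptr.hpp>
#include <c10/util/Half.h>

inline glm::vec3 load_vec3_as_float(const c10::Half *p) {
    return glm::vec3(glm::make_vec3(p));
}
```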
6 changes: 3 additions & 3 deletions gsplat/cuda/csrc/quat_scale_to_covar_preci_fwd.cu
@@ -121,9 +121,9 @@ quat_scale_to_covar_preci_fwd_tensor(const torch::Tensor &quats, // [N, 4]
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, quats.scalar_type(), "quat_scale_to_covar_preci_fwd", [&]() {
quat_scale_to_covar_preci_fwd_kernel<<<(N + N_THREADS - 1) / N_THREADS,
N_THREADS, 0, stream>>>(
- N, quats.data_ptr<float>(), scales.data_ptr<float>(), triu,
- compute_covar ? covars.data_ptr<float>() : nullptr,
- compute_preci ? precis.data_ptr<float>() : nullptr);
+ N, quats.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), triu,
+ compute_covar ? covars.data_ptr<scalar_t>() : nullptr,
+ compute_preci ? precis.data_ptr<scalar_t>() : nullptr);
});
}
return std::make_tuple(covars, precis);
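Since this launch already sits inside `AT_DISPATCH_FLOATING_TYPES_AND2`, the tensors have to be read through `data_ptr<scalar_t>()`; keeping `data_ptr<float>()` would raise a dtype mismatch as soon as the inputs are half, bfloat16, or double. A self-contained sketch of the same dispatch pattern on a hypothetical kernel (the names here are illustrative, not gsplat APIs):

```cpp
// Sketch: AT_DISPATCH_FLOATING_TYPES_AND2 switches on the runtime dtype
// (float, double, plus the two extra types listed) and defines scalar_t
// inside the lambda, so one launch covers every supported precision.
#include <ATen/Dispatch.h>
#include <torch/torch.h>

template <typename T>
__global__ void double_kernel(int64_t n, const T *in, T *out) {
    int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
    if (i < n) {
        // compute in float for the sketch, store back in the tensor dtype
        out[i] = static_cast<T>(static_cast<float>(in[i]) * 2.0f);
    }
}

void double_tensor(const torch::Tensor &in, torch::Tensor &out) {
    int64_t n = in.numel();
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half, at::ScalarType::BFloat16, in.scalar_type(),
        "double_tensor", [&]() {
            double_kernel<scalar_t><<<(n + 255) / 256, 256>>>(
                n, in.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
        });
}
```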
78 changes: 43 additions & 35 deletions gsplat/cuda/csrc/rasterize_to_indices_in_range.cu
@@ -5,6 +5,8 @@
#include <cub/cub.cuh>
#include <cuda_runtime.h>

+ #include <c10/cuda/CUDAMathCompat.h>

namespace cg = cooperative_groups;

/****************************************************************************
@@ -31,6 +33,9 @@ __global__ void rasterize_to_indices_in_range_kernel(
// each thread draws one pixel, but also timeshares caching gaussians in a
// shared tile

+ // For now we'll upcast float16 and bfloat16 to float32
+ using OpT = typename OpType<T>::type;

auto block = cg::this_thread_block();
uint32_t camera_id = block.group_index().x;
uint32_t tile_id = block.group_index().y * tile_width + block.group_index().z;
@@ -41,8 +46,8 @@ __global__ void rasterize_to_indices_in_range_kernel(
tile_offsets += camera_id * tile_height * tile_width;
transmittances += camera_id * image_height * image_width;

- T px = (T)j + 0.5f;
- T py = (T)i + 0.5f;
+ OpT px = (OpT)j + 0.5f;
+ OpT py = (OpT)i + 0.5f;
int32_t pix_id = i * image_width + j;

// return if out of bounds
@@ -77,16 +82,16 @@ __global__ void rasterize_to_indices_in_range_kernel(

extern __shared__ int s[];
int32_t *id_batch = (int32_t *)s; // [block_size]
- vec3<T> *xy_opacity_batch =
+ vec3<float> *xy_opacity_batch =
reinterpret_cast<vec3<float> *>(&id_batch[block_size]); // [block_size]
- vec3<T> *conic_batch =
+ vec3<float> *conic_batch =
reinterpret_cast<vec3<float> *>(&xy_opacity_batch[block_size]); // [block_size]

// current visibility left to render
// transmittance is gonna be used in the backward pass which requires a high
// numerical precision so we (should) use double for it. However double make
// bwd 1.5x slower so we stick with float for now.
- T trans, next_trans;
+ OpT trans, next_trans;
if (inside) {
trans = transmittances[pix_id];
next_trans = trans;
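The shared-memory staging buffers are pinned to `vec3<float>` instead of `vec3<T>`: values are upcast once when a batch is cached, the per-tile loop never accumulates in half precision, and the dynamic shared-memory size stays the same for every dtype the kernel is instantiated with. A hedged sketch of the matching host-side size computation (the exact expression used by gsplat is not shown in this diff):

```cpp
// Assumed layout, per the kernel above: block_size int32 gaussian ids,
// then block_size float3 (xy, opacity), then block_size float3 conics.
#include <glm/glm.hpp>
#include <cstdint>

inline uint32_t rasterize_shared_mem_bytes(uint32_t block_size) {
    return block_size *
           (sizeof(int32_t) + 2u * static_cast<uint32_t>(sizeof(glm::vec3)));
}
```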
@@ -112,10 +117,10 @@ __global__ void rasterize_to_indices_in_range_kernel(
if (idx < isect_range_end) {
int32_t g = flatten_ids[idx];
id_batch[tr] = g;
- const vec2<T> xy = means2d[g];
- const T opac = opacities[g];
- xy_opacity_batch[tr] = {xy.x, xy.y, opac};
- conic_batch[tr] = conics[g];
+ const vec2<OpT> xy = means2d[g];
+ const OpT opac = opacities[g];
+ xy_opacity_batch[tr] = vec3<float>(xy.x, xy.y, opac);
+ conic_batch[tr] = vec3<float>(conics[g]);
}

// wait for other threads to collect the gaussians in batch
@@ -124,14 +129,14 @@ __global__ void rasterize_to_indices_in_range_kernel(
// process gaussians in the current batch for this pixel
uint32_t batch_size = min(block_size, isect_range_end - batch_start);
for (uint32_t t = 0; (t < batch_size) && !done; ++t) {
- const vec3<T> conic = conic_batch[t];
- const vec3<T> xy_opac = xy_opacity_batch[t];
- const T opac = xy_opac.z;
- const vec2<T> delta = {xy_opac.x - px, xy_opac.y - py};
- const T sigma =
+ const vec3<OpT> conic = vec3<OpT>(conic_batch[t]);
+ const vec3<OpT> xy_opac = vec3<OpT>(xy_opacity_batch[t]);
+ const OpT opac = xy_opac.z;
+ const vec2<OpT> delta = {xy_opac.x - px, xy_opac.y - py};
+ const OpT sigma =
0.5f * (conic.x * delta.x * delta.x + conic.z * delta.y * delta.y) +
conic.y * delta.x * delta.y;
- T alpha = min(0.999f, opac * __expf(-sigma));
+ OpT alpha = min(0.999f, opac * c10::cuda::compat::exp(-sigma));

if (sigma < 0.f || alpha < 1.f / 255.f) {
continue;
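`__expf` is a float-only fast intrinsic, so it no longer fits once the arithmetic type can be double; `c10::cuda::compat::exp` (from the `CUDAMathCompat.h` header added above) provides device overloads for both float and double and lets the compiler pick one per instantiation. A small illustrative helper, not part of the PR:

```cpp
// Illustration: an exp-based falloff that compiles for OpT = float or double.
#include <c10/cuda/CUDAMathCompat.h>

template <typename OpT>
__device__ inline OpT gaussian_falloff(OpT sigma, OpT opac) {
    // same shape as the alpha term above, without the 0.999 clamp
    return opac * c10::cuda::compat::exp(-sigma);
}
```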
@@ -214,16 +219,17 @@ std::tuple<torch::Tensor, torch::Tensor> rasterize_to_indices_in_range_tensor(
if (n_isects) {
torch::Tensor chunk_cnts = torch::zeros({C * image_height * image_width},
means2d.options().dtype(torch::kInt32));
- rasterize_to_indices_in_range_kernel<float>
- <<<blocks, threads, shared_mem, stream>>>(
- range_start, range_end, C, N, n_isects,
- reinterpret_cast<vec2<float> *>(means2d.data_ptr<float>()),
- reinterpret_cast<vec3<float> *>(conics.data_ptr<float>()),
- opacities.data_ptr<float>(), image_width, image_height, tile_size,
- tile_width, tile_height, tile_offsets.data_ptr<int32_t>(),
- flatten_ids.data_ptr<int32_t>(), transmittances.data_ptr<float>(),
- nullptr, chunk_cnts.data_ptr<int32_t>(), nullptr, nullptr);
-
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, means2d.scalar_type(), "rasterize_to_indices_in_range", [&]() {
+ rasterize_to_indices_in_range_kernel<scalar_t>
+ <<<blocks, threads, shared_mem, stream>>>(
+ range_start, range_end, C, N, n_isects,
+ reinterpret_cast<vec2<scalar_t> *>(means2d.data_ptr<scalar_t>()),
+ reinterpret_cast<vec3<scalar_t> *>(conics.data_ptr<scalar_t>()),
+ opacities.data_ptr<scalar_t>(), image_width, image_height, tile_size,
+ tile_width, tile_height, tile_offsets.data_ptr<int32_t>(),
+ flatten_ids.data_ptr<int32_t>(), transmittances.data_ptr<scalar_t>(),
+ nullptr, chunk_cnts.data_ptr<int32_t>(), nullptr, nullptr);
+ });
torch::Tensor cumsum = torch::cumsum(chunk_cnts, 0, chunk_cnts.scalar_type());
n_elems = cumsum[-1].item<int64_t>();
chunk_starts = cumsum - chunk_cnts;
@@ -237,16 +243,18 @@ std::tuple<torch::Tensor, torch::Tensor> rasterize_to_indices_in_range_tensor(
torch::Tensor pixel_ids =
torch::empty({n_elems}, means2d.options().dtype(torch::kInt64));
if (n_elems) {
- rasterize_to_indices_in_range_kernel<float>
- <<<blocks, threads, shared_mem, stream>>>(
- range_start, range_end, C, N, n_isects,
- reinterpret_cast<vec2<float> *>(means2d.data_ptr<float>()),
- reinterpret_cast<vec3<float> *>(conics.data_ptr<float>()),
- opacities.data_ptr<float>(), image_width, image_height, tile_size,
- tile_width, tile_height, tile_offsets.data_ptr<int32_t>(),
- flatten_ids.data_ptr<int32_t>(), transmittances.data_ptr<float>(),
- chunk_starts.data_ptr<int32_t>(), nullptr,
- gaussian_ids.data_ptr<int64_t>(), pixel_ids.data_ptr<int64_t>());
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, means2d.scalar_type(), "rasterize_to_indices_in_range", [&]() {
+ rasterize_to_indices_in_range_kernel<scalar_t>
+ <<<blocks, threads, shared_mem, stream>>>(
+ range_start, range_end, C, N, n_isects,
+ reinterpret_cast<vec2<scalar_t> *>(means2d.data_ptr<scalar_t>()),
+ reinterpret_cast<vec3<scalar_t> *>(conics.data_ptr<scalar_t>()),
+ opacities.data_ptr<scalar_t>(), image_width, image_height, tile_size,
+ tile_width, tile_height, tile_offsets.data_ptr<int32_t>(),
+ flatten_ids.data_ptr<int32_t>(), transmittances.data_ptr<scalar_t>(),
+ chunk_starts.data_ptr<int32_t>(), nullptr,
+ gaussian_ids.data_ptr<int64_t>(), pixel_ids.data_ptr<int64_t>());
+ });
}
return std::make_tuple(gaussian_ids, pixel_ids);
}
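One assumption worth noting: both launches reinterpret the `[N, 2]` / `[N, 3]` tensors' contiguous `scalar_t` buffers as arrays of `vec2<scalar_t>` / `vec3<scalar_t>`, which relies on the glm vectors being tightly packed for every dispatched dtype. A hedged compile-time check one could add (illustrative; not part of this PR):

```cpp
// Layout check for the pointer reinterpretation used in the kernel launches.
#include <glm/glm.hpp>
#include <c10/util/Half.h>

template <typename T>
constexpr bool packed_vec_layout() {
    return sizeof(glm::vec<2, T>) == 2 * sizeof(T) &&
           sizeof(glm::vec<3, T>) == 3 * sizeof(T);
}

static_assert(packed_vec_layout<float>(), "vec2/vec3<float> must be tightly packed");
static_assert(packed_vec_layout<c10::Half>(), "vec2/vec3<Half> must be tightly packed");
```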