Skip to content

Commit

Permalink
use dynamic shared mem to allow for adjusting tile_size (#229)
Browse files Browse the repository at this point in the history
Co-authored-by: Ruilong Li <[email protected]>
  • Loading branch information
liruilong940607 and Ruilong Li authored Jun 23, 2024
1 parent 02f0c6f commit 2aa70c2
Show file tree
Hide file tree
Showing 5 changed files with 501 additions and 90 deletions.
24 changes: 22 additions & 2 deletions gsplat/cuda/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,10 +425,30 @@ def rasterize_to_pixels(

# Pad the channels to the nearest supported number if necessary
channels = colors.shape[-1]
if channels > 33 or channels == 0:
if channels > 513 or channels == 0:
# TODO: maybe worth to support zero channels?
raise ValueError(f"Unsupported number of color channels: {channels}")
if channels not in (1, 2, 3, 4, 5, 8, 9, 16, 17, 32, 33):
if channels not in (
1,
2,
3,
4,
5,
8,
9,
16,
17,
32,
33,
64,
65,
128,
129,
256,
257,
512,
513,
):
padded_channels = (1 << (channels - 1).bit_length()) - channels
colors = torch.cat(
[
Expand Down
65 changes: 33 additions & 32 deletions gsplat/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include <torch/extension.h>
#include <tuple>

#define MAX_BLOCK_SIZE (16 * 16)
#define N_THREADS 256

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
Expand Down Expand Up @@ -69,15 +68,16 @@ world_to_cam_bwd_tensor(const torch::Tensor &means, // [N, 3]
const bool viewmats_requires_grad);

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
fully_fused_projection_fwd_tensor(const torch::Tensor &means, // [N, 3]
const at::optional<torch::Tensor> &covars, // [N, 6] optional
const at::optional<torch::Tensor> &quats, // [N, 4] optional
const at::optional<torch::Tensor> &scales, // [N, 3] optional
const torch::Tensor &viewmats, // [C, 4, 4]
const torch::Tensor &Ks, // [C, 3, 3]
const uint32_t image_width, const uint32_t image_height, const float eps2d,
const float near_plane, const float far_plane,
const float radius_clip, const bool calc_compensations);
fully_fused_projection_fwd_tensor(
const torch::Tensor &means, // [N, 3]
const at::optional<torch::Tensor> &covars, // [N, 6] optional
const at::optional<torch::Tensor> &quats, // [N, 4] optional
const at::optional<torch::Tensor> &scales, // [N, 3] optional
const torch::Tensor &viewmats, // [C, 4, 4]
const torch::Tensor &Ks, // [C, 3, 3]
const uint32_t image_width, const uint32_t image_height, const float eps2d,
const float near_plane, const float far_plane, const float radius_clip,
const bool calc_compensations);

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
fully_fused_projection_bwd_tensor(
Expand All @@ -101,13 +101,14 @@ fully_fused_projection_bwd_tensor(
const bool viewmats_requires_grad);

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
isect_tiles_tensor(const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2]
const torch::Tensor &radii, // [C, N] or [nnz]
const torch::Tensor &depths, // [C, N] or [nnz]
const at::optional<torch::Tensor> &camera_ids, // [nnz]
isect_tiles_tensor(const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2]
const torch::Tensor &radii, // [C, N] or [nnz]
const torch::Tensor &depths, // [C, N] or [nnz]
const at::optional<torch::Tensor> &camera_ids, // [nnz]
const at::optional<torch::Tensor> &gaussian_ids, // [nnz]
const uint32_t C, const uint32_t tile_size, const uint32_t tile_width,
const uint32_t tile_height, const bool sort, const bool double_buffer);
const uint32_t C, const uint32_t tile_size,
const uint32_t tile_width, const uint32_t tile_height,
const bool sort, const bool double_buffer);

torch::Tensor isect_offset_encode_tensor(const torch::Tensor &isect_ids, // [n_isects]
const uint32_t C, const uint32_t tile_width,
Expand All @@ -124,7 +125,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> rasterize_to_pixels_fwd_
const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size,
// intersections
const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
const torch::Tensor &flatten_ids // [n_isects]
const torch::Tensor &flatten_ids // [n_isects]
);

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
Expand All @@ -139,7 +140,7 @@ rasterize_to_pixels_bwd_tensor(
const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size,
// intersections
const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
const torch::Tensor &flatten_ids, // [n_isects]
const torch::Tensor &flatten_ids, // [n_isects]
// forward outputs
const torch::Tensor &render_alphas, // [C, image_height, image_width, 1]
const torch::Tensor &last_ids, // [C, image_height, image_width]
Expand All @@ -150,7 +151,7 @@ rasterize_to_pixels_bwd_tensor(
bool absgrad);

std::tuple<torch::Tensor, torch::Tensor> rasterize_to_indices_in_range_tensor(
const uint32_t range_start, const uint32_t range_end, // iteration steps
const uint32_t range_start, const uint32_t range_end, // iteration steps
const torch::Tensor transmittances, // [C, image_height, image_width]
// Gaussian parameters
const torch::Tensor &means2d, // [C, N, 2]
Expand All @@ -160,7 +161,7 @@ std::tuple<torch::Tensor, torch::Tensor> rasterize_to_indices_in_range_tensor(
const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size,
// intersections
const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
const torch::Tensor &flatten_ids // [n_isects]
const torch::Tensor &flatten_ids // [n_isects]
);

torch::Tensor compute_sh_fwd_tensor(const uint32_t degrees_to_use,
Expand All @@ -181,16 +182,16 @@ compute_sh_bwd_tensor(const uint32_t K, const uint32_t degrees_to_use,
****************************************************************************************/
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
torch::Tensor, torch::Tensor, torch::Tensor>
fully_fused_projection_packed_fwd_tensor(const torch::Tensor &means, // [N, 3]
const at::optional<torch::Tensor> &covars, // [N, 6]
const at::optional<torch::Tensor> &quats, // [N, 3]
const at::optional<torch::Tensor> &scales, // [N, 3]
const torch::Tensor &viewmats, // [C, 4, 4]
const torch::Tensor &Ks, // [C, 3, 3]
const uint32_t image_width, const uint32_t image_height,
const float eps2d, const float near_plane,
const float far_plane, const float radius_clip,
const bool calc_compensations);
fully_fused_projection_packed_fwd_tensor(
const torch::Tensor &means, // [N, 3]
const at::optional<torch::Tensor> &covars, // [N, 6]
const at::optional<torch::Tensor> &quats, // [N, 3]
const at::optional<torch::Tensor> &scales, // [N, 3]
const torch::Tensor &viewmats, // [C, 4, 4]
const torch::Tensor &Ks, // [C, 3, 3]
const uint32_t image_width, const uint32_t image_height, const float eps2d,
const float near_plane, const float far_plane, const float radius_clip,
const bool calc_compensations);

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
fully_fused_projection_packed_bwd_tensor(
Expand All @@ -203,8 +204,8 @@ fully_fused_projection_packed_bwd_tensor(
const torch::Tensor &Ks, // [C, 3, 3]
const uint32_t image_width, const uint32_t image_height, const float eps2d,
// fwd outputs
const torch::Tensor &camera_ids, // [nnz]
const torch::Tensor &gaussian_ids, // [nnz]
const torch::Tensor &camera_ids, // [nnz]
const torch::Tensor &gaussian_ids, // [nnz]
const torch::Tensor &conics, // [nnz, 3]
const at::optional<torch::Tensor> &compensations, // [nnz] optional
// grad outputs
Expand Down
Loading

0 comments on commit 2aa70c2

Please sign in to comment.