diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py index 7375ecd8..c636b7e4 100644 --- a/gsplat/cuda/_wrapper.py +++ b/gsplat/cuda/_wrapper.py @@ -579,9 +579,9 @@ def backward(ctx, v_covars: Tensor, v_precis: Tensor): compute_covar = ctx.compute_covar compute_preci = ctx.compute_preci triu = ctx.triu - if v_covars.is_sparse: + if compute_covar and v_covars.is_sparse: v_covars = v_covars.to_dense() - if v_precis.is_sparse: + if compute_preci and v_precis.is_sparse: v_precis = v_precis.to_dense() v_quats, v_scales = _make_lazy_cuda_func("quat_scale_to_covar_preci_bwd")( quats,