diff --git a/gsplat/strategy/default.py b/gsplat/strategy/default.py index 0e311fea..30ffae6c 100644 --- a/gsplat/strategy/default.py +++ b/gsplat/strategy/default.py @@ -46,7 +46,10 @@ class DefaultStrategy(Strategy): refine_start_iter (int): Start refining GSs after this iteration. Default is 500. refine_stop_iter (int): Stop refining GSs after this iteration. Default is 15_000. reset_every (int): Reset opacities every this steps. Default is 3000. - refine_every (int): Reine GSs every this steps. Default is 100. + refine_every (int): Refine GSs every this steps. Default is 100. + pause_refine_after_reset (int): Pause refining GSs until this number of steps after + reset, Default is 0 (no pause at all) and one might want to set this number to the + number of images in training set. absgrad (bool): Use absolute gradients for GS splitting. Default is False. revised_opacity (bool): Whether to use revised opacity heuristic from arXiv:2404.06109 (experimental). Default is False. @@ -80,6 +83,7 @@ class DefaultStrategy(Strategy): refine_stop_iter: int = 15_000 reset_every: int = 3000 refine_every: int = 100 + pause_refine_after_reset: int = 0 absgrad: bool = False revised_opacity: bool = False verbose: bool = False @@ -155,7 +159,11 @@ def step_post_backward( self._update_state(params, state, info, packed=packed) - if step > self.refine_start_iter and step % self.refine_every == 0: + if ( + step > self.refine_start_iter + and step % self.refine_every == 0 + and step % self.reset_every >= self.pause_refine_after_reset + ): # grow GSs n_dupli, n_split = self._grow_gs(params, optimizers, state, step) if self.verbose: @@ -175,6 +183,8 @@ def step_post_backward( # reset running stats state["grad2d"].zero_() state["count"].zero_() + if self.refine_scale2d_stop_iter > 0: + state["radii"].zero_() torch.cuda.empty_cache() if step % self.reset_every == 0: @@ -258,9 +268,9 @@ def _grow_gs( n_dupli = is_dupli.sum().item() is_large = ~is_small - if step < self.refine_scale2d_stop_iter: - is_large |= state["radii"] > self.grow_scale2d is_split = is_grad_high & is_large + if step < self.refine_scale2d_stop_iter: + is_split |= state["radii"] > self.grow_scale2d n_split = is_split.sum().item() # first duplicate