Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug related to refine_scale2d and add pause_refine_after_reset to default strategy #354

Merged
merged 1 commit into from
Aug 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions gsplat/strategy/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ class DefaultStrategy(Strategy):
refine_start_iter (int): Start refining GSs after this iteration. Default is 500.
refine_stop_iter (int): Stop refining GSs after this iteration. Default is 15_000.
reset_every (int): Reset opacities every this steps. Default is 3000.
refine_every (int): Reine GSs every this steps. Default is 100.
refine_every (int): Refine GSs every this steps. Default is 100.
pause_refine_after_reset (int): Pause refining GSs until this number of steps after
reset, Default is 0 (no pause at all) and one might want to set this number to the
number of images in training set.
absgrad (bool): Use absolute gradients for GS splitting. Default is False.
revised_opacity (bool): Whether to use revised opacity heuristic from
arXiv:2404.06109 (experimental). Default is False.
Expand Down Expand Up @@ -80,6 +83,7 @@ class DefaultStrategy(Strategy):
refine_stop_iter: int = 15_000
reset_every: int = 3000
refine_every: int = 100
pause_refine_after_reset: int = 0
absgrad: bool = False
revised_opacity: bool = False
verbose: bool = False
Expand Down Expand Up @@ -155,7 +159,11 @@ def step_post_backward(

self._update_state(params, state, info, packed=packed)

if step > self.refine_start_iter and step % self.refine_every == 0:
if (
step > self.refine_start_iter
and step % self.refine_every == 0
and step % self.reset_every >= self.pause_refine_after_reset
):
# grow GSs
n_dupli, n_split = self._grow_gs(params, optimizers, state, step)
if self.verbose:
Expand All @@ -175,6 +183,8 @@ def step_post_backward(
# reset running stats
state["grad2d"].zero_()
state["count"].zero_()
if self.refine_scale2d_stop_iter > 0:
state["radii"].zero_()
torch.cuda.empty_cache()

if step % self.reset_every == 0:
Expand Down Expand Up @@ -258,9 +268,9 @@ def _grow_gs(
n_dupli = is_dupli.sum().item()

is_large = ~is_small
if step < self.refine_scale2d_stop_iter:
is_large |= state["radii"] > self.grow_scale2d
is_split = is_grad_high & is_large
if step < self.refine_scale2d_stop_iter:
is_split |= state["radii"] > self.grow_scale2d
jb-ye marked this conversation as resolved.
Show resolved Hide resolved
n_split = is_split.sum().item()

# first duplicate
Expand Down
Loading