
Abstracting out the GS heuristics into a Strategy class #278

Merged
merged 46 commits into main from rtabrizi/refactor on Jul 17, 2024

Conversation

@liruilong940607 (Collaborator) commented Jul 11, 2024

The idea is to move the densification heuristics from the example scripts into the gsplat library and package them into a class, so that downstream users can simply use them via from gsplat.strategy import DefaultStrategy, MCMCStrategy.

The API is fairly simple:

params: torch.nn.ParameterDict = ...
optimizers: Dict[str, torch.optim.Optimizer] = ...
strategy = DefaultStrategy(<arguments>)
state = strategy.initialize_state()

for step in range(max_steps):
    render_color, render_alpha, info = rasterization(...)
    strategy.step_pre_backward(params, optimizers, state, step=step, info=info, <other arguments>)
    loss = ...  # compute the training loss from the rendered outputs
    loss.backward()
    strategy.step_post_backward(params, optimizers, state, step=step, info=info, <other arguments>)

This also allows users to switch between different strategies, and makes it easier to adopt additional GS heuristics. This is a joint effort with @rtabrizi.

I also verified that switching to DefaultStrategy and MCMCStrategy in the example scripts does not affect the results. The two scripts now differ very little, so we could potentially merge them into one, but I'll leave that to another PR to keep the code changes in this PR minimal.

python simple_trainer.py

# Before this PR
> Step:  6999 {'mem': 7.582746505737305, 'ellipse_time': 424.46320390701294, 'num_GS': 4488772}
> {"psnr": 26.358938217163086, "ssim": 0.8353559374809265, "lpips": 0.12293213605880737, "ellipse_time": 0.026054859161376953, "num_GS": 4488772}

# After this PR
> Step:  6999 {'mem': 7.579144477844238, 'ellipse_time': 423.3516597747803, 'num_GS': 4481712}
> {"psnr": 26.348445892333984, "ssim": 0.8354260325431824, "lpips": 0.12293919920921326, "ellipse_time": 0.026077985763549805, "num_GS": 4481712}
python simple_trainer_mcmc.py 

# Before this PR
> Step:  6999 {'mem': 1.7906079292297363, 'ellipse_time': 247.97816705703735, 'num_GS': 1000000}
> {"psnr": 26.010623931884766, "ssim": 0.8161215782165527, "lpips": 0.15706905722618103, "ellipse_time": 0.009712855021158854, "num_GS": 1000000}

# After this PR
> Step:  6999 {'mem': 1.7903776168823242, 'ellipse_time': 239.95758247375488, 'num_GS': 1000000}
> {"psnr": 26.001811981201172, "ssim": 0.8158013224601746, "lpips": 0.1573895812034607, "ellipse_time": 0.009663820266723633, "num_GS": 1000000}

TODO:

tentative doc preview: https://669107b6a724ef1f2c186e67--inquisitive-speculoos-27b42b.netlify.app/apis/strategy

@liruilong940607 changed the title from "[Refactor] Abstract out the densification logics (Original and MCMC) from the train script into gsplat library" to "Abstracting out the GS heuristics into a Strategy class" on Jul 11, 2024
def __init__(
    self,
    params: torch.nn.ParameterDict,
    optimizers: List[torch.optim.Optimizer],
Collaborator
wouldn't this need to be a dict matching the key in the paramdict to the optimizer?

Collaborator Author
A single optimizer is enough as long as each param group has a "name" field defined. The reason I'm using a List here is in case a user wants to use a different optimizer for different fields, e.g., SGD for color and Adam for the rest.

https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py#L217-L224
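For illustration, a minimal sketch of a single optimizer whose param groups carry a "name" field; the parameter shapes and learning rates here are placeholders, not the example script's values:

import torch

params = torch.nn.ParameterDict({
    "means": torch.nn.Parameter(torch.zeros(100, 3)),
    "colors": torch.nn.Parameter(torch.zeros(100, 3)),
})

# One Adam optimizer; each param group carries a "name" field so the strategy
# can match it back to the corresponding key in the ParameterDict.
optimizer = torch.optim.Adam(
    [{"params": [param], "name": name, "lr": 1e-3} for name, param in params.items()]
)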

Collaborator
I see, so it will automatically loop through all optimizers and check if any of the param group names match the paramdict?

Collaborator
In nerfstudio there is a separate torch.optim.Optimizer for each param group, so we should check whether passing in a list of those works.

Collaborator Author
A list of optimizers also works as long as each param group has a "name" field.

The logic loops over the list of optimizers and every parameter group within them to find a match, based on the "name" field.
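Roughly, that matching could look like the sketch below; the helper name is hypothetical and this is only the logic described above, not the library code:

from typing import List, Optional
import torch


def find_param_group(optimizers: List[torch.optim.Optimizer], name: str) -> Optional[dict]:
    # Scan every optimizer and every param group for a matching "name" field.
    for optimizer in optimizers:
        for group in optimizer.param_groups:
            if group.get("name") == name:
                return group
    return None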

Collaborator Author
Actually, maybe it's more user-friendly to just use a dict of optimizers and require each parameter to have a separate optimizer?
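That dict form is what the snippet in the PR description uses (Dict[str, torch.optim.Optimizer]); a minimal sketch, with placeholder parameters and learning rates:

from typing import Dict
import torch

params = torch.nn.ParameterDict({
    "means": torch.nn.Parameter(torch.zeros(100, 3)),
    "colors": torch.nn.Parameter(torch.zeros(100, 3)),
})

# One optimizer per parameter, keyed by the same name as in the ParameterDict.
optimizers: Dict[str, torch.optim.Optimizer] = {
    name: torch.optim.Adam([param], lr=1e-3) for name, param in params.items()
}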

# Refine GSs every this many steps
refine_every: int = 100,
# Use absolute gradients for GS splitting
absgrad: bool = False,
Collaborator
If this is true, maybe we want to automatically set grow_grad2d to 0.008, since otherwise the number of Gaussians will explode? Or maybe warn the user?

Collaborator Author
Not sure if a warning is a good idea; on different scenes this threshold could be very different. Auto-setting is also not ideal, as we want to give the user the freedom to adjust it.

I think the best we can do is probably note it clearly in the docstring.

Collaborator Author
Or, another idea is to create a class AbsGradStrategy that inherits from this class and sets the default value to 0.008, as in the sketch below.
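A hypothetical sketch of that idea, assuming DefaultStrategy accepts absgrad and grow_grad2d keyword arguments as the review context suggests; AbsGradStrategy is not part of this PR:

from gsplat.strategy import DefaultStrategy


class AbsGradStrategy(DefaultStrategy):
    """DefaultStrategy preconfigured for absolute gradients."""

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("absgrad", True)       # use absolute gradients
        kwargs.setdefault("grow_grad2d", 0.008)  # higher grow/split threshold
        super().__init__(*args, **kwargs)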

def _refine_duplicate(self, mask: Tensor):
    device = mask.device
    sel = torch.where(mask)[0]
    for optimizer in self._optimizers:
Collaborator
do we want some sort of helper function for this? looping through optimizers and adding/pruning/removing is a common pattern that all strategies will need

Collaborator
maybe base level class functions that add/remove/duplicate an arbitrary set of indices for gaussians?

Collaborator
so these lines could then call self.duplicate_in_optimizer(index_mask) or something

Collaborator Author
I tried once but didn't end up creating anything elegant. Open for suggestions!

Collaborator
These functions in splatfacto try to do something like this, so any subclass could also call dup_in_optim and have the same effect

Collaborator
I think the body of the function would need to be different but the general interface of taking in a per-gaussian mask could work for removing/duplicating, or specifying how many to append

Collaborator Author (@liruilong940607, Jul 11, 2024)
The nerfstudio implementation mixes the duplication, splitting, and removal operations together so that the removal on the optimizer can be done once. The downside is that it's hard to read and modify. I would like to keep each operation isolated. But yeah, maybe there is a way to share some code between them; I just don't have a concrete implementation idea.

Collaborator Author
put reusable functions into gsplat/strategy/ops.py!
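For reference, a rough sketch of the kind of reusable helper discussed above: duplicate the Gaussians selected by a mask in both the parameters and the matching optimizer state. The function name and the one-parameter-per-optimizer layout are assumptions for illustration, not the actual contents of gsplat/strategy/ops.py:

from typing import Dict
import torch
from torch import Tensor


@torch.no_grad()
def duplicate_selected(
    params: torch.nn.ParameterDict,
    optimizers: Dict[str, torch.optim.Optimizer],
    mask: Tensor,
) -> None:
    # Append a copy of the Gaussians selected by `mask` to every parameter and
    # grow the matching optimizer state (e.g. Adam's exp_avg) with zeros.
    sel = torch.where(mask)[0]
    for name in list(params.keys()):
        param = params[name]
        new_param = torch.nn.Parameter(torch.cat([param, param[sel]], dim=0))
        optimizer = optimizers[name]
        old_param = optimizer.param_groups[0]["params"][0]
        state = optimizer.state.pop(old_param, {})
        for key, value in state.items():
            # Only per-Gaussian state tensors are padded; scalars (e.g. "step") are kept.
            if isinstance(value, Tensor) and value.dim() > 0 and value.shape[0] == mask.shape[0]:
                state[key] = torch.cat([value, torch.zeros_like(value[sel])], dim=0)
        optimizer.param_groups[0]["params"] = [new_param]
        optimizer.state[new_param] = state
        params[name] = new_param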

@liruilong940607 merged commit 68ff2c1 into main Jul 17, 2024
2 checks passed
@liruilong940607 deleted the rtabrizi/refactor branch July 17, 2024 23:23