Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move camera optimization out of datamanager and parallelize dataloading #2092

Merged
merged 83 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
83 commits
Select commit Hold shift + click to select a range
b95daad
Increase max_res of first proposal network
kerrj Jan 24, 2023
bf0f2e5
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Jan 27, 2023
d2698e0
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Jan 28, 2023
fdd3c72
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Feb 1, 2023
85491fa
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Feb 9, 2023
bae8e5c
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Feb 16, 2023
bb02ae0
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Feb 22, 2023
03efc6a
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Mar 14, 2023
6f89617
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Mar 18, 2023
03a4a45
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Mar 23, 2023
4a63d6b
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Mar 27, 2023
a886dc4
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Apr 2, 2023
ea2378e
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Apr 4, 2023
1b1bb22
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Apr 5, 2023
06825f4
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Apr 14, 2023
4b040a1
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Apr 16, 2023
0f56cc6
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Apr 17, 2023
134ff3e
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Apr 21, 2023
4ac02b3
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Apr 22, 2023
b421b2e
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Apr 22, 2023
1ebfaed
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Apr 24, 2023
9b86877
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Apr 24, 2023
a259f93
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj May 18, 2023
7e32c8c
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj May 19, 2023
9db295a
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj May 19, 2023
86135cf
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj May 22, 2023
04f7af5
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj May 25, 2023
09c24de
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj May 26, 2023
c06e2b5
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Jun 11, 2023
4697122
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Jun 15, 2023
d67e176
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Jun 16, 2023
9db9777
fix render statemachine logic for viewer beta
kerrj Jun 16, 2023
efd1826
clicking on cameras in the viewer
kerrj Jun 16, 2023
cb66e31
fix fov calculation, update camera positions while training
kerrj Jun 18, 2023
0a9eb8a
cleanup
kerrj Jun 18, 2023
1182b18
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Jun 18, 2023
ae5ef15
rip out camera optimization from datamanager and move to nerfacto.py
kerrj Jun 18, 2023
be47ffc
add l2 penalties on optimizers
kerrj Jun 18, 2023
2fb9cd7
Merge branch 'justin/camera_clicking_beta' into justin/camera_opt_ref…
kerrj Jun 18, 2023
75153f6
update viewer beta vis
kerrj Jun 18, 2023
b9feec2
param
kerrj Jun 18, 2023
767dc10
param
kerrj Jun 20, 2023
caf50a7
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Jun 20, 2023
807837a
merge main
kerrj Jun 20, 2023
f6b6ed8
ensure that non_trainable_indices are on the correct device during tr…
maturk Jun 21, 2023
e26bd4d
initial impl of parallel data manager
kerrj Jun 21, 2023
75d1013
merge
kerrj Jun 21, 2023
09e9a71
minor speedup
kerrj Jun 21, 2023
a73d892
params
kerrj Jun 21, 2023
1d8cef2
clean
kerrj Jun 21, 2023
3f73e78
oops
kerrj Jun 21, 2023
10fc5f3
explicitly catch full queue, notify of other exceptions
maturk Jun 26, 2023
7121bc8
docstring
maturk Jun 27, 2023
5f85ef1
data queue as config parameter
maturk Jun 27, 2023
d1350e1
max num threads as config parameter
maturk Jun 27, 2023
4c3db79
barf
maturk Jun 30, 2023
37e763f
optional pose noise in datamanager
maturk Jun 30, 2023
2e338a2
refactor + docstrings + eval dataset in main python process
maturk Jul 10, 2023
6ebdab9
add getters for param groups, metrics dict, and correction matrices
maturk Jul 20, 2023
3e059a4
refactor to use camera opt getters
maturk Jul 20, 2023
dff0d32
ruff format, license, + docstrings
maturk Jul 20, 2023
e0c2451
type checker fixes
maturk Jul 20, 2023
4ddbee9
type ignore
maturk Jul 20, 2023
aa7fd51
update method configs
maturk Jul 22, 2023
f2971b5
camera opt in tensorf
maturk Jul 22, 2023
483a6c2
bump viser version for viewer_beta
maturk Jul 22, 2023
dc1b809
remove barf
kerrj Aug 2, 2023
b171e68
merge main
kerrj Aug 2, 2023
017c274
typo
kerrj Aug 2, 2023
b6a3fa4
config inherit from base_dm, remove more camera opt stuff from datama…
kerrj Aug 2, 2023
f0574cc
Deprication Warning for Camera Optimization
maturk Aug 5, 2023
818c057
mipnerf in parallel
maturk Aug 7, 2023
83e93e0
rename num_cameras -> camera_indices
maturk Aug 10, 2023
fae69c2
formatting deprecation warnings more, throw FutureWarning to track wh…
kerrj Aug 28, 2023
f80d56e
lint
kerrj Aug 28, 2023
70d8662
similar warning for schedulers in cameraoptconfig
maturk Sep 4, 2023
084f725
lint, merge
kerrj Sep 25, 2023
7010f0d
Merge branch 'main' into justin/camera_opt_refactor
brentyi Sep 25, 2023
1cff893
merge
kerrj Oct 10, 2023
18d953a
spawn method in wrapper statement to prevent setting twice
kerrj Oct 10, 2023
60c6a5b
lint
kerrj Oct 10, 2023
ed45a53
fix pytest pin memory crash
kerrj Oct 10, 2023
33cc3c7
Merge branch 'main' into justin/camera_opt_refactor
tancik Oct 10, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 68 additions & 31 deletions nerfstudio/cameras/camera_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,16 @@
from typing import Literal, Optional, Type, Union

import torch
import tyro
from jaxtyping import Float, Int
from torch import Tensor, nn
from typing_extensions import assert_never

from nerfstudio.cameras.lie_groups import exp_map_SE3, exp_map_SO3xR3
from nerfstudio.cameras.rays import RayBundle
from nerfstudio.configs.base_config import InstantiateConfig
from nerfstudio.engine.optimizers import AdamOptimizerConfig, OptimizerConfig
from nerfstudio.engine.schedulers import (
ExponentialDecaySchedulerConfig,
SchedulerConfig,
)
from nerfstudio.utils import poses as pose_utils
from nerfstudio.engine.optimizers import OptimizerConfig
from nerfstudio.engine.schedulers import SchedulerConfig


@dataclass
Expand All @@ -47,21 +44,38 @@ class CameraOptimizerConfig(InstantiateConfig):
mode: Literal["off", "SO3xR3", "SE3"] = "off"
"""Pose optimization strategy to use. If enabled, we recommend SO3xR3."""

position_noise_std: float = 0.0
"""Noise to add to initial positions. Useful for debugging."""
trans_l2_penalty: float = 1e-2
"""L2 penalty on translation parameters."""

orientation_noise_std: float = 0.0
"""Noise to add to initial orientations. Useful for debugging."""
rot_l2_penalty: float = 1e-3
"""L2 penalty on rotation parameters."""

optimizer: OptimizerConfig = field(default_factory=lambda: AdamOptimizerConfig(lr=6e-4, eps=1e-15))
"""ADAM parameters for camera optimization."""
optimizer: Optional[OptimizerConfig] = field(default=None)
"""Deprecated, now specified inside the optimizers dict"""

scheduler: SchedulerConfig = field(default_factory=lambda: ExponentialDecaySchedulerConfig(max_steps=10000))
"""Learning rate scheduler for camera optimizer.."""
scheduler: Optional[SchedulerConfig] = field(default=None)
"""Deprecated, now specified inside the optimizers dict"""

param_group: tyro.conf.Suppress[str] = "camera_opt"
"""Name of the parameter group used for pose optimization. Can be any string that doesn't conflict with other
groups."""
def __post_init__(self):
if self.optimizer is not None:
import warnings
from nerfstudio.utils.rich_utils import CONSOLE

CONSOLE.print(
"\noptimizer is no longer specified in the CameraOptimizerConfig, it is now defined with the rest of the param groups inside the config file under the name 'camera_opt'\n",
style="bold yellow",
)
warnings.warn("above message coming from", FutureWarning, stacklevel=3)

if self.scheduler is not None:
import warnings
from nerfstudio.utils.rich_utils import CONSOLE

CONSOLE.print(
"\nscheduler is no longer specified in the CameraOptimizerConfig, it is now defined with the rest of the param groups inside the config file under the name 'camera_opt'\n",
style="bold yellow",
)
warnings.warn("above message coming from", FutureWarning, stacklevel=3)


class CameraOptimizer(nn.Module):
Expand Down Expand Up @@ -91,16 +105,6 @@ def __init__(
else:
assert_never(self.config.mode)

# Initialize pose noise; useful for debugging.
if config.position_noise_std != 0.0 or config.orientation_noise_std != 0.0:
assert config.position_noise_std >= 0.0 and config.orientation_noise_std >= 0.0
std_vector = torch.tensor(
[config.position_noise_std] * 3 + [config.orientation_noise_std] * 3, device=device
)
self.pose_noise = exp_map_SE3(torch.normal(torch.zeros((num_cameras, 6), device=device), std_vector))
else:
self.pose_noise = None

def forward(
self,
indices: Int[Tensor, "camera_indices"],
Expand All @@ -125,13 +129,46 @@ def forward(
assert_never(self.config.mode)
# Detach non-trainable indices by setting to identity transform
if self.non_trainable_camera_indices is not None:
outputs[0][self.non_trainable_camera_indices] = torch.eye(4, device=self.device)[:3, :4]
if self.non_trainable_camera_indices.device != self.pose_adjustment.device:
self.non_trainable_camera_indices = self.non_trainable_camera_indices.to(self.pose_adjustment.device)
outputs[0][self.non_trainable_camera_indices] = torch.eye(4, device=self.pose_adjustment.device)[:3, :4]

# Apply initial pose noise.
if self.pose_noise is not None:
outputs.append(self.pose_noise[indices, :, :])
# Return: identity if no transforms are needed, otherwise multiply transforms together.
if len(outputs) == 0:
# Note that using repeat() instead of tile() here would result in unnecessary copies.
return torch.eye(4, device=self.device)[None, :3, :4].tile(indices.shape[0], 1, 1)
return functools.reduce(pose_utils.multiply, outputs)

def apply_to_raybundle(self, raybundle: RayBundle) -> None:
"""Apply the pose correction to the raybundle"""
if self.config.mode != "off":
correction_matrices = self(raybundle.camera_indices.squeeze()) # type: ignore
raybundle.origins = raybundle.origins + correction_matrices[:, :3, 3]
raybundle.directions = torch.bmm(correction_matrices[:, :3, :3], raybundle.directions[..., None]).squeeze()

def get_loss_dict(self, loss_dict: dict) -> None:
"""Add regularization"""
if self.config.mode != "off":
loss_dict["camera_opt_regularizer"] = (
self.pose_adjustment[:, :3].norm(dim=-1).mean() * self.config.trans_l2_penalty
+ self.pose_adjustment[:, 3:].norm(dim=-1).mean() * self.config.rot_l2_penalty
)

def get_correction_matrices(self):
"""Get optimized pose correction matrices"""
return self(torch.arange(0, self.num_cameras).long())

def get_metrics_dict(self, metrics_dict: dict) -> None:
"""Get camera optimizer metrics"""
if self.config.mode != "off":
metrics_dict["camera_opt_translation"] = self.pose_adjustment[:, :3].norm()
metrics_dict["camera_opt_rotation"] = self.pose_adjustment[:, 3:].norm()

def get_param_groups(self, param_groups: dict) -> None:
"""Get camera optimizer parameters"""
camera_opt_params = list(self.parameters())
if self.config.mode != "off":
assert len(camera_opt_params) > 0
param_groups["camera_opt"] = camera_opt_params
else:
assert len(camera_opt_params) == 0
Loading
Loading