Skip to content

Commit

Permalink
Adjusting Instant NGP to Inner AABB
Browse files Browse the repository at this point in the history
  • Loading branch information
Anthony-Tafoya committed Sep 20, 2024
1 parent f86dbe6 commit f985925
Show file tree
Hide file tree
Showing 15 changed files with 244 additions and 309 deletions.
18 changes: 9 additions & 9 deletions nerfstudio/cameras/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,15 +778,15 @@ def _compute_rays_for_vr180(

return vr180_origins, directions_stack

for cam in cam_types:
if CameraType.PERSPECTIVE.value in cam_types:
for cam_type in cam_types:
if CameraType.PERSPECTIVE.value == cam_type:
mask = (self.camera_type[true_indices] == CameraType.PERSPECTIVE.value).squeeze(-1) # (num_rays)
mask = torch.stack([mask, mask, mask], dim=0)
directions_stack[..., 0][mask] = torch.masked_select(coord_stack[..., 0], mask).float()
directions_stack[..., 1][mask] = torch.masked_select(coord_stack[..., 1], mask).float()
directions_stack[..., 2][mask] = -1.0

elif CameraType.FISHEYE.value in cam_types:
elif CameraType.FISHEYE.value == cam_type:
mask = (self.camera_type[true_indices] == CameraType.FISHEYE.value).squeeze(-1) # (num_rays)
mask = torch.stack([mask, mask, mask], dim=0)

Expand All @@ -803,7 +803,7 @@ def _compute_rays_for_vr180(
).float()
directions_stack[..., 2][mask] = -torch.masked_select(torch.cos(theta), mask).float()

elif CameraType.EQUIRECTANGULAR.value in cam_types:
elif CameraType.EQUIRECTANGULAR.value == cam_type:
mask = (self.camera_type[true_indices] == CameraType.EQUIRECTANGULAR.value).squeeze(-1) # (num_rays)
mask = torch.stack([mask, mask, mask], dim=0)

Expand All @@ -816,22 +816,22 @@ def _compute_rays_for_vr180(
directions_stack[..., 1][mask] = torch.masked_select(torch.cos(phi), mask).float()
directions_stack[..., 2][mask] = torch.masked_select(-torch.cos(theta) * torch.sin(phi), mask).float()

elif CameraType.OMNIDIRECTIONALSTEREO_L.value in cam_types:
elif CameraType.OMNIDIRECTIONALSTEREO_L.value == cam_type:
ods_origins_circle, directions_stack = _compute_rays_for_omnidirectional_stereo("left")
# assign final camera origins
c2w[..., :3, 3] = ods_origins_circle

elif CameraType.OMNIDIRECTIONALSTEREO_R.value in cam_types:
elif CameraType.OMNIDIRECTIONALSTEREO_R.value == cam_type:
ods_origins_circle, directions_stack = _compute_rays_for_omnidirectional_stereo("right")
# assign final camera origins
c2w[..., :3, 3] = ods_origins_circle

elif CameraType.VR180_L.value in cam_types:
elif CameraType.VR180_L.value == cam_type:
vr180_origins, directions_stack = _compute_rays_for_vr180("left")
# assign final camera origins
c2w[..., :3, 3] = vr180_origins

elif CameraType.VR180_R.value in cam_types:
elif CameraType.VR180_R.value == cam_type:
vr180_origins, directions_stack = _compute_rays_for_vr180("right")
# assign final camera origins
c2w[..., :3, 3] = vr180_origins
Expand Down Expand Up @@ -880,7 +880,7 @@ def _compute_rays_for_vr180(
directions_stack[coord_mask] = camera_utils.fisheye624_unproject(masked_coords, camera_params)

else:
raise ValueError(f"Camera type {cam} not supported.")
raise ValueError(f"Camera type {cam_type} not supported.")

assert directions_stack.shape == (3,) + num_rays_shape + (3,)

Expand Down
3 changes: 1 addition & 2 deletions nerfstudio/configs/method_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,8 +663,7 @@
),
model=SplatfactoModelConfig(
cull_alpha_thresh=0.005,
continue_cull_post_densification=False,
densify_grad_thresh=0.0006,
densify_grad_thresh=0.0005,
),
),
optimizers={
Expand Down
11 changes: 8 additions & 3 deletions nerfstudio/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ class TrainerConfig(ExperimentConfig):
"""Optionally log gradients during training"""
gradient_accumulation_steps: Dict[str, int] = field(default_factory=lambda: {})
"""Number of steps to accumulate gradients over. Contains a mapping of {param_group:num}"""
start_paused: bool = False
"""Whether to start the training in a paused state."""


class Trainer:
Expand Down Expand Up @@ -121,7 +123,9 @@ def __init__(self, config: TrainerConfig, local_rank: int = 0, world_size: int =
self.device += f":{local_rank}"
self.mixed_precision: bool = self.config.mixed_precision
self.use_grad_scaler: bool = self.mixed_precision or self.config.use_grad_scaler
self.training_state: Literal["training", "paused", "completed"] = "training"
self.training_state: Literal["training", "paused", "completed"] = (
"paused" if self.config.start_paused else "training"
)
self.gradient_accumulation_steps: DefaultDict = defaultdict(lambda: 1)
self.gradient_accumulation_steps.update(self.config.gradient_accumulation_steps)

Expand Down Expand Up @@ -296,7 +300,8 @@ def train(self) -> None:

# Do not perform evaluation if there are no validation images
if self.pipeline.datamanager.eval_dataset:
self.eval_iteration(step)
with self.train_lock:
self.eval_iteration(step)

if step_check(step, self.config.steps_per_save):
self.save_checkpoint(step)
Expand Down Expand Up @@ -361,7 +366,7 @@ def _init_viewer_state(self) -> None:
assert self.viewer_state and self.pipeline.datamanager.train_dataset
self.viewer_state.init_scene(
train_dataset=self.pipeline.datamanager.train_dataset,
train_state="training",
train_state=self.training_state,
eval_dataset=self.pipeline.datamanager.eval_dataset,
)

Expand Down
10 changes: 5 additions & 5 deletions nerfstudio/exporter/exporter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,11 @@ def generate_point_cloud(

if crop_obb is not None:
mask = crop_obb.within(point)
point = point[mask]
rgb = rgb[mask]
view_direction = view_direction[mask]
if normal is not None:
normal = normal[mask]
point = point[mask]
rgb = rgb[mask]
view_direction = view_direction[mask]
if normal is not None:
normal = normal[mask]

points.append(point)
rgbs.append(rgb)
Expand Down
6 changes: 1 addition & 5 deletions nerfstudio/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,7 @@ def get_rgba_image(self, outputs: Dict[str, torch.Tensor], output_name: str = "r
RGBA image.
"""
accumulation_name = output_name.replace("rgb", "accumulation")
if (
not hasattr(self, "renderer_rgb")
or not hasattr(self.renderer_rgb, "background_color")
or accumulation_name not in outputs
):
if accumulation_name not in outputs:
raise NotImplementedError(f"get_rgba_image is not implemented for model {self.__class__.__name__}")
rgb = outputs[output_name]
if self.renderer_rgb.background_color == "random": # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion nerfstudio/models/instant_ngp.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def populate_modules(self):
self.config.render_step_size = ((self.scene_aabb[3:] - self.scene_aabb[:3]) ** 2).sum().sqrt().item() / 1000
# Occupancy Grid.
self.occupancy_grid = nerfacc.OccGridEstimator(
roi_aabb=self.scene_aabb,
roi_aabb=self.scene_aabb * 2 ** -(levels - 1),
resolution=self.config.grid_resolution,
levels=self.config.grid_levels,
)
Expand Down
Loading

0 comments on commit f985925

Please sign in to comment.