Skip to content

Commit

Permalink
RGBA renderer for splatfacto (#3307)
Browse files Browse the repository at this point in the history
* Update base_model.py

* Update render.py

* Update render.py

* fixing error

* Update render.py

* Update render.py
  • Loading branch information
hardikdava authored Sep 9, 2024
1 parent f86dbe6 commit d67b281
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
6 changes: 1 addition & 5 deletions nerfstudio/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,7 @@ def get_rgba_image(self, outputs: Dict[str, torch.Tensor], output_name: str = "r
RGBA image.
"""
accumulation_name = output_name.replace("rgb", "accumulation")
if (
not hasattr(self, "renderer_rgb")
or not hasattr(self.renderer_rgb, "background_color")
or accumulation_name not in outputs
):
if accumulation_name not in outputs:
raise NotImplementedError(f"get_rgba_image is not implemented for model {self.__class__.__name__}")
rgb = outputs[output_name]
if self.renderer_rgb.background_color == "random": # type: ignore
Expand Down
11 changes: 10 additions & 1 deletion nerfstudio/scripts/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ def _render_trajectory_video(
outputs = pipeline.model.get_outputs_for_camera(
cameras[camera_idx : camera_idx + 1], obb_box=obb_box
)
if rendered_output_names is not None and "rgba" in rendered_output_names:
rgba = pipeline.model.get_rgba_image(outputs=outputs, output_name="rgb")
outputs["rgba"] = rgba

render_image = []
for rendered_output_name in rendered_output_names:
Expand All @@ -221,6 +224,8 @@ def _render_trajectory_video(
.cpu()
.numpy()
)
elif rendered_output_name == "rgba":
output_image = output_image.detach().cpu().numpy()
else:
output_image = (
colormaps.apply_colormap(
Expand Down Expand Up @@ -790,6 +795,9 @@ def update_config(config: TrainerConfig) -> TrainerConfig:
for camera_idx, (camera, batch) in enumerate(progress.track(dataloader, total=len(dataset))):
with torch.no_grad():
outputs = pipeline.model.get_outputs_for_camera(camera)
if self.rendered_output_names is not None and "rgba" in self.rendered_output_names:
rgba = pipeline.model.get_rgba_image(outputs=outputs, output_name="rgb")
outputs["rgba"] = rgba

gt_batch = batch.copy()
gt_batch["rgb"] = gt_batch.pop("image")
Expand Down Expand Up @@ -841,11 +849,12 @@ def update_config(config: TrainerConfig) -> TrainerConfig:
output_image = gt_batch[output_name]
else:
output_image = outputs[output_name]
del output_name

# Map to color spaces / numpy
if is_raw:
output_image = output_image.cpu().numpy()
elif output_name == "rgba":
output_image = output_image.detach().cpu().numpy()
elif is_depth:
output_image = (
colormaps.apply_depth_colormap(
Expand Down

0 comments on commit d67b281

Please sign in to comment.