Skip to content

Commit

Permalink
reverting full_images_datamanager to main branch
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonioMacaronio committed Oct 1, 2024
1 parent ba81e11 commit 1922566
Showing 1 changed file with 5 additions and 46 deletions.
51 changes: 5 additions & 46 deletions nerfstudio/data/datamanagers/full_images_datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@
from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs
from nerfstudio.data.dataparsers.nerfstudio_dataparser import NerfstudioDataParserConfig
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.data.utils.data_utils import identity_collate
from nerfstudio.data.utils.dataloaders import ImageBatchStream
from nerfstudio.utils.misc import get_orig_class
from nerfstudio.utils.rich_utils import CONSOLE

Expand Down Expand Up @@ -90,7 +88,7 @@ class FullImageDatamanagerConfig(DataManagerConfig):

def __post_init__(self):
if self.load_from_disk:
self.prefetch_factor = 2 if self.use_parallel_dataloader else None
self.prefetch_factor = 4 if self.use_parallel_dataloader else None

if self.use_parallel_dataloader:
try:
Expand Down Expand Up @@ -140,7 +138,6 @@ def __init__(
self.train_dataparser_outputs: DataparserOutputs = self.dataparser.get_dataparser_outputs(split="train")
self.train_dataset = self.create_train_dataset()
self.eval_dataset = self.create_eval_dataset()

if len(self.train_dataset) > 500 and self.config.cache_images == "gpu":
CONSOLE.print(
"Train dataset has over 500 images, overriding cache_images to cpu",
Expand Down Expand Up @@ -325,45 +322,15 @@ def get_datapath(self) -> Path:

def setup_train(self):
"""Sets up the data loaders for training"""
if self.config.use_parallel_dataloader:
self.train_imagebatch_stream = ImageBatchStream(
input_dataset=self.train_dataset,
datamanager_config=self.config,
device=self.device,
)
self.train_image_dataloader = torch.utils.data.DataLoader(
self.train_imagebatch_stream,
batch_size=1,
num_workers=self.config.dataloader_num_workers,
collate_fn=identity_collate,
# pin_memory_device=self.device, # for some reason if we pin memory, exporting to PLY file doesn't work?
)
self.iter_train_image_dataloader = iter(self.train_image_dataloader)

def setup_eval(self):
"""Sets up the data loader for evaluation"""
if self.config.use_parallel_dataloader:
self.eval_imagebatch_stream = ImageBatchStream(
input_dataset=self.eval_dataset,
datamanager_config=self.config,
device=self.device,
)
self.eval_image_dataloader = torch.utils.data.DataLoader(
self.eval_imagebatch_stream,
batch_size=1,
num_workers=self.config.dataloader_num_workers,
collate_fn=identity_collate,
# pin_memory_device=self.device,
)
self.iter_eval_image_dataloader = iter(self.eval_image_dataloader)

@property
def fixed_indices_eval_dataloader(self) -> List[Tuple[Cameras, Dict]]:
"""
Pretends to be the dataloader for evaluation, it returns a list of (camera, data) tuples
"""
if self.config.use_parallel_dataloader:
return self.iter_eval_image_dataloader
image_indices = [i for i in range(len(self.eval_dataset))]
data = [d.copy() for d in self.cached_eval]
_cameras = deepcopy(self.eval_dataset.cameras).to(self.device)
Expand Down Expand Up @@ -394,22 +361,18 @@ def next_train(self, step: int) -> Tuple[Cameras, Dict]:
"""Returns the next training batch
Returns a Camera instead of raybundle"""
self.train_count += 1
if self.config.use_parallel_dataloader:
camera, data = next(self.iter_train_image_dataloader)[0]
return camera, data

image_idx = self.train_unseen_cameras.pop(0)
# Make sure to re-populate the unseen cameras list if we have exhausted it
if len(self.train_unseen_cameras) == 0:
self.train_unseen_cameras = self.sample_train_cameras()

data = self.cached_train[image_idx]
# We're going to copy to make sure we don't mutate the cached dictionary.
# This can cause a memory leak: https://github.com/nerfstudio-project/nerfstudio/issues/3335
data = data.copy()
data["image"] = data["image"].to(self.device)

assert lDuden(self.train_cameras.shape) == 1, "Assumes single batch dimension"
assert len(self.train_cameras.shape) == 1, "Assumes single batch dimension"
camera = self.train_cameras[image_idx : image_idx + 1].to(self.device)
if camera.metadata is None:
camera.metadata = {}
Expand All @@ -420,11 +383,6 @@ def next_eval(self, step: int) -> Tuple[Cameras, Dict]:
"""Returns the next evaluation batch
Returns a Camera instead of raybundle"""
self.eval_count += 1
if self.config.use_parallel_dataloader:
camera, data = next(self.iter_train_image_dataloader)[0]
return camera, data

return self.next_eval_image(step=step)

def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]:
Expand All @@ -451,7 +409,7 @@ def _undistort_image(
mask = None
if camera.camera_type.item() == CameraType.PERSPECTIVE.value:
assert distortion_params[3] == 0, (
"We don't support the 4th Brown parameter for image undistortion, "
"We doesn't support the 4th Brown parameter for image undistortion, "
"Only k1, k2, k3, p1, p2 can be non-zero."
)
# because OpenCV expects the order of distortion parameters to be (k1, k2, p1, p2, k3), we need to reorder them
Expand Down Expand Up @@ -620,4 +578,5 @@ def _undistort_image(
K = undist_K.numpy()
else:
raise NotImplementedError("Only perspective and fisheye cameras are supported")

return K, image, mask

0 comments on commit 1922566

Please sign in to comment.