From 77ac95e6f4a503e6ac8502e8c6979744ed131168 Mon Sep 17 00:00:00 2001
From: Kyle Vedder
Date: Tue, 23 Apr 2024 18:31:58 -0400
Subject: [PATCH] Added control over debug steps to model constructor

---
 configs/fast_nsf/argo/val_debug.py                 | 3 ++-
 configs/gigachad_nsf/argo/noncausal/val_debug.py   | 4 +++-
 models/fast_nsf_model.py                           | 9 ++++++++-
 models/nsfp_model.py                               | 7 ++++++-
 models/optimization/test_time_optimizer_loop.py    | 5 +++--
 visualization/visualize_tto_itermediary_results.py | 4 +++-
 6 files changed, 25 insertions(+), 7 deletions(-)

diff --git a/configs/fast_nsf/argo/val_debug.py b/configs/fast_nsf/argo/val_debug.py
index f1eb3b9..f555729 100644
--- a/configs/fast_nsf/argo/val_debug.py
+++ b/configs/fast_nsf/argo/val_debug.py
@@ -4,5 +4,6 @@
 save_output_folder = "/efs/argoverse2_small/val_fast_nsf_flow_replicate/"
 
 test_dataset = dict(
-    args=dict(root_dir=test_dataset_root, split=dict(split_idx=56 // 2, num_splits=314 // 2))
+    args=dict(root_dir=test_dataset_root, split=dict(split_idx=56 // 4, num_splits=314 // 4))
 )
+model = dict(args=dict(save_flow_every=10))
diff --git a/configs/gigachad_nsf/argo/noncausal/val_debug.py b/configs/gigachad_nsf/argo/noncausal/val_debug.py
index b916330..3a3388e 100644
--- a/configs/gigachad_nsf/argo/noncausal/val_debug.py
+++ b/configs/gigachad_nsf/argo/noncausal/val_debug.py
@@ -3,4 +3,6 @@
 test_dataset_root = "/efs/argoverse2_small/val/"
 save_output_folder = "/efs/argoverse2_small/val_gigachad_nsf_flow_feather/"
 
-test_dataset = dict(args=dict(root_dir=test_dataset_root))
+test_dataset = dict(args=dict(root_dir=test_dataset_root, split=dict(split_idx=4, num_splits=20)))
+
+model = dict(args=dict(save_flow_every=10))
diff --git a/models/fast_nsf_model.py b/models/fast_nsf_model.py
index 56d9464..7b7d45a 100644
--- a/models/fast_nsf_model.py
+++ b/models/fast_nsf_model.py
@@ -4,6 +4,7 @@
 from models.neural_reps import FastNSF, FastNSFPlusPlus
 from .nsfp_model import NSFPModel
 from pytorch_lightning.loggers import Logger
+from typing import Optional
 
 
 class FastNSFModel(NSFPModel):
@@ -29,8 +30,14 @@ def __init__(
         patience: int = 100,
         min_delta: float = 0.00005,
         speed_threshold: float = 30.0 / 10.0,  # 30 m/s cap
+        save_flow_every: Optional[int] = None,
     ) -> None:
-        super().__init__(iterations=iterations, patience=patience, min_delta=min_delta)
+        super().__init__(
+            iterations=iterations,
+            patience=patience,
+            min_delta=min_delta,
+            save_flow_every=save_flow_every,
+        )
         self.speed_threshold = speed_threshold
 
     def forward_single(
diff --git a/models/nsfp_model.py b/models/nsfp_model.py
index cb31e73..c1f08df 100644
--- a/models/nsfp_model.py
+++ b/models/nsfp_model.py
@@ -8,6 +8,7 @@
 from models.optimization import OptimizationLoop
 from models.neural_reps import NSFPCycleConsistency
 from pytorch_lightning.loggers import Logger
+from typing import Optional
 
 
 class NSFPModel(BaseModel):
@@ -16,10 +17,14 @@ def __init__(
         iterations: int = 5000,
         patience: int = 100,
         min_delta: float = 0.00005,
+        save_flow_every: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.optimization_loop = OptimizationLoop(
-            iterations=iterations, min_delta=min_delta, patience=patience
+            iterations=iterations,
+            min_delta=min_delta,
+            patience=patience,
+            save_flow_every=save_flow_every,
         )
 
     def _validate_input(self, batched_sequence: list[BucketedSceneFlowInputSequence]) -> None:
diff --git a/models/optimization/test_time_optimizer_loop.py b/models/optimization/test_time_optimizer_loop.py
index 67f3a07..bd222a4 100644
--- a/models/optimization/test_time_optimizer_loop.py
+++ b/models/optimization/test_time_optimizer_loop.py
@@ -21,6 +21,7 @@ def __init__(
         min_delta: float = 0.00005,
         weight_decay: float = 0,
         compile: bool = True,
+        save_flow_every: Optional[int] = None,
     ):
         self.iterations = iterations
         self.lr = lr
@@ -28,6 +29,7 @@
         self.patience = patience
         self.min_delta = min_delta
         self.compile = compile
+        self.save_flow_every = save_flow_every
 
     def _save_intermediary_results(
         self,
@@ -62,7 +64,6 @@ def optimize(
         min_delta: Optional[float] = None,
         title: Optional[str] = "Optimizing Neur Rep",
         leave: bool = False,
-        save_flow_every: Optional[int] = None,
     ) -> BucketedSceneFlowOutputSequence:
         model = model.train()
         if self.compile:
@@ -92,7 +93,7 @@
                 {f"log/{problem.sequence_log_id}/{problem.dataset_idx:06d}": cost.item()}, step=step
             )
 
-            if save_flow_every is not None and step % save_flow_every == 0:
+            if self.save_flow_every is not None and step % self.save_flow_every == 0:
                 self._save_intermediary_results(model, problem, logger, step)
 
             if cost.item() < lowest_cost:
diff --git a/visualization/visualize_tto_itermediary_results.py b/visualization/visualize_tto_itermediary_results.py
index defc9f6..6dcb0f1 100644
--- a/visualization/visualize_tto_itermediary_results.py
+++ b/visualization/visualize_tto_itermediary_results.py
@@ -146,7 +146,9 @@ def main():
 
     dataset = construct_dataset(
         name=args.dataset_name,
-        args=dict(root_dir=args.root_dir, subsequence_length=args.subsequence_length),
+        args=dict(
+            root_dir=args.root_dir, subsequence_length=args.subsequence_length, with_ground=False
+        ),
     )
 
     visualizer = ResultsVisualizer(dataset, args.intermediary_results_folder)
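
Note for reviewers: the patch threads one optional `save_flow_every` knob from the model constructors (`FastNSFModel`, `NSFPModel`) down into `OptimizationLoop`, replacing the per-call `save_flow_every` argument of `optimize()`. The snippet below is a minimal, self-contained sketch of that pattern only; `ToyLoop`, `toy step`, and the printed output are hypothetical stand-ins and are not part of this repository's API.

```python
from typing import Optional


class ToyLoop:
    """Sketch of the patched pattern: the loop object owns the debug cadence."""

    def __init__(self, iterations: int = 100, save_flow_every: Optional[int] = None):
        self.iterations = iterations
        # None disables intermediary saving entirely (the default in the patch).
        self.save_flow_every = save_flow_every

    def _save_intermediary_results(self, step: int, cost: float) -> None:
        # Stand-in for the real _save_intermediary_results(model, problem, logger, step).
        print(f"step {step:04d}: cost={cost:.4f}")

    def optimize(self) -> float:
        cost = 1.0
        for step in range(self.iterations):
            cost *= 0.99  # placeholder for one optimization step
            # Same guard as the patched loop: only save when the knob is set.
            if self.save_flow_every is not None and step % self.save_flow_every == 0:
                self._save_intermediary_results(step, cost)
        return cost


if __name__ == "__main__":
    # Mirrors `model = dict(args=dict(save_flow_every=10))` in the debug configs.
    ToyLoop(iterations=50, save_flow_every=10).optimize()
```

Moving the cadence from an `optimize()` argument to constructor state is what lets the debug configs switch it on with `model = dict(args=dict(save_flow_every=10))`, without changing any call sites of `optimize()`.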