Skip to content

Commit

Permalink
Added control over debug steps to model constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
kylevedder committed Apr 23, 2024
1 parent faf310e commit 77ac95e
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 7 deletions.
3 changes: 2 additions & 1 deletion configs/fast_nsf/argo/val_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
save_output_folder = "/efs/argoverse2_small/val_fast_nsf_flow_replicate/"

test_dataset = dict(
args=dict(root_dir=test_dataset_root, split=dict(split_idx=56 // 2, num_splits=314 // 2))
args=dict(root_dir=test_dataset_root, split=dict(split_idx=56 // 4, num_splits=314 // 4))
)
model = dict(args=dict(save_flow_every=10))
4 changes: 3 additions & 1 deletion configs/gigachad_nsf/argo/noncausal/val_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@
test_dataset_root = "/efs/argoverse2_small/val/"
save_output_folder = "/efs/argoverse2_small/val_gigachad_nsf_flow_feather/"

test_dataset = dict(args=dict(root_dir=test_dataset_root))
test_dataset = dict(args=dict(root_dir=test_dataset_root, split=dict(split_idx=4, num_splits=20)))

model = dict(args=dict(save_flow_every=10))
9 changes: 8 additions & 1 deletion models/fast_nsf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from models.neural_reps import FastNSF, FastNSFPlusPlus
from .nsfp_model import NSFPModel
from pytorch_lightning.loggers import Logger
from typing import Optional


class FastNSFModel(NSFPModel):
Expand All @@ -29,8 +30,14 @@ def __init__(
patience: int = 100,
min_delta: float = 0.00005,
speed_threshold: float = 30.0 / 10.0, # 30 m/s cap
save_flow_every: Optional[int] = None,
) -> None:
super().__init__(iterations=iterations, patience=patience, min_delta=min_delta)
super().__init__(
iterations=iterations,
patience=patience,
min_delta=min_delta,
save_flow_every=save_flow_every,
)
self.speed_threshold = speed_threshold

def forward_single(
Expand Down
7 changes: 6 additions & 1 deletion models/nsfp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from models.optimization import OptimizationLoop
from models.neural_reps import NSFPCycleConsistency
from pytorch_lightning.loggers import Logger
from typing import Optional


class NSFPModel(BaseModel):
Expand All @@ -16,10 +17,14 @@ def __init__(
iterations: int = 5000,
patience: int = 100,
min_delta: float = 0.00005,
save_flow_every: Optional[int] = None,
) -> None:
super().__init__()
self.optimization_loop = OptimizationLoop(
iterations=iterations, min_delta=min_delta, patience=patience
iterations=iterations,
min_delta=min_delta,
patience=patience,
save_flow_every=save_flow_every,
)

def _validate_input(self, batched_sequence: list[BucketedSceneFlowInputSequence]) -> None:
Expand Down
5 changes: 3 additions & 2 deletions models/optimization/test_time_optimizer_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ def __init__(
min_delta: float = 0.00005,
weight_decay: float = 0,
compile: bool = True,
save_flow_every: Optional[int] = None,
):
self.iterations = iterations
self.lr = lr
self.weight_decay = weight_decay
self.patience = patience
self.min_delta = min_delta
self.compile = compile
self.save_flow_every = save_flow_every

def _save_intermediary_results(
self,
Expand Down Expand Up @@ -62,7 +64,6 @@ def optimize(
min_delta: Optional[float] = None,
title: Optional[str] = "Optimizing Neur Rep",
leave: bool = False,
save_flow_every: Optional[int] = None,
) -> BucketedSceneFlowOutputSequence:
model = model.train()
if self.compile:
Expand Down Expand Up @@ -92,7 +93,7 @@ def optimize(
{f"log/{problem.sequence_log_id}/{problem.dataset_idx:06d}": cost.item()}, step=step
)

if save_flow_every is not None and step % save_flow_every == 0:
if self.save_flow_every is not None and step % self.save_flow_every == 0:
self._save_intermediary_results(model, problem, logger, step)

if cost.item() < lowest_cost:
Expand Down
4 changes: 3 additions & 1 deletion visualization/visualize_tto_itermediary_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ def main():

dataset = construct_dataset(
name=args.dataset_name,
args=dict(root_dir=args.root_dir, subsequence_length=args.subsequence_length),
args=dict(
root_dir=args.root_dir, subsequence_length=args.subsequence_length, with_ground=False
),
)

visualizer = ResultsVisualizer(dataset, args.intermediary_results_folder)
Expand Down

0 comments on commit 77ac95e

Please sign in to comment.