From 5277d2e53b228da8d24bf0618c43cbe479f233bb Mon Sep 17 00:00:00 2001 From: Tom Silver Date: Thu, 29 Jun 2023 14:20:15 -0400 Subject: [PATCH] hot fix: do not terminate when goal reached during online interaction --- predicators/main.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/predicators/main.py b/predicators/main.py index 3aae37fc65..0109f49355 100644 --- a/predicators/main.py +++ b/predicators/main.py @@ -255,6 +255,7 @@ def _generate_interaction_results( "train", request.train_task_idx, max_num_steps=CFG.max_num_steps_interaction_request, + terminate_on_goal_reached=False, exceptions_to_break_on={ utils.EnvironmentFailure, utils.OptionExecutionFailure, @@ -433,6 +434,7 @@ def _run_episode( task_idx: int, max_num_steps: int, do_env_reset: bool = True, + terminate_on_goal_reached: bool = True, exceptions_to_break_on: Optional[Set[TypingType[Exception]]] = None, monitor: Optional[utils.LoggingMonitor] = None ) -> Tuple[Tuple[List[Observation], List[Action]], bool, Metrics]: @@ -445,6 +447,7 @@ def _run_episode( (1) cogman.step returns None, indicating termination (2) max_num_steps is reached (3) cogman or env raise an exception of type in exceptions_to_break_on + (4) terminate_on_goal_reached is True and the env goal is reached. Note that in the case where the exception is raised in step, we exclude the last action from the returned trajectory to maintain the invariant that @@ -463,7 +466,7 @@ def _run_episode( metrics: Metrics = defaultdict(float) metrics["policy_call_time"] = 0.0 exception_raised_in_step = False - if not env.goal_reached(): + if not (terminate_on_goal_reached and env.goal_reached()): for _ in range(max_num_steps): monitor_observed = False exception_raised_in_step = False @@ -494,7 +497,7 @@ def _run_episode( if monitor is not None and not monitor_observed: monitor.observe(obs, None) raise e - if env.goal_reached(): + if terminate_on_goal_reached and env.goal_reached(): break if monitor is not None and not exception_raised_in_step: monitor.observe(obs, None)