Skip to content

Commit

Permalink
hot fix: do not terminate when goal reached during online interaction
Browse files Browse the repository at this point in the history
  • Loading branch information
tsilver-bdai committed Jun 29, 2023
1 parent c7e1d78 commit 5277d2e
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions predicators/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ def _generate_interaction_results(
"train",
request.train_task_idx,
max_num_steps=CFG.max_num_steps_interaction_request,
+ terminate_on_goal_reached=False,
exceptions_to_break_on={
utils.EnvironmentFailure,
utils.OptionExecutionFailure,
Expand Down Expand Up @@ -433,6 +434,7 @@ def _run_episode(
task_idx: int,
max_num_steps: int,
do_env_reset: bool = True,
+ terminate_on_goal_reached: bool = True,
exceptions_to_break_on: Optional[Set[TypingType[Exception]]] = None,
monitor: Optional[utils.LoggingMonitor] = None
) -> Tuple[Tuple[List[Observation], List[Action]], bool, Metrics]:
Expand All @@ -445,6 +447,7 @@ def _run_episode(
(1) cogman.step returns None, indicating termination
(2) max_num_steps is reached
(3) cogman or env raise an exception of type in exceptions_to_break_on
+ (4) terminate_on_goal_reached is True and the env goal is reached.
Note that in the case where the exception is raised in step, we exclude the
last action from the returned trajectory to maintain the invariant that
Expand All @@ -463,7 +466,7 @@ def _run_episode(
metrics: Metrics = defaultdict(float)
metrics["policy_call_time"] = 0.0
exception_raised_in_step = False
- if not env.goal_reached():
+ if not (terminate_on_goal_reached and env.goal_reached()):
for _ in range(max_num_steps):
monitor_observed = False
exception_raised_in_step = False
Expand Down Expand Up @@ -494,7 +497,7 @@ def _run_episode(
if monitor is not None and not monitor_observed:
monitor.observe(obs, None)
raise e
- if env.goal_reached():
+ if terminate_on_goal_reached and env.goal_reached():
break
if monitor is not None and not exception_raised_in_step:
monitor.observe(obs, None)
Expand Down

0 comments on commit 5277d2e

Please sign in to comment.