Skip to content

Commit

Permalink
[RLlib] Terminated/Truncated in Quickstart Script. (#31386)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArturNiederfahrenhorst authored Jan 5, 2023
1 parent fba15f6 commit e618c8c
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions rllib/examples/documentation/rllib_on_ray_readme.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,33 @@ def __init__(self, config):
self.observation_space = gym.spaces.Box(0.0, self.end_pos, shape=(1,))

def reset(self, *, seed=None, options=None):
    """Resets the episode.

    Args:
        seed: Unused here.
        options: Unused here.

    Returns:
        Initial observation of the new episode and an info dict.
    """
    # Every episode starts at position 0.
    self.cur_pos = 0
    # Return the initial observation plus an (empty) info dict.
    return [self.cur_pos], {}

def step(self, action):
    """Takes a single step in the episode given `action`.

    Args:
        action: 0 to walk left (no effect at position 0), 1 to walk right.
            Any other value leaves the position unchanged.

    Returns:
        New observation, reward, terminated-flag, truncated-flag, info-dict (empty).
    """
    # Walk left.
    if action == 0 and self.cur_pos > 0:
        self.cur_pos -= 1
    # Walk right.
    elif action == 1:
        self.cur_pos += 1
    # Set `terminated` flag when end of corridor (goal) reached.
    terminated = self.cur_pos >= self.end_pos
    # This method never truncates an episode; only reaching the goal ends it.
    truncated = False
    # +1 when goal reached, otherwise -0.1.
    reward = 1.0 if terminated else -0.1
    return [self.cur_pos], reward, terminated, truncated, {}


# Create an RLlib Algorithm instance from a PPOConfig object.
Expand Down Expand Up @@ -78,15 +83,15 @@ def step(self, action):
env = SimpleCorridor({"corridor_length": 10})
# Get the initial observation (should be: [0] for the starting position).
obs, info = env.reset()
# An episode ends when the env reports either termination or truncation.
terminated = truncated = False
total_reward = 0.0
# Play one episode.
while not terminated and not truncated:
    # Compute a single action, given the current observation
    # from the environment.
    action = algo.compute_single_action(obs)
    # Apply the computed action in the environment.
    obs, reward, terminated, truncated, info = env.step(action)
    # Sum up rewards for reporting purposes.
    total_reward += reward
# Report results.
Expand Down

0 comments on commit e618c8c

Please sign in to comment.