From 726ba0660164f4ea9c3c1aae3583e2d7f64e5b96 Mon Sep 17 00:00:00 2001
From: Edilmo Palencia
Date: Thu, 1 Oct 2020 09:17:30 -0700
Subject: [PATCH 1/6] Making states available in SampleBatch and ReplayBuffer

---
 rllib/evaluation/sampler.py      |  1 +
 rllib/execution/replay_buffer.py | 25 +++++++++++++++----------
 2 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py
index 7cb23d8168ec..e74641ed76bd 100644
--- a/rllib/evaluation/sampler.py
+++ b/rllib/evaluation/sampler.py
@@ -674,6 +674,7 @@ def _process_observations(self,
                       agent_done),
                 infos=infos[env_id].get(agent_id, {}),
                 new_obs=filtered_obs,
+                states=episode.rnn_state_for(agent_id),
                 **episode.last_pi_info_for(agent_id))
 
         # Invoke the step callback after the step is logged to the episode
diff --git a/rllib/execution/replay_buffer.py b/rllib/execution/replay_buffer.py
index 25d6e57d6b44..18d450c62749 100644
--- a/rllib/execution/replay_buffer.py
+++ b/rllib/execution/replay_buffer.py
@@ -41,8 +41,8 @@ def __len__(self):
         return len(self._storage)
 
     @DeveloperAPI
-    def add(self, obs_t, action, reward, obs_tp1, done, weight):
-        data = (obs_t, action, reward, obs_tp1, done)
+    def add(self, obs_t, action, reward, obs_tp1, done, state, weight):
+        data = (obs_t, action, reward, obs_tp1, done, state)
         self._num_added += 1
 
         if self._next_idx >= len(self._storage):
@@ -58,18 +58,19 @@ def add(self, obs_t, action, reward, obs_tp1, done, weight):
             self._hit_count[self._next_idx] = 0
 
     def _encode_sample(self, idxes):
-        obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], []
+        obses_t, actions, rewards, obses_tp1, dones, states = [], [], [], [], [], []
         for i in idxes:
             data = self._storage[i]
-            obs_t, action, reward, obs_tp1, done = data
+            obs_t, action, reward, obs_tp1, done, state = data
             obses_t.append(np.array(unpack_if_needed(obs_t), copy=False))
             actions.append(np.array(action, copy=False))
             rewards.append(reward)
             obses_tp1.append(np.array(unpack_if_needed(obs_tp1), copy=False))
             dones.append(done)
+            states.append(state)
             self._hit_count[i] += 1
         return (np.array(obses_t), np.array(actions), np.array(rewards),
-                np.array(obses_tp1), np.array(dones))
+                np.array(obses_tp1), np.array(dones), np.array(states))
 
     @DeveloperAPI
     def sample_idxes(self, batch_size):
@@ -156,12 +157,13 @@ def __init__(self, size, alpha):
         self._prio_change_stats = WindowStat("reprio", 1000)
 
     @DeveloperAPI
-    def add(self, obs_t, action, reward, obs_tp1, done, weight):
+    def add(self, obs_t, action, reward, obs_tp1, done, state, weight):
        """See ReplayBuffer.store_effect"""
 
         idx = self._next_idx
         super(PrioritizedReplayBuffer, self).add(obs_t, action, reward,
-                                                 obs_tp1, done, weight)
+                                                 obs_tp1, done, state,
+                                                 weight)
         if weight is None:
             weight = self._max_priority
         self._it_sum[idx] = weight**self._alpha
@@ -356,8 +358,8 @@ def add_batch(self, batch):
                 for row in s.rows():
                     self.replay_buffers[policy_id].add(
                         row["obs"], row["actions"], row["rewards"],
-                        row["new_obs"], row["dones"], row["weights"]
-                        if "weights" in row else None)
+                        row["new_obs"], row["dones"], row["states"],
+                        row["weights"] if "weights" in row else None)
         self.num_added += batch.count
 
     def replay(self):
@@ -380,7 +382,7 @@ def replay(self):
                         self.replay_batch_size)
                 else:
                     idxes = replay_buffer.sample_idxes(self.replay_batch_size)
-                (obses_t, actions, rewards, obses_tp1, dones, weights,
+                (obses_t, actions, rewards, obses_tp1, dones, states, weights,
                  batch_indexes) = replay_buffer.sample_with_idxes(
                      idxes, beta=self.prioritized_replay_beta)
                 samples[policy_id] = SampleBatch({
@@ -390,6 +392,7 @@ def replay(self):
                     "new_obs": obses_tp1,
                     "dones": dones,
                     "weights": weights,
+                    "states": states,
                     "batch_indexes": batch_indexes
                 })
         return MultiAgentBatch(samples, self.replay_batch_size)
@@ -495,6 +498,7 @@ def replay(self):
                     rewards,
                     obses_tp1,
                     dones,
+                    states,
                 ) = replay_buffer.sample_with_idxes(idxes)
                 samples[policy_id] = SampleBatch(
                     {
@@ -503,6 +507,7 @@ def replay(self):
                         "rewards": rewards,
                         "new_obs": obses_tp1,
                         "dones": dones,
+                        "states": states,
                     }
                 )
         return MultiAgentBatch(samples, self.replay_batch_size)

From a08fc780e336dc262f6ffda4611af80745f11866 Mon Sep 17 00:00:00 2001
From: Edilmo Palencia
Date: Sun, 4 Oct 2020 15:36:57 -0700
Subject: [PATCH 2/6] RNN support in DQN

---
 rllib/agents/dqn/dqn_tf_policy.py             | 60 ++++++++++---
 rllib/evaluation/sampler.py                   |  1 -
 rllib/execution/replay_buffer.py              | 89 ++++++++++++++-----
 ...output-2019-02-03_20-27-20_worker-0_0.json |  6 +-
 4 files changed, 119 insertions(+), 37 deletions(-)

diff --git a/rllib/agents/dqn/dqn_tf_policy.py b/rllib/agents/dqn/dqn_tf_policy.py
index 8587a4321cbc..0802ae284f8f 100644
--- a/rllib/agents/dqn/dqn_tf_policy.py
+++ b/rllib/agents/dqn/dqn_tf_policy.py
@@ -111,8 +111,15 @@ def __init__(self,
 class ComputeTDErrorMixin:
     def __init__(self):
         @make_tf_callable(self.get_session(), dynamic_shape=True)
-        def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
-                             importance_weights):
+        def compute_td_error(obs_t, states_t, act_t, rew_t, obs_tp1, done_mask,
+                             states_tp1, seq_lens, importance_weights):
+            state_in_out_and_seq_lens = {}
+            for i, h in enumerate(states_t):
+                state_in_out_and_seq_lens["state_in_{}".format(i)] = h
+            for i, h in enumerate(states_tp1):
+                state_in_out_and_seq_lens["state_out_{}".format(i)] = h
+            if state_in_out_and_seq_lens:
+                state_in_out_and_seq_lens["seq_lens"] = seq_lens
             # Do forward pass on loss to update td error attribute
             build_q_losses(
                 self, self.model, None, {
@@ -122,6 +129,7 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask,
                     SampleBatch.NEXT_OBS: tf.convert_to_tensor(obs_tp1),
                     SampleBatch.DONES: tf.convert_to_tensor(done_mask),
                     PRIO_WEIGHTS: tf.convert_to_tensor(importance_weights),
+                    **state_in_out_and_seq_lens
                 })
 
             return self.q_loss.td_error
@@ -190,10 +198,12 @@ def build_q_model(policy, obs_space, action_space, config):
 def get_distribution_inputs_and_class(policy,
                                       model,
                                       obs_batch,
+                                      state_batches,
+                                      seq_lens,
                                       *,
                                       explore=True,
                                       **kwargs):
-    q_vals = compute_q_values(policy, model, obs_batch, explore)
+    q_vals = compute_q_values(policy, model, obs_batch, state_batches, seq_lens, explore)
     q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals
     policy.q_values = q_vals
@@ -202,12 +212,25 @@ def get_distribution_inputs_and_class(policy,
 def build_q_losses(policy, model, _, train_batch):
+    states_in = []
+    i = 0
+    while "state_in_{}".format(i) in train_batch:
+        states_in.append(train_batch["state_in_{}".format(i)])
+        i += 1
+    states_out = []
+    i = 0
+    while "state_out_{}".format(i) in train_batch:
+        states_out.append(train_batch["state_out_{}".format(i)])
+        i += 1
+    seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else None
     config = policy.config
     # q network evaluation
     q_t, q_logits_t, q_dist_t = compute_q_values(
         policy,
         policy.q_model,
         train_batch[SampleBatch.CUR_OBS],
+        states_in,
+        seq_lens,
         explore=False)
 
     # target q network evalution
     q_tp1, q_logits_tp1, q_dist_tp1 = compute_q_values(
         policy,
         policy.target_q_model,
         train_batch[SampleBatch.NEXT_OBS],
+        states_out,
+        seq_lens,
         explore=False)
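The hunks above thread recurrent state through the DQN loss by repeatedly walking the indexed keys of a batch-like mapping: collect "state_in_{i}" and "state_out_{i}" until the next index is missing, and pick up "seq_lens" when present (the same loop reappears below in postprocess_nstep_and_prio and in LocalReplayBuffer.add_batch). The following is a standalone sketch of that collection step, not code from this patch; the helper name collect_rnn_state_columns is illustrative only.

    from typing import Any, Dict, List, Optional, Tuple

    def collect_rnn_state_columns(
            batch: Dict[str, Any]) -> Tuple[List[Any], List[Any], Optional[Any]]:
        # Gather "state_in_0", "state_in_1", ... until an index is missing.
        states_in = []
        i = 0
        while "state_in_{}".format(i) in batch:
            states_in.append(batch["state_in_{}".format(i)])
            i += 1
        # Same for the post-step states "state_out_{i}".
        states_out = []
        i = 0
        while "state_out_{}".format(i) in batch:
            states_out.append(batch["state_out_{}".format(i)])
            i += 1
        # "seq_lens" is only present when the model is recurrent.
        seq_lens = batch["seq_lens"] if "seq_lens" in batch else None
        return states_in, states_out, seq_lens

    # Example, with plain lists standing in for tensors:
    # collect_rnn_state_columns({"obs": [1], "state_in_0": [[0.0]],
    #                            "state_out_0": [[0.1]], "seq_lens": [1]})
    # -> ([[[0.0]]], [[[0.1]]], [1])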
policy.target_q_func_vars = policy.target_q_model.variables() @@ -232,6 +257,8 @@ def build_q_losses(policy, model, _, train_batch): q_dist_tp1_using_online_net = compute_q_values( policy, policy.q_model, train_batch[SampleBatch.NEXT_OBS], + states_out, + seq_lens, explore=False) q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1) q_tp1_best_one_hot_selection = tf.one_hot(q_tp1_best_using_online_net, @@ -293,13 +320,13 @@ def setup_late_mixins(policy, obs_space, action_space, config): TargetNetworkMixin.__init__(policy, obs_space, action_space, config) -def compute_q_values(policy, model, obs, explore): +def compute_q_values(policy, model, obs, states, seq_lens, explore): config = policy.config model_out, state = model({ SampleBatch.CUR_OBS: obs, "is_training": policy._get_is_training_placeholder(), - }, [], None) + }, states, seq_lens) if config["num_atoms"] > 1: (action_scores, z, support_logits_per_action, logits, @@ -335,7 +362,8 @@ def compute_q_values(policy, model, obs, explore): return value, logits, dist -def _adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones): +def _adjust_nstep(n_step, gamma, obs, states_in, actions, rewards, new_obs, dones, + states_out, seq_lens): """Rewrites the given trajectory fragments to encode n-step rewards. reward[i] = ( @@ -358,15 +386,27 @@ def _adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones): new_obs[i] = new_obs[i + j] dones[i] = dones[i + j] rewards[i] += gamma**j * rewards[i + j] + states_out[i] = states_out[i + j] def postprocess_nstep_and_prio(policy, batch, other_agent=None, episode=None): + states_in = [] + i = 0 + while "state_in_{}".format(i) in batch: + states_in.append(batch["state_in_{}".format(i)]) + i += 1 + states_out = [] + i = 0 + while "state_out_{}".format(i) in batch: + states_out.append(batch["state_out_{}".format(i)]) + i += 1 + seq_lens = batch["seq_lens"] if "seq_lens" in batch else None # N-step Q adjustments if policy.config["n_step"] > 1: _adjust_nstep(policy.config["n_step"], policy.config["gamma"], - batch[SampleBatch.CUR_OBS], batch[SampleBatch.ACTIONS], + batch[SampleBatch.CUR_OBS], states_in, batch[SampleBatch.ACTIONS], batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS], - batch[SampleBatch.DONES]) + batch[SampleBatch.DONES], states_out, seq_lens) if PRIO_WEIGHTS not in batch: batch[PRIO_WEIGHTS] = np.ones_like(batch[SampleBatch.REWARDS]) @@ -379,9 +419,9 @@ def postprocess_nstep_and_prio(policy, batch, other_agent=None, episode=None): elif isinstance(policy.action_space, Discrete) and actions.shape[-1] == 1: actions = np.reshape(actions, [-1]) td_errors = policy.compute_td_error( - batch[SampleBatch.CUR_OBS], actions, + batch[SampleBatch.CUR_OBS], states_in, actions, batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS], - batch[SampleBatch.DONES], batch[PRIO_WEIGHTS]) + batch[SampleBatch.DONES], states_out, seq_lens, batch[PRIO_WEIGHTS]) new_priorities = ( np.abs(td_errors) + policy.config["prioritized_replay_eps"]) batch.data[PRIO_WEIGHTS] = new_priorities diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index e74641ed76bd..7cb23d8168ec 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -674,7 +674,6 @@ def _process_observations(self, agent_done), infos=infos[env_id].get(agent_id, {}), new_obs=filtered_obs, - states=episode.rnn_state_for(agent_id), **episode.last_pi_info_for(agent_id)) # Invoke the step callback after the step is logged to the episode diff --git a/rllib/execution/replay_buffer.py 
b/rllib/execution/replay_buffer.py index 18d450c62749..1e48b8ca20f5 100644 --- a/rllib/execution/replay_buffer.py +++ b/rllib/execution/replay_buffer.py @@ -1,3 +1,5 @@ +from typing import Optional, Iterator + import numpy as np import random import collections @@ -41,8 +43,8 @@ def __len__(self): return len(self._storage) @DeveloperAPI - def add(self, obs_t, action, reward, obs_tp1, done, state, weight): - data = (obs_t, action, reward, obs_tp1, done, state) + def add(self, obs_t, state_t, action, reward, obs_tp1, done, state_tp1, seq_lens, weight): + data = (obs_t, state_t, action, reward, obs_tp1, done, state_tp1, seq_lens) self._num_added += 1 if self._next_idx >= len(self._storage): @@ -58,19 +60,24 @@ def add(self, obs_t, action, reward, obs_tp1, done, state, weight): self._hit_count[self._next_idx] = 0 def _encode_sample(self, idxes): - obses_t, actions, rewards, obses_tp1, dones, states = [], [], [], [], [], [] + obses_t, states_t, actions, rewards, obses_tp1, dones, states_tp1, seq_lens = [], [], [], [], [], [], [], [] for i in idxes: data = self._storage[i] - obs_t, action, reward, obs_tp1, done, state = data + obs_t, state_t, action, reward, obs_tp1, done, state_tp1, seq_len = data obses_t.append(np.array(unpack_if_needed(obs_t), copy=False)) + states_t.append(state_t) actions.append(np.array(action, copy=False)) rewards.append(reward) obses_tp1.append(np.array(unpack_if_needed(obs_tp1), copy=False)) dones.append(done) - states.append(state) + states_tp1.append(state_tp1) + seq_lens.append(seq_len) self._hit_count[i] += 1 - return (np.array(obses_t), np.array(actions), np.array(rewards), - np.array(obses_tp1), np.array(dones), np.array(states)) + states_t = [np.array(h) for h in states_t if h] + states_tp1 = [np.array(h) for h in states_tp1 if h] + seq_lens = np.array(seq_lens) if all([sl is not None for sl in seq_lens]) else np.array([]) + return (np.array(obses_t), states_t, np.array(actions), np.array(rewards), + np.array(obses_tp1), np.array(dones), states_tp1, seq_lens) @DeveloperAPI def sample_idxes(self, batch_size): @@ -157,13 +164,13 @@ def __init__(self, size, alpha): self._prio_change_stats = WindowStat("reprio", 1000) @DeveloperAPI - def add(self, obs_t, action, reward, obs_tp1, done, state, weight): + def add(self, obs_t, state_t, action, reward, obs_tp1, done, state_tp1, seq_lens, weight): """See ReplayBuffer.store_effect""" idx = self._next_idx - super(PrioritizedReplayBuffer, self).add(obs_t, action, reward, - obs_tp1, done, state, - weight) + super(PrioritizedReplayBuffer, self).add(obs_t, state_t, action, reward, + obs_tp1, done, state_tp1, + seq_lens, weight) if weight is None: weight = self._max_priority self._it_sum[idx] = weight**self._alpha @@ -292,6 +299,17 @@ def stats(self, debug=False): _local_replay_buffer = None +class _GenReplay(Iterator[Optional[MultiAgentBatch]]): + def __init__(self, parent_buffer: "LocalReplayBuffer"): + self.parent_buffer = parent_buffer + + def __iter__(self) -> Iterator[Optional[MultiAgentBatch]]: + return self + + def __next__(self) -> Optional[MultiAgentBatch]: + return self.parent_buffer.replay() + + # TODO(ekl) move this class to common class LocalReplayBuffer(ParallelIteratorWorker): """A replay buffer shard. 
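With the hunks above, ReplayBuffer.add() now stores an extended transition tuple, (obs_t, state_t, action, reward, obs_tp1, done, state_tp1, seq_lens), and _encode_sample() unpacks those columns back out at sampling time. Below is a minimal, self-contained sketch of that storage layout; it is not RLlib code, and the class name TinyStateReplayBuffer is made up for illustration.

    import random

    import numpy as np

    class TinyStateReplayBuffer:
        def __init__(self, capacity):
            self.capacity = capacity
            self._storage = []
            self._next_idx = 0

        def add(self, obs_t, state_t, action, reward, obs_tp1, done,
                state_tp1, seq_lens):
            # Same extended transition layout as the patched ReplayBuffer.add().
            data = (obs_t, state_t, action, reward, obs_tp1, done, state_tp1,
                    seq_lens)
            if self._next_idx >= len(self._storage):
                self._storage.append(data)
            else:
                self._storage[self._next_idx] = data
            self._next_idx = (self._next_idx + 1) % self.capacity

        def sample(self, batch_size):
            rows = [random.choice(self._storage) for _ in range(batch_size)]
            # Unzip the row tuples into per-column lists, as _encode_sample() does;
            # RNN states stay as plain lists, the dense columns become arrays.
            (obs_t, state_t, action, reward, obs_tp1, done, state_tp1,
             seq_lens) = map(list, zip(*rows))
            return (np.array(obs_t), state_t, np.array(action),
                    np.array(reward), np.array(obs_tp1), np.array(done),
                    state_tp1, np.array(seq_lens))

    # Usage:
    #   buf = TinyStateReplayBuffer(capacity=4)
    #   for t in range(6):
    #       buf.add([t], [np.zeros(2)], 0, 1.0, [t + 1], False, [np.zeros(2)], 1)
    #   obs_batch = buf.sample(2)[0]   # shape (2, 1)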
@@ -316,8 +334,7 @@ def __init__(self, self.multiagent_sync_replay = multiagent_sync_replay def gen_replay(): - while True: - yield self.replay() + return _GenReplay(self) ParallelIteratorWorker.__init__(self, gen_replay, False) @@ -356,9 +373,20 @@ def add_batch(self, batch): with self.add_batch_timer: for policy_id, s in batch.policy_batches.items(): for row in s.rows(): + states_in = [] + i = 0 + while "state_in_{}".format(i) in row: + states_in.append(row["state_in_{}".format(i)]) + i += 1 + states_out = [] + i = 0 + while "state_out_{}".format(i) in row: + states_out.append(row["state_out_{}".format(i)]) + i += 1 self.replay_buffers[policy_id].add( - row["obs"], row["actions"], row["rewards"], - row["new_obs"], row["dones"], row["states"], + row["obs"], states_in, row["actions"], row["rewards"], + row["new_obs"], row["dones"], states_out, + row["seq_lens"] if "seq_lens" in row else None, row["weights"] if "weights" in row else None) self.num_added += batch.count @@ -382,9 +410,16 @@ def replay(self): self.replay_batch_size) else: idxes = replay_buffer.sample_idxes(self.replay_batch_size) - (obses_t, actions, rewards, obses_tp1, dones, states, weights, - batch_indexes) = replay_buffer.sample_with_idxes( + (obses_t, states_t, actions, rewards, obses_tp1, dones, states_tp1, seq_lens, + weights, batch_indexes) = replay_buffer.sample_with_idxes( idxes, beta=self.prioritized_replay_beta) + state_in_out_and_seq_lens = {} + for i, h in enumerate(states_t): + state_in_out_and_seq_lens["state_in_{}".format(i)] = h + for i, h in enumerate(states_tp1): + state_in_out_and_seq_lens["state_out_{}".format(i)] = h + if state_in_out_and_seq_lens: + state_in_out_and_seq_lens["seq_lens"] = seq_lens samples[policy_id] = SampleBatch({ "obs": obses_t, "actions": actions, @@ -392,8 +427,8 @@ def replay(self): "new_obs": obses_tp1, "dones": dones, "weights": weights, - "states": states, - "batch_indexes": batch_indexes + "batch_indexes": batch_indexes, + **state_in_out_and_seq_lens }) return MultiAgentBatch(samples, self.replay_batch_size) @@ -448,8 +483,7 @@ def __init__( self.multiagent_sync_replay = multiagent_sync_replay def gen_replay(): - while True: - yield self.replay() + return _GenReplay(self) ParallelIteratorWorker.__init__(self, gen_replay, False) @@ -494,12 +528,21 @@ def replay(self): idxes = replay_buffer.sample_idxes(self.replay_batch_size) ( obses_t, + states_t, actions, rewards, obses_tp1, dones, - states, + states_tp1, + seq_lens, ) = replay_buffer.sample_with_idxes(idxes) + state_in_out_and_seq_lens = {} + for i, h in enumerate(states_t): + state_in_out_and_seq_lens["state_in_{}".format(i)] = h + for i, h in enumerate(states_tp1): + state_in_out_and_seq_lens["state_out_{}".format(i)] = h + if state_in_out_and_seq_lens: + state_in_out_and_seq_lens["seq_lens"] = seq_lens samples[policy_id] = SampleBatch( { "obs": obses_t, @@ -507,7 +550,7 @@ def replay(self): "rewards": rewards, "new_obs": obses_tp1, "dones": dones, - "states": states, + **state_in_out_and_seq_lens } ) return MultiAgentBatch(samples, self.replay_batch_size) diff --git a/rllib/tests/data/cartpole_small/output-2019-02-03_20-27-20_worker-0_0.json b/rllib/tests/data/cartpole_small/output-2019-02-03_20-27-20_worker-0_0.json index 803617e91d7e..258020398f10 100644 --- a/rllib/tests/data/cartpole_small/output-2019-02-03_20-27-20_worker-0_0.json +++ b/rllib/tests/data/cartpole_small/output-2019-02-03_20-27-20_worker-0_0.json @@ -1,3 +1,3 @@ -{"type": "SampleBatch", "weights": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "eps_id": [241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760], "dones": [false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, true], "infos": [{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}], "prev_rewards": [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "t": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55], "agent_index": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "action_prob": [0.4979577958583832, 0.5745141506195068, 0.5042742490768433, 0.5248998403549194, 0.5048907995223999, 0.5254997611045837, 0.4930223524570465, 0.5723332166671753, 0.5071576237678528, 0.5262983441352844, 0.5075111389160156, 0.4721700847148895, 0.4541035294532776, 0.5691784024238586, 0.45002007484436035, 0.42802754044532776, 0.5951988697052002, 0.5743389129638672, 0.44297751784324646, 0.5751434564590454, 0.4427056908607483, 0.575354278087616, 0.5583169460296631, 0.5349109768867493, 0.49323225021362305, 0.42819857597351074, 0.6240300536155701, 0.42723774909973145, 0.6247843503952026, 0.4268564283847809, 0.6255699396133423, 0.5718400478363037, 0.49357253313064575, 0.5718478560447693, 0.506999135017395, 0.4627947509288788, 0.44369709491729736, 0.42281273007392883, 0.40176495909690857, 0.6177492141723633, 0.6000679731369019, 0.4211883246898651, 0.5995147228240967, 0.578464925289154, 0.5586039423942566, 0.5260810256004333, 0.4879906177520752, 0.42811155319213867, 0.6308852434158325, 0.5760338306427002, 0.5073276162147522, 0.46694710850715637, 0.43938523530960083, 0.5832104086875916, 0.5628215670585632, 0.5309032201766968], "actions": [0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], "q_values": 
[[-0.005643954500555992, 0.0025248583406209946], [-0.04723002016544342, 0.2530632019042969], [-0.004162287805229425, 0.012935103848576546], [0.05779631435871124, -0.041885510087013245], [-0.0001599406823515892, 0.019403917714953423], [0.05187809467315674, -0.05020952224731445], [-6.351247429847717e-05, 0.027848877012729645], [-0.03533334285020828, 0.2560437023639679], [0.005023432895541191, 0.03365574777126312], [0.04525064304471016, -0.06003996357321739], [0.002838471904397011, 0.032885171473026276], [0.03723599761724472, -0.07419878989458084], [0.09575563669204712, -0.0883483961224556], [0.16416001319885254, -0.11433979868888855], [0.09313704073429108, -0.10745253413915634], [0.16196757555007935, -0.12793570756912231], [0.23910409212112427, -0.1463954746723175], [0.15805242955684662, -0.14152376353740692], [0.09662380814552307, -0.1324627697467804], [0.1541520208120346, -0.14871598780155182], [0.0929112657904625, -0.1372770369052887], [0.1511463224887848, -0.15258446335792542], [0.0875367745757103, -0.14679750800132751], [0.08854943513870239, -0.05132210999727249], [0.018426118418574333, 0.045498818159103394], [-0.04996141046285629, 0.23924344778060913], [-0.09354546666145325, 0.4131438434123993], [-0.038044273853302, 0.255085825920105], [-0.09211604297161102, 0.4177895784378052], [-0.030748017132282257, 0.26394063234329224], [-0.09104493260383606, 0.4222134053707123], [-0.02319370210170746, 0.2661687135696411], [0.02133956551551819, 0.04705086350440979], [-0.021654099225997925, 0.2677402198314667], [0.01794305630028248, 0.04594135284423828], [0.05681019276380539, -0.0922863557934761], [0.11023147404193878, -0.1159394159913063], [0.16652457416057587, -0.14471273124217987], [0.23569053411483765, -0.16242587566375732], [0.31461724638938904, -0.165388286113739], [0.22523169219493866, -0.1805165857076645], [0.14499591290950775, -0.17290116846561432], [0.2126035839319229, -0.19084002077579498], [0.12525871396064758, -0.19121608138084412], [0.07890036702156067, -0.15659788250923157], [0.07070913910865784, -0.03370969370007515], [-0.0010413788259029388, 0.047005534172058105], [-0.05502410978078842, 0.2345360815525055], [-0.15737640857696533, 0.37863999605178833], [-0.09506852179765701, 0.21144413948059082], [-0.06340484321117401, -0.0340922586619854], [0.016717009246349335, -0.11568755656480789], [0.059842679649591446, -0.1838146150112152], [0.12809047102928162, -0.20787617564201355], [0.055311597883701324, -0.19730976223945618], [-0.022230863571166992, -0.14600159227848053]], "rewards": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "prev_actions": [0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], "obs": [[0.040251147001981735, -0.009447001852095127, 0.04735473543405533, -0.00123753328807652], [0.040062207728624344, -0.2052149772644043, 0.04732998460531235, 0.30600231885910034], [0.03595791012048721, -0.010798314586281776, 0.05345002934336662, 0.028613731265068054], [0.03574194014072418, 0.18351800739765167, 0.054022304713726044, -0.24673765897750854], [0.039412301033735275, -0.012332209385931492, 0.04908755049109459, 0.06248391792178154], [0.03916565701365471, 0.1820528209209442, 0.050337228924036026, -0.2143164724111557], [0.0428067147731781, 
-0.013751287944614887, 0.04605090245604515, 0.09381057322025299], [0.04253168776631355, -0.20950199663639069, 0.04792711138725281, 0.4006595313549042], [0.03834164887666702, -0.015091483481228352, 0.055940303951501846, 0.12346379458904266], [0.03803981840610504, 0.17918626964092255, 0.05840957909822464, -0.1510591059923172], [0.041623543947935104, -0.01672130823135376, 0.055388398468494415, 0.1594637781381607], [0.04128911718726158, 0.17756572365760803, 0.05857767164707184, -0.11524398624897003], [0.044840432703495026, 0.37180155515670776, 0.05627279356122017, -0.3888860046863556], [0.05227646231651306, 0.5660815238952637, 0.04849507287144661, -0.6633091568946838], [0.06359809637069702, 0.3703196048736572, 0.03522888943552971, -0.3557596206665039], [0.0710044875741005, 0.5649234652519226, 0.028113696724176407, -0.6371290683746338], [0.08230295777320862, 0.7596423029899597, 0.015371114946901798, -0.9208275675773621], [0.09749580174684525, 0.5643160343170166, -0.003045437391847372, -0.623353898525238], [0.10878212004899979, 0.36923670768737793, -0.015512514859437943, -0.3316316306591034], [0.116166852414608, 0.5645759701728821, -0.022145148366689682, -0.6291658282279968], [0.1274583786725998, 0.36976999044418335, -0.03472846373915672, -0.3435385525226593], [0.1348537802696228, 0.5653683543205261, -0.04159923642873764, -0.6469672918319702], [0.14616113901138306, 0.3708499073982239, -0.054538581520318985, -0.3676687479019165], [0.15357813239097595, 0.17654363811016083, -0.06189195439219475, -0.09266908466815948], [0.15710900723934174, -0.01763911545276642, -0.06374533474445343, 0.17986272275447845], [0.1567562371492386, -0.2117937058210373, -0.06014808267354965, 0.4517746567726135], [0.15252035856246948, -0.4060157239437103, -0.0511125884950161, 0.7249079942703247], [0.14440004527568817, -0.21022562682628632, -0.03661442920565605, 0.4165858030319214], [0.14019553363323212, -0.4048100411891937, -0.028282713145017624, 0.6975045800209045], [0.13209933042526245, -0.20930756628513336, -0.014332621358335018, 0.39605414867401123], [0.12791317701339722, -0.4042232632637024, -0.006411538925021887, 0.6841840147972107], [0.1198287084698677, -0.20901288092136383, 0.007272141519933939, 0.38948947191238403], [0.11564845591783524, -0.013994891196489334, 0.015061930753290653, 0.09910821169614792], [0.11536855250597, -0.20932942628860474, 0.01704409532248974, 0.39650481939315796], [0.11118196696043015, -0.014453399926424026, 0.024974191561341286, 0.1092439591884613], [0.1108928993344307, 0.18030193448066711, 0.0271590705960989, -0.17545630037784576], [0.11449893563985825, 0.3750248849391937, 0.023649943992495537, -0.45944923162460327], [0.12199943512678146, 0.5698046684265137, 0.014460960403084755, -0.7445847988128662], [0.13339552283287048, 0.7647241353988647, -0.000430735235568136, -1.032681941986084], [0.14869001507759094, 0.9598518013954163, -0.02108437567949295, -1.3255001306533813], [0.16788704693317413, 0.7650023102760315, -0.047594375908374786, -1.0394892692565918], [0.1831870973110199, 0.5705440044403076, -0.06838416308164597, -0.762119472026825], [0.1945979744195938, 0.7665379047393799, -0.08362655341625214, -1.0755125284194946], [0.2099287360906601, 0.5726144313812256, -0.10513680428266525, -0.8102014064788818], [0.2213810235261917, 0.3790779709815979, -0.12134082615375519, -0.552353024482727], [0.22896258533000946, 0.1858503371477127, -0.1323878914117813, -0.30022940039634705], [0.2326795905828476, -0.007160619366914034, -0.13839247822761536, -0.05205482989549637], [0.23253637552261353, 
-0.2000548243522644, -0.1394335776567459, 0.1939624696969986], [0.22853527963161469, -0.3929353952407837, -0.13555432856082916, 0.4396146833896637], [0.22067657113075256, -0.1961815357208252, -0.1267620325088501, 0.10746019333600998], [0.21675294637680054, 0.0005075104418210685, -0.12461283057928085, -0.22237446904182434], [0.21676309406757355, 0.19716985523700714, -0.1290603131055832, -0.5516219735145569], [0.2207064926624298, 0.39384564757347107, -0.14009276032447815, -0.8820206522941589], [0.22858339548110962, 0.5905638933181763, -0.15773317217826843, -1.2152597904205322], [0.2403946816921234, 0.39778846502304077, -0.18203836679458618, -0.9758678674697876], [0.24835044145584106, 0.20551282167434692, -0.20155572891235352, -0.745444118976593]], "new_obs": [[0.040062207728624344, -0.2052149772644043, 0.04732998460531235, 0.30600231885910034], [0.03595791012048721, -0.010798314586281776, 0.05345002934336662, 0.028613731265068054], [0.03574194014072418, 0.18351800739765167, 0.054022304713726044, -0.24673765897750854], [0.039412301033735275, -0.012332209385931492, 0.04908755049109459, 0.06248391792178154], [0.03916565701365471, 0.1820528209209442, 0.050337228924036026, -0.2143164724111557], [0.0428067147731781, -0.013751287944614887, 0.04605090245604515, 0.09381057322025299], [0.04253168776631355, -0.20950199663639069, 0.04792711138725281, 0.4006595313549042], [0.03834164887666702, -0.015091483481228352, 0.055940303951501846, 0.12346379458904266], [0.03803981840610504, 0.17918626964092255, 0.05840957909822464, -0.1510591059923172], [0.041623543947935104, -0.01672130823135376, 0.055388398468494415, 0.1594637781381607], [0.04128911718726158, 0.17756572365760803, 0.05857767164707184, -0.11524398624897003], [0.044840432703495026, 0.37180155515670776, 0.05627279356122017, -0.3888860046863556], [0.05227646231651306, 0.5660815238952637, 0.04849507287144661, -0.6633091568946838], [0.06359809637069702, 0.3703196048736572, 0.03522888943552971, -0.3557596206665039], [0.0710044875741005, 0.5649234652519226, 0.028113696724176407, -0.6371290683746338], [0.08230295777320862, 0.7596423029899597, 0.015371114946901798, -0.9208275675773621], [0.09749580174684525, 0.5643160343170166, -0.003045437391847372, -0.623353898525238], [0.10878212004899979, 0.36923670768737793, -0.015512514859437943, -0.3316316306591034], [0.116166852414608, 0.5645759701728821, -0.022145148366689682, -0.6291658282279968], [0.1274583786725998, 0.36976999044418335, -0.03472846373915672, -0.3435385525226593], [0.1348537802696228, 0.5653683543205261, -0.04159923642873764, -0.6469672918319702], [0.14616113901138306, 0.3708499073982239, -0.054538581520318985, -0.3676687479019165], [0.15357813239097595, 0.17654363811016083, -0.06189195439219475, -0.09266908466815948], [0.15710900723934174, -0.01763911545276642, -0.06374533474445343, 0.17986272275447845], [0.1567562371492386, -0.2117937058210373, -0.06014808267354965, 0.4517746567726135], [0.15252035856246948, -0.4060157239437103, -0.0511125884950161, 0.7249079942703247], [0.14440004527568817, -0.21022562682628632, -0.03661442920565605, 0.4165858030319214], [0.14019553363323212, -0.4048100411891937, -0.028282713145017624, 0.6975045800209045], [0.13209933042526245, -0.20930756628513336, -0.014332621358335018, 0.39605414867401123], [0.12791317701339722, -0.4042232632637024, -0.006411538925021887, 0.6841840147972107], [0.1198287084698677, -0.20901288092136383, 0.007272141519933939, 0.38948947191238403], [0.11564845591783524, -0.013994891196489334, 0.015061930753290653, 0.09910821169614792], 
[0.11536855250597, -0.20932942628860474, 0.01704409532248974, 0.39650481939315796], [0.11118196696043015, -0.014453399926424026, 0.024974191561341286, 0.1092439591884613], [0.1108928993344307, 0.18030193448066711, 0.0271590705960989, -0.17545630037784576], [0.11449893563985825, 0.3750248849391937, 0.023649943992495537, -0.45944923162460327], [0.12199943512678146, 0.5698046684265137, 0.014460960403084755, -0.7445847988128662], [0.13339552283287048, 0.7647241353988647, -0.000430735235568136, -1.032681941986084], [0.14869001507759094, 0.9598518013954163, -0.02108437567949295, -1.3255001306533813], [0.16788704693317413, 0.7650023102760315, -0.047594375908374786, -1.0394892692565918], [0.1831870973110199, 0.5705440044403076, -0.06838416308164597, -0.762119472026825], [0.1945979744195938, 0.7665379047393799, -0.08362655341625214, -1.0755125284194946], [0.2099287360906601, 0.5726144313812256, -0.10513680428266525, -0.8102014064788818], [0.2213810235261917, 0.3790779709815979, -0.12134082615375519, -0.552353024482727], [0.22896258533000946, 0.1858503371477127, -0.1323878914117813, -0.30022940039634705], [0.2326795905828476, -0.007160619366914034, -0.13839247822761536, -0.05205482989549637], [0.23253637552261353, -0.2000548243522644, -0.1394335776567459, 0.1939624696969986], [0.22853527963161469, -0.3929353952407837, -0.13555432856082916, 0.4396146833896637], [0.22067657113075256, -0.1961815357208252, -0.1267620325088501, 0.10746019333600998], [0.21675294637680054, 0.0005075104418210685, -0.12461283057928085, -0.22237446904182434], [0.21676309406757355, 0.19716985523700714, -0.1290603131055832, -0.5516219735145569], [0.2207064926624298, 0.39384564757347107, -0.14009276032447815, -0.8820206522941589], [0.22858339548110962, 0.5905638933181763, -0.15773317217826843, -1.2152597904205322], [0.2403946816921234, 0.39778846502304077, -0.18203836679458618, -0.9758678674697876], [0.24835044145584106, 0.20551282167434692, -0.20155572891235352, -0.745444118976593], [0.2524607181549072, 0.01365789957344532, -0.21646460890769958, -0.5223444700241089]]} -{"type": "SampleBatch", "weights": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "eps_id": [1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020], "dones": [false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, true], "infos": [{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}], "prev_rewards": [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "t": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25], "agent_index": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "action_prob": [0.5135254263877869, 0.4770704507827759, 0.5442214012145996, 0.47627949714660645, 0.5454674363136292, 0.5253314971923828, 0.48434364795684814, 0.5828204154968262, 0.48531463742256165, 0.5827109813690186, 0.5136748552322388, 0.4766709804534912, 0.45407694578170776, 0.4279625415802002, 0.5955550074577332, 0.5748928189277649, 
0.5481062531471252, 0.4735119938850403, 0.5489782094955444, 0.47440415620803833, 0.5505622625350952, 0.5247683525085449, 0.5148704051971436, 0.4746163487434387, 0.4442490339279175, 0.4205590784549713], "actions": [1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], "q_values": [[-0.015597449615597725, 0.038517292588949203], [0.04316295310854912, -0.04861947521567345], [0.09876783937215805, -0.0785810723900795], [0.03863132745027542, -0.05632191151380539], [0.09450361132621765, -0.08787006139755249], [0.033118072897195816, -0.06829479336738586], [-0.011613234877586365, 0.0510326623916626], [-0.08389873802661896, 0.25046348571777344], [-0.021378351375460625, 0.0373799204826355], [-0.08555285632610321, 0.24835921823978424], [-0.028901388868689537, 0.025811681523919106], [0.02785981446504593, -0.0655241534113884], [0.0917566642165184, -0.09245472401380539], [0.1692613959312439, -0.12090739607810974], [0.25693047046661377, -0.1300475001335144], [0.1545487344264984, -0.14729353785514832], [0.055337414145469666, -0.13768470287322998], [0.00671960785984993, -0.09933169186115265], [0.05141502618789673, -0.14512820541858673], [-0.008995093405246735, -0.1114681214094162], [0.0450827032327652, -0.15785999596118927], [-0.02486952394247055, -0.12402410060167313], [-0.15750475227832794, -0.09800545871257782], [-0.04371977970004082, -0.14534175395965576], [0.03489668667316437, -0.1890382468700409], [0.1171964704990387, -0.20328232645988464]], "rewards": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "prev_actions": [0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1], "obs": [[0.0450199730694294, -0.03486160933971405, 0.016064710915088654, 0.011697827838361263], [0.04432274028658867, 0.16002631187438965, 0.01629866659641266, -0.2758735120296478], [0.047523267567157745, 0.3549119830131531, 0.010781196877360344, -0.5633715987205505], [0.05462150648236275, 0.15964041650295258, -0.0004862352798227221, -0.2673116624355316], [0.05781431496143341, 0.3547693192958832, -0.0058324686251580715, -0.5601479411125183], [0.06490969657897949, 0.1597297042608261, -0.017035426571965218, -0.2693082094192505], [0.06810429692268372, -0.035145051777362823, -0.022421590983867645, 0.01795332506299019], [0.06740139424800873, -0.22993838787078857, -0.022062525153160095, 0.30347850918769836], [0.06280262768268585, -0.03450907766819, -0.01599295437335968, 0.00392001261934638], [0.06211244314908981, -0.22939805686473846, -0.0159145537763834, 0.2915143668651581], [0.057524483650922775, -0.03405284881591797, -0.010084266774356365, -0.006145021412521601], [0.05684342607855797, 0.1612122654914856, -0.010207167826592922, -0.3019925057888031], [0.06006767228245735, 0.35647818446159363, -0.016247017309069633, -0.597877025604248], [0.06719723343849182, 0.5518236756324768, -0.028204558417201042, -0.8956329822540283], [0.07823371142148972, 0.7473164796829224, -0.04611721634864807, -1.1970465183258057], [0.09318003803491592, 0.5528207421302795, -0.07005815207958221, -0.9191668629646301], [0.10423645377159119, 0.35871216654777527, -0.08844148367643356, -0.6492984294891357], [0.11141069233417511, 0.16492627561092377, -0.10142745822668076, -0.38572362065315247], [0.11470922082662582, 0.3613308370113373, -0.10914192348718643, -0.7085849642753601], [0.12193583697080612, 0.16787634789943695, -0.12331362813711166, -0.45215386152267456], [0.12529335916042328, 0.36450672149658203, -0.1323567032814026, 
-0.7810221314430237], [0.13258349895477295, 0.1714283674955368, -0.14797714352607727, -0.5327370762825012], [0.13601206243038177, -0.021336432546377182, -0.15863189101219177, -0.29009655117988586], [0.13558533787727356, 0.17564991116523743, -0.1644338220357895, -0.6283085346221924], [0.13909833133220673, 0.3726385235786438, -0.1769999861717224, -0.9679317474365234], [0.14655110239982605, 0.5696383714675903, -0.19635862112045288, -1.3105814456939697]], "new_obs": [[0.04432274028658867, 0.16002631187438965, 0.01629866659641266, -0.2758735120296478], [0.047523267567157745, 0.3549119830131531, 0.010781196877360344, -0.5633715987205505], [0.05462150648236275, 0.15964041650295258, -0.0004862352798227221, -0.2673116624355316], [0.05781431496143341, 0.3547693192958832, -0.0058324686251580715, -0.5601479411125183], [0.06490969657897949, 0.1597297042608261, -0.017035426571965218, -0.2693082094192505], [0.06810429692268372, -0.035145051777362823, -0.022421590983867645, 0.01795332506299019], [0.06740139424800873, -0.22993838787078857, -0.022062525153160095, 0.30347850918769836], [0.06280262768268585, -0.03450907766819, -0.01599295437335968, 0.00392001261934638], [0.06211244314908981, -0.22939805686473846, -0.0159145537763834, 0.2915143668651581], [0.057524483650922775, -0.03405284881591797, -0.010084266774356365, -0.006145021412521601], [0.05684342607855797, 0.1612122654914856, -0.010207167826592922, -0.3019925057888031], [0.06006767228245735, 0.35647818446159363, -0.016247017309069633, -0.597877025604248], [0.06719723343849182, 0.5518236756324768, -0.028204558417201042, -0.8956329822540283], [0.07823371142148972, 0.7473164796829224, -0.04611721634864807, -1.1970465183258057], [0.09318003803491592, 0.5528207421302795, -0.07005815207958221, -0.9191668629646301], [0.10423645377159119, 0.35871216654777527, -0.08844148367643356, -0.6492984294891357], [0.11141069233417511, 0.16492627561092377, -0.10142745822668076, -0.38572362065315247], [0.11470922082662582, 0.3613308370113373, -0.10914192348718643, -0.7085849642753601], [0.12193583697080612, 0.16787634789943695, -0.12331362813711166, -0.45215386152267456], [0.12529335916042328, 0.36450672149658203, -0.1323567032814026, -0.7810221314430237], [0.13258349895477295, 0.1714283674955368, -0.14797714352607727, -0.5327370762825012], [0.13601206243038177, -0.021336432546377182, -0.15863189101219177, -0.29009655117988586], [0.13558533787727356, 0.17564991116523743, -0.1644338220357895, -0.6283085346221924], [0.13909833133220673, 0.3726385235786438, -0.1769999861717224, -0.9679317474365234], [0.14655110239982605, 0.5696383714675903, -0.19635862112045288, -1.3105814456939697], [0.15794387459754944, 0.7666289806365967, -0.2225702553987503, -1.6577483415603638]]} -{"type": "SampleBatch", "weights": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "eps_id": [464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363], "dones": [false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, true], "infos": [{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}], "prev_rewards": [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "t": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], "agent_index": 
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "action_prob": [0.49811699986457825, 0.5603018999099731, 0.4948766827583313, 0.5607614517211914, 0.4922669231891632, 0.43934890627861023, 0.6127749681472778, 0.438413143157959, 0.38857191801071167, 0.6461699604988098, 0.6107516288757324, 0.43830615282058716, 0.608411967754364, 0.5631444454193115, 0.518650472164154, 0.5026047825813293, 0.48087823390960693, 0.5650154948234558, 0.4770132005214691, 0.5669832229614258], "actions": [0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1], "q_values": [[0.034373246133327484, 0.041905246675014496], [-0.040324486792087555, 0.20206278562545776], [0.03108956664800644, 0.0515836626291275], [-0.03812238574028015, 0.20613068342208862], [0.016220448538661003, 0.047155141830444336], [-0.03483893722295761, 0.20896606147289276], [-0.10473792254924774, 0.3542538285255432], [-0.02594645321369171, 0.22165822982788086], [-0.10031923651695251, 0.35299989581108093], [-0.1714298129081726, 0.430816113948822], [-0.09505866467952728, 0.3554142117500305], [0.0006859749555587769, 0.2487252801656723], [-0.08787457644939423, 0.35276734828948975], [0.004122734069824219, 0.25805625319480896], [0.038704317063093185, 0.11334069073200226], [-0.01853189617395401, -0.028951097279787064], [0.025288723409175873, 0.10181311517953873], [-0.020684152841567993, 0.24085858464241028], [0.013561476022005081, 0.10557354986667633], [-0.03565507382154465, 0.23389792442321777]], "rewards": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "prev_actions": [0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0], "obs": [[-0.03543581813573837, 0.03231120854616165, 0.04250812903046608, -0.04545578733086586], [-0.03478959575295448, -0.16339369118213654, 0.04159901291131973, 0.2603300213813782], [-0.038057468831539154, 0.0311104916036129, 0.04680561274290085, -0.018947282806038857], [-0.03743525967001915, -0.16465036571025848, 0.046426668763160706, 0.28812822699546814], [-0.04072826728224754, 0.02977983094751835, 0.052189234644174576, 0.010441737249493599], [-0.04013266786932945, -0.1660502403974533, 0.052398066967725754, 0.3191235661506653], [-0.043453674763441086, -0.36187776923179626, 0.05878053978085518, 0.6278597116470337], [-0.05069122835993767, -0.1676233410835266, 0.0713377296924591, 0.35425281524658203], [-0.05404369533061981, -0.36368328332901, 0.07842279225587845, 0.6685502529144287], [-0.061317361891269684, -0.5598031282424927, 0.09179379791021347, 0.9848584532737732], [-0.07251342386007309, -0.36602261662483215, 0.11149096488952637, 0.7223610281944275], [-0.07983388006687164, -0.17260494828224182, 0.12593817710876465, 0.4667462706565857], [-0.08328597247600555, -0.3692602813243866, 0.13527311384677887, 0.7963211536407471], [-0.09067118167877197, -0.17622822523117065, 0.15119953453540802, 0.5490673184394836], [-0.09419574588537216, 0.01648259162902832, 0.16218088567256927, 0.30758246779441833], [-0.09386609494686127, 0.20896689593791962, 0.1683325320482254, 0.0701172798871994], [-0.08968675881624222, 0.011881737969815731, 0.1697348803281784, 0.4108228385448456], [-0.08944912254810333, -0.18518869578838348, 0.17795133590698242, 0.751843273639679], [-0.09315289556980133, 0.00709147984161973, 0.19298820197582245, 0.5200196504592896], [-0.09301106631755829, -0.1901485174894333, 0.20338858664035797, 0.8667741417884827]], "new_obs": [[-0.03478959575295448, -0.16339369118213654, 0.04159901291131973, 0.2603300213813782], [-0.038057468831539154, 0.0311104916036129, 
0.04680561274290085, -0.018947282806038857], [-0.03743525967001915, -0.16465036571025848, 0.046426668763160706, 0.28812822699546814], [-0.04072826728224754, 0.02977983094751835, 0.052189234644174576, 0.010441737249493599], [-0.04013266786932945, -0.1660502403974533, 0.052398066967725754, 0.3191235661506653], [-0.043453674763441086, -0.36187776923179626, 0.05878053978085518, 0.6278597116470337], [-0.05069122835993767, -0.1676233410835266, 0.0713377296924591, 0.35425281524658203], [-0.05404369533061981, -0.36368328332901, 0.07842279225587845, 0.6685502529144287], [-0.061317361891269684, -0.5598031282424927, 0.09179379791021347, 0.9848584532737732], [-0.07251342386007309, -0.36602261662483215, 0.11149096488952637, 0.7223610281944275], [-0.07983388006687164, -0.17260494828224182, 0.12593817710876465, 0.4667462706565857], [-0.08328597247600555, -0.3692602813243866, 0.13527311384677887, 0.7963211536407471], [-0.09067118167877197, -0.17622822523117065, 0.15119953453540802, 0.5490673184394836], [-0.09419574588537216, 0.01648259162902832, 0.16218088567256927, 0.30758246779441833], [-0.09386609494686127, 0.20896689593791962, 0.1683325320482254, 0.0701172798871994], [-0.08968675881624222, 0.011881737969815731, 0.1697348803281784, 0.4108228385448456], [-0.08944912254810333, -0.18518869578838348, 0.17795133590698242, 0.751843273639679], [-0.09315289556980133, 0.00709147984161973, 0.19298820197582245, 0.5200196504592896], [-0.09301106631755829, -0.1901485174894333, 0.20338858664035797, 0.8667741417884827], [-0.09681403636932373, 0.0017116105882450938, 0.22072407603263855, 0.6443008184432983]]} +{"type": "SampleBatch", "states": [[], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], []],, "weights": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "eps_id": [241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760, 241561760], "dones": [false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, true], "infos": [{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}], "prev_rewards": [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "t": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55], "agent_index": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "action_prob": [0.4979577958583832, 0.5745141506195068, 0.5042742490768433, 0.5248998403549194, 0.5048907995223999, 0.5254997611045837, 0.4930223524570465, 0.5723332166671753, 0.5071576237678528, 0.5262983441352844, 0.5075111389160156, 0.4721700847148895, 0.4541035294532776, 0.5691784024238586, 0.45002007484436035, 0.42802754044532776, 0.5951988697052002, 0.5743389129638672, 0.44297751784324646, 0.5751434564590454, 0.4427056908607483, 0.575354278087616, 0.5583169460296631, 0.5349109768867493, 0.49323225021362305, 0.42819857597351074, 0.6240300536155701, 0.42723774909973145, 0.6247843503952026, 0.4268564283847809, 0.6255699396133423, 0.5718400478363037, 0.49357253313064575, 0.5718478560447693, 0.506999135017395, 0.4627947509288788, 0.44369709491729736, 0.42281273007392883, 0.40176495909690857, 0.6177492141723633, 0.6000679731369019, 0.4211883246898651, 0.5995147228240967, 0.578464925289154, 0.5586039423942566, 0.5260810256004333, 0.4879906177520752, 0.42811155319213867, 0.6308852434158325, 0.5760338306427002, 0.5073276162147522, 0.46694710850715637, 0.43938523530960083, 0.5832104086875916, 0.5628215670585632, 0.5309032201766968], "actions": [0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], "q_values": [[-0.005643954500555992, 0.0025248583406209946], [-0.04723002016544342, 0.2530632019042969], [-0.004162287805229425, 0.012935103848576546], [0.05779631435871124, -0.041885510087013245], [-0.0001599406823515892, 0.019403917714953423], [0.05187809467315674, -0.05020952224731445], [-6.351247429847717e-05, 0.027848877012729645], [-0.03533334285020828, 0.2560437023639679], [0.005023432895541191, 0.03365574777126312], [0.04525064304471016, -0.06003996357321739], [0.002838471904397011, 0.032885171473026276], [0.03723599761724472, -0.07419878989458084], [0.09575563669204712, -0.0883483961224556], [0.16416001319885254, -0.11433979868888855], [0.09313704073429108, -0.10745253413915634], [0.16196757555007935, -0.12793570756912231], [0.23910409212112427, -0.1463954746723175], [0.15805242955684662, -0.14152376353740692], [0.09662380814552307, -0.1324627697467804], [0.1541520208120346, -0.14871598780155182], [0.0929112657904625, -0.1372770369052887], [0.1511463224887848, -0.15258446335792542], [0.0875367745757103, -0.14679750800132751], [0.08854943513870239, -0.05132210999727249], [0.018426118418574333, 0.045498818159103394], [-0.04996141046285629, 0.23924344778060913], [-0.09354546666145325, 0.4131438434123993], [-0.038044273853302, 0.255085825920105], [-0.09211604297161102, 0.4177895784378052], [-0.030748017132282257, 0.26394063234329224], [-0.09104493260383606, 0.4222134053707123], [-0.02319370210170746, 0.2661687135696411], [0.02133956551551819, 0.04705086350440979], [-0.021654099225997925, 0.2677402198314667], [0.01794305630028248, 0.04594135284423828], [0.05681019276380539, -0.0922863557934761], [0.11023147404193878, 
-0.1159394159913063], [0.16652457416057587, -0.14471273124217987], [0.23569053411483765, -0.16242587566375732], [0.31461724638938904, -0.165388286113739], [0.22523169219493866, -0.1805165857076645], [0.14499591290950775, -0.17290116846561432], [0.2126035839319229, -0.19084002077579498], [0.12525871396064758, -0.19121608138084412], [0.07890036702156067, -0.15659788250923157], [0.07070913910865784, -0.03370969370007515], [-0.0010413788259029388, 0.047005534172058105], [-0.05502410978078842, 0.2345360815525055], [-0.15737640857696533, 0.37863999605178833], [-0.09506852179765701, 0.21144413948059082], [-0.06340484321117401, -0.0340922586619854], [0.016717009246349335, -0.11568755656480789], [0.059842679649591446, -0.1838146150112152], [0.12809047102928162, -0.20787617564201355], [0.055311597883701324, -0.19730976223945618], [-0.022230863571166992, -0.14600159227848053]], "rewards": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "prev_actions": [0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], "obs": [[0.040251147001981735, -0.009447001852095127, 0.04735473543405533, -0.00123753328807652], [0.040062207728624344, -0.2052149772644043, 0.04732998460531235, 0.30600231885910034], [0.03595791012048721, -0.010798314586281776, 0.05345002934336662, 0.028613731265068054], [0.03574194014072418, 0.18351800739765167, 0.054022304713726044, -0.24673765897750854], [0.039412301033735275, -0.012332209385931492, 0.04908755049109459, 0.06248391792178154], [0.03916565701365471, 0.1820528209209442, 0.050337228924036026, -0.2143164724111557], [0.0428067147731781, -0.013751287944614887, 0.04605090245604515, 0.09381057322025299], [0.04253168776631355, -0.20950199663639069, 0.04792711138725281, 0.4006595313549042], [0.03834164887666702, -0.015091483481228352, 0.055940303951501846, 0.12346379458904266], [0.03803981840610504, 0.17918626964092255, 0.05840957909822464, -0.1510591059923172], [0.041623543947935104, -0.01672130823135376, 0.055388398468494415, 0.1594637781381607], [0.04128911718726158, 0.17756572365760803, 0.05857767164707184, -0.11524398624897003], [0.044840432703495026, 0.37180155515670776, 0.05627279356122017, -0.3888860046863556], [0.05227646231651306, 0.5660815238952637, 0.04849507287144661, -0.6633091568946838], [0.06359809637069702, 0.3703196048736572, 0.03522888943552971, -0.3557596206665039], [0.0710044875741005, 0.5649234652519226, 0.028113696724176407, -0.6371290683746338], [0.08230295777320862, 0.7596423029899597, 0.015371114946901798, -0.9208275675773621], [0.09749580174684525, 0.5643160343170166, -0.003045437391847372, -0.623353898525238], [0.10878212004899979, 0.36923670768737793, -0.015512514859437943, -0.3316316306591034], [0.116166852414608, 0.5645759701728821, -0.022145148366689682, -0.6291658282279968], [0.1274583786725998, 0.36976999044418335, -0.03472846373915672, -0.3435385525226593], [0.1348537802696228, 0.5653683543205261, -0.04159923642873764, -0.6469672918319702], [0.14616113901138306, 0.3708499073982239, -0.054538581520318985, -0.3676687479019165], [0.15357813239097595, 0.17654363811016083, -0.06189195439219475, -0.09266908466815948], [0.15710900723934174, -0.01763911545276642, -0.06374533474445343, 0.17986272275447845], 
[0.1567562371492386, -0.2117937058210373, -0.06014808267354965, 0.4517746567726135], [0.15252035856246948, -0.4060157239437103, -0.0511125884950161, 0.7249079942703247], [0.14440004527568817, -0.21022562682628632, -0.03661442920565605, 0.4165858030319214], [0.14019553363323212, -0.4048100411891937, -0.028282713145017624, 0.6975045800209045], [0.13209933042526245, -0.20930756628513336, -0.014332621358335018, 0.39605414867401123], [0.12791317701339722, -0.4042232632637024, -0.006411538925021887, 0.6841840147972107], [0.1198287084698677, -0.20901288092136383, 0.007272141519933939, 0.38948947191238403], [0.11564845591783524, -0.013994891196489334, 0.015061930753290653, 0.09910821169614792], [0.11536855250597, -0.20932942628860474, 0.01704409532248974, 0.39650481939315796], [0.11118196696043015, -0.014453399926424026, 0.024974191561341286, 0.1092439591884613], [0.1108928993344307, 0.18030193448066711, 0.0271590705960989, -0.17545630037784576], [0.11449893563985825, 0.3750248849391937, 0.023649943992495537, -0.45944923162460327], [0.12199943512678146, 0.5698046684265137, 0.014460960403084755, -0.7445847988128662], [0.13339552283287048, 0.7647241353988647, -0.000430735235568136, -1.032681941986084], [0.14869001507759094, 0.9598518013954163, -0.02108437567949295, -1.3255001306533813], [0.16788704693317413, 0.7650023102760315, -0.047594375908374786, -1.0394892692565918], [0.1831870973110199, 0.5705440044403076, -0.06838416308164597, -0.762119472026825], [0.1945979744195938, 0.7665379047393799, -0.08362655341625214, -1.0755125284194946], [0.2099287360906601, 0.5726144313812256, -0.10513680428266525, -0.8102014064788818], [0.2213810235261917, 0.3790779709815979, -0.12134082615375519, -0.552353024482727], [0.22896258533000946, 0.1858503371477127, -0.1323878914117813, -0.30022940039634705], [0.2326795905828476, -0.007160619366914034, -0.13839247822761536, -0.05205482989549637], [0.23253637552261353, -0.2000548243522644, -0.1394335776567459, 0.1939624696969986], [0.22853527963161469, -0.3929353952407837, -0.13555432856082916, 0.4396146833896637], [0.22067657113075256, -0.1961815357208252, -0.1267620325088501, 0.10746019333600998], [0.21675294637680054, 0.0005075104418210685, -0.12461283057928085, -0.22237446904182434], [0.21676309406757355, 0.19716985523700714, -0.1290603131055832, -0.5516219735145569], [0.2207064926624298, 0.39384564757347107, -0.14009276032447815, -0.8820206522941589], [0.22858339548110962, 0.5905638933181763, -0.15773317217826843, -1.2152597904205322], [0.2403946816921234, 0.39778846502304077, -0.18203836679458618, -0.9758678674697876], [0.24835044145584106, 0.20551282167434692, -0.20155572891235352, -0.745444118976593]], "new_obs": [[0.040062207728624344, -0.2052149772644043, 0.04732998460531235, 0.30600231885910034], [0.03595791012048721, -0.010798314586281776, 0.05345002934336662, 0.028613731265068054], [0.03574194014072418, 0.18351800739765167, 0.054022304713726044, -0.24673765897750854], [0.039412301033735275, -0.012332209385931492, 0.04908755049109459, 0.06248391792178154], [0.03916565701365471, 0.1820528209209442, 0.050337228924036026, -0.2143164724111557], [0.0428067147731781, -0.013751287944614887, 0.04605090245604515, 0.09381057322025299], [0.04253168776631355, -0.20950199663639069, 0.04792711138725281, 0.4006595313549042], [0.03834164887666702, -0.015091483481228352, 0.055940303951501846, 0.12346379458904266], [0.03803981840610504, 0.17918626964092255, 0.05840957909822464, -0.1510591059923172], [0.041623543947935104, -0.01672130823135376, 0.055388398468494415, 
0.1594637781381607], [0.04128911718726158, 0.17756572365760803, 0.05857767164707184, -0.11524398624897003], [0.044840432703495026, 0.37180155515670776, 0.05627279356122017, -0.3888860046863556], [0.05227646231651306, 0.5660815238952637, 0.04849507287144661, -0.6633091568946838], [0.06359809637069702, 0.3703196048736572, 0.03522888943552971, -0.3557596206665039], [0.0710044875741005, 0.5649234652519226, 0.028113696724176407, -0.6371290683746338], [0.08230295777320862, 0.7596423029899597, 0.015371114946901798, -0.9208275675773621], [0.09749580174684525, 0.5643160343170166, -0.003045437391847372, -0.623353898525238], [0.10878212004899979, 0.36923670768737793, -0.015512514859437943, -0.3316316306591034], [0.116166852414608, 0.5645759701728821, -0.022145148366689682, -0.6291658282279968], [0.1274583786725998, 0.36976999044418335, -0.03472846373915672, -0.3435385525226593], [0.1348537802696228, 0.5653683543205261, -0.04159923642873764, -0.6469672918319702], [0.14616113901138306, 0.3708499073982239, -0.054538581520318985, -0.3676687479019165], [0.15357813239097595, 0.17654363811016083, -0.06189195439219475, -0.09266908466815948], [0.15710900723934174, -0.01763911545276642, -0.06374533474445343, 0.17986272275447845], [0.1567562371492386, -0.2117937058210373, -0.06014808267354965, 0.4517746567726135], [0.15252035856246948, -0.4060157239437103, -0.0511125884950161, 0.7249079942703247], [0.14440004527568817, -0.21022562682628632, -0.03661442920565605, 0.4165858030319214], [0.14019553363323212, -0.4048100411891937, -0.028282713145017624, 0.6975045800209045], [0.13209933042526245, -0.20930756628513336, -0.014332621358335018, 0.39605414867401123], [0.12791317701339722, -0.4042232632637024, -0.006411538925021887, 0.6841840147972107], [0.1198287084698677, -0.20901288092136383, 0.007272141519933939, 0.38948947191238403], [0.11564845591783524, -0.013994891196489334, 0.015061930753290653, 0.09910821169614792], [0.11536855250597, -0.20932942628860474, 0.01704409532248974, 0.39650481939315796], [0.11118196696043015, -0.014453399926424026, 0.024974191561341286, 0.1092439591884613], [0.1108928993344307, 0.18030193448066711, 0.0271590705960989, -0.17545630037784576], [0.11449893563985825, 0.3750248849391937, 0.023649943992495537, -0.45944923162460327], [0.12199943512678146, 0.5698046684265137, 0.014460960403084755, -0.7445847988128662], [0.13339552283287048, 0.7647241353988647, -0.000430735235568136, -1.032681941986084], [0.14869001507759094, 0.9598518013954163, -0.02108437567949295, -1.3255001306533813], [0.16788704693317413, 0.7650023102760315, -0.047594375908374786, -1.0394892692565918], [0.1831870973110199, 0.5705440044403076, -0.06838416308164597, -0.762119472026825], [0.1945979744195938, 0.7665379047393799, -0.08362655341625214, -1.0755125284194946], [0.2099287360906601, 0.5726144313812256, -0.10513680428266525, -0.8102014064788818], [0.2213810235261917, 0.3790779709815979, -0.12134082615375519, -0.552353024482727], [0.22896258533000946, 0.1858503371477127, -0.1323878914117813, -0.30022940039634705], [0.2326795905828476, -0.007160619366914034, -0.13839247822761536, -0.05205482989549637], [0.23253637552261353, -0.2000548243522644, -0.1394335776567459, 0.1939624696969986], [0.22853527963161469, -0.3929353952407837, -0.13555432856082916, 0.4396146833896637], [0.22067657113075256, -0.1961815357208252, -0.1267620325088501, 0.10746019333600998], [0.21675294637680054, 0.0005075104418210685, -0.12461283057928085, -0.22237446904182434], [0.21676309406757355, 0.19716985523700714, -0.1290603131055832, 
-0.5516219735145569], [0.2207064926624298, 0.39384564757347107, -0.14009276032447815, -0.8820206522941589], [0.22858339548110962, 0.5905638933181763, -0.15773317217826843, -1.2152597904205322], [0.2403946816921234, 0.39778846502304077, -0.18203836679458618, -0.9758678674697876], [0.24835044145584106, 0.20551282167434692, -0.20155572891235352, -0.745444118976593], [0.2524607181549072, 0.01365789957344532, -0.21646460890769958, -0.5223444700241089]]} +{"type": "SampleBatch", "states": [[], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], []], "weights": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "eps_id": [1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020, 1238833020], "dones": [false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, true], "infos": [{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}], "prev_rewards": [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "t": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25], "agent_index": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "action_prob": [0.5135254263877869, 0.4770704507827759, 0.5442214012145996, 0.47627949714660645, 0.5454674363136292, 0.5253314971923828, 0.48434364795684814, 0.5828204154968262, 0.48531463742256165, 0.5827109813690186, 0.5136748552322388, 0.4766709804534912, 0.45407694578170776, 0.4279625415802002, 0.5955550074577332, 0.5748928189277649, 0.5481062531471252, 0.4735119938850403, 0.5489782094955444, 0.47440415620803833, 0.5505622625350952, 0.5247683525085449, 0.5148704051971436, 0.4746163487434387, 0.4442490339279175, 0.4205590784549713], "actions": [1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], "q_values": [[-0.015597449615597725, 0.038517292588949203], [0.04316295310854912, -0.04861947521567345], [0.09876783937215805, -0.0785810723900795], [0.03863132745027542, -0.05632191151380539], [0.09450361132621765, -0.08787006139755249], [0.033118072897195816, -0.06829479336738586], [-0.011613234877586365, 0.0510326623916626], [-0.08389873802661896, 0.25046348571777344], [-0.021378351375460625, 0.0373799204826355], [-0.08555285632610321, 0.24835921823978424], [-0.028901388868689537, 0.025811681523919106], [0.02785981446504593, -0.0655241534113884], [0.0917566642165184, -0.09245472401380539], [0.1692613959312439, -0.12090739607810974], [0.25693047046661377, -0.1300475001335144], [0.1545487344264984, -0.14729353785514832], [0.055337414145469666, -0.13768470287322998], [0.00671960785984993, -0.09933169186115265], [0.05141502618789673, -0.14512820541858673], [-0.008995093405246735, -0.1114681214094162], [0.0450827032327652, -0.15785999596118927], [-0.02486952394247055, -0.12402410060167313], [-0.15750475227832794, -0.09800545871257782], [-0.04371977970004082, -0.14534175395965576], [0.03489668667316437, -0.1890382468700409], [0.1171964704990387, -0.20328232645988464]], "rewards": [1.0, 1.0, 
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "prev_actions": [0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1], "obs": [[0.0450199730694294, -0.03486160933971405, 0.016064710915088654, 0.011697827838361263], [0.04432274028658867, 0.16002631187438965, 0.01629866659641266, -0.2758735120296478], [0.047523267567157745, 0.3549119830131531, 0.010781196877360344, -0.5633715987205505], [0.05462150648236275, 0.15964041650295258, -0.0004862352798227221, -0.2673116624355316], [0.05781431496143341, 0.3547693192958832, -0.0058324686251580715, -0.5601479411125183], [0.06490969657897949, 0.1597297042608261, -0.017035426571965218, -0.2693082094192505], [0.06810429692268372, -0.035145051777362823, -0.022421590983867645, 0.01795332506299019], [0.06740139424800873, -0.22993838787078857, -0.022062525153160095, 0.30347850918769836], [0.06280262768268585, -0.03450907766819, -0.01599295437335968, 0.00392001261934638], [0.06211244314908981, -0.22939805686473846, -0.0159145537763834, 0.2915143668651581], [0.057524483650922775, -0.03405284881591797, -0.010084266774356365, -0.006145021412521601], [0.05684342607855797, 0.1612122654914856, -0.010207167826592922, -0.3019925057888031], [0.06006767228245735, 0.35647818446159363, -0.016247017309069633, -0.597877025604248], [0.06719723343849182, 0.5518236756324768, -0.028204558417201042, -0.8956329822540283], [0.07823371142148972, 0.7473164796829224, -0.04611721634864807, -1.1970465183258057], [0.09318003803491592, 0.5528207421302795, -0.07005815207958221, -0.9191668629646301], [0.10423645377159119, 0.35871216654777527, -0.08844148367643356, -0.6492984294891357], [0.11141069233417511, 0.16492627561092377, -0.10142745822668076, -0.38572362065315247], [0.11470922082662582, 0.3613308370113373, -0.10914192348718643, -0.7085849642753601], [0.12193583697080612, 0.16787634789943695, -0.12331362813711166, -0.45215386152267456], [0.12529335916042328, 0.36450672149658203, -0.1323567032814026, -0.7810221314430237], [0.13258349895477295, 0.1714283674955368, -0.14797714352607727, -0.5327370762825012], [0.13601206243038177, -0.021336432546377182, -0.15863189101219177, -0.29009655117988586], [0.13558533787727356, 0.17564991116523743, -0.1644338220357895, -0.6283085346221924], [0.13909833133220673, 0.3726385235786438, -0.1769999861717224, -0.9679317474365234], [0.14655110239982605, 0.5696383714675903, -0.19635862112045288, -1.3105814456939697]], "new_obs": [[0.04432274028658867, 0.16002631187438965, 0.01629866659641266, -0.2758735120296478], [0.047523267567157745, 0.3549119830131531, 0.010781196877360344, -0.5633715987205505], [0.05462150648236275, 0.15964041650295258, -0.0004862352798227221, -0.2673116624355316], [0.05781431496143341, 0.3547693192958832, -0.0058324686251580715, -0.5601479411125183], [0.06490969657897949, 0.1597297042608261, -0.017035426571965218, -0.2693082094192505], [0.06810429692268372, -0.035145051777362823, -0.022421590983867645, 0.01795332506299019], [0.06740139424800873, -0.22993838787078857, -0.022062525153160095, 0.30347850918769836], [0.06280262768268585, -0.03450907766819, -0.01599295437335968, 0.00392001261934638], [0.06211244314908981, -0.22939805686473846, -0.0159145537763834, 0.2915143668651581], [0.057524483650922775, -0.03405284881591797, -0.010084266774356365, -0.006145021412521601], [0.05684342607855797, 0.1612122654914856, -0.010207167826592922, -0.3019925057888031], [0.06006767228245735, 0.35647818446159363, -0.016247017309069633, 
-0.597877025604248], [0.06719723343849182, 0.5518236756324768, -0.028204558417201042, -0.8956329822540283], [0.07823371142148972, 0.7473164796829224, -0.04611721634864807, -1.1970465183258057], [0.09318003803491592, 0.5528207421302795, -0.07005815207958221, -0.9191668629646301], [0.10423645377159119, 0.35871216654777527, -0.08844148367643356, -0.6492984294891357], [0.11141069233417511, 0.16492627561092377, -0.10142745822668076, -0.38572362065315247], [0.11470922082662582, 0.3613308370113373, -0.10914192348718643, -0.7085849642753601], [0.12193583697080612, 0.16787634789943695, -0.12331362813711166, -0.45215386152267456], [0.12529335916042328, 0.36450672149658203, -0.1323567032814026, -0.7810221314430237], [0.13258349895477295, 0.1714283674955368, -0.14797714352607727, -0.5327370762825012], [0.13601206243038177, -0.021336432546377182, -0.15863189101219177, -0.29009655117988586], [0.13558533787727356, 0.17564991116523743, -0.1644338220357895, -0.6283085346221924], [0.13909833133220673, 0.3726385235786438, -0.1769999861717224, -0.9679317474365234], [0.14655110239982605, 0.5696383714675903, -0.19635862112045288, -1.3105814456939697], [0.15794387459754944, 0.7666289806365967, -0.2225702553987503, -1.6577483415603638]]} +{"type": "SampleBatch", "states": [[], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], []], "weights": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "eps_id": [464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363, 464626363], "dones": [false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, true], "infos": [{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}], "prev_rewards": [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "t": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], "agent_index": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "action_prob": [0.49811699986457825, 0.5603018999099731, 0.4948766827583313, 0.5607614517211914, 0.4922669231891632, 0.43934890627861023, 0.6127749681472778, 0.438413143157959, 0.38857191801071167, 0.6461699604988098, 0.6107516288757324, 0.43830615282058716, 0.608411967754364, 0.5631444454193115, 0.518650472164154, 0.5026047825813293, 0.48087823390960693, 0.5650154948234558, 0.4770132005214691, 0.5669832229614258], "actions": [0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1], "q_values": [[0.034373246133327484, 0.041905246675014496], [-0.040324486792087555, 0.20206278562545776], [0.03108956664800644, 0.0515836626291275], [-0.03812238574028015, 0.20613068342208862], [0.016220448538661003, 0.047155141830444336], [-0.03483893722295761, 0.20896606147289276], [-0.10473792254924774, 0.3542538285255432], [-0.02594645321369171, 0.22165822982788086], [-0.10031923651695251, 0.35299989581108093], [-0.1714298129081726, 0.430816113948822], [-0.09505866467952728, 0.3554142117500305], [0.0006859749555587769, 0.2487252801656723], [-0.08787457644939423, 0.35276734828948975], [0.004122734069824219, 0.25805625319480896], [0.038704317063093185, 0.11334069073200226], [-0.01853189617395401, -0.028951097279787064], [0.025288723409175873, 0.10181311517953873], [-0.020684152841567993, 0.24085858464241028], [0.013561476022005081, 
0.10557354986667633], [-0.03565507382154465, 0.23389792442321777]], "rewards": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "prev_actions": [0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0], "obs": [[-0.03543581813573837, 0.03231120854616165, 0.04250812903046608, -0.04545578733086586], [-0.03478959575295448, -0.16339369118213654, 0.04159901291131973, 0.2603300213813782], [-0.038057468831539154, 0.0311104916036129, 0.04680561274290085, -0.018947282806038857], [-0.03743525967001915, -0.16465036571025848, 0.046426668763160706, 0.28812822699546814], [-0.04072826728224754, 0.02977983094751835, 0.052189234644174576, 0.010441737249493599], [-0.04013266786932945, -0.1660502403974533, 0.052398066967725754, 0.3191235661506653], [-0.043453674763441086, -0.36187776923179626, 0.05878053978085518, 0.6278597116470337], [-0.05069122835993767, -0.1676233410835266, 0.0713377296924591, 0.35425281524658203], [-0.05404369533061981, -0.36368328332901, 0.07842279225587845, 0.6685502529144287], [-0.061317361891269684, -0.5598031282424927, 0.09179379791021347, 0.9848584532737732], [-0.07251342386007309, -0.36602261662483215, 0.11149096488952637, 0.7223610281944275], [-0.07983388006687164, -0.17260494828224182, 0.12593817710876465, 0.4667462706565857], [-0.08328597247600555, -0.3692602813243866, 0.13527311384677887, 0.7963211536407471], [-0.09067118167877197, -0.17622822523117065, 0.15119953453540802, 0.5490673184394836], [-0.09419574588537216, 0.01648259162902832, 0.16218088567256927, 0.30758246779441833], [-0.09386609494686127, 0.20896689593791962, 0.1683325320482254, 0.0701172798871994], [-0.08968675881624222, 0.011881737969815731, 0.1697348803281784, 0.4108228385448456], [-0.08944912254810333, -0.18518869578838348, 0.17795133590698242, 0.751843273639679], [-0.09315289556980133, 0.00709147984161973, 0.19298820197582245, 0.5200196504592896], [-0.09301106631755829, -0.1901485174894333, 0.20338858664035797, 0.8667741417884827]], "new_obs": [[-0.03478959575295448, -0.16339369118213654, 0.04159901291131973, 0.2603300213813782], [-0.038057468831539154, 0.0311104916036129, 0.04680561274290085, -0.018947282806038857], [-0.03743525967001915, -0.16465036571025848, 0.046426668763160706, 0.28812822699546814], [-0.04072826728224754, 0.02977983094751835, 0.052189234644174576, 0.010441737249493599], [-0.04013266786932945, -0.1660502403974533, 0.052398066967725754, 0.3191235661506653], [-0.043453674763441086, -0.36187776923179626, 0.05878053978085518, 0.6278597116470337], [-0.05069122835993767, -0.1676233410835266, 0.0713377296924591, 0.35425281524658203], [-0.05404369533061981, -0.36368328332901, 0.07842279225587845, 0.6685502529144287], [-0.061317361891269684, -0.5598031282424927, 0.09179379791021347, 0.9848584532737732], [-0.07251342386007309, -0.36602261662483215, 0.11149096488952637, 0.7223610281944275], [-0.07983388006687164, -0.17260494828224182, 0.12593817710876465, 0.4667462706565857], [-0.08328597247600555, -0.3692602813243866, 0.13527311384677887, 0.7963211536407471], [-0.09067118167877197, -0.17622822523117065, 0.15119953453540802, 0.5490673184394836], [-0.09419574588537216, 0.01648259162902832, 0.16218088567256927, 0.30758246779441833], [-0.09386609494686127, 0.20896689593791962, 0.1683325320482254, 0.0701172798871994], [-0.08968675881624222, 0.011881737969815731, 0.1697348803281784, 0.4108228385448456], [-0.08944912254810333, -0.18518869578838348, 0.17795133590698242, 0.751843273639679], [-0.09315289556980133, 0.00709147984161973, 
0.19298820197582245, 0.5200196504592896], [-0.09301106631755829, -0.1901485174894333, 0.20338858664035797, 0.8667741417884827], [-0.09681403636932373, 0.0017116105882450938, 0.22072407603263855, 0.6443008184432983]]} From 5d195896d81576d1e9c203c55912fc37eb42074a Mon Sep 17 00:00:00 2001 From: Edilmo Palencia Date: Mon, 5 Oct 2020 19:56:47 -0700 Subject: [PATCH 3/6] RNN support in DDPG and SAC --- rllib/agents/ddpg/ddpg_tf_policy.py | 32 ++++++++++++++++--- rllib/agents/ddpg/ddpg_torch_policy.py | 28 ++++++++++++++--- rllib/agents/dqn/dqn_tf_policy.py | 43 +++++++++++++++----------- rllib/agents/dqn/dqn_torch_policy.py | 31 ++++++++++++++++--- rllib/agents/sac/sac_tf_policy.py | 22 ++++++++++--- rllib/agents/sac/sac_torch_policy.py | 30 +++++++++++++----- rllib/policy/torch_policy.py | 2 ++ rllib/utils/tf_ops.py | 23 +++++++++++--- 8 files changed, 164 insertions(+), 47 deletions(-) diff --git a/rllib/agents/ddpg/ddpg_tf_policy.py b/rllib/agents/ddpg/ddpg_tf_policy.py index b57b53de1b68..8ca865b04893 100644 --- a/rllib/agents/ddpg/ddpg_tf_policy.py +++ b/rllib/agents/ddpg/ddpg_tf_policy.py @@ -105,6 +105,8 @@ def build_ddpg_models(policy, observation_space, action_space, config): def get_distribution_inputs_and_class(policy, model, obs_batch, + state_batches, + seq_lens, *, explore=True, is_training=False, @@ -112,7 +114,7 @@ def get_distribution_inputs_and_class(policy, model_out, _ = model({ "obs": obs_batch, "is_training": is_training, - }, [], None) + }, state_batches, seq_lens) dist_inputs = model.get_policy_output(model_out) return dist_inputs, (TorchDeterministic @@ -128,6 +130,18 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch): huber_threshold = policy.config["huber_threshold"] l2_reg = policy.config["l2_reg"] + states_in = [] + i = 0 + while "state_in_{}".format(i) in train_batch: + states_in.append(train_batch["state_in_{}".format(i)]) + i += 1 + states_out = [] + i = 0 + while "state_out_{}".format(i) in train_batch: + states_out.append(train_batch["state_out_{}".format(i)]) + i += 1 + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else [] + input_dict = { "obs": train_batch[SampleBatch.CUR_OBS], "is_training": True, @@ -137,9 +151,9 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch): "is_training": True, } - model_out_t, _ = model(input_dict, [], None) - model_out_tp1, _ = model(input_dict_next, [], None) - target_model_out_tp1, _ = policy.target_model(input_dict_next, [], None) + model_out_t, _ = model(input_dict, states_in, seq_lens) + model_out_tp1, _ = model(input_dict_next, states_out, seq_lens) + target_model_out_tp1, _ = policy.target_model(input_dict_next, states_out, seq_lens) # Policy network evaluation. 
with tf.variable_scope(POLICY_SCOPE, reuse=True): @@ -246,6 +260,10 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch): input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS] input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES] input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS] + for i, h in enumerate(states_in): + input_dict["state_in_{}".format(i)] = h + for i, h in enumerate(states_out): + input_dict["state_out_{}".format(i)] = h if log_once("ddpg_custom_loss"): logger.warning( "You are using a state-preprocessor with DDPG and " @@ -362,7 +380,10 @@ class ComputeTDErrorMixin: def __init__(self, loss_fn): @make_tf_callable(self.get_session(), dynamic_shape=True) def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, - importance_weights): + importance_weights, **kwargs): + # kwargs should contain only state input tensors, + # state output tensors and the seq_lens tensor + state_in_out_and_seq_lens = kwargs # Do forward pass on loss to update td errors attribute # (one TD-error value per item in batch to update PR weights). loss_fn( @@ -373,6 +394,7 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, SampleBatch.NEXT_OBS: tf.convert_to_tensor(obs_tp1), SampleBatch.DONES: tf.convert_to_tensor(done_mask), PRIO_WEIGHTS: tf.convert_to_tensor(importance_weights), + **state_in_out_and_seq_lens }) # `self.td_error` is set in loss_fn. return self.td_error diff --git a/rllib/agents/ddpg/ddpg_torch_policy.py b/rllib/agents/ddpg/ddpg_torch_policy.py index 1c8d1d7a9bfa..6337e8ec14a5 100644 --- a/rllib/agents/ddpg/ddpg_torch_policy.py +++ b/rllib/agents/ddpg/ddpg_torch_policy.py @@ -35,6 +35,18 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch): huber_threshold = policy.config["huber_threshold"] l2_reg = policy.config["l2_reg"] + states_in = [] + i = 0 + while "state_in_{}".format(i) in train_batch: + states_in.append(train_batch["state_in_{}".format(i)]) + i += 1 + states_out = [] + i = 0 + while "state_out_{}".format(i) in train_batch: + states_out.append(train_batch["state_out_{}".format(i)]) + i += 1 + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else [] + input_dict = { "obs": train_batch[SampleBatch.CUR_OBS], "is_training": True, @@ -44,9 +56,9 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch): "is_training": True, } - model_out_t, _ = model(input_dict, [], None) - model_out_tp1, _ = model(input_dict_next, [], None) - target_model_out_tp1, _ = policy.target_model(input_dict_next, [], None) + model_out_t, _ = model(input_dict, states_in, seq_lens) + model_out_tp1, _ = model(input_dict_next, states_out, seq_lens) + target_model_out_tp1, _ = policy.target_model(input_dict_next, states_out, seq_lens) # Policy network evaluation. 
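# A minimal sketch of the state-extraction pattern the loss functions above
# inline (illustrative only, not a hunk of this patch): RNN state tensors are
# stored as flat "state_in_0", "state_in_1", ... / "state_out_0", ... columns,
# so they are recovered by probing the batch dict until a key is missing.
# The helper name `collect_state_columns` is hypothetical.
def collect_state_columns(train_batch, prefix="state_in"):
    states = []
    i = 0
    while "{}_{}".format(prefix, i) in train_batch:
        states.append(train_batch["{}_{}".format(prefix, i)])
        i += 1
    return states

# Mirrors the inlined loops above:
#   states_in = collect_state_columns(train_batch, "state_in")
#   states_out = collect_state_columns(train_batch, "state_out")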
# prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) @@ -146,6 +158,10 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch): input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS] input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES] input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS] + for i, h in enumerate(states_in): + input_dict["state_in_{}".format(i)] = h + for i, h in enumerate(states_out): + input_dict["state_out_{}".format(i)] = h [actor_loss, critic_loss] = model.custom_loss( [actor_loss, critic_loss], input_dict) @@ -214,7 +230,10 @@ def before_init_fn(policy, obs_space, action_space, config): class ComputeTDErrorMixin: def __init__(self, loss_fn): def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, - importance_weights): + importance_weights, **kwargs): + # kwargs should contain only state input tensors, + # state output tensors and the seq_lens tensor + state_in_out_and_seq_lens = kwargs input_dict = self._lazy_tensor_dict({ SampleBatch.CUR_OBS: obs_t, SampleBatch.ACTIONS: act_t, @@ -223,6 +242,7 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, SampleBatch.DONES: done_mask, PRIO_WEIGHTS: importance_weights, }) + input_dict.update(state_in_out_and_seq_lens) # Do forward pass on loss to update td errors attribute # (one TD-error value per item in batch to update PR weights). loss_fn(self, self.model, None, input_dict) diff --git a/rllib/agents/dqn/dqn_tf_policy.py b/rllib/agents/dqn/dqn_tf_policy.py index 0802ae284f8f..fe975784aa93 100644 --- a/rllib/agents/dqn/dqn_tf_policy.py +++ b/rllib/agents/dqn/dqn_tf_policy.py @@ -111,15 +111,11 @@ def __init__(self, class ComputeTDErrorMixin: def __init__(self): @make_tf_callable(self.get_session(), dynamic_shape=True) - def compute_td_error(obs_t, states_t, act_t, rew_t, obs_tp1, done_mask, - states_tp1, seq_lens, importance_weights): - state_in_out_and_seq_lens = {} - for i, h in enumerate(states_t): - state_in_out_and_seq_lens["state_in_{}".format(i)] = h - for i, h in enumerate(states_tp1): - state_in_out_and_seq_lens["state_out_{}".format(i)] = h - if state_in_out_and_seq_lens: - state_in_out_and_seq_lens["seq_lens"] = seq_lens + def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, + importance_weights, **kwargs): + # kwargs should contain only state input tensors, + # state output tensors and the seq_lens tensor + state_in_out_and_seq_lens = kwargs # Do forward pass on loss to update td error attribute build_q_losses( self, self.model, None, { @@ -222,7 +218,7 @@ def build_q_losses(policy, model, _, train_batch): while "state_out_{}".format(i) in train_batch: states_out.append(train_batch["state_out_{}".format(i)]) i += 1 - seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else None + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.array([]) config = policy.config # q network evaluation q_t, q_logits_t, q_dist_t = compute_q_values( @@ -323,6 +319,8 @@ def setup_late_mixins(policy, obs_space, action_space, config): def compute_q_values(policy, model, obs, states, seq_lens, explore): config = policy.config + if type(states) is tuple: + states = list(states) model_out, state = model({ SampleBatch.CUR_OBS: obs, "is_training": policy._get_is_training_placeholder(), @@ -386,21 +384,30 @@ def _adjust_nstep(n_step, gamma, obs, states_in, actions, rewards, new_obs, done new_obs[i] = new_obs[i + j] dones[i] = dones[i + j] rewards[i] += gamma**j * rewards[i + j] - states_out[i] = states_out[i + j] + if 
states_out: + states_out[i] = states_out[i + j] def postprocess_nstep_and_prio(policy, batch, other_agent=None, episode=None): + state_in_out_and_seq_lens = {} states_in = [] i = 0 - while "state_in_{}".format(i) in batch: - states_in.append(batch["state_in_{}".format(i)]) + key ="state_in_{}".format(i) + while key in batch: + states_in.append(batch[key]) + state_in_out_and_seq_lens[key] = batch[key] i += 1 + key ="state_in_{}".format(i) states_out = [] i = 0 - while "state_out_{}".format(i) in batch: - states_out.append(batch["state_out_{}".format(i)]) + key ="state_out_{}".format(i) + while key in batch: + states_out.append(batch[key]) + state_in_out_and_seq_lens[key] = batch[key] i += 1 - seq_lens = batch["seq_lens"] if "seq_lens" in batch else None + key = "state_out_{}".format(i) + seq_lens = batch["seq_lens"] if "seq_lens" in batch else np.array([]) + state_in_out_and_seq_lens["seq_lens"] = seq_lens # N-step Q adjustments if policy.config["n_step"] > 1: _adjust_nstep(policy.config["n_step"], policy.config["gamma"], @@ -419,9 +426,9 @@ def postprocess_nstep_and_prio(policy, batch, other_agent=None, episode=None): elif isinstance(policy.action_space, Discrete) and actions.shape[-1] == 1: actions = np.reshape(actions, [-1]) td_errors = policy.compute_td_error( - batch[SampleBatch.CUR_OBS], states_in, actions, + batch[SampleBatch.CUR_OBS], actions, batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS], - batch[SampleBatch.DONES], states_out, seq_lens, batch[PRIO_WEIGHTS]) + batch[SampleBatch.DONES], batch[PRIO_WEIGHTS], **state_in_out_and_seq_lens) new_priorities = ( np.abs(td_errors) + policy.config["prioritized_replay_eps"]) batch.data[PRIO_WEIGHTS] = new_priorities diff --git a/rllib/agents/dqn/dqn_torch_policy.py b/rllib/agents/dqn/dqn_torch_policy.py index 22bf7905f70e..f57a85d4bee1 100644 --- a/rllib/agents/dqn/dqn_torch_policy.py +++ b/rllib/agents/dqn/dqn_torch_policy.py @@ -60,13 +60,17 @@ def __init__(self, class ComputeTDErrorMixin: def __init__(self): def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, - importance_weights): + importance_weights, **kwargs): + # kwargs should contain only state input tensors, + # state output tensors and the seq_lens tensor + state_in_out_and_seq_lens = kwargs input_dict = self._lazy_tensor_dict({SampleBatch.CUR_OBS: obs_t}) input_dict[SampleBatch.ACTIONS] = act_t input_dict[SampleBatch.REWARDS] = rew_t input_dict[SampleBatch.NEXT_OBS] = obs_tp1 input_dict[SampleBatch.DONES] = done_mask input_dict[PRIO_WEIGHTS] = importance_weights + input_dict.update(state_in_out_and_seq_lens) # Do forward pass on loss to update td error attribute build_q_losses(self, self.model, None, input_dict) @@ -137,11 +141,13 @@ def build_q_model_and_distribution(policy, obs_space, action_space, config): def get_distribution_inputs_and_class(policy, model, obs_batch, + state_batches, + seq_lens, *, explore=True, is_training=False, **kwargs): - q_vals = compute_q_values(policy, model, obs_batch, explore, is_training) + q_vals = compute_q_values(policy, model, obs_batch, state_batches, seq_lens, explore, is_training) q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals policy.q_values = q_vals @@ -149,12 +155,25 @@ def get_distribution_inputs_and_class(policy, def build_q_losses(policy, model, _, train_batch): + states_in = [] + i = 0 + while "state_in_{}".format(i) in train_batch: + states_in.append(train_batch["state_in_{}".format(i)]) + i += 1 + states_out = [] + i = 0 + while "state_out_{}".format(i) in train_batch: + 
states_out.append(train_batch["state_out_{}".format(i)]) + i += 1 + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else [] config = policy.config # q network evaluation q_t = compute_q_values( policy, policy.q_model, train_batch[SampleBatch.CUR_OBS], + states_in, + seq_lens, explore=False, is_training=True) @@ -163,6 +182,8 @@ def build_q_losses(policy, model, _, train_batch): policy, policy.target_q_model, train_batch[SampleBatch.NEXT_OBS], + states_out, + seq_lens, explore=False, is_training=True) @@ -177,6 +198,8 @@ def build_q_losses(policy, model, _, train_batch): policy, policy.q_model, train_batch[SampleBatch.NEXT_OBS], + states_out, + seq_lens, explore=False, is_training=True) q_tp1_best_using_online_net = torch.argmax(q_tp1_using_online_net, 1) @@ -221,14 +244,14 @@ def after_init(policy, obs_space, action_space, config): policy.target_q_model = policy.target_q_model.to(policy.device) -def compute_q_values(policy, model, obs, explore, is_training=False): +def compute_q_values(policy, model, obs, states, seq_lens, explore, is_training=False): if policy.config["num_atoms"] > 1: raise ValueError("torch DQN does not support distributional DQN yet!") model_out, state = model({ SampleBatch.CUR_OBS: obs, "is_training": is_training, - }, [], None) + }, states, seq_lens) advantages_or_q_values = model.get_advantages_or_q_values(model_out) diff --git a/rllib/agents/sac/sac_tf_policy.py b/rllib/agents/sac/sac_tf_policy.py index 94cb2553b26f..19b7583aa9d9 100644 --- a/rllib/agents/sac/sac_tf_policy.py +++ b/rllib/agents/sac/sac_tf_policy.py @@ -116,6 +116,8 @@ def get_dist_class(config, action_space): def get_distribution_inputs_and_class(policy, model, obs_batch, + state_batches, + seq_lens, *, explore=True, **kwargs): @@ -123,7 +125,7 @@ def get_distribution_inputs_and_class(policy, model_out, state_out = model({ "obs": obs_batch, "is_training": policy._get_is_training_placeholder(), - }, [], None) + }, state_batches, seq_lens) # Get action model output from base-model output. distribution_inputs = model.get_policy_output(model_out) action_dist_class = get_dist_class(policy.config, policy.action_space) @@ -134,20 +136,32 @@ def sac_actor_critic_loss(policy, model, _, train_batch): # Should be True only for debugging purposes (e.g. test cases)! deterministic = policy.config["_deterministic_loss"] + states_in = [] + i = 0 + while "state_in_{}".format(i) in train_batch: + states_in.append(train_batch["state_in_{}".format(i)]) + i += 1 + states_out = [] + i = 0 + while "state_out_{}".format(i) in train_batch: + states_out.append(train_batch["state_out_{}".format(i)]) + i += 1 + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else [] + model_out_t, _ = model({ "obs": train_batch[SampleBatch.CUR_OBS], "is_training": policy._get_is_training_placeholder(), - }, [], None) + }, states_in, seq_lens) model_out_tp1, _ = model({ "obs": train_batch[SampleBatch.NEXT_OBS], "is_training": policy._get_is_training_placeholder(), - }, [], None) + }, states_out, seq_lens) target_model_out_tp1, _ = policy.target_model({ "obs": train_batch[SampleBatch.NEXT_OBS], "is_training": policy._get_is_training_placeholder(), - }, [], None) + }, states_out, seq_lens) # Discrete case. 
if model.discrete: diff --git a/rllib/agents/sac/sac_torch_policy.py b/rllib/agents/sac/sac_torch_policy.py index 503bc4979498..3660540fe64d 100644 --- a/rllib/agents/sac/sac_torch_policy.py +++ b/rllib/agents/sac/sac_torch_policy.py @@ -39,9 +39,9 @@ def get_dist_class(config, action_space): def action_distribution_fn(policy, model, obs_batch, + state_batches, + seq_lens, *, - state_batches=None, - seq_lens=None, prev_action_batch=None, prev_reward_batch=None, explore=None, @@ -50,7 +50,7 @@ def action_distribution_fn(policy, model_out, _ = model({ "obs": obs_batch, "is_training": is_training, - }, [], None) + }, state_batches, seq_lens) distribution_inputs = model.get_policy_output(model_out) action_dist_class = get_dist_class(policy.config, policy.action_space) @@ -61,20 +61,32 @@ def actor_critic_loss(policy, model, _, train_batch): # Should be True only for debugging purposes (e.g. test cases)! deterministic = policy.config["_deterministic_loss"] + states_in = [] + i = 0 + while "state_in_{}".format(i) in train_batch: + states_in.append(train_batch["state_in_{}".format(i)]) + i += 1 + states_out = [] + i = 0 + while "state_out_{}".format(i) in train_batch: + states_out.append(train_batch["state_out_{}".format(i)]) + i += 1 + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else [] + model_out_t, _ = model({ "obs": train_batch[SampleBatch.CUR_OBS], "is_training": True, - }, [], None) + }, states_in, seq_lens) model_out_tp1, _ = model({ "obs": train_batch[SampleBatch.NEXT_OBS], "is_training": True, - }, [], None) + }, states_out, seq_lens) target_model_out_tp1, _ = policy.target_model({ "obs": train_batch[SampleBatch.NEXT_OBS], "is_training": True, - }, [], None) + }, states_out, seq_lens) alpha = torch.exp(model.log_alpha) @@ -280,7 +292,10 @@ def optimizer_fn(policy, config): class ComputeTDErrorMixin: def __init__(self): def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, - importance_weights): + importance_weights, **kwargs): + # kwargs should contain only state input tensors, + # state output tensors and the seq_lens tensor + state_in_out_and_seq_lens = kwargs input_dict = self._lazy_tensor_dict({ SampleBatch.CUR_OBS: obs_t, SampleBatch.ACTIONS: act_t, @@ -289,6 +304,7 @@ def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, SampleBatch.DONES: done_mask, PRIO_WEIGHTS: importance_weights, }) + input_dict.update(state_in_out_and_seq_lens) # Do forward pass on loss to update td errors attribute # (one TD-error value per item in batch to update PR weights). 
actor_critic_loss(self, self.model, None, input_dict) diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index c13ff5f8afe5..13562158705d 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -142,6 +142,8 @@ def compute_actions(self, self, self.model, input_dict[SampleBatch.CUR_OBS], + state_batches=state_batches, + seq_lens=seq_lens, explore=explore, timestep=timestep, is_training=False) diff --git a/rllib/utils/tf_ops.py b/rllib/utils/tf_ops.py index ea636f744f04..e8871e947fb6 100644 --- a/rllib/utils/tf_ops.py +++ b/rllib/utils/tf_ops.py @@ -58,10 +58,11 @@ def make_tf_callable(session_or_none, dynamic_shape=False): def make_wrapper(fn): if session_or_none: - placeholders = [] + args_placeholders = [] + kwargs_placeholders = {} symbolic_out = [None] - def call(*args): + def call(*args, **kwargs): args_flat = [] for a in args: if type(a) is list: @@ -79,13 +80,25 @@ def call(*args): shape = () else: shape = v.shape - placeholders.append( + args_placeholders.append( tf.placeholder( dtype=v.dtype, shape=shape, name="arg_{}".format(i))) - symbolic_out[0] = fn(*placeholders) - feed_dict = dict(zip(placeholders, args)) + for k, v in kwargs.items(): + if dynamic_shape: + if len(v.shape) > 0: + shape = (None, ) + v.shape[1:] + else: + shape = () + else: + shape = v.shape + kwargs_placeholders[k] = tf.placeholder( + dtype=v.dtype, + shape=shape, + name="karg_{}".format(k)) + symbolic_out[0] = fn(*args_placeholders, **kwargs_placeholders) + feed_dict = dict(zip(args_placeholders, args)) ret = session_or_none.run(symbolic_out[0], feed_dict) return ret From a09e815b2ccb93285510820d52caa4fb355da215 Mon Sep 17 00:00:00 2001 From: Edilmo Palencia Date: Tue, 6 Oct 2020 01:18:27 -0700 Subject: [PATCH 4/6] Fixing missing changes --- .../tests/test_prioritized_replay_buffer.py | 19 +++++++++++-------- rllib/policy/eager_tf_policy.py | 8 +++++++- rllib/policy/torch_policy.py | 2 ++ rllib/tests/test_evaluators.py | 8 +++++++- 4 files changed, 27 insertions(+), 10 deletions(-) diff --git a/rllib/execution/tests/test_prioritized_replay_buffer.py b/rllib/execution/tests/test_prioritized_replay_buffer.py index 38441b8962b8..4b7adad676f4 100644 --- a/rllib/execution/tests/test_prioritized_replay_buffer.py +++ b/rllib/execution/tests/test_prioritized_replay_buffer.py @@ -20,9 +20,12 @@ def _generate_data(self): return ( np.random.random((4, )), # obs_t np.random.choice([0, 1]), # action + [np.random.random((2, )), np.random.random((3, ))], # state t np.random.rand(), # reward np.random.random((4, )), # obs_tp1 np.random.choice([False, True]), # done + [np.random.random((2,)), np.random.random((3,))], # state tp1 + np.random.rand(), # seq lens ) def test_add(self): @@ -65,7 +68,7 @@ def test_update_priorities(self): self.assertTrue(memory._next_idx == i + 1) # Fetch records, their indices and weights. - _, _, _, _, _, weights, indices = \ + _, _, _, _, _, _, _, _, weights, indices = \ memory.sample(3, beta=self.beta) check(weights, np.ones(shape=(3, ))) self.assertEqual(3, len(indices)) @@ -78,7 +81,7 @@ def test_update_priorities(self): # Expect to sample almost only index 1 # (which still has a weight of 1.0). 
for _ in range(10): - _, _, _, _, _, weights, indices = memory.sample( + _, _, _, _, _, _, _, _, weights, indices = memory.sample( 1000, beta=self.beta) self.assertTrue(970 < np.sum(indices) < 1100) @@ -87,7 +90,7 @@ def test_update_priorities(self): for _ in range(10): rand = np.random.random() + 0.2 memory.update_priorities(np.array([0, 1]), np.array([rand, rand])) - _, _, _, _, _, _, indices = memory.sample(1000, beta=self.beta) + _, _, _, _, _, _, _, _, _, indices = memory.sample(1000, beta=self.beta) # Expect biased to higher values due to some 2s, 3s, and 4s. # print(np.sum(indices)) self.assertTrue(400 < np.sum(indices) < 800) @@ -99,7 +102,7 @@ def test_update_priorities(self): rand = np.random.random() + 0.2 memory.update_priorities( np.array([0, 1]), np.array([rand, rand * 2])) - _, _, _, _, _, _, indices = memory.sample(1000, beta=self.beta) + _, _, _, _, _, _, _, _, _, indices = memory.sample(1000, beta=self.beta) # print(np.sum(indices)) self.assertTrue(600 < np.sum(indices) < 850) @@ -110,7 +113,7 @@ def test_update_priorities(self): rand = np.random.random() + 0.2 memory.update_priorities( np.array([0, 1]), np.array([rand, rand * 4])) - _, _, _, _, _, _, indices = memory.sample(1000, beta=self.beta) + _, _, _, _, _, _, _, _, _, indices = memory.sample(1000, beta=self.beta) # print(np.sum(indices)) self.assertTrue(750 < np.sum(indices) < 950) @@ -121,7 +124,7 @@ def test_update_priorities(self): rand = np.random.random() + 0.2 memory.update_priorities( np.array([0, 1]), np.array([rand, rand * 9])) - _, _, _, _, _, _, indices = memory.sample(1000, beta=self.beta) + _, _, _, _, _, _, _, _, _, indices = memory.sample(1000, beta=self.beta) # print(np.sum(indices)) self.assertTrue(850 < np.sum(indices) < 1100) @@ -139,7 +142,7 @@ def test_update_priorities(self): np.array([0.001, 0.1, 2., 8., 16., 32., 64., 128., 256., 512.])) counts = Counter() for _ in range(10): - _, _, _, _, _, _, indices = memory.sample( + _, _, _, _, _, _, _, _, _, indices = memory.sample( np.random.randint(100, 600), beta=self.beta) for i in indices: counts[i] += 1 @@ -163,7 +166,7 @@ def test_alpha_parameter(self): self.assertTrue(memory._next_idx == i + 1) # Fetch records, their indices and weights. 
- _, _, _, _, _, weights, indices = \ + _, _, _, _, _, _, _, _, weights, indices = \ memory.sample(1000, beta=self.beta) counts = Counter() for i in indices: diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index 8dd15584f170..ad39e5eecc15 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -248,7 +248,9 @@ def __init__(self, observation_space, action_space, config): if action_distribution_fn: dist_inputs, self.dist_class, _ = action_distribution_fn( - self, self.model, input_dict[SampleBatch.CUR_OBS]) + self, self.model, input_dict[SampleBatch.CUR_OBS], + state_batches=self._state_in, + seq_lens=tf.convert_to_tensor([1])) else: self.model(input_dict, self._state_in, tf.convert_to_tensor([1])) @@ -360,6 +362,8 @@ def compute_actions(self, action_distribution_fn( self, self.model, input_dict[SampleBatch.CUR_OBS], + state_batches=state_batches, + seq_lens=seq_lens, explore=explore, timestep=timestep, is_training=False) @@ -428,6 +432,8 @@ def compute_log_likelihoods(self, self, self.model, input_dict[SampleBatch.CUR_OBS], + state_batches=state_batches, + seq_lens=seq_lens, explore=False, is_training=False) action_dist = dist_class(dist_inputs, self.model) diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index 13562158705d..449678e076de 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -209,6 +209,8 @@ def compute_log_likelihoods(self, policy=self, model=self.model, obs_batch=input_dict[SampleBatch.CUR_OBS], + state_batches=state_batches, + seq_lens=seq_lens, explore=False, is_training=False) # Default action-dist inputs calculation. diff --git a/rllib/tests/test_evaluators.py b/rllib/tests/test_evaluators.py index 93399fa90c5d..74e220891331 100644 --- a/rllib/tests/test_evaluators.py +++ b/rllib/tests/test_evaluators.py @@ -12,17 +12,23 @@ class EvalTest(unittest.TestCase): def test_dqn_n_step(self): obs = [1, 2, 3, 4, 5, 6, 7] + state_in = [[1], [2], [3], [4], [5], [6], [7]] actions = ["a", "b", "a", "a", "a", "b", "a"] rewards = [10.0, 0.0, 100.0, 100.0, 100.0, 100.0, 100.0] new_obs = [2, 3, 4, 5, 6, 7, 8] dones = [0, 0, 0, 0, 0, 0, 1] - _adjust_nstep(3, 0.9, obs, actions, rewards, new_obs, dones) + state_out = [[1], [2], [3], [4], [5], [6], [7]] + seq_lens = [1, 2, 3, 4, 5, 6, 7] + _adjust_nstep(3, 0.9, obs, state_in, actions, rewards, new_obs, dones, state_out, seq_lens) self.assertEqual(obs, [1, 2, 3, 4, 5, 6, 7]) + self.assertEqual(state_in, [[1], [2], [3], [4], [5], [6], [7]]) self.assertEqual(actions, ["a", "b", "a", "a", "a", "b", "a"]) self.assertEqual(new_obs, [4, 5, 6, 7, 8, 8, 8]) self.assertEqual(dones, [0, 0, 0, 0, 1, 1, 1]) self.assertEqual(rewards, [91.0, 171.0, 271.0, 271.0, 271.0, 190.0, 100.0]) + self.assertEqual(state_out, [[3], [4], [5], [6], [7], [7], [7]]) + self.assertEqual(seq_lens, [1, 2, 3, 4, 5, 6, 7]) def test_evaluation_option(self): def env_creator(env_config): From 8ca19130dbe0688268c3d88824118984107bd0c4 Mon Sep 17 00:00:00 2001 From: Edilmo Palencia Date: Thu, 8 Oct 2020 08:30:38 -0700 Subject: [PATCH 5/6] Fixing DDPG and SAC --- rllib/agents/ddpg/ddpg_tf_policy.py | 12 ++--- rllib/agents/ddpg/ddpg_torch_policy.py | 9 ++-- rllib/agents/dqn/dqn_tf_policy.py | 19 ++++---- rllib/agents/dqn/dqn_torch_policy.py | 15 +++--- rllib/agents/sac/sac_tf_policy.py | 3 +- rllib/agents/sac/sac_torch_policy.py | 7 +-- rllib/execution/replay_buffer.py | 2 +- rllib/policy/tf_policy.py | 4 ++ rllib/tests/agents/parameters.py | 64 +++++++++++++++++++++++++- 
rllib/tests/agents/test_learning.py | 40 +++++++++++++++- rllib/utils/tf_ops.py | 3 +- 11 files changed, 144 insertions(+), 34 deletions(-) diff --git a/rllib/agents/ddpg/ddpg_tf_policy.py b/rllib/agents/ddpg/ddpg_tf_policy.py index 8ca865b04893..65a9a765b6eb 100644 --- a/rllib/agents/ddpg/ddpg_tf_policy.py +++ b/rllib/agents/ddpg/ddpg_tf_policy.py @@ -111,7 +111,7 @@ def get_distribution_inputs_and_class(policy, explore=True, is_training=False, **kwargs): - model_out, _ = model({ + model_out, state_out = model({ "obs": obs_batch, "is_training": is_training, }, state_batches, seq_lens) @@ -119,7 +119,7 @@ def get_distribution_inputs_and_class(policy, return dist_inputs, (TorchDeterministic if policy.config["framework"] == "torch" else - Deterministic), [] # []=state out + Deterministic), state_out # state out def ddpg_actor_critic_loss(policy, model, _, train_batch): @@ -140,7 +140,7 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch): while "state_out_{}".format(i) in train_batch: states_out.append(train_batch["state_out_{}".format(i)]) i += 1 - seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else [] + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.ones(len(train_batch[SampleBatch.CUR_OBS])) input_dict = { "obs": train_batch[SampleBatch.CUR_OBS], @@ -151,9 +151,9 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch): "is_training": True, } - model_out_t, _ = model(input_dict, states_in, seq_lens) - model_out_tp1, _ = model(input_dict_next, states_out, seq_lens) - target_model_out_tp1, _ = policy.target_model(input_dict_next, states_out, seq_lens) + model_out_t, states_out_t = model(input_dict, states_in, seq_lens) + model_out_tp1, states_out_tp1 = model(input_dict_next, states_out, seq_lens) + target_model_out_tp1, target_states_out_tp1 = policy.target_model(input_dict_next, states_out, seq_lens) # Policy network evaluation. with tf.variable_scope(POLICY_SCOPE, reuse=True): diff --git a/rllib/agents/ddpg/ddpg_torch_policy.py b/rllib/agents/ddpg/ddpg_torch_policy.py index 6337e8ec14a5..d5caf53ba8c2 100644 --- a/rllib/agents/ddpg/ddpg_torch_policy.py +++ b/rllib/agents/ddpg/ddpg_torch_policy.py @@ -1,4 +1,5 @@ import logging +import numpy as np import ray from ray.rllib.agents.ddpg.ddpg_tf_policy import build_ddpg_models, \ @@ -45,7 +46,7 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch): while "state_out_{}".format(i) in train_batch: states_out.append(train_batch["state_out_{}".format(i)]) i += 1 - seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else [] + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.ones(len(train_batch[SampleBatch.CUR_OBS])) input_dict = { "obs": train_batch[SampleBatch.CUR_OBS], @@ -56,9 +57,9 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch): "is_training": True, } - model_out_t, _ = model(input_dict, states_in, seq_lens) - model_out_tp1, _ = model(input_dict_next, states_out, seq_lens) - target_model_out_tp1, _ = policy.target_model(input_dict_next, states_out, seq_lens) + model_out_t, states_out_t = model(input_dict, states_in, seq_lens) + model_out_tp1, states_out_tp1 = model(input_dict_next, states_out, seq_lens) + target_model_out_tp1, target_states_out_tp1 = policy.target_model(input_dict_next, states_out, seq_lens) # Policy network evaluation. 
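# A minimal sketch, assuming a plain dict for `train_batch` (illustrative only,
# not a hunk of this patch): when the model is not recurrent the sampled batch
# carries no "seq_lens" column, so the losses above fall back to a vector of
# ones, i.e. every transition is treated as its own length-1 sequence.
import numpy as np

def default_seq_lens(train_batch, batch_size):
    if "seq_lens" in train_batch:
        return train_batch["seq_lens"]
    return np.ones(batch_size, dtype=np.int32)

# default_seq_lens({"obs": np.zeros((32, 4))}, batch_size=32)  ->  32 ones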
# prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) diff --git a/rllib/agents/dqn/dqn_tf_policy.py b/rllib/agents/dqn/dqn_tf_policy.py index fe975784aa93..9c7d07362624 100644 --- a/rllib/agents/dqn/dqn_tf_policy.py +++ b/rllib/agents/dqn/dqn_tf_policy.py @@ -200,11 +200,12 @@ def get_distribution_inputs_and_class(policy, explore=True, **kwargs): q_vals = compute_q_values(policy, model, obs_batch, state_batches, seq_lens, explore) + state_out = q_vals[3] if isinstance(q_vals, tuple) else [] q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals policy.q_values = q_vals policy.q_func_vars = model.variables() - return policy.q_values, Categorical, [] # state-out + return policy.q_values, Categorical, state_out # state-out def build_q_losses(policy, model, _, train_batch): @@ -218,10 +219,10 @@ def build_q_losses(policy, model, _, train_batch): while "state_out_{}".format(i) in train_batch: states_out.append(train_batch["state_out_{}".format(i)]) i += 1 - seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.array([]) + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.ones(len(train_batch[SampleBatch.CUR_OBS])) config = policy.config # q network evaluation - q_t, q_logits_t, q_dist_t = compute_q_values( + q_t, q_logits_t, q_dist_t, q_state_t = compute_q_values( policy, policy.q_model, train_batch[SampleBatch.CUR_OBS], @@ -230,7 +231,7 @@ def build_q_losses(policy, model, _, train_batch): explore=False) # target q network evalution - q_tp1, q_logits_tp1, q_dist_tp1 = compute_q_values( + q_tp1, q_logits_tp1, q_dist_tp1, q_state_tp1 = compute_q_values( policy, policy.target_q_model, train_batch[SampleBatch.NEXT_OBS], @@ -250,7 +251,7 @@ def build_q_losses(policy, model, _, train_batch): # compute estimate of best possible value starting from state at t + 1 if config["double_q"]: q_tp1_using_online_net, q_logits_tp1_using_online_net, \ - q_dist_tp1_using_online_net = compute_q_values( + q_dist_tp1_using_online_net, q_state_tp1_using_online_net = compute_q_values( policy, policy.q_model, train_batch[SampleBatch.NEXT_OBS], states_out, @@ -357,7 +358,7 @@ def compute_q_values(policy, model, obs, states, seq_lens, explore): else: value = action_scores - return value, logits, dist + return value, logits, dist, state def _adjust_nstep(n_step, gamma, obs, states_in, actions, rewards, new_obs, dones, @@ -384,8 +385,10 @@ def _adjust_nstep(n_step, gamma, obs, states_in, actions, rewards, new_obs, done new_obs[i] = new_obs[i + j] dones[i] = dones[i + j] rewards[i] += gamma**j * rewards[i + j] - if states_out: + if len(states_out) >= (i + j + 1): states_out[i] = states_out[i + j] + elif len(states_out) >= (i + 1): + states_out[i] = states_out[-1] def postprocess_nstep_and_prio(policy, batch, other_agent=None, episode=None): @@ -406,7 +409,7 @@ def postprocess_nstep_and_prio(policy, batch, other_agent=None, episode=None): state_in_out_and_seq_lens[key] = batch[key] i += 1 key = "state_out_{}".format(i) - seq_lens = batch["seq_lens"] if "seq_lens" in batch else np.array([]) + seq_lens = batch["seq_lens"] if "seq_lens" in batch else np.ones(len(batch[SampleBatch.CUR_OBS])) state_in_out_and_seq_lens["seq_lens"] = seq_lens # N-step Q adjustments if policy.config["n_step"] > 1: diff --git a/rllib/agents/dqn/dqn_torch_policy.py b/rllib/agents/dqn/dqn_torch_policy.py index f57a85d4bee1..ee954699057d 100644 --- a/rllib/agents/dqn/dqn_torch_policy.py +++ b/rllib/agents/dqn/dqn_torch_policy.py @@ -1,4 +1,5 @@ from gym.spaces import Discrete +import numpy as 
np import ray from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \ @@ -147,11 +148,11 @@ def get_distribution_inputs_and_class(policy, explore=True, is_training=False, **kwargs): - q_vals = compute_q_values(policy, model, obs_batch, state_batches, seq_lens, explore, is_training) + q_vals, state_out = compute_q_values(policy, model, obs_batch, state_batches, seq_lens, explore, is_training) q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals policy.q_values = q_vals - return policy.q_values, TorchCategorical, [] # state-out + return policy.q_values, TorchCategorical, state_out # state-out def build_q_losses(policy, model, _, train_batch): @@ -165,10 +166,10 @@ def build_q_losses(policy, model, _, train_batch): while "state_out_{}".format(i) in train_batch: states_out.append(train_batch["state_out_{}".format(i)]) i += 1 - seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else [] + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.ones(len(train_batch[SampleBatch.CUR_OBS])) config = policy.config # q network evaluation - q_t = compute_q_values( + q_t, q_state_t = compute_q_values( policy, policy.q_model, train_batch[SampleBatch.CUR_OBS], @@ -178,7 +179,7 @@ def build_q_losses(policy, model, _, train_batch): is_training=True) # target q network evalution - q_tp1 = compute_q_values( + q_tp1, q_state_tp1 = compute_q_values( policy, policy.target_q_model, train_batch[SampleBatch.NEXT_OBS], @@ -194,7 +195,7 @@ def build_q_losses(policy, model, _, train_batch): # compute estimate of best possible value starting from state at t + 1 if config["double_q"]: - q_tp1_using_online_net = compute_q_values( + q_tp1_using_online_net, q_state_tp1_using_online_net = compute_q_values( policy, policy.q_model, train_batch[SampleBatch.NEXT_OBS], @@ -264,7 +265,7 @@ def compute_q_values(policy, model, obs, states, seq_lens, explore, is_training= else: q_values = advantages_or_q_values - return q_values + return q_values, state def grad_process_and_td_error_fn(policy, optimizer, loss): diff --git a/rllib/agents/sac/sac_tf_policy.py b/rllib/agents/sac/sac_tf_policy.py index 19b7583aa9d9..04e13bf78d7a 100644 --- a/rllib/agents/sac/sac_tf_policy.py +++ b/rllib/agents/sac/sac_tf_policy.py @@ -1,5 +1,6 @@ from gym.spaces import Box, Discrete import logging +import numpy as np import ray import ray.experimental.tf_utils @@ -146,7 +147,7 @@ def sac_actor_critic_loss(policy, model, _, train_batch): while "state_out_{}".format(i) in train_batch: states_out.append(train_batch["state_out_{}".format(i)]) i += 1 - seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else [] + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.ones(len(train_batch[SampleBatch.CUR_OBS])) model_out_t, _ = model({ "obs": train_batch[SampleBatch.CUR_OBS], diff --git a/rllib/agents/sac/sac_torch_policy.py b/rllib/agents/sac/sac_torch_policy.py index 3660540fe64d..b38c6919c500 100644 --- a/rllib/agents/sac/sac_torch_policy.py +++ b/rllib/agents/sac/sac_torch_policy.py @@ -1,5 +1,6 @@ from gym.spaces import Discrete import logging +import numpy as np import ray import ray.experimental.tf_utils @@ -47,14 +48,14 @@ def action_distribution_fn(policy, explore=None, timestep=None, is_training=None): - model_out, _ = model({ + model_out, state_out = model({ "obs": obs_batch, "is_training": is_training, }, state_batches, seq_lens) distribution_inputs = model.get_policy_output(model_out) action_dist_class = get_dist_class(policy.config, policy.action_space) - 
return distribution_inputs, action_dist_class, [] + return distribution_inputs, action_dist_class, state_out def actor_critic_loss(policy, model, _, train_batch): @@ -71,7 +72,7 @@ def actor_critic_loss(policy, model, _, train_batch): while "state_out_{}".format(i) in train_batch: states_out.append(train_batch["state_out_{}".format(i)]) i += 1 - seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else [] + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.ones(len(train_batch[SampleBatch.CUR_OBS])) model_out_t, _ = model({ "obs": train_batch[SampleBatch.CUR_OBS], diff --git a/rllib/execution/replay_buffer.py b/rllib/execution/replay_buffer.py index 1e48b8ca20f5..c7eee58e83d5 100644 --- a/rllib/execution/replay_buffer.py +++ b/rllib/execution/replay_buffer.py @@ -386,7 +386,7 @@ def add_batch(self, batch): self.replay_buffers[policy_id].add( row["obs"], states_in, row["actions"], row["rewards"], row["new_obs"], row["dones"], states_out, - row["seq_lens"] if "seq_lens" in row else None, + row["seq_lens"] if "seq_lens" in row else np.ones(len(row["obs"])), row["weights"] if "weights" in row else None) self.num_added += batch.count diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index d3b00a2d89d1..56b6b149fef8 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -598,6 +598,10 @@ def _build_compute_actions(self, raise ValueError( "Must pass in RNN state batches for placeholders {}, got {}". format(self._state_inputs, state_batches)) + if len(obs_batch) == 0 and state_batches: + raise ValueError( + "Must pass obs_batch in order to build RNN properly {}, got {}". + format(obs_batch, state_batches)) builder.add_feed_dict(self.extra_compute_action_feed_dict()) builder.add_feed_dict({self._obs_input: obs_batch}) diff --git a/rllib/tests/agents/parameters.py b/rllib/tests/agents/parameters.py index 6eb5d9fb236d..fc6dd2ec328b 100644 --- a/rllib/tests/agents/parameters.py +++ b/rllib/tests/agents/parameters.py @@ -1,8 +1,9 @@ """Contains the pytest params used in `test_agents` tests.""" from itertools import chain -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import attr +import gym import pytest from ray.rllib.agents.trainer_factory import ( @@ -11,6 +12,7 @@ ContinuousActionSpaceAlgorithm, Framework, ) +from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole @attr.s(auto_attribs=True) @@ -27,7 +29,7 @@ def for_frameworks( cls, algorithm: Algorithm, config_updates: dict, - env: str, + env: Union[str, gym.Env], frameworks: Optional[List[Framework]] = None, n_iter=2, threshold=1.0, @@ -64,6 +66,24 @@ def for_cart_pole( threshold=threshold, ) + @classmethod + def for_stateless_cart_pole( + cls, + algorithm: Algorithm, + config_updates: dict, + frameworks: Optional[List[Framework]] = None, + n_iter=2, + threshold=1.0, + ) -> List["TestAgentParams"]: + return cls.for_frameworks( + algorithm=algorithm, + config_updates=config_updates, + env=StatelessCartPole, + frameworks=frameworks, + n_iter=n_iter, + threshold=threshold, + ) + @classmethod def for_pendulum( cls, @@ -338,6 +358,46 @@ def astuple(self): ] +test_rnn_monotonic_convergence_params: List[ + Tuple[Algorithm, dict, str, Framework, int, float] +] = [ + x.astuple() + for x in chain( + # TestAgentParams.for_pendulum( + # algorithm=ContinuousActionSpaceAlgorithm.APEX_DDPG, + # config_updates={ + # "use_huber": True, + # "clip_rewards": False, + # "num_workers": 4, + # "n_step": 1, + # "target_network_update_freq": 
50000, + # "tau": 1.0, + # }, + # n_iter=200, + # threshold=-750.0, + # ), + TestAgentParams.for_stateless_cart_pole( + # TestAgentParams.for_cart_pole( + algorithm=DiscreteActionSpaceAlgorithm.APEX_DQN, + config_updates={ + "target_network_update_freq": 20000, + "num_workers": 1, + "num_envs_per_worker": 8, + "train_batch_size": 64, + "gamma": 0.95, + "model": { + "use_lstm": True, + "lstm_use_prev_action_reward": True, + }, + }, + n_iter=200, + threshold=150.0, + frameworks=[Framework.TensorFlow] + ), + ) +] + + if __name__ == "__main__": import sys diff --git a/rllib/tests/agents/test_learning.py b/rllib/tests/agents/test_learning.py index beeb1e9ff881..25523d406ba2 100644 --- a/rllib/tests/agents/test_learning.py +++ b/rllib/tests/agents/test_learning.py @@ -13,7 +13,7 @@ from ray.rllib.agents.trainer_factory import Algorithm, Framework from ray.rllib.tests.agents.parameters import ( test_convergence_params, - test_monotonic_convergence_params, + test_monotonic_convergence_params, test_rnn_monotonic_convergence_params, ) @@ -85,3 +85,41 @@ def test_monotonically_improving_algorithms_can_converge_with_different_framewor break assert learnt, f"{episode_reward_mean} < {threshold}" + +@pytest.mark.skip("WIP") +@pytest.mark.minutes +@pytest.mark.usefixtures("ray_env") +@pytest.mark.usefixtures("using_framework") +@pytest.mark.parametrize( + "algorithm, config_overrides, env, framework, n_iter, threshold", + test_rnn_monotonic_convergence_params, +) +def test_rnn_monotonically_improving_algorithms_can_converge_with_different_frameworks( + algorithm: Algorithm, + config_overrides: dict, + env: str, + framework: Framework, + n_iter: int, + threshold: float, + trainer: Trainer, +): + """I should be able to train an algorithm to convergence with the following + frameworks: + 1. TensorFlow (Graph Mode) + 2. TensorFlow (Eager Mode) + 3. PyTorch + NOTE: Not all algorithms have been implemented in all frameworks. + NOTE: For monotonically improving algorithms (like PPO), its enough to stop training + after the episode reward mean of an epoch exceeds the set threshold, even if we + haven't trained for n_iter number of epochs. 
+ """ + learnt = False + episode_reward_mean = -float("inf") + for _ in range(n_iter): + results = trainer.train() + episode_reward_mean = results["episode_reward_mean"] + if episode_reward_mean >= threshold: + learnt = True + break + + assert learnt, f"{episode_reward_mean} < {threshold}" diff --git a/rllib/utils/tf_ops.py b/rllib/utils/tf_ops.py index e8871e947fb6..28babe3e8b7a 100644 --- a/rllib/utils/tf_ops.py +++ b/rllib/utils/tf_ops.py @@ -98,7 +98,8 @@ def call(*args, **kwargs): shape=shape, name="karg_{}".format(k)) symbolic_out[0] = fn(*args_placeholders, **kwargs_placeholders) - feed_dict = dict(zip(args_placeholders, args)) + feed_dict = dict(list(zip(args_placeholders, args)) + + [(p, kwargs[k]) for k, p in kwargs_placeholders.items()]) ret = session_or_none.run(symbolic_out[0], feed_dict) return ret From e39c3268b65d8bd89fca2d9cea9085f27bfcff49 Mon Sep 17 00:00:00 2001 From: Edilmo Palencia Date: Thu, 8 Oct 2020 22:26:15 -0700 Subject: [PATCH 6/6] Fixing eager mode --- rllib/agents/ddpg/ddpg_tf_policy.py | 5 ++++- rllib/agents/ddpg/ddpg_torch_policy.py | 5 ++++- rllib/agents/dqn/dqn_tf_policy.py | 5 ++++- rllib/agents/dqn/dqn_torch_policy.py | 5 ++++- rllib/agents/sac/sac_tf_policy.py | 5 ++++- rllib/agents/sac/sac_torch_policy.py | 5 ++++- rllib/tests/agents/test_learning.py | 2 +- 7 files changed, 25 insertions(+), 7 deletions(-) diff --git a/rllib/agents/ddpg/ddpg_tf_policy.py b/rllib/agents/ddpg/ddpg_tf_policy.py index 65a9a765b6eb..0c81b143d3c9 100644 --- a/rllib/agents/ddpg/ddpg_tf_policy.py +++ b/rllib/agents/ddpg/ddpg_tf_policy.py @@ -140,7 +140,10 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch): while "state_out_{}".format(i) in train_batch: states_out.append(train_batch["state_out_{}".format(i)]) i += 1 - seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.ones(len(train_batch[SampleBatch.CUR_OBS])) + batch_size = (train_batch[SampleBatch.CUR_OBS].shape[0] + if isinstance(train_batch[SampleBatch.CUR_OBS], tf.Tensor) + else len(train_batch[SampleBatch.CUR_OBS])) + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.ones(batch_size) input_dict = { "obs": train_batch[SampleBatch.CUR_OBS], diff --git a/rllib/agents/ddpg/ddpg_torch_policy.py b/rllib/agents/ddpg/ddpg_torch_policy.py index d5caf53ba8c2..67d9b407cebc 100644 --- a/rllib/agents/ddpg/ddpg_torch_policy.py +++ b/rllib/agents/ddpg/ddpg_torch_policy.py @@ -46,7 +46,10 @@ def ddpg_actor_critic_loss(policy, model, _, train_batch): while "state_out_{}".format(i) in train_batch: states_out.append(train_batch["state_out_{}".format(i)]) i += 1 - seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.ones(len(train_batch[SampleBatch.CUR_OBS])) + batch_size = (train_batch[SampleBatch.CUR_OBS].shape[0] + if isinstance(train_batch[SampleBatch.CUR_OBS], torch.Tensor) + else len(train_batch[SampleBatch.CUR_OBS])) + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.ones(batch_size) input_dict = { "obs": train_batch[SampleBatch.CUR_OBS], diff --git a/rllib/agents/dqn/dqn_tf_policy.py b/rllib/agents/dqn/dqn_tf_policy.py index 9c7d07362624..e640aa349c97 100644 --- a/rllib/agents/dqn/dqn_tf_policy.py +++ b/rllib/agents/dqn/dqn_tf_policy.py @@ -219,7 +219,10 @@ def build_q_losses(policy, model, _, train_batch): while "state_out_{}".format(i) in train_batch: states_out.append(train_batch["state_out_{}".format(i)]) i += 1 - seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.ones(len(train_batch[SampleBatch.CUR_OBS])) + 
batch_size = (train_batch[SampleBatch.CUR_OBS].shape[0] + if isinstance(train_batch[SampleBatch.CUR_OBS], tf.Tensor) + else len(train_batch[SampleBatch.CUR_OBS])) + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.ones(batch_size) config = policy.config # q network evaluation q_t, q_logits_t, q_dist_t, q_state_t = compute_q_values( diff --git a/rllib/agents/dqn/dqn_torch_policy.py b/rllib/agents/dqn/dqn_torch_policy.py index ee954699057d..9f76cc1cbed5 100644 --- a/rllib/agents/dqn/dqn_torch_policy.py +++ b/rllib/agents/dqn/dqn_torch_policy.py @@ -166,7 +166,10 @@ def build_q_losses(policy, model, _, train_batch): while "state_out_{}".format(i) in train_batch: states_out.append(train_batch["state_out_{}".format(i)]) i += 1 - seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.ones(len(train_batch[SampleBatch.CUR_OBS])) + batch_size = (train_batch[SampleBatch.CUR_OBS].shape[0] + if isinstance(train_batch[SampleBatch.CUR_OBS], torch.Tensor) + else len(train_batch[SampleBatch.CUR_OBS])) + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.ones(batch_size) config = policy.config # q network evaluation q_t, q_state_t = compute_q_values( diff --git a/rllib/agents/sac/sac_tf_policy.py b/rllib/agents/sac/sac_tf_policy.py index 04e13bf78d7a..c8b8ea9de7d1 100644 --- a/rllib/agents/sac/sac_tf_policy.py +++ b/rllib/agents/sac/sac_tf_policy.py @@ -147,7 +147,10 @@ def sac_actor_critic_loss(policy, model, _, train_batch): while "state_out_{}".format(i) in train_batch: states_out.append(train_batch["state_out_{}".format(i)]) i += 1 - seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.ones(len(train_batch[SampleBatch.CUR_OBS])) + batch_size = (train_batch[SampleBatch.CUR_OBS].shape[0] + if isinstance(train_batch[SampleBatch.CUR_OBS], tf.Tensor) + else len(train_batch[SampleBatch.CUR_OBS])) + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.ones(batch_size) model_out_t, _ = model({ "obs": train_batch[SampleBatch.CUR_OBS], diff --git a/rllib/agents/sac/sac_torch_policy.py b/rllib/agents/sac/sac_torch_policy.py index b38c6919c500..45db129b0d58 100644 --- a/rllib/agents/sac/sac_torch_policy.py +++ b/rllib/agents/sac/sac_torch_policy.py @@ -72,7 +72,10 @@ def actor_critic_loss(policy, model, _, train_batch): while "state_out_{}".format(i) in train_batch: states_out.append(train_batch["state_out_{}".format(i)]) i += 1 - seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.ones(len(train_batch[SampleBatch.CUR_OBS])) + batch_size = (train_batch[SampleBatch.CUR_OBS].shape[0] + if isinstance(train_batch[SampleBatch.CUR_OBS], torch.Tensor) + else len(train_batch[SampleBatch.CUR_OBS])) + seq_lens = train_batch["seq_lens"] if "seq_lens" in train_batch else np.ones(batch_size) model_out_t, _ = model({ "obs": train_batch[SampleBatch.CUR_OBS], diff --git a/rllib/tests/agents/test_learning.py b/rllib/tests/agents/test_learning.py index 25523d406ba2..40a17517a05b 100644 --- a/rllib/tests/agents/test_learning.py +++ b/rllib/tests/agents/test_learning.py @@ -86,7 +86,7 @@ def test_monotonically_improving_algorithms_can_converge_with_different_framewor assert learnt, f"{episode_reward_mean} < {threshold}" -@pytest.mark.skip("WIP") + @pytest.mark.minutes @pytest.mark.usefixtures("ray_env") @pytest.mark.usefixtures("using_framework")