From f34c245eb14e7ca90a3519b85af93ece14253d77 Mon Sep 17 00:00:00 2001
From: sven1977
Date: Mon, 25 Oct 2021 11:15:08 +0200
Subject: [PATCH 1/3] wip

---
 rllib/agents/a3c/a3c_tf_policy.py | 2 +-
 rllib/agents/a3c/a3c_torch_policy.py | 2 +-
 rllib/agents/impala/vtrace_tf_policy.py | 2 +-
 rllib/agents/impala/vtrace_torch_policy.py | 2 +-
 rllib/agents/maml/maml_tf_policy.py | 2 +-
 rllib/agents/maml/maml_torch_policy.py | 2 +-
 rllib/agents/marwil/marwil_tf_policy.py | 4 +--
 rllib/agents/marwil/marwil_torch_policy.py | 2 +-
 rllib/agents/marwil/tests/test_marwil.py | 2 +-
 rllib/agents/pg/pg_tf_policy.py | 2 +-
 rllib/agents/pg/pg_torch_policy.py | 2 +-
 rllib/agents/ppo/appo_tf_policy.py | 4 +--
 rllib/agents/ppo/appo_torch_policy.py | 4 +--
 rllib/agents/ppo/ppo_tf_policy.py | 2 +-
 rllib/examples/custom_tf_policy.py | 2 +-
 rllib/examples/eager_execution.py | 2 +-
 rllib/examples/export/onnx_tf.py | 2 +-
 rllib/examples/export/onnx_torch.py | 2 +-
 .../rock_paper_scissors_multiagent.py | 2 +-
 rllib/models/modelv2.py | 25 +++++++++----------
 rllib/utils/typing.py | 2 +-
 21 files changed, 35 insertions(+), 36 deletions(-)

diff --git a/rllib/agents/a3c/a3c_tf_policy.py b/rllib/agents/a3c/a3c_tf_policy.py
index dabd5b1134b8..32f8c29457b4 100644
--- a/rllib/agents/a3c/a3c_tf_policy.py
+++ b/rllib/agents/a3c/a3c_tf_policy.py
@@ -73,7 +73,7 @@ def __init__(self,
 def actor_critic_loss(policy: Policy, model: ModelV2,
                       dist_class: ActionDistribution,
                       train_batch: SampleBatch) -> TensorType:
-    model_out, _ = model.from_batch(train_batch)
+    model_out, _ = model(train_batch)
     action_dist = dist_class(model_out, model)
     if policy.is_recurrent():
         max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
diff --git a/rllib/agents/a3c/a3c_torch_policy.py b/rllib/agents/a3c/a3c_torch_policy.py
index 150d689702e9..1939aaa80b0a 100644
--- a/rllib/agents/a3c/a3c_torch_policy.py
+++ b/rllib/agents/a3c/a3c_torch_policy.py
@@ -39,7 +39,7 @@ def add_advantages(
 def actor_critic_loss(policy: Policy, model: ModelV2,
                       dist_class: ActionDistribution,
                       train_batch: SampleBatch) -> TensorType:
-    logits, _ = model.from_batch(train_batch)
+    logits, _ = model(train_batch)
     values = model.value_function()
 
     if policy.is_recurrent():
diff --git a/rllib/agents/impala/vtrace_tf_policy.py b/rllib/agents/impala/vtrace_tf_policy.py
index f5b5ddc4192d..5a786a4da8e9 100644
--- a/rllib/agents/impala/vtrace_tf_policy.py
+++ b/rllib/agents/impala/vtrace_tf_policy.py
@@ -161,7 +161,7 @@ def _make_time_major(policy, seq_lens, tensor, drop_last=False):
 
 
 def build_vtrace_loss(policy, model, dist_class, train_batch):
-    model_out, _ = model.from_batch(train_batch)
+    model_out, _ = model(train_batch)
     action_dist = dist_class(model_out, model)
 
     if isinstance(policy.action_space, gym.spaces.Discrete):
diff --git a/rllib/agents/impala/vtrace_torch_policy.py b/rllib/agents/impala/vtrace_torch_policy.py
index c8738d1875f6..d8ec81483503 100644
--- a/rllib/agents/impala/vtrace_torch_policy.py
+++ b/rllib/agents/impala/vtrace_torch_policy.py
@@ -113,7 +113,7 @@ def __init__(self,
 
 
 def build_vtrace_loss(policy, model, dist_class, train_batch):
-    model_out, _ = model.from_batch(train_batch)
+    model_out, _ = model(train_batch)
     action_dist = dist_class(model_out, model)
 
     if isinstance(policy.action_space, gym.spaces.Discrete):
diff --git a/rllib/agents/maml/maml_tf_policy.py b/rllib/agents/maml/maml_tf_policy.py
index 990005ba0b3e..2126bde0482e 100644
--- a/rllib/agents/maml/maml_tf_policy.py
+++ b/rllib/agents/maml/maml_tf_policy.py
@@ -308,7 +308,7 @@ def split_placeholders(self, placeholder, split):
 
 
 def maml_loss(policy, model, dist_class, train_batch):
-    logits, state = model.from_batch(train_batch)
+    logits, state = model(train_batch)
     policy.cur_lr = policy.config["lr"]
 
     if policy.config["worker_index"]:
diff --git a/rllib/agents/maml/maml_torch_policy.py b/rllib/agents/maml/maml_torch_policy.py
index 695826798272..a5d71fc55db0 100644
--- a/rllib/agents/maml/maml_torch_policy.py
+++ b/rllib/agents/maml/maml_torch_policy.py
@@ -246,7 +246,7 @@ def split_placeholders(self, placeholder, split):
 
 
 def maml_loss(policy, model, dist_class, train_batch):
-    logits, state = model.from_batch(train_batch)
+    logits, state = model(train_batch)
     policy.cur_lr = policy.config["lr"]
 
     if policy.config["worker_index"]:
diff --git a/rllib/agents/marwil/marwil_tf_policy.py b/rllib/agents/marwil/marwil_tf_policy.py
index 1748770ec1a3..045c9ae62b15 100644
--- a/rllib/agents/marwil/marwil_tf_policy.py
+++ b/rllib/agents/marwil/marwil_tf_policy.py
@@ -30,7 +30,7 @@ def __init__(self, obs_space: gym.spaces.Space,
             # input_dict.
             @make_tf_callable(self.get_session())
             def value(**input_dict):
-                model_out, _ = self.model.from_batch(input_dict, is_training=False)
+                model_out, _ = self.model(input_dict)
                 # [0] = remove the batch dim.
                 return self.model.value_function()[0]
 
@@ -150,7 +150,7 @@ def __init__(self, policy: Policy, value_estimates: TensorType,
 def marwil_loss(policy: Policy, model: ModelV2,
                 dist_class: ActionDistribution,
                 train_batch: SampleBatch) -> TensorType:
-    model_out, _ = model.from_batch(train_batch)
+    model_out, _ = model(train_batch)
     action_dist = dist_class(model_out, model)
     value_estimates = model.value_function()
 
diff --git a/rllib/agents/marwil/marwil_torch_policy.py b/rllib/agents/marwil/marwil_torch_policy.py
index f16414a960a1..c15447c4825b 100644
--- a/rllib/agents/marwil/marwil_torch_policy.py
+++ b/rllib/agents/marwil/marwil_torch_policy.py
@@ -19,7 +19,7 @@
 def marwil_loss(policy: Policy, model: ModelV2,
                 dist_class: ActionDistribution,
                 train_batch: SampleBatch) -> TensorType:
-    model_out, _ = model.from_batch(train_batch)
+    model_out, _ = model(train_batch)
     action_dist = dist_class(model_out, model)
     actions = train_batch[SampleBatch.ACTIONS]
     # log\pi_\theta(a|s)
diff --git a/rllib/agents/marwil/tests/test_marwil.py b/rllib/agents/marwil/tests/test_marwil.py
index b8ca7af86ae2..f16af43cc1e1 100644
--- a/rllib/agents/marwil/tests/test_marwil.py
+++ b/rllib/agents/marwil/tests/test_marwil.py
@@ -149,7 +149,7 @@ def test_marwil_loss_function(self):
                 cummulative_rewards = torch.tensor(cummulative_rewards)
             if fw != "tf":
                 batch = policy._lazy_tensor_dict(batch)
-            model_out, _ = model.from_batch(batch)
+            model_out, _ = model(batch)
             vf_estimates = model.value_function()
             if fw == "tf":
                 model_out, vf_estimates = \
diff --git a/rllib/agents/pg/pg_tf_policy.py b/rllib/agents/pg/pg_tf_policy.py
index c49d277d6ed9..f5cd970acc93 100644
--- a/rllib/agents/pg/pg_tf_policy.py
+++ b/rllib/agents/pg/pg_tf_policy.py
@@ -34,7 +34,7 @@ def pg_tf_loss(
             of loss tensors.
     """
     # Pass the training data through our model to get distribution parameters.
-    dist_inputs, _ = model.from_batch(train_batch)
+    dist_inputs, _ = model(train_batch)
 
     # Create an action distribution object.
     action_dist = dist_class(dist_inputs, model)
diff --git a/rllib/agents/pg/pg_torch_policy.py b/rllib/agents/pg/pg_torch_policy.py
index 34a17c5e03f9..cbc7ab2306a8 100644
--- a/rllib/agents/pg/pg_torch_policy.py
+++ b/rllib/agents/pg/pg_torch_policy.py
@@ -35,7 +35,7 @@ def pg_torch_loss(
             of loss tensors.
     """
     # Pass the training data through our model to get distribution parameters.
-    dist_inputs, _ = model.from_batch(train_batch)
+    dist_inputs, _ = model(train_batch)
 
     # Create an action distribution object.
     action_dist = dist_class(dist_inputs, model)
diff --git a/rllib/agents/ppo/appo_tf_policy.py b/rllib/agents/ppo/appo_tf_policy.py
index 6eb18af97aea..fa1b738f9e56 100644
--- a/rllib/agents/ppo/appo_tf_policy.py
+++ b/rllib/agents/ppo/appo_tf_policy.py
@@ -99,7 +99,7 @@ def appo_surrogate_loss(
         Union[TensorType, List[TensorType]]: A single loss tensor or a list
             of loss tensors.
     """
-    model_out, _ = model.from_batch(train_batch)
+    model_out, _ = model(train_batch)
     action_dist = dist_class(model_out, model)
 
     if isinstance(policy.action_space, gym.spaces.Discrete):
@@ -123,7 +123,7 @@ def make_time_major(*args, **kw):
     rewards = train_batch[SampleBatch.REWARDS]
     behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]
 
-    target_model_out, _ = policy.target_model.from_batch(train_batch)
+    target_model_out, _ = policy.target_model(train_batch)
     prev_action_dist = dist_class(behaviour_logits, policy.model)
     values = policy.model.value_function()
     values_time_major = make_time_major(values)
diff --git a/rllib/agents/ppo/appo_torch_policy.py b/rllib/agents/ppo/appo_torch_policy.py
index 46bcd5044b7b..17ffefd31dad 100644
--- a/rllib/agents/ppo/appo_torch_policy.py
+++ b/rllib/agents/ppo/appo_torch_policy.py
@@ -56,7 +56,7 @@ def appo_surrogate_loss(policy: Policy, model: ModelV2,
     """
     target_model = policy.target_models[model]
 
-    model_out, _ = model.from_batch(train_batch)
+    model_out, _ = model(train_batch)
     action_dist = dist_class(model_out, model)
 
     if isinstance(policy.action_space, gym.spaces.Discrete):
@@ -79,7 +79,7 @@ def _make_time_major(*args, **kwargs):
     rewards = train_batch[SampleBatch.REWARDS]
     behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]
 
-    target_model_out, _ = target_model.from_batch(train_batch)
+    target_model_out, _ = target_model(train_batch)
     prev_action_dist = dist_class(behaviour_logits, model)
     values = model.value_function()
 
diff --git a/rllib/agents/ppo/ppo_tf_policy.py b/rllib/agents/ppo/ppo_tf_policy.py
index 33df5b8fcca5..5cdb3cb20efb 100644
--- a/rllib/agents/ppo/ppo_tf_policy.py
+++ b/rllib/agents/ppo/ppo_tf_policy.py
@@ -50,7 +50,7 @@ def ppo_surrogate_loss(
         logits, state, extra_outs = model(train_batch)
         value_fn_out = extra_outs[SampleBatch.VF_PREDS]
     else:
-        logits, state = model.from_batch(train_batch)
+        logits, state = model(train_batch)
         value_fn_out = model.value_function()
 
     curr_action_dist = dist_class(logits, model)
diff --git a/rllib/examples/custom_tf_policy.py b/rllib/examples/custom_tf_policy.py
index 47eeaeac85c2..4f14a179e82c 100644
--- a/rllib/examples/custom_tf_policy.py
+++ b/rllib/examples/custom_tf_policy.py
@@ -16,7 +16,7 @@
 
 
 def policy_gradient_loss(policy, model, dist_class, train_batch):
-    logits, _ = model.from_batch(train_batch)
+    logits, _ = model(train_batch)
     action_dist = dist_class(logits, model)
     return -tf.reduce_mean(
         action_dist.logp(train_batch["actions"]) * train_batch["returns"])
diff --git a/rllib/examples/eager_execution.py b/rllib/examples/eager_execution.py
index 8d68f407377c..118978c66c29 100644
--- a/rllib/examples/eager_execution.py
+++ b/rllib/examples/eager_execution.py
@@ -80,7 +80,7 @@ def compute_penalty(actions, rewards):
         print("The eagerly computed penalty is", penalty, actions, rewards)
         return penalty
 
-    logits, _ = model.from_batch(train_batch)
+    logits, _ = model(train_batch)
     action_dist = dist_class(logits, model)
 
     actions = train_batch[SampleBatch.ACTIONS]
diff --git a/rllib/examples/export/onnx_tf.py b/rllib/examples/export/onnx_tf.py
index 4eb3429eb381..64e976724fa4 100644
--- a/rllib/examples/export/onnx_tf.py
+++ b/rllib/examples/export/onnx_tf.py
@@ -33,7 +33,7 @@
 
 # Let's run inference on the tensorflow model
 policy = trainer.get_policy()
-result_tf, _ = policy.model.from_batch(test_data)
+result_tf, _ = policy.model(test_data)
 
 # Evaluate tensor to fetch numpy array
 with policy._sess.as_default():
diff --git a/rllib/examples/export/onnx_torch.py b/rllib/examples/export/onnx_torch.py
index d63ae7634a81..251398b1850e 100644
--- a/rllib/examples/export/onnx_torch.py
+++ b/rllib/examples/export/onnx_torch.py
@@ -35,7 +35,7 @@
 
 # Let's run inference on the torch model
 policy = trainer.get_policy()
-result_pytorch, _ = policy.model.from_batch({
+result_pytorch, _ = policy.model({
     "obs": torch.tensor(test_data["obs"]),
 })
 
diff --git a/rllib/examples/rock_paper_scissors_multiagent.py b/rllib/examples/rock_paper_scissors_multiagent.py
index 0905314c1140..49e03dd46360 100644
--- a/rllib/examples/rock_paper_scissors_multiagent.py
+++ b/rllib/examples/rock_paper_scissors_multiagent.py
@@ -146,7 +146,7 @@ def run_with_custom_entropy_loss(args, stop):
     This performs about the same as the default loss does."""
 
     def entropy_policy_gradient_loss(policy, model, dist_class, train_batch):
-        logits, _ = model.from_batch(train_batch)
+        logits, _ = model(train_batch)
         action_dist = dist_class(logits, model)
         if args.framework == "torch":
             # Required by PGTorchPolicy's stats fn.
diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py
index 520546b995b8..e5fb908fa587 100644
--- a/rllib/models/modelv2.py
+++ b/rllib/models/modelv2.py
@@ -10,7 +10,7 @@ from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.view_requirement import ViewRequirement
 from ray.rllib.utils import NullContextManager
-from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
+from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, PublicAPI
 from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
     TensorType
 from ray.rllib.utils.spaces.repeated import Repeated
@@ -204,18 +204,19 @@ def __call__(
         # where tensors get automatically converted).
         if isinstance(input_dict, SampleBatch):
             restored = input_dict.copy(shallow=True)
-            # Backward compatibility.
-            if seq_lens is None:
-                seq_lens = input_dict.get(SampleBatch.SEQ_LENS)
-            if not state:
-                state = []
-                i = 0
-                while "state_in_{}".format(i) in input_dict:
-                    state.append(input_dict["state_in_{}".format(i)])
-                    i += 1
         else:
             restored = input_dict.copy()
 
+        # Backward compatibility.
+        if not state:
+            state = []
+            i = 0
+            while "state_in_{}".format(i) in input_dict:
+                state.append(input_dict["state_in_{}".format(i)])
+                i += 1
+        if seq_lens is None:
+            seq_lens = input_dict.get(SampleBatch.SEQ_LENS)
+
         # No Preprocessor used: `config._disable_preprocessor_api`=True.
         # TODO: This is unnecessary for when no preprocessor is used.
         # Obs are not flat then anymore. However, we'll keep this
@@ -255,9 +256,7 @@ def __call__(
         self._last_output = outputs
         return outputs, state_out if len(state_out) > 0 else (state or [])
 
-    # TODO: (sven) obsolete this method at some point (replace by
-    # simply calling model directly with a sample_batch as only input).
-    @PublicAPI
+    @Deprecated(new="ModelV2.__call__()", error=False)
     def from_batch(self, train_batch: SampleBatch,
                    is_training: bool = True) -> (TensorType, List[TensorType]):
         """Convenience function that calls this model with a tensor batch.
diff --git a/rllib/utils/typing.py b/rllib/utils/typing.py
index 5a8e7ff1c5e3..ec3d6159616e 100644
--- a/rllib/utils/typing.py
+++ b/rllib/utils/typing.py
@@ -107,7 +107,7 @@
 # Type of dict returned by get_weights() representing model weights.
 ModelWeights = dict
 
-# An input dict used for direct ModelV2 calls or `ModelV2.from_batch` calls.
+# An input dict used for direct ModelV2 calls.
 ModelInputDict = Dict[str, TensorType]
 
 # Some kind of sample batch.

From c24d7069c19123685113de176c821fd19fc3bfe3 Mon Sep 17 00:00:00 2001
From: sven1977
Date: Mon, 25 Oct 2021 11:37:31 +0200
Subject: [PATCH 2/3] wip

---
 rllib/models/modelv2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py
index e5fb908fa587..db234dc4247e 100644
--- a/rllib/models/modelv2.py
+++ b/rllib/models/modelv2.py
@@ -266,7 +266,7 @@ def from_batch(self, train_batch: SampleBatch,
         """
 
         input_dict = train_batch.copy()
-        input_dict.is_training = is_training
+        input_dict["is_training"] = is_training
         states = []
         i = 0
         while "state_in_{}".format(i) in input_dict:

From 7619016f74a34e7670d71f0775c36d7a930c13dd Mon Sep 17 00:00:00 2001
From: sven1977
Date: Mon, 25 Oct 2021 13:51:49 +0200
Subject: [PATCH 3/3] wip

---
 rllib/agents/dqn/dqn_tf_policy.py | 5 ++---
 rllib/agents/dqn/dqn_torch_policy.py | 8 ++------
 rllib/policy/dynamic_tf_policy.py | 4 +++-
 rllib/policy/torch_policy.py | 4 +++-
 4 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/rllib/agents/dqn/dqn_tf_policy.py b/rllib/agents/dqn/dqn_tf_policy.py
index 0605cf430cfb..0cb24451f008 100644
--- a/rllib/agents/dqn/dqn_tf_policy.py
+++ b/rllib/agents/dqn/dqn_tf_policy.py
@@ -212,12 +212,12 @@ def build_q_model(policy: Policy, obs_space: gym.spaces.Space,
 
 def get_distribution_inputs_and_class(policy: Policy,
                                       model: ModelV2,
-                                      obs_batch: TensorType,
+                                      input_dict: SampleBatch,
                                       *,
                                       explore=True,
                                       **kwargs):
     q_vals = compute_q_values(
-        policy, model, {"obs": obs_batch}, state_batches=None, explore=explore)
+        policy, model, input_dict, state_batches=None, explore=explore)
     q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals
     policy.q_values = q_vals
 
@@ -342,7 +342,6 @@ def compute_q_values(policy: Policy,
 
     config = policy.config
 
-    input_dict.is_training = policy._get_is_training_placeholder()
     model_out, state = model(input_dict, state_batches or [], seq_lens)
 
     if config["num_atoms"] > 1:
diff --git a/rllib/agents/dqn/dqn_torch_policy.py b/rllib/agents/dqn/dqn_torch_policy.py
index fda7889a94c9..10def1bfd099 100644
--- a/rllib/agents/dqn/dqn_torch_policy.py
+++ b/rllib/agents/dqn/dqn_torch_policy.py
@@ -204,16 +204,13 @@ def build_q_model_and_distribution(
 def get_distribution_inputs_and_class(
         policy: Policy,
         model: ModelV2,
-        obs_batch: TensorType,
+        input_dict: SampleBatch,
         *,
         explore: bool = True,
         is_training: bool = False,
         **kwargs) -> Tuple[TensorType, type, List[TensorType]]:
     q_vals = compute_q_values(
-        policy,
-        model, {"obs": obs_batch},
-        explore=explore,
-        is_training=is_training)
+        policy, model, input_dict, explore=explore, is_training=is_training)
     q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals
 
     model.tower_stats["q_values"] = q_vals
@@ -350,7 +347,6 @@ def compute_q_values(policy: Policy,
                      is_training: bool = False):
     config = policy.config
 
-    input_dict.is_training = is_training
     model_out, state = model(input_dict, state_batches or [], seq_lens)
 
     if config["num_atoms"] > 1:
diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py
index bce494358669..f2cb0cd9c51b 100644
--- a/rllib/policy/dynamic_tf_policy.py
+++ b/rllib/policy/dynamic_tf_policy.py
@@ -621,7 +621,9 @@ def _initialize_loss_from_dummy_batch(
         )
 
         train_batch = SampleBatch(
-            dict(self._input_dict, **self._loss_input_dict))
+            dict(self._input_dict, **self._loss_input_dict),
+            _is_training=True,
+        )
 
         if self._state_inputs:
             train_batch[SampleBatch.SEQ_LENS] = self._seq_lens
diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py
index bf1c69410ff8..5bb97a58e8de 100644
--- a/rllib/policy/torch_policy.py
+++ b/rllib/policy/torch_policy.py
@@ -264,7 +264,8 @@ def compute_actions(
         with torch.no_grad():
             seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
             input_dict = self._lazy_tensor_dict({
-                SampleBatch.CUR_OBS: obs_batch
+                SampleBatch.CUR_OBS: obs_batch,
+                "is_training": False,
             })
             if prev_action_batch is not None:
                 input_dict[SampleBatch.PREV_ACTIONS] = \
                     prev_action_batch
@@ -291,6 +292,7 @@ def compute_actions_from_input_dict(
         with torch.no_grad():
             # Pass lazy (torch) tensor dict to Model as `input_dict`.
             input_dict = self._lazy_tensor_dict(input_dict)
+            input_dict.is_training = False
             # Pack internal state inputs into (separate) list.
             state_batches = [
                 input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]