[RLlib] Fix failing test cases: Soft-deprecate ModelV2.from_batch (in favor of ModelV2.__call__). #19693

Merged (3 commits) on Oct 25, 2021
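This PR soft-deprecates `ModelV2.from_batch()` in favor of calling the model object directly via `ModelV2.__call__()`, and updates all in-tree call sites accordingly. A minimal before/after sketch, assuming a default catalog-built model and a made-up observation space (nothing below is part of the diff itself):

```python
import gym
import numpy as np

from ray.rllib.models import MODEL_DEFAULTS, ModelCatalog
from ray.rllib.policy.sample_batch import SampleBatch

# Build some ModelV2 instance (hypothetical spaces; any ModelV2 works here).
obs_space = gym.spaces.Box(-1.0, 1.0, shape=(4,), dtype=np.float32)
action_space = gym.spaces.Discrete(2)
model = ModelCatalog.get_model_v2(
    obs_space, action_space, num_outputs=2,
    model_config=MODEL_DEFAULTS.copy(), framework="tf2")

train_batch = SampleBatch(
    {SampleBatch.OBS: np.random.random((8, 4)).astype(np.float32)})

# Old pattern (soft-deprecated by this PR; still works, but logs a warning):
logits, state = model.from_batch(train_batch)

# New pattern: call the model directly. ModelV2.__call__ pulls seq_lens and
# any "state_in_*" columns out of the SampleBatch itself.
logits, state = model(train_batch)
```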
2 changes: 1 addition & 1 deletion rllib/agents/a3c/a3c_tf_policy.py
@@ -73,7 +73,7 @@ def __init__(self,
def actor_critic_loss(policy: Policy, model: ModelV2,
dist_class: ActionDistribution,
train_batch: SampleBatch) -> TensorType:
model_out, _ = model.from_batch(train_batch)
model_out, _ = model(train_batch)
action_dist = dist_class(model_out, model)
if policy.is_recurrent():
max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
2 changes: 1 addition & 1 deletion rllib/agents/a3c/a3c_torch_policy.py
@@ -39,7 +39,7 @@ def add_advantages(
def actor_critic_loss(policy: Policy, model: ModelV2,
dist_class: ActionDistribution,
train_batch: SampleBatch) -> TensorType:
logits, _ = model.from_batch(train_batch)
logits, _ = model(train_batch)
values = model.value_function()

if policy.is_recurrent():
5 changes: 2 additions & 3 deletions rllib/agents/dqn/dqn_tf_policy.py
@@ -212,12 +212,12 @@ def build_q_model(policy: Policy, obs_space: gym.spaces.Space,

def get_distribution_inputs_and_class(policy: Policy,
model: ModelV2,
obs_batch: TensorType,
input_dict: SampleBatch,
*,
explore=True,
**kwargs):
q_vals = compute_q_values(
policy, model, {"obs": obs_batch}, state_batches=None, explore=explore)
policy, model, input_dict, state_batches=None, explore=explore)
q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals

policy.q_values = q_vals
@@ -342,7 +342,6 @@ def compute_q_values(policy: Policy,

config = policy.config

input_dict.is_training = policy._get_is_training_placeholder()
model_out, state = model(input_dict, state_batches or [], seq_lens)

if config["num_atoms"] > 1:
8 changes: 2 additions & 6 deletions rllib/agents/dqn/dqn_torch_policy.py
@@ -204,16 +204,13 @@ def build_q_model_and_distribution(
def get_distribution_inputs_and_class(
policy: Policy,
model: ModelV2,
obs_batch: TensorType,
input_dict: SampleBatch,
*,
explore: bool = True,
is_training: bool = False,
**kwargs) -> Tuple[TensorType, type, List[TensorType]]:
q_vals = compute_q_values(
policy,
model, {"obs": obs_batch},
explore=explore,
is_training=is_training)
policy, model, input_dict, explore=explore, is_training=is_training)
q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals

model.tower_stats["q_values"] = q_vals
@@ -350,7 +347,6 @@ def compute_q_values(policy: Policy,
is_training: bool = False):
config = policy.config

input_dict.is_training = is_training
model_out, state = model(input_dict, state_batches or [], seq_lens)

if config["num_atoms"] > 1:
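The two DQN changes above replace the raw `obs_batch` tensor argument with the full `input_dict` (a `SampleBatch`), which is forwarded to `compute_q_values()` unchanged. A small self-contained sketch of the difference in what gets passed around (batch shape is made up):

```python
import numpy as np

from ray.rllib.policy.sample_batch import SampleBatch

obs_batch = np.random.random((32, 4)).astype(np.float32)

# What the old code built by hand before calling compute_q_values():
old_input = {"obs": obs_batch}

# What get_distribution_inputs_and_class() now receives and forwards as is:
# a SampleBatch supports the same key access, but can also carry state_in_*
# columns, seq_lens, prev. actions/rewards, etc. alongside the observations.
input_dict = SampleBatch({SampleBatch.OBS: obs_batch})

assert np.allclose(old_input["obs"], input_dict[SampleBatch.OBS])
```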
2 changes: 1 addition & 1 deletion rllib/agents/impala/vtrace_tf_policy.py
@@ -161,7 +161,7 @@ def _make_time_major(policy, seq_lens, tensor, drop_last=False):


def build_vtrace_loss(policy, model, dist_class, train_batch):
model_out, _ = model.from_batch(train_batch)
model_out, _ = model(train_batch)
action_dist = dist_class(model_out, model)

if isinstance(policy.action_space, gym.spaces.Discrete):
2 changes: 1 addition & 1 deletion rllib/agents/impala/vtrace_torch_policy.py
@@ -113,7 +113,7 @@ def __init__(self,


def build_vtrace_loss(policy, model, dist_class, train_batch):
model_out, _ = model.from_batch(train_batch)
model_out, _ = model(train_batch)
action_dist = dist_class(model_out, model)

if isinstance(policy.action_space, gym.spaces.Discrete):
2 changes: 1 addition & 1 deletion rllib/agents/maml/maml_tf_policy.py
@@ -308,7 +308,7 @@ def split_placeholders(self, placeholder, split):


def maml_loss(policy, model, dist_class, train_batch):
logits, state = model.from_batch(train_batch)
logits, state = model(train_batch)
policy.cur_lr = policy.config["lr"]

if policy.config["worker_index"]:
2 changes: 1 addition & 1 deletion rllib/agents/maml/maml_torch_policy.py
@@ -246,7 +246,7 @@ def split_placeholders(self, placeholder, split):


def maml_loss(policy, model, dist_class, train_batch):
logits, state = model.from_batch(train_batch)
logits, state = model(train_batch)
policy.cur_lr = policy.config["lr"]

if policy.config["worker_index"]:
4 changes: 2 additions & 2 deletions rllib/agents/marwil/marwil_tf_policy.py
@@ -30,7 +30,7 @@ def __init__(self, obs_space: gym.spaces.Space,
# input_dict.
@make_tf_callable(self.get_session())
def value(**input_dict):
model_out, _ = self.model.from_batch(input_dict, is_training=False)
model_out, _ = self.model(input_dict)
# [0] = remove the batch dim.
return self.model.value_function()[0]

@@ -150,7 +150,7 @@ def __init__(self, policy: Policy, value_estimates: TensorType,

def marwil_loss(policy: Policy, model: ModelV2, dist_class: ActionDistribution,
train_batch: SampleBatch) -> TensorType:
model_out, _ = model.from_batch(train_batch)
model_out, _ = model(train_batch)
action_dist = dist_class(model_out, model)
value_estimates = model.value_function()

2 changes: 1 addition & 1 deletion rllib/agents/marwil/marwil_torch_policy.py
@@ -19,7 +19,7 @@

def marwil_loss(policy: Policy, model: ModelV2, dist_class: ActionDistribution,
train_batch: SampleBatch) -> TensorType:
model_out, _ = model.from_batch(train_batch)
model_out, _ = model(train_batch)
action_dist = dist_class(model_out, model)
actions = train_batch[SampleBatch.ACTIONS]
# log\pi_\theta(a|s)
2 changes: 1 addition & 1 deletion rllib/agents/marwil/tests/test_marwil.py
@@ -149,7 +149,7 @@ def test_marwil_loss_function(self):
cummulative_rewards = torch.tensor(cummulative_rewards)
if fw != "tf":
batch = policy._lazy_tensor_dict(batch)
model_out, _ = model.from_batch(batch)
model_out, _ = model(batch)
vf_estimates = model.value_function()
if fw == "tf":
model_out, vf_estimates = \
2 changes: 1 addition & 1 deletion rllib/agents/pg/pg_tf_policy.py
@@ -34,7 +34,7 @@ def pg_tf_loss(
of loss tensors.
"""
# Pass the training data through our model to get distribution parameters.
dist_inputs, _ = model.from_batch(train_batch)
dist_inputs, _ = model(train_batch)

# Create an action distribution object.
action_dist = dist_class(dist_inputs, model)
2 changes: 1 addition & 1 deletion rllib/agents/pg/pg_torch_policy.py
@@ -35,7 +35,7 @@ def pg_torch_loss(
of loss tensors.
"""
# Pass the training data through our model to get distribution parameters.
dist_inputs, _ = model.from_batch(train_batch)
dist_inputs, _ = model(train_batch)

# Create an action distribution object.
action_dist = dist_class(dist_inputs, model)
4 changes: 2 additions & 2 deletions rllib/agents/ppo/appo_tf_policy.py
@@ -99,7 +99,7 @@ def appo_surrogate_loss(
Union[TensorType, List[TensorType]]: A single loss tensor or a list
of loss tensors.
"""
model_out, _ = model.from_batch(train_batch)
model_out, _ = model(train_batch)
action_dist = dist_class(model_out, model)

if isinstance(policy.action_space, gym.spaces.Discrete):
@@ -123,7 +123,7 @@ def make_time_major(*args, **kw):
rewards = train_batch[SampleBatch.REWARDS]
behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

target_model_out, _ = policy.target_model.from_batch(train_batch)
target_model_out, _ = policy.target_model(train_batch)
prev_action_dist = dist_class(behaviour_logits, policy.model)
values = policy.model.value_function()
values_time_major = make_time_major(values)
4 changes: 2 additions & 2 deletions rllib/agents/ppo/appo_torch_policy.py
@@ -56,7 +56,7 @@ def appo_surrogate_loss(policy: Policy, model: ModelV2,
"""
target_model = policy.target_models[model]

model_out, _ = model.from_batch(train_batch)
model_out, _ = model(train_batch)
action_dist = dist_class(model_out, model)

if isinstance(policy.action_space, gym.spaces.Discrete):
@@ -79,7 +79,7 @@ def _make_time_major(*args, **kwargs):
rewards = train_batch[SampleBatch.REWARDS]
behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

target_model_out, _ = target_model.from_batch(train_batch)
target_model_out, _ = target_model(train_batch)

prev_action_dist = dist_class(behaviour_logits, model)
values = model.value_function()
2 changes: 1 addition & 1 deletion rllib/agents/ppo/ppo_tf_policy.py
@@ -50,7 +50,7 @@ def ppo_surrogate_loss(
logits, state, extra_outs = model(train_batch)
value_fn_out = extra_outs[SampleBatch.VF_PREDS]
else:
logits, state = model.from_batch(train_batch)
logits, state = model(train_batch)
value_fn_out = model.value_function()

curr_action_dist = dist_class(logits, model)
2 changes: 1 addition & 1 deletion rllib/examples/custom_tf_policy.py
@@ -16,7 +16,7 @@


def policy_gradient_loss(policy, model, dist_class, train_batch):
logits, _ = model.from_batch(train_batch)
logits, _ = model(train_batch)
action_dist = dist_class(logits, model)
return -tf.reduce_mean(
action_dist.logp(train_batch["actions"]) * train_batch["returns"])
2 changes: 1 addition & 1 deletion rllib/examples/eager_execution.py
@@ -80,7 +80,7 @@ def compute_penalty(actions, rewards):
print("The eagerly computed penalty is", penalty, actions, rewards)
return penalty

logits, _ = model.from_batch(train_batch)
logits, _ = model(train_batch)
action_dist = dist_class(logits, model)

actions = train_batch[SampleBatch.ACTIONS]
2 changes: 1 addition & 1 deletion rllib/examples/export/onnx_tf.py
@@ -33,7 +33,7 @@

# Let's run inference on the tensorflow model
policy = trainer.get_policy()
result_tf, _ = policy.model.from_batch(test_data)
result_tf, _ = policy.model(test_data)

# Evaluate tensor to fetch numpy array
with policy._sess.as_default():
2 changes: 1 addition & 1 deletion rllib/examples/export/onnx_torch.py
@@ -35,7 +35,7 @@

# Let's run inference on the torch model
policy = trainer.get_policy()
result_pytorch, _ = policy.model.from_batch({
result_pytorch, _ = policy.model({
"obs": torch.tensor(test_data["obs"]),
})

2 changes: 1 addition & 1 deletion rllib/examples/rock_paper_scissors_multiagent.py
@@ -146,7 +146,7 @@ def run_with_custom_entropy_loss(args, stop):
This performs about the same as the default loss does."""

def entropy_policy_gradient_loss(policy, model, dist_class, train_batch):
logits, _ = model.from_batch(train_batch)
logits, _ = model(train_batch)
action_dist = dist_class(logits, model)
if args.framework == "torch":
# Required by PGTorchPolicy's stats fn.
27 changes: 13 additions & 14 deletions rllib/models/modelv2.py
@@ -10,7 +10,7 @@
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils import NullContextManager
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.annotations import Deprecated, DeveloperAPI, PublicAPI
from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
TensorType
from ray.rllib.utils.spaces.repeated import Repeated
@@ -204,18 +204,19 @@ def __call__(
# where tensors get automatically converted).
if isinstance(input_dict, SampleBatch):
restored = input_dict.copy(shallow=True)
# Backward compatibility.
if seq_lens is None:
seq_lens = input_dict.get(SampleBatch.SEQ_LENS)
if not state:
state = []
i = 0
while "state_in_{}".format(i) in input_dict:
state.append(input_dict["state_in_{}".format(i)])
i += 1
else:
restored = input_dict.copy()

# Backward compatibility.
if not state:
state = []
i = 0
while "state_in_{}".format(i) in input_dict:
state.append(input_dict["state_in_{}".format(i)])
i += 1
if seq_lens is None:
seq_lens = input_dict.get(SampleBatch.SEQ_LENS)

# No Preprocessor used: `config._disable_preprocessor_api`=True.
# TODO: This is unnecessary for when no preprocessor is used.
# Obs are not flat then anymore. However, we'll keep this
@@ -255,9 +256,7 @@ def __call__(
self._last_output = outputs
return outputs, state_out if len(state_out) > 0 else (state or [])

# TODO: (sven) obsolete this method at some point (replace by
# simply calling model directly with a sample_batch as only input).
@PublicAPI
@Deprecated(new="ModelV2.__call__()", error=False)
def from_batch(self, train_batch: SampleBatch,
is_training: bool = True) -> (TensorType, List[TensorType]):
"""Convenience function that calls this model with a tensor batch.
@@ -267,7 +266,7 @@ def from_batch(self, train_batch: SampleBatch,
"""

input_dict = train_batch.copy()
input_dict.is_training = is_training
input_dict["is_training"] = is_training
states = []
i = 0
while "state_in_{}".format(i) in input_dict:
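`from_batch()` itself is kept, but now runs through RLlib's `@Deprecated` wrapper with `error=False`, so it logs a deprecation warning and then executes as before. A rough sketch of that pattern on a hypothetical class (not RLlib code; the import path mirrors the one used in this diff, while newer Ray versions expose `Deprecated` from `ray.rllib.utils.deprecation`):

```python
from ray.rllib.utils.annotations import Deprecated


class MyModel:
    def __call__(self, input_dict):
        return {"out": input_dict["obs"]}

    # error=False -> warn and fall through; error=True would raise instead.
    @Deprecated(new="MyModel.__call__()", error=False)
    def from_batch(self, train_batch):
        return self(train_batch)


m = MyModel()
m.from_batch({"obs": [1, 2, 3]})  # Logs a deprecation warning, still works.
m({"obs": [1, 2, 3]})             # Preferred call path, no warning.
```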
4 changes: 3 additions & 1 deletion rllib/policy/dynamic_tf_policy.py
@@ -621,7 +621,9 @@ def _initialize_loss_from_dummy_batch(
)

train_batch = SampleBatch(
dict(self._input_dict, **self._loss_input_dict))
dict(self._input_dict, **self._loss_input_dict),
_is_training=True,
)

if self._state_inputs:
train_batch[SampleBatch.SEQ_LENS] = self._seq_lens
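The dummy train batch used for loss initialization is now flagged as a training batch at construction time via `SampleBatch`'s `_is_training` keyword instead of being mutated afterwards. Minimal sketch (the observation column is a placeholder):

```python
import numpy as np

from ray.rllib.policy.sample_batch import SampleBatch

train_batch = SampleBatch(
    {"obs": np.zeros((2, 4), dtype=np.float32)},
    _is_training=True,
)
assert train_batch.is_training is True
```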
4 changes: 3 additions & 1 deletion rllib/policy/torch_policy.py
@@ -264,7 +264,8 @@ def compute_actions(
with torch.no_grad():
seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
input_dict = self._lazy_tensor_dict({
SampleBatch.CUR_OBS: obs_batch
SampleBatch.CUR_OBS: obs_batch,
"is_training": False,
})
if prev_action_batch is not None:
input_dict[SampleBatch.PREV_ACTIONS] = \
@@ -291,6 +292,7 @@ def compute_actions_from_input_dict(
with torch.no_grad():
# Pass lazy (torch) tensor dict to Model as `input_dict`.
input_dict = self._lazy_tensor_dict(input_dict)
input_dict.is_training = False
# Pack internal state inputs into (separate) list.
state_batches = [
input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
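Both torch inference paths above now explicitly mark their input batches as non-training: `compute_actions()` via an `"is_training"` key in the dict handed to `_lazy_tensor_dict()`, and `compute_actions_from_input_dict()` via the `is_training` attribute of the resulting `SampleBatch`. A sketch of the two spellings on a plain `SampleBatch` (column name and shape are illustrative; in this Ray version the dict key is intercepted by `SampleBatch`'s backward-compatibility handling and stored as the same flag):

```python
import numpy as np

from ray.rllib.policy.sample_batch import SampleBatch

infer_batch = SampleBatch(
    {SampleBatch.CUR_OBS: np.zeros((1, 4), dtype=np.float32)})

# Dict-style, as in compute_actions() above (may log a deprecation warning).
infer_batch["is_training"] = False

# Attribute-style, as in compute_actions_from_input_dict() above.
infer_batch.is_training = False

assert infer_batch.is_training is False
```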
2 changes: 1 addition & 1 deletion rllib/utils/typing.py
@@ -107,7 +107,7 @@
# Type of dict returned by get_weights() representing model weights.
ModelWeights = dict

# An input dict used for direct ModelV2 calls or `ModelV2.from_batch` calls.
# An input dict used for direct ModelV2 calls.
ModelInputDict = Dict[str, TensorType]

# Some kind of sample batch.