Skip to content

Commit

Permalink
[RLlib] Allow passing **kwargs to action distribution. (#24692)
Browse files Browse the repository at this point in the history
  • Loading branch information
nmatare authored May 18, 2022
1 parent 9b2086c commit 012a4c8
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions rllib/models/tf/tf_action_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,15 +570,19 @@ class MultiActionDistribution(TFActionDistribution):
inputs (Tensor list): A list of tensors from which to compute samples.
"""

def __init__(self, inputs, model, *, child_distributions, input_lens, action_space):
def __init__(
self, inputs, model, *, child_distributions, input_lens, action_space, **kwargs
):
ActionDistribution.__init__(self, inputs, model)

self.action_space_struct = get_base_struct_from_space(action_space)

self.input_lens = np.array(input_lens, dtype=np.int32)
split_inputs = tf.split(inputs, self.input_lens, axis=1)
self.flat_child_distributions = tree.map_structure(
lambda dist, input_: dist(input_, model), child_distributions, split_inputs
lambda dist, input_: dist(input_, model, **kwargs),
child_distributions,
split_inputs,
)

@override(ActionDistribution)
Expand Down

0 comments on commit 012a4c8

Please sign in to comment.