From 377a522ce24b0fc998f6c51d6c767d7925e12527 Mon Sep 17 00:00:00 2001
From: Xuehai Pan
Date: Fri, 29 Apr 2022 16:39:03 +0800
Subject: [PATCH] [RLlib] Fix time dimension shaping for PyTorch RNN models.
 (#21735)

---
 rllib/policy/rnn_sequencing.py            | 10 ++--
 rllib/policy/tests/test_rnn_sequencing.py | 56 ++++++++++++++++++++++-
 2 files changed, 61 insertions(+), 5 deletions(-)

diff --git a/rllib/policy/rnn_sequencing.py b/rllib/policy/rnn_sequencing.py
index ef2efe9ac7cb..22502cd51c4f 100644
--- a/rllib/policy/rnn_sequencing.py
+++ b/rllib/policy/rnn_sequencing.py
@@ -206,11 +206,13 @@ def add_time_dimension(
 
         # Dynamically reshape the padded batch to introduce a time dimension.
         new_batch_size = padded_batch_size // max_seq_len
+        batch_major_shape = (new_batch_size, max_seq_len) + padded_inputs.shape[1:]
+        padded_outputs = padded_inputs.view(batch_major_shape)
+
         if time_major:
-            new_shape = (max_seq_len, new_batch_size) + padded_inputs.shape[1:]
-        else:
-            new_shape = (new_batch_size, max_seq_len) + padded_inputs.shape[1:]
-        return torch.reshape(padded_inputs, new_shape)
+            # Swap the batch and time dimensions
+            padded_outputs = padded_outputs.transpose(0, 1)
+        return padded_outputs
 
 
 @DeveloperAPI
diff --git a/rllib/policy/tests/test_rnn_sequencing.py b/rllib/policy/tests/test_rnn_sequencing.py
index d1425ca88592..d595058c2bf6 100644
--- a/rllib/policy/tests/test_rnn_sequencing.py
+++ b/rllib/policy/tests/test_rnn_sequencing.py
@@ -2,12 +2,20 @@
 import unittest
 
 import ray
-from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
+from ray.rllib.policy.rnn_sequencing import (
+    pad_batch_to_sequences_of_same_size,
+    add_time_dimension,
+)
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.view_requirement import ViewRequirement
+from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.test_utils import check
 
 
+tf1, tf, tfv = try_import_tf()
+torch, nn = try_import_torch()
+
+
 class TestRNNSequencing(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
@@ -89,6 +97,52 @@ def test_pad_batch_fixed_max(self):
         check(s1["a"].shape[0], max_seq_len * num_seqs)
         check(s1["b"].shape[0], max_seq_len * num_seqs)
 
+    def test_add_time_dimension(self):
+        """Test add_time_dimension gives sequential data along the time dimension"""
+
+        B, T, F = np.random.choice(
+            np.asarray(list(range(8, 32)), dtype=np.int32),  # use int32 for seq_lens
+            size=3,
+            replace=False,
+        )
+
+        inputs_numpy = np.repeat(
+            np.arange(B * T)[:, np.newaxis], repeats=F, axis=-1
+        ).astype(np.int32)
+        check(inputs_numpy.shape, (B * T, F))
+
+        time_shift_diff_batch_major = np.ones(shape=(B, T - 1, F), dtype=np.int32)
+        time_shift_diff_time_major = np.ones(shape=(T - 1, B, F), dtype=np.int32)
+
+        if tf is not None:
+            # Test tensorflow batch-major
+            padded_inputs = tf.constant(inputs_numpy)
+            batch_major_outputs = add_time_dimension(
+                padded_inputs, max_seq_len=T, framework="tf", time_major=False
+            )
+            check(batch_major_outputs.shape.as_list(), [B, T, F])
+            time_shift_diff = batch_major_outputs[:, 1:] - batch_major_outputs[:, :-1]
+            check(time_shift_diff, time_shift_diff_batch_major)
+
+        if torch is not None:
+            # Test torch batch-major
+            padded_inputs = torch.from_numpy(inputs_numpy)
+            batch_major_outputs = add_time_dimension(
+                padded_inputs, max_seq_len=T, framework="torch", time_major=False
+            )
+            check(batch_major_outputs.shape, (B, T, F))
+            time_shift_diff = batch_major_outputs[:, 1:] - batch_major_outputs[:, :-1]
+            check(time_shift_diff, time_shift_diff_batch_major)
+
+            # Test torch time-major
+            padded_inputs = torch.from_numpy(inputs_numpy)
+            time_major_outputs = add_time_dimension(
+                padded_inputs, max_seq_len=T, framework="torch", time_major=True
+            )
+            check(time_major_outputs.shape, (T, B, F))
+            time_shift_diff = time_major_outputs[1:] - time_major_outputs[:-1]
+            check(time_shift_diff, time_shift_diff_time_major)
+
 
 if __name__ == "__main__":
     import pytest
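
Illustration (not part of the patch): a minimal standalone sketch of the reshape-then-transpose behavior that the new torch branch of add_time_dimension uses. The sizes B, T, F below are assumed toy values and the variable names are illustrative; only PyTorch is required.

import torch

# Assumed toy sizes: 4 sequences (batch), 5 time steps, 3 features.
B, T, F = 4, 5, 3

# A padded, flat batch of shape (B * T, F), the layout add_time_dimension receives.
padded_inputs = torch.arange(B * T).unsqueeze(-1).repeat(1, F)

# Batch-major: view the flat batch as (B, T, F) without copying data.
batch_major = padded_inputs.view((B, T) + tuple(padded_inputs.shape[1:]))
assert batch_major.shape == (B, T, F)

# Time-major: swap the batch and time dimensions, giving (T, B, F).
time_major = batch_major.transpose(0, 1)
assert time_major.shape == (T, B, F)

# Consecutive time steps within each sequence differ by exactly 1,
# which is the property test_add_time_dimension verifies.
assert torch.all(batch_major[:, 1:] - batch_major[:, :-1] == 1)
assert torch.all(time_major[1:] - time_major[:-1] == 1)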