diff --git a/merlin/models/torch/__init__.py b/merlin/models/torch/__init__.py
index bcbc77c5d5..4edcb6bd87 100644
--- a/merlin/models/torch/__init__.py
+++ b/merlin/models/torch/__init__.py
@@ -44,6 +44,7 @@
 from merlin.models.torch.outputs.tabular import TabularOutputBlock
 from merlin.models.torch.router import RouterBlock
 from merlin.models.torch.transforms.agg import Concat, Stack
+from merlin.models.torch.transforms.sequences import BroadcastToSequence, TabularPadding
 
 input_schema = schema.input_schema
 output_schema = schema.output_schema
@@ -92,4 +93,6 @@
     "MMOEBlock",
     "PLEBlock",
     "CGCBlock",
+    "TabularPadding",
+    "BroadcastToSequence",
 ]
diff --git a/merlin/models/torch/transforms/sequences.py b/merlin/models/torch/transforms/sequences.py
index c3343ef85b..9654f934e1 100644
--- a/merlin/models/torch/transforms/sequences.py
+++ b/merlin/models/torch/transforms/sequences.py
@@ -20,6 +20,7 @@
 from torch import nn
 
 from merlin.models.torch.batch import Batch, Sequence
+from merlin.models.torch.schema import Selection, select
 from merlin.schema import Schema, Tags
 
@@ -173,3 +174,110 @@ def _pad_dense_tensor(self, tensor: torch.Tensor, length: int) -> torch.Tensor:
             pad_diff = length - tensor.shape[1]
             return F.pad(input=tensor, pad=(0, pad_diff, 0, 0))
         return tensor
+
+
+class BroadcastToSequence(nn.Module):
+    """
+    A PyTorch module that broadcasts features to match the sequence length.
+
+    BroadcastToSequence broadcasts selected features of a data schema to match
+    the sequence length. This is particularly useful in sequence-based neural
+    networks, where different types of inputs must be processed in sync within
+    the network and therefore need to have the same length.
+
+    For example, in a sequence-to-sequence learning problem, you might have a
+    feature representing a constant property for each sequence (like an ID or
+    a group) that should be available at each time step. In that case, you can
+    use BroadcastToSequence to 'broadcast' this feature along the time
+    dimension, creating a copy for each time step.
+
+    Parameters
+    ----------
+    to_broadcast : Selection
+        The features that need to be broadcast.
+    sequence : Selection
+        The sequence features.
+
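+    Examples
+    --------
+    A minimal sketch of the intended usage; the feature names, schema, and
+    tensors below are illustrative, not part of the library:
+
+    >>> import torch
+    >>> from merlin.schema import Schema
+    >>> schema = Schema(["user_id", "item_ids"])
+    >>> module = BroadcastToSequence(Schema(["user_id"]), Schema(["item_ids"]))
+    >>> module.initialize_from_schema(schema)
+    >>> outputs = module(
+    ...     {"user_id": torch.tensor([1, 2]), "item_ids": torch.rand(2, 3, 8)}
+    ... )
+    >>> outputs["user_id"].shape
+    torch.Size([2, 3])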
+ """ + + outputs = {} + seq_length = self.get_seq_length(inputs) + + # Iterate over the to_broadcast_features and broadcast each tensor to the sequence length + for key, val in inputs.items(): + if key in self.to_broadcast_features: + # Check the dimension of the original tensor + if len(val.shape) == 1: # for 1D tensor (batch dimension only) + broadcasted_tensor = val.unsqueeze(1).repeat(1, seq_length) + elif len(val.shape) == 2: # for 2D tensor (batch dimension + feature dimension) + broadcasted_tensor = val.unsqueeze(1).repeat(1, seq_length, 1) + else: + raise RuntimeError(f"Unsupported number of dimensions: {len(val.shape)}") + + # Update the inputs dictionary with the broadcasted tensor + outputs[key] = broadcasted_tensor + else: + outputs[key] = val + + return outputs + + def get_seq_length(self, inputs: Dict[str, torch.Tensor]) -> int: + """ + Get the sequence length from inputs. + + Parameters + ---------- + inputs : Dict[str, torch.Tensor] + The inputs dictionary. + + Returns + ------- + int + The sequence length. + """ + + first_feat = self.sequence_features[0] + + if first_feat + "__offsets" in inputs: + return inputs[first_feat + "__offsets"][-1].item() + + return inputs[first_feat].shape[1] diff --git a/tests/unit/torch/transforms/test_sequences.py b/tests/unit/torch/transforms/test_sequences.py index 76286400a9..03af7fa271 100644 --- a/tests/unit/torch/transforms/test_sequences.py +++ b/tests/unit/torch/transforms/test_sequences.py @@ -4,7 +4,7 @@ import torch from merlin.models.torch.batch import Batch -from merlin.models.torch.transforms.sequences import TabularPadding +from merlin.models.torch.transforms.sequences import BroadcastToSequence, TabularPadding from merlin.models.torch.utils import module_utils from merlin.schema import ColumnSchema, Schema, Tags @@ -95,3 +95,52 @@ def test_padded_targets(self, sequence_batch, sequence_schema): assert padded_batch.targets["target_2"].shape[1] == _max_sequence_length assert torch.equal(padded_batch.targets["target_1"], sequence_batch.targets["target_1"]) + + +class TestBroadcastToSequence: + def setup_method(self): + self.input_tensors = { + "feature_1": torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), + "feature_2": torch.tensor( + [[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], [[4.0, 4.0], [5.0, 5.0], [6.0, 6.0]]] + ), + "feature_3": torch.tensor([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]), + } + self.schema = Schema(list(self.input_tensors.keys())) + self.to_broadcast = Schema(["feature_1", "feature_3"]) + self.sequence = Schema(["feature_2"]) + self.broadcast = BroadcastToSequence(self.to_broadcast, self.sequence) + + def test_initialize_from_schema(self): + self.broadcast.initialize_from_schema(self.schema) + assert self.broadcast.to_broadcast_features == ["feature_1", "feature_3"] + assert self.broadcast.sequence_features == ["feature_2"] + + def test_get_seq_length(self): + self.broadcast.initialize_from_schema(self.schema) + assert self.broadcast.get_seq_length(self.input_tensors) == 3 + + def test_get_seq_length_offsets(self): + self.broadcast.initialize_from_schema(self.schema) + + inputs = { + "feature_1": torch.tensor([1, 2]), + "feature_2__offsets": torch.tensor([2, 3]), + "feature_3": torch.tensor([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]), + } + + assert self.broadcast.get_seq_length(inputs) == 3 + + def test_forward(self): + self.broadcast.initialize_from_schema(self.schema) + output = module_utils.module_test(self.broadcast, self.input_tensors) + assert output["feature_1"].shape == (2, 3, 3) + assert output["feature_3"].shape 
+        if first_feat + "__offsets" in inputs:
+            return inputs[first_feat + "__offsets"][-1].item()
+
+        return inputs[first_feat].shape[1]
diff --git a/tests/unit/torch/transforms/test_sequences.py b/tests/unit/torch/transforms/test_sequences.py
index 76286400a9..03af7fa271 100644
--- a/tests/unit/torch/transforms/test_sequences.py
+++ b/tests/unit/torch/transforms/test_sequences.py
@@ -4,7 +4,7 @@
 import torch
 
 from merlin.models.torch.batch import Batch
-from merlin.models.torch.transforms.sequences import TabularPadding
+from merlin.models.torch.transforms.sequences import BroadcastToSequence, TabularPadding
 from merlin.models.torch.utils import module_utils
 from merlin.schema import ColumnSchema, Schema, Tags
 
@@ -95,3 +95,52 @@ def test_padded_targets(self, sequence_batch, sequence_schema):
         assert padded_batch.targets["target_2"].shape[1] == _max_sequence_length
 
         assert torch.equal(padded_batch.targets["target_1"], sequence_batch.targets["target_1"])
+
+
+class TestBroadcastToSequence:
+    def setup_method(self):
+        self.input_tensors = {
+            "feature_1": torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
+            "feature_2": torch.tensor(
+                [[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], [[4.0, 4.0], [5.0, 5.0], [6.0, 6.0]]]
+            ),
+            "feature_3": torch.tensor([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]),
+        }
+        self.schema = Schema(list(self.input_tensors.keys()))
+        self.to_broadcast = Schema(["feature_1", "feature_3"])
+        self.sequence = Schema(["feature_2"])
+        self.broadcast = BroadcastToSequence(self.to_broadcast, self.sequence)
+
+    def test_initialize_from_schema(self):
+        self.broadcast.initialize_from_schema(self.schema)
+        assert self.broadcast.to_broadcast_features == ["feature_1", "feature_3"]
+        assert self.broadcast.sequence_features == ["feature_2"]
+
+    def test_get_seq_length(self):
+        self.broadcast.initialize_from_schema(self.schema)
+        assert self.broadcast.get_seq_length(self.input_tensors) == 3
+
+    def test_get_seq_length_offsets(self):
+        self.broadcast.initialize_from_schema(self.schema)
+
+        inputs = {
+            "feature_1": torch.tensor([1, 2]),
+            "feature_2__offsets": torch.tensor([2, 3]),
+            "feature_3": torch.tensor([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]),
+        }
+
+        assert self.broadcast.get_seq_length(inputs) == 3
+
+    def test_forward(self):
+        self.broadcast.initialize_from_schema(self.schema)
+        output = module_utils.module_test(self.broadcast, self.input_tensors)
+        assert output["feature_1"].shape == (2, 3, 3)
+        assert output["feature_3"].shape == (2, 3, 3)
+        assert output["feature_2"].shape == (2, 3, 2)
+
+    def test_unsupported_dimensions(self):
+        self.broadcast.initialize_from_schema(self.schema)
+        self.input_tensors["feature_3"] = torch.rand(10, 3, 3, 3)
+
+        with pytest.raises(RuntimeError, match="Unsupported number of dimensions: 4"):
+            self.broadcast(self.input_tensors)
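+
+    def test_forward_1d_feature(self):
+        # Hypothetical extra test, sketched here because the 1D
+        # (batch-dimension-only) broadcast path in forward() is not exercised
+        # by the tests above; the tensor values are illustrative.
+        self.broadcast.initialize_from_schema(self.schema)
+        self.input_tensors["feature_1"] = torch.tensor([1.0, 2.0])
+
+        output = self.broadcast(self.input_tensors)
+        assert output["feature_1"].shape == (2, 3)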