Adding BroadcastToSequence (#1195)
* First commit

* Adding BroadcastToSequence

* Adding to __init__

---------

Co-authored-by: edknv <[email protected]>
marcromeyn and edknv committed Jul 11, 2023
1 parent 36d582a commit 86ac779
Showing 3 changed files with 161 additions and 1 deletion.
3 changes: 3 additions & 0 deletions merlin/models/torch/__init__.py
@@ -44,6 +44,7 @@
from merlin.models.torch.outputs.tabular import TabularOutputBlock
from merlin.models.torch.router import RouterBlock
from merlin.models.torch.transforms.agg import Concat, Stack
from merlin.models.torch.transforms.sequences import BroadcastToSequence, TabularPadding

input_schema = schema.input_schema
output_schema = schema.output_schema
@@ -92,4 +93,6 @@
"MMOEBlock",
"PLEBlock",
"CGCBlock",
"TabularPadding",
"BroadcastToSequence",
]
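
With this export in place, BroadcastToSequence can be used directly from the top-level torch namespace, alongside TabularPadding. A minimal import sketch (assuming merlin-models is installed; the column names are illustrative):

import merlin.models.torch as mm
from merlin.schema import Schema

# Selections can be plain Schema objects, as the unit tests below do.
broadcast = mm.BroadcastToSequence(Schema(["user_id"]), Schema(["item_id_seq"]))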
108 changes: 108 additions & 0 deletions merlin/models/torch/transforms/sequences.py
@@ -20,6 +20,7 @@
from torch import nn

from merlin.models.torch.batch import Batch, Sequence
from merlin.models.torch.schema import Selection, select
from merlin.schema import Schema, Tags


@@ -173,3 +174,110 @@ def _pad_dense_tensor(self, tensor: torch.Tensor, length: int) -> torch.Tensor:
pad_diff = length - tensor.shape[1]
return F.pad(input=tensor, pad=(0, pad_diff, 0, 0))
return tensor


class BroadcastToSequence(nn.Module):
"""
A PyTorch module to broadcast features to match the sequence length.
BroadcastToSequence is a PyTorch module designed to facilitate broadcasting
of specific features within a given data schema to match a given sequence length.
This can be particularly useful in sequence-based neural networks, where different
types of inputs need to be processed in sync within the network, and all inputs need
to be of the same length.
For example, in a sequence-to-sequence learning problem, one might have a feature
representing a constant property for each sequence (like an ID or a group), and you
want this feature to be available at each time step. In this case, you can use
BroadcastToSequence to 'broadcast' this feature along the time dimension,
creating a copy for each time step.
Parameters
----------
to_broadcast : Selection
The features that need to be broadcasted.
sequence : Selection
The sequence features.
"""

def __init__(self, to_broadcast: Selection, sequence: Selection):
super().__init__()
self.to_broadcast = to_broadcast
self.sequence = sequence

def initialize_from_schema(self, schema: Schema):
"""
Initialize the module from a schema.
Parameters
----------
schema : Schema
The input-schema of this module
"""
self.schema = schema
self.to_broadcast_features: List[str] = select(schema, self.to_broadcast).column_names
self.sequence_features: List[str] = select(schema, self.sequence).column_names

def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Forward propagation method.
Parameters
----------
inputs : Dict[str, torch.Tensor]
The inputs dictionary containing the tensors to be broadcasted.
Returns
-------
Dict[str, torch.Tensor]
The dictionary containing the broadcasted tensors.
Raises
------
RuntimeError
If a tensor has an unsupported number of dimensions.
"""

outputs = {}
seq_length = self.get_seq_length(inputs)

# Iterate over the to_broadcast_features and broadcast each tensor to the sequence length
for key, val in inputs.items():
if key in self.to_broadcast_features:
# Check the dimension of the original tensor
if len(val.shape) == 1: # for 1D tensor (batch dimension only)
broadcasted_tensor = val.unsqueeze(1).repeat(1, seq_length)
elif len(val.shape) == 2: # for 2D tensor (batch dimension + feature dimension)
broadcasted_tensor = val.unsqueeze(1).repeat(1, seq_length, 1)
else:
raise RuntimeError(f"Unsupported number of dimensions: {len(val.shape)}")

# Update the inputs dictionary with the broadcasted tensor
outputs[key] = broadcasted_tensor
else:
outputs[key] = val

return outputs

def get_seq_length(self, inputs: Dict[str, torch.Tensor]) -> int:
"""
Get the sequence length from inputs.
Parameters
----------
inputs : Dict[str, torch.Tensor]
The inputs dictionary.
Returns
-------
int
The sequence length.
"""

first_feat = self.sequence_features[0]

if first_feat + "__offsets" in inputs:
return inputs[first_feat + "__offsets"][-1].item()

return inputs[first_feat].shape[1]
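
To make the broadcasting behaviour concrete, here is a small usage sketch that mirrors the shapes exercised by the unit tests below (feature names are illustrative):

import torch

from merlin.models.torch.transforms.sequences import BroadcastToSequence
from merlin.schema import Schema

inputs = {
    "context": torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),  # (batch=2, features=3)
    "sequence": torch.rand(2, 4, 8),  # (batch=2, seq_length=4, dim=8)
}

broadcast = BroadcastToSequence(Schema(["context"]), Schema(["sequence"]))
broadcast.initialize_from_schema(Schema(list(inputs.keys())))

outputs = broadcast(inputs)
# "context" gains a time axis, (2, 3) -> (2, 4, 3); "sequence" passes through unchanged.
assert outputs["context"].shape == (2, 4, 3)
assert outputs["sequence"].shape == (2, 4, 8)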
51 changes: 50 additions & 1 deletion tests/unit/torch/transforms/test_sequences.py
@@ -4,7 +4,7 @@
import torch

from merlin.models.torch.batch import Batch
from merlin.models.torch.transforms.sequences import TabularPadding
from merlin.models.torch.transforms.sequences import BroadcastToSequence, TabularPadding
from merlin.models.torch.utils import module_utils
from merlin.schema import ColumnSchema, Schema, Tags

@@ -95,3 +95,52 @@ def test_padded_targets(self, sequence_batch, sequence_schema):

assert padded_batch.targets["target_2"].shape[1] == _max_sequence_length
assert torch.equal(padded_batch.targets["target_1"], sequence_batch.targets["target_1"])


class TestBroadcastToSequence:
def setup_method(self):
self.input_tensors = {
"feature_1": torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
"feature_2": torch.tensor(
[[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], [[4.0, 4.0], [5.0, 5.0], [6.0, 6.0]]]
),
"feature_3": torch.tensor([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]),
}
self.schema = Schema(list(self.input_tensors.keys()))
self.to_broadcast = Schema(["feature_1", "feature_3"])
self.sequence = Schema(["feature_2"])
self.broadcast = BroadcastToSequence(self.to_broadcast, self.sequence)

def test_initialize_from_schema(self):
self.broadcast.initialize_from_schema(self.schema)
assert self.broadcast.to_broadcast_features == ["feature_1", "feature_3"]
assert self.broadcast.sequence_features == ["feature_2"]

def test_get_seq_length(self):
self.broadcast.initialize_from_schema(self.schema)
assert self.broadcast.get_seq_length(self.input_tensors) == 3

def test_get_seq_length_offsets(self):
self.broadcast.initialize_from_schema(self.schema)

inputs = {
"feature_1": torch.tensor([1, 2]),
"feature_2__offsets": torch.tensor([2, 3]),
"feature_3": torch.tensor([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]),
}

assert self.broadcast.get_seq_length(inputs) == 3

def test_forward(self):
self.broadcast.initialize_from_schema(self.schema)
output = module_utils.module_test(self.broadcast, self.input_tensors)
assert output["feature_1"].shape == (2, 3, 3)
assert output["feature_3"].shape == (2, 3, 3)
assert output["feature_2"].shape == (2, 3, 2)

def test_unsupported_dimensions(self):
self.broadcast.initialize_from_schema(self.schema)
self.input_tensors["feature_3"] = torch.rand(10, 3, 3, 3)

with pytest.raises(RuntimeError, match="Unsupported number of dimensions: 4"):
self.broadcast(self.input_tensors)
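
The offsets branch of get_seq_length, exercised by test_get_seq_length_offsets above, relies on Merlin's ragged representation: a variable-length feature is stored as a flat values tensor plus an offsets tensor marking row boundaries, and the transform reads the final offset as the sequence length. A sketch of that layout (contents illustrative; assumes the conventional leading-zero offsets, which the test elides):

import torch

# The ragged batch [[10, 11], [12]] as flat values plus row-boundary offsets:
values = torch.tensor([10, 11, 12])
offsets = torch.tensor([0, 2, 3])  # row i spans values[offsets[i]:offsets[i+1]]

# BroadcastToSequence.get_seq_length reads the last offset as the length.
assert offsets[-1].item() == 3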
