Adding CrossAttentionBlock (#1193)
* Adding CrossAttentionBlock

* Fix failing test by setting dropout=0.0

* proposed changes

---------

Co-authored-by: edknv <[email protected]>
3 people committed Jul 12, 2023
1 parent 28b690c commit f035867
Showing 4 changed files with 225 additions and 0 deletions.
2 changes: 2 additions & 0 deletions merlin/models/torch/__init__.py
@@ -26,6 +26,7 @@
    repeat_parallel,
    repeat_parallel_like,
)
from merlin.models.torch.blocks.attention import CrossAttentionBlock
from merlin.models.torch.blocks.dlrm import DLRMBlock
from merlin.models.torch.blocks.experts import CGCBlock, MMOEBlock, PLEBlock
from merlin.models.torch.blocks.mlp import MLPBlock
@@ -102,4 +103,5 @@
"DaskEncoder",
"DaskPredictor",
"stack_context",
"CrossAttentionBlock",
]
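
With the export above, the new block is reachable from the package root. A minimal sketch (assuming the ``mm`` alias commonly used in Merlin Models examples; not part of the commit):

import merlin.models.torch as mm
from torch import nn

# CrossAttentionBlock now sits in the public namespace alongside DLRMBlock, MLPBlock, etc.
cross = mm.CrossAttentionBlock(
    attention=nn.MultiheadAttention(10, 2, batch_first=True),
    key="context",
    seq_key="sequence",
)
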
168 changes: 168 additions & 0 deletions merlin/models/torch/blocks/attention.py
@@ -0,0 +1,168 @@
from copy import deepcopy
from typing import Dict, Optional, Union

import torch
from torch import nn

from merlin.models.torch.batch import Batch
from merlin.models.torch.block import Block


class CrossAttentionBlock(Block):
    """
    Cross Attention Block module which performs a multihead attention operation
    on a provided context and sequence.

    Note that this block assumes that the input and output tensors are provided as
    (batch, seq, feature). When using modules provided in PyTorch, e.g.,
    ``torch.nn.MultiheadAttention``, the ``batch_first`` parameter should be
    set to True to match the shape.

    Example usage
    -------------
    >>> cross = CrossAttentionBlock(
    ...     attention=nn.MultiheadAttention(10, 2, batch_first=True),
    ...     key="context",
    ...     seq_key="sequence",
    ... )
    >>> input_dict = {
    ...     "context": torch.randn(1, 2, 10),
    ...     "sequence": torch.randn(1, 6, 10),
    ... }
    >>> cross(input_dict)

    Parameters
    ----------
    *module : nn.Module
        Variable number of modules applied to the sequence after the cross-attention step.
    attention : nn.MultiheadAttention, optional
        Predefined multihead attention module. If not provided, it is inferred from the
        ``d_model``, ``nhead``, and ``dropout`` attributes of the first module.
    name : str, optional
        Name for the block.
    key : str, optional
        Key for the context tensor in the input dictionary.
    seq_key : str, optional
        Key for the sequence tensor in the input dictionary.
    """

    def __init__(
        self,
        *module: nn.Module,
        attention: Optional[nn.MultiheadAttention] = None,
        name: Optional[str] = None,
        key: str = "context",
        seq_key: Optional[str] = None,
    ):
        super().__init__(*module, name=name)

        self.key = key
        self.seq_key = seq_key
        if attention is None:
            if not (
                hasattr(module[0], "d_model")
                and hasattr(module[0], "nhead")
                and hasattr(module[0], "dropout")
            ):
                raise ValueError(
                    "Attention module not provided and cannot be inferred from module"
                )

            # Try to infer from module
            cross_attention = nn.MultiheadAttention(
                module[0].d_model, module[0].nhead, module[0].dropout
            )
        else:
            cross_attention = attention

        self.cross_attention = nn.ModuleList([cross_attention])
        if len(module) > 1:
            for m in module:
                self.cross_attention.append(
                    m.copy() if hasattr(m, "copy") else deepcopy(cross_attention)
                )

    def forward(
        self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
    ) -> torch.Tensor:
        """
        Perform a forward pass of the CrossAttentionBlock.

        Parameters
        ----------
        inputs : Union[torch.Tensor, Dict[str, torch.Tensor]]
            Dictionary containing the input tensors.
        batch : Optional[Batch]
            Optional batch information for the forward pass.

        Returns
        -------
        torch.Tensor
            Output tensor after the multihead attention operation.

        Raises
        ------
        ValueError
            If the input is a torch.Tensor instead of a dictionary.
        """

        if isinstance(inputs, torch.Tensor):
            raise ValueError("CrossAttentionBlock requires a dictionary input")

        context, sequence = self.get_context(inputs), self.get_seq(inputs)

        for module, attention in zip(self.values, self.cross_attention):
            sequence, _ = attention(sequence, context, context)
            sequence = module(sequence, batch=batch)

        return sequence

    def get_context(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Retrieve the context tensor from the input dictionary using the key.

        Parameters
        ----------
        x : Dict[str, torch.Tensor]
            Input dictionary containing the tensors.

        Returns
        -------
        torch.Tensor
            The context tensor.
        """
        return x[self.key]

    def get_seq(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Retrieve the sequence tensor from the input dictionary using the seq_key.

        Parameters
        ----------
        x : Dict[str, torch.Tensor]
            Input dictionary containing the tensors.

        Returns
        -------
        torch.Tensor
            The sequence tensor.

        Raises
        ------
        RuntimeError
            If the seq_key is not found in the input dictionary or if the dictionary has more
            than 2 keys and seq_key is not defined.
        """
        if self.seq_key is None:
            if len(x) == 2:
                for key in x.keys():
                    if key != self.key:
                        return x[key]
            else:
                raise RuntimeError(
                    "Please set seq_key for when more than 2 keys are present "
                    f"in the input dictionary, got: {x}."
                )

        if self.seq_key not in x:
            raise RuntimeError(f"Could not find {self.seq_key} in input dictionary, got: {x}.")

        return x[self.seq_key]
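
For orientation, a minimal end-to-end sketch of the new block; it mirrors the unit test below, and the shapes are illustrative assumptions rather than part of the commit:

import torch
from torch import nn

from merlin.models.torch.blocks.attention import CrossAttentionBlock

# The "sequence" tensor attends over the "context" tensor, then the wrapped
# TransformerEncoderLayer refines the attended sequence.
cross = CrossAttentionBlock(
    nn.TransformerEncoderLayer(10, 2, dim_feedforward=10, batch_first=True, dropout=0.0),
    attention=nn.MultiheadAttention(10, 2, batch_first=True),
    key="context",
    seq_key="sequence",
)

inputs = {
    "context": torch.randn(1, 2, 10),   # (batch, context_len, features)
    "sequence": torch.randn(1, 6, 10),  # (batch, seq_len, features)
}

out = cross(inputs)
print(out.shape)  # torch.Size([1, 6, 10]) -- same shape as the "sequence" input
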
5 changes: 5 additions & 0 deletions merlin/models/torch/schema.py
@@ -79,6 +79,7 @@ def tensors(self, inputs):
    def get_schema(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor], Schema]) -> Schema:
        if isinstance(inputs, Schema):
            return inputs

        return self.tensors(inputs)


@@ -553,6 +554,9 @@ def select(self, selection: Selection) -> "Selectable":

@output_schema.register_tensor(torch.Tensor)
def _tensor_to_schema(input, name="output"):
    if input is None:
        return Schema([ColumnSchema(name)])

    kwargs = dict(dims=input.shape[1:], dtype=input.dtype)

    if len(input.shape) > 1 and input.dtype != torch.int32:
@@ -587,6 +591,7 @@ def _(input):
@output_schema.register_tensor(Tuple[torch.Tensor])
@input_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor])
@output_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor])
@output_schema.register_tensor(Tuple[torch.Tensor, Optional[torch.Tensor]])
@input_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor])
@output_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor])
@input_schema.register_tensor(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor])
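
The ``Optional`` registration above is presumably needed because ``nn.MultiheadAttention.forward`` is typed as returning ``Tuple[Tensor, Optional[Tensor]]``, with the attention weights being ``None`` when they are not requested. A quick illustration (a sketch, not part of the commit):

import torch
from torch import nn

mha = nn.MultiheadAttention(10, 2, batch_first=True)
seq, ctx = torch.randn(1, 6, 10), torch.randn(1, 2, 10)

# The second element of the returned tuple is None when need_weights=False.
out, weights = mha(seq, ctx, ctx, need_weights=False)
print(out.shape, weights)  # torch.Size([1, 6, 10]) None
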
50 changes: 50 additions & 0 deletions tests/unit/torch/blocks/test_attention.py
@@ -0,0 +1,50 @@
import pytest
import torch
from torch import nn

from merlin.models.torch.blocks.attention import CrossAttentionBlock
from merlin.models.torch.utils import module_utils


class TestCrossAttentionBlock:
    def setup_method(self):
        # Set up a simple CrossAttentionBlock instance for testing.
        self.cross = CrossAttentionBlock(
            nn.TransformerEncoderLayer(10, 2, dim_feedforward=10, batch_first=True, dropout=0.0),
            attention=nn.MultiheadAttention(10, 2, batch_first=True),
            key="context",
            seq_key="sequence",
        )
        self.input_dict = {"context": torch.randn(1, 2, 10), "sequence": torch.randn(1, 6, 10)}

    def test_init(self):
        assert self.cross.key == "context"
        assert self.cross.seq_key == "sequence"
        assert isinstance(self.cross.cross_attention, nn.ModuleList)
        assert isinstance(self.cross.cross_attention[0], nn.MultiheadAttention)

    def test_forward(self):
        out = self.cross(self.input_dict)
        assert isinstance(out, torch.Tensor)
        assert out.shape == self.input_dict["sequence"].shape

    def test_forward_torch_script(self):
        out = module_utils.module_test(self.cross, self.input_dict)
        assert isinstance(out, torch.Tensor)
        assert out.shape == self.input_dict["sequence"].shape

    def test_get_seq_error(self):
        with pytest.raises(RuntimeError, match="Could not find"):
            self.cross.get_seq(
                {"context": torch.randn(1, 10), "0": torch.randn(1, 10), "1": torch.randn(1, 10)}
            )

        with pytest.raises(
            RuntimeError, match="Please set seq_key for when more than 2 keys are present"
        ):
            cross = CrossAttentionBlock(
                attention=nn.MultiheadAttention(10, 2, batch_first=True),
            )
            cross.get_seq(
                {"context": torch.randn(1, 10), "0": torch.randn(1, 10), "1": torch.randn(1, 10)}
            )
