Remove Link and introduce ShortcutBlock + ResidualBlock (#1170)
* Remove Link in favour of new blocks: ResidualBlock & ShortcutBlock

* Add conversion test

* Some bug fixes + 100% test-coverage for block.py

* Improve docstrings

* Remove unused isblock again
marcromeyn authored Jul 3, 2023
1 parent 86d0a34 commit eff55db
Showing 8 changed files with 332 additions and 345 deletions.
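
Before the file-by-file diffs, a minimal usage sketch (not part of the commit) of the API that replaces Link: repeat() no longer takes a link= argument, and residual/skip wiring is now expressed with the two new blocks. It assumes plain torch modules such as nn.Linear can be wrapped directly, as the nn.Identity example in the ShortcutBlock docstring below suggests.

# Hedged sketch of the replacement API; the nn.Linear layers are stand-ins.
import torch
from torch import nn

import merlin.models.torch as mm

x = torch.ones(2, 8)

# Plain repetition: repeat() no longer accepts a `link=` argument.
stacked = mm.Block(nn.Linear(8, 8)).repeat(2)

# ResidualBlock adds the block input back onto each contained module's output.
residual = mm.ResidualBlock(nn.Linear(8, 8))

# ShortcutBlock keeps the untouched input and returns a dictionary of tensors.
shortcut = mm.ShortcutBlock(nn.Linear(8, 8))

print(stacked(x).shape)            # torch.Size([2, 8])
print(residual(x).shape)           # torch.Size([2, 8])
print(sorted(shortcut(x).keys()))  # ['output', 'shortcut']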
4 changes: 3 additions & 1 deletion merlin/models/torch/__init__.py
@@ -16,7 +16,7 @@

from merlin.models.torch import schema
from merlin.models.torch.batch import Batch, Sequence
from merlin.models.torch.block import Block, ParallelBlock
from merlin.models.torch.block import Block, ParallelBlock, ResidualBlock, ShortcutBlock
from merlin.models.torch.blocks.dlrm import DLRMBlock
from merlin.models.torch.blocks.mlp import MLPBlock
from merlin.models.torch.inputs.embedding import EmbeddingTable, EmbeddingTables
@@ -45,9 +45,11 @@
"ParallelBlock",
"Sequence",
"RegressionOutput",
"ResidualBlock",
"RouterBlock",
"SelectKeys",
"SelectFeatures",
"ShortcutBlock",
"TabularInputBlock",
"Concat",
"Stack",
196 changes: 174 additions & 22 deletions merlin/models/torch/block.py
@@ -25,9 +25,8 @@
from merlin.models.torch import schema
from merlin.models.torch.batch import Batch
from merlin.models.torch.container import BlockContainer, BlockContainerDict
from merlin.models.torch.link import Link, LinkType
from merlin.models.torch.registry import registry
from merlin.models.torch.utils.traversal_utils import TraversableMixin, leaf
from merlin.models.torch.utils.traversal_utils import TraversableMixin
from merlin.models.utils.registry import RegistryMixin
from merlin.schema import Schema

@@ -41,8 +40,6 @@ class Block(BlockContainer, RegistryMixin, TraversableMixin):
Variable length argument list of PyTorch modules to be contained in the block.
name : Optional[str], default = None
The name of the block. If None, no name is assigned.
track_schema : bool, default = True
If True, the schema of the output tensors are tracked.
"""

registry = registry
@@ -73,7 +70,7 @@ def forward(

return inputs

def repeat(self, n: int = 1, link: Optional[LinkType] = None, name=None) -> "Block":
def repeat(self, n: int = 1, name=None) -> "Block":
"""
Creates a new block by repeating the current block `n` times.
Each repetition is a deep copy of the current block.
@@ -97,9 +94,6 @@ def repeat(self, n: int = 1, link: Optional[LinkType] = None, name=None) -> "Blo
raise ValueError("n must be greater than 0")

repeats = [self.copy() for _ in range(n - 1)]
if link:
parsed_link = Link.parse(link)
repeats = [parsed_link.copy().setup_link(repeat) for repeat in repeats]

return Block(self, *repeats, name=name)

@@ -221,7 +215,7 @@ def forward(

return outputs

def append(self, module: nn.Module, link: Optional[LinkType] = None):
def append(self, module: nn.Module):
"""Appends a module to the post-processing stage.
Parameters
@@ -235,7 +229,7 @@ def append(self, module: nn.Module, link: Optional[LinkType] = None):
The current object itself.
"""

self.post.append(module, link=link)
self.post.append(module)

return self

@@ -244,7 +238,7 @@ def prepend(self, module: nn.Module):

return self

def append_to(self, name: str, module: nn.Module, link: Optional[LinkType] = None):
def append_to(self, name: str, module: nn.Module):
"""Appends a module to a specified branch.
Parameters
@@ -260,11 +254,11 @@ def append_to(self, name: str, module: nn.Module, link: Optional[LinkType] = Non
The current object itself.
"""

self.branches[name].append(module, link=link)
self.branches[name].append(module)

return self

def prepend_to(self, name: str, module: nn.Module, link: Optional[LinkType] = None):
def prepend_to(self, name: str, module: nn.Module):
"""Prepends a module to a specified branch.
Parameters
@@ -279,11 +273,11 @@ def prepend_to(self, name: str, module: nn.Module, link: Optional[LinkType] = No
ParallelBlock
The current object itself.
"""
self.branches[name].prepend(module, link=link)
self.branches[name].prepend(module)

return self

def append_for_each(self, module: nn.Module, shared=False, link: Optional[LinkType] = None):
def append_for_each(self, module: nn.Module, shared=False):
"""Appends a module to each branch.
Parameters
@@ -300,11 +294,11 @@ def append_for_each(self, module: nn.Module, shared=False, link: Optional[LinkTy
The current object itself.
"""

self.branches.append_for_each(module, shared=shared, link=link)
self.branches.append_for_each(module, shared=shared)

return self

def prepend_for_each(self, module: nn.Module, shared=False, link: Optional[LinkType] = None):
def prepend_for_each(self, module: nn.Module, shared=False):
"""Prepends a module to each branch.
Parameters
@@ -321,7 +315,7 @@ def prepend_for_each(self, module: nn.Module, shared=False, link: Optional[LinkT
The current object itself.
"""

self.branches.prepend_for_each(module, shared=shared, link=link)
self.branches.prepend_for_each(module, shared=shared)

return self

@@ -356,10 +350,7 @@ def leaf(self) -> nn.Module:
raise ValueError("Cannot call leaf() on a ParallelBlock with multiple branches")

first = list(self.branches.values())[0]
if hasattr(first, "leaf"):
return first.leaf()

return leaf(first)
return first.leaf()

def __getitem__(self, idx: Union[slice, int, str]):
if isinstance(idx, str) and idx in self.branches:
@@ -415,6 +406,167 @@ def __repr__(self) -> str:
return self._get_name() + branches


class ResidualBlock(Block):
"""
A block that applies each contained module sequentially on the input
and performs a residual connection after each module.
Parameters
----------
*module : nn.Module
Variable length argument list of PyTorch modules to be contained in the block.
name : Optional[str], default = None
The name of the block. If None, no name is assigned.
"""

def forward(self, inputs: torch.Tensor, batch: Optional[Batch] = None):
"""
Forward pass through the block. Applies each contained module sequentially,
adding the original input (the shortcut) to every module's output.
Parameters
----------
inputs : torch.Tensor
The input tensor.
batch : Optional[Batch], default = None
Optional batch of data. If provided, it is passed to each contained module.
Returns
-------
torch.Tensor
The output of the block after the residual connections.
"""
shortcut, outputs = inputs, inputs
for module in self.values:
outputs = shortcut + module(outputs, batch=batch)

return outputs


class ShortcutBlock(Block):
"""
A block with a 'shortcut' or a 'skip connection'.
The shortcut tensor can be propagated through the layers of the module or not,
depending on the value of the `propagate_shortcut` argument:
If `propagate_shortcut` is True, the shortcut tensor is passed through
each layer of the module.
If `propagate_shortcut` is False, the shortcut tensor is only used as part of
the final output dictionary.
Example usage::
>>> shortcut = mm.ShortcutBlock(nn.Identity())
>>> shortcut(torch.ones(1, 1))
{'shortcut': tensor([[1.]]), 'output': tensor([[1.]])}
Parameters
----------
*module : nn.Module
Variable length argument list of PyTorch modules to be contained in the block.
name : str, optional
The name of the module, by default None.
propagate_shortcut : bool, optional
If True, propagates the shortcut tensor through the layers of this block, by default False.
shortcut_name : str, optional
The name to use for the shortcut tensor, by default "shortcut".
output_name : str, optional
The name to use for the output tensor, by default "output".
"""

def __init__(
self,
*module: nn.Module,
name: Optional[str] = None,
propagate_shortcut: bool = False,
shortcut_name: str = "shortcut",
output_name: str = "output",
):
super().__init__(*module, name=name)
self.shortcut_name = shortcut_name
self.output_name = output_name
self.propagate_shortcut = propagate_shortcut

def forward(
self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
) -> Dict[str, torch.Tensor]:
"""
Defines the forward propagation of the module.
Parameters
----------
inputs : Union[torch.Tensor, Dict[str, torch.Tensor]]
The input tensor or a dictionary of tensors.
batch : Batch, optional
A batch of inputs, by default None.
Returns
-------
Dict[str, torch.Tensor]
The output tensor as a dictionary.
Raises
------
RuntimeError
If the shortcut name is not found in the input dictionary, or
if a module does not return a tensor or a dictionary with the key given by `output_name`.
"""

if torch.jit.isinstance(inputs, Dict[str, torch.Tensor]):
if self.shortcut_name not in inputs:
raise RuntimeError(
f"Shortcut name {self.shortcut_name} not found in inputs {inputs}"
)
shortcut = inputs[self.shortcut_name]
else:
shortcut = inputs

output = inputs
for module in self.values:
if self.propagate_shortcut:
if torch.jit.isinstance(output, Dict[str, torch.Tensor]):
module_output = module(output, batch=batch)
else:
to_pass: Dict[str, torch.Tensor] = {
self.shortcut_name: shortcut,
self.output_name: torch.jit.annotate(torch.Tensor, output),
}

module_output = module(to_pass, batch=batch)

if torch.jit.isinstance(module_output, torch.Tensor):
output = module_output
elif torch.jit.isinstance(module_output, Dict[str, torch.Tensor]):
output = module_output[self.output_name]
else:
raise RuntimeError(
f"Module {module} must return a tensor or a dict ",
f"with key {self.output_name}",
)
else:
if torch.jit.isinstance(inputs, Dict[str, torch.Tensor]) and torch.jit.isinstance(
output, Dict[str, torch.Tensor]
):
output = output[self.output_name]
_output = module(output, batch=batch)
if torch.jit.isinstance(_output, torch.Tensor) or torch.jit.isinstance(
_output, Dict[str, torch.Tensor]
):
output = _output
else:
raise RuntimeError(
f"Module {module} must return a tensor or a dict ",
f"with key {self.output_name}",
)

to_return = {self.shortcut_name: shortcut}
if torch.jit.isinstance(output, Dict[str, torch.Tensor]):
to_return.update(output)
else:
to_return[self.output_name] = output

return to_return


def get_pre(module: nn.Module) -> BlockContainer:
if hasattr(module, "pre"):
return module.pre
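
A short behavioural sketch (not from the commit) of the two new blocks defined above, relying only on the forward() implementations shown in this file; nn.Identity and nn.Linear are stand-in modules.

import torch
from torch import nn

import merlin.models.torch as mm

x = torch.ones(2, 4)

# ResidualBlock: the block input is added to each contained module's output,
# so with nn.Identity the result is simply 2 * x.
res = mm.ResidualBlock(nn.Identity())
print(torch.equal(res(x), 2 * x))  # True

# ShortcutBlock with the default propagate_shortcut=False: the contained module
# only sees the running output; the untouched input is re-attached as "shortcut".
skip = mm.ShortcutBlock(nn.Linear(4, 4))
out = skip(x)
print(sorted(out.keys()))               # ['output', 'shortcut']
print(torch.equal(out["shortcut"], x))  # True

# With propagate_shortcut=True, each contained module instead receives a dict
# {"shortcut": ..., "output": ...} and may return either a tensor or a dict
# holding the key named by output_name.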
40 changes: 25 additions & 15 deletions merlin/models/torch/blocks/dlrm.py
@@ -1,13 +1,13 @@
from typing import Dict, Optional
from typing import Dict, Optional, Union

import torch
from torch import nn

from merlin.models.torch.batch import Batch
from merlin.models.torch.block import Block
from merlin.models.torch.inputs.embedding import EmbeddingTables
from merlin.models.torch.inputs.tabular import TabularInputBlock
from merlin.models.torch.link import Link
from merlin.models.torch.transforms.agg import MaybeAgg, Stack
from merlin.models.torch.transforms.agg import Stack
from merlin.models.utils.doc_utils import docstring_parameter
from merlin.schema import Schema, Tags

@@ -77,21 +77,36 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return interactions_flat


class ShortcutConcatContinuous(Link):
class InteractionBlock(Block):
"""
A block that computes feature interactions and concatenates the result
with the continuous input features, when they are present.
When there is no continuous input, the interaction output is returned.
"""

def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
intermediate_output = self.output(inputs)
def __init__(
self,
*module: nn.Module,
name: Optional[str] = None,
prepend_agg: bool = True,
):
if prepend_agg:
module = (Stack(dim=1),) + module
super().__init__(*module, name=name)

if "continuous" in inputs:
return torch.cat((inputs["continuous"], intermediate_output), dim=1)
def forward(
self, inputs: Union[Dict[str, torch.Tensor], torch.Tensor], batch: Optional[Batch] = None
) -> torch.Tensor:
outputs = inputs
for module in self.values:
outputs = module(outputs, batch)

return intermediate_output
if torch.jit.isinstance(inputs, Dict[str, torch.Tensor]):
if "continuous" in inputs:
return torch.cat((inputs["continuous"], outputs), dim=1)

return outputs


@docstring_parameter(dlrm_reference=_DLRM_REF)
@@ -131,11 +146,6 @@ def __init__(
interaction: Optional[nn.Module] = None,
):
super().__init__(DLRMInputBlock(schema, dim, bottom_block))

self.append(
Block(MaybeAgg(Stack(dim=1)), interaction or DLRMInteraction()),
link=ShortcutConcatContinuous(),
)

self.append(InteractionBlock(interaction or DLRMInteraction()))
if top_block:
self.append(top_block)
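
To close, a hedged sketch (not part of the commit) of the new InteractionBlock, which takes over the role of the removed ShortcutConcatContinuous link: when the input dictionary has a "continuous" entry, that tensor is concatenated with the interaction output. The feature names and sizes are made up for illustration, and it assumes the default Stack(dim=1) aggregation can stack a plain dict of equally shaped tensors.

import torch

from merlin.models.torch.blocks.dlrm import DLRMInteraction, InteractionBlock

# Hypothetical inputs: two embedded categorical features plus projected
# continuous features, all with the same dimensionality.
features = {
    "continuous": torch.randn(8, 16),
    "user_id": torch.randn(8, 16),
    "item_id": torch.randn(8, 16),
}

# prepend_agg=True (the default) prepends Stack(dim=1), turning the dict into a
# (batch, num_features, dim) tensor before the pairwise interaction.
interaction = InteractionBlock(DLRMInteraction())
out = interaction(features)

# Because a "continuous" key is present, the output concatenates the continuous
# features with the flattened pairwise interactions along the last dimension.
print(out.shape)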