Add TabularPredictNext for causal item prediction task (#1202)
* first implementation of tabular predict next

* add unit tests

* include PR comments

* fix docstrings description

* include Marc's comments

* add torchscript support to TabularPredictNextModule

---------

Co-authored-by: edknv <[email protected]>
sararb and edknv authored Jul 18, 2023
1 parent fe7a9f4 commit 52c89a4
Showing 2 changed files with 322 additions and 3 deletions.
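For orientation, here is a minimal sketch of the shift-and-truncate semantics the new transform implements, using plain tensors rather than the merlin `Batch` API (the names here are illustrative, not from the diff): inputs drop the last position, targets are the same sequence shifted left by one, and positions past the shortened length are masked.

```python
import torch

# Toy padded batch: two sessions of item ids, right-padded with 0.
item_ids = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])
lengths = torch.tensor([3, 2])

inputs = item_ids[:, :-1]   # truncate the last position -> model inputs
targets = item_ids[:, 1:]   # shift left by one          -> next-item targets
new_lengths = lengths - 1   # each sequence loses one usable position

# Boolean mask over target positions: True where a real (non-padding) target exists.
mask = torch.arange(targets.shape[1])[None, :] < new_lengths[:, None]
```

This is the same `[:, 1:]` / `[:, :-1]` slicing and length bookkeeping visible in `TabularPredictNextModule.forward` below.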
188 changes: 187 additions & 1 deletion merlin/models/torch/transforms/sequences.py
@@ -113,7 +113,7 @@ def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: B
)

_max_sequence_length = self.max_sequence_length
-        if not _max_sequence_length:
+        if not torch.jit.isinstance(_max_sequence_length, int):
# Infer the maximum length from the current batch
batch_max_sequence_length = 0
for key, val in inputs.items():
@@ -338,3 +338,189 @@ def get_seq_length(self, inputs: Dict[str, torch.Tensor]) -> int:
return inputs[first_feat + "__offsets"][-1].item()

return inputs[first_feat].shape[1]
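# Hedged aside for readers of this diff (not part of the committed file):
# ragged sequence features arrive as flat "<name>__values" plus
# "<name>__offsets" tensors, e.g.
#     {"a__values": torch.tensor([1, 2, 3, 4, 5]),
#      "a__offsets": torch.tensor([0, 3, 5])}  # rows [1, 2, 3] and [4, 5]
# which is why get_seq_length above reads first_feat + "__offsets" for the
# ragged case and falls back to shape[1] for already-padded dense tensors.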


class TabularPredictNext(BatchBlock):
    """A BatchBlock for preparing sequential inputs and targets
    for next-item prediction. The target is extracted by shifting the
    sequence of the target feature one position to the left, and the
    sequential input features are truncated by dropping the last position.

    Parameters
    ----------
    target : Optional[Selection], default=Tags.ID
        The sequential input column(s) that will be used to extract the target.
        Targets can be one or multiple input features with the same sequence length.
    schema : Optional[Schema]
        The schema with the sequential columns to be truncated.
    apply_padding : Optional[bool], default=True
        Whether to pad sequential inputs before extracting the target(s).
    max_sequence_length : Optional[int], default=None
        The maximum length of the sequences after padding.
        If None, sequences will be padded to the maximum length in the current batch.

    Example usage::

        features = {
            "feature1": torch.tensor([[4, 3], [5, 2]]),
            "feature2": torch.tensor([[3, 8], [7, 9]]),
        }
        schema = Schema(["feature1", "feature2"])
        next_item_op = TabularPredictNext(schema=schema, target="feature1")
        transformed_batch = next_item_op(Batch(features))
    """

def __init__(
self,
target: Selection = Tags.ID,
schema: Optional[Schema] = None,
apply_padding: bool = True,
        max_sequence_length: Optional[int] = None,
name: Optional[str] = None,
):
super().__init__(
TabularPredictNextModule(
schema=schema,
target=target,
apply_padding=apply_padding,
max_sequence_length=max_sequence_length,
),
name=name,
)
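# Hedged aside: the commit message advertises TorchScript support for the
# wrapped module, so scripting along these lines is presumably the intended
# use (an assumption, not shown in this diff):
#     scripted = torch.jit.script(TabularPredictNextModule(schema=schema, target="a"))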


class TabularSequenceTransform(nn.Module):
"""Base PyTorch module for preparing targets from a batch of sequential inputs.
Parameters
----------
target : Optional[Selection], default=Tags.ID
The sequential input column that will be used to extract the target.
In case of multiple targets, either a list of target feature names
or a shared Tag indicating the targets should be provided.
schema : Optional[Schema]
The schema with the sequential columns to be truncated
apply_padding : Optional[bool], default=True
Whether to pad sequential inputs before extracting the target(s).
max_sequence_length : Optional[int], default=None
The maximum length of the sequences after padding.
If None, sequences will be padded to the maximum length in the current batch.
"""

def __init__(
self,
target: Optional[Selection] = Tags.ID,
schema: Optional[Schema] = None,
apply_padding: bool = True,
        max_sequence_length: Optional[int] = None,
):
super().__init__()
self.target = target
if schema:
self.initialize_from_schema(schema)
self._initialized_from_schema = True
self.padding_idx = 0
self.apply_padding = apply_padding
if self.apply_padding:
self.padding_operator = TabularPadding(
schema=self.schema, max_sequence_length=max_sequence_length
)

def initialize_from_schema(self, schema: Schema):
self.schema = schema
self.features: List[str] = self.schema.column_names
target = select(self.schema, self.target)
if not target:
raise ValueError(
f"The target '{self.target}' was not found in the "
f"provided sequential schema: {self.schema}"
)
self.target_name = self._get_target(target)

def _get_target(self, target):
return target.column_names

def forward(
self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Batch, **kwargs
) -> Batch:
raise NotImplementedError()

def _check_seq_inputs_targets(self, batch: Batch):
self._check_input_sequence_lengths(batch)
self._check_target_shape(batch)

def _check_target_shape(self, batch: Batch):
for name in self.target_name:
if name not in batch.features:
raise ValueError(f"Inputs features do not contain target column ({name})")

target = batch.features[name]
if target.ndim < 2:
raise ValueError(
f"Sequential target column ({name}) "
f"must be a 2D tensor, but shape is {target.ndim}"
)
lengths = batch.sequences.length(name)
            if torch.any(lengths <= 1):
                raise ValueError(
                    f"2nd dim of target column ({name}) "
                    "must be greater than 1 for sequential input to be shifted as target"
                )

def _check_input_sequence_lengths(self, batch: Batch):
if not batch.sequences.lengths:
raise ValueError(
"The input `batch` should include information about input sequences lengths"
)
sequence_lengths = torch.stack([batch.sequences.length(name) for name in self.features])
assert torch.all(sequence_lengths.eq(sequence_lengths[0])), (
"All tabular sequence features need to have the same sequence length, "
f"found {sequence_lengths}"
)


class TabularPredictNextModule(TabularSequenceTransform):
"""A PyTorch module for preparing tabular sequence data for next-item prediction."""

def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Batch) -> Batch:
if self.apply_padding:
batch = self.padding_operator(batch)
self._check_seq_inputs_targets(batch)

# Shifts the target column to be the next item of corresponding input column
new_targets: Dict[str, torch.Tensor] = dict()
for name in self.target_name:
new_target = batch.features[name]
new_target = new_target[:, 1:]
new_targets[name] = new_target

# Removes the last item of the sequence, as it belongs to the target
new_inputs = dict()
for k, v in batch.features.items():
if k in self.features:
new_inputs[k] = v[:, :-1]

# Generates information about the new lengths and causal masks
new_lengths, causal_masks = {}, {}
for name in self.features:
new_lengths[name] = batch.sequences.lengths[name] - 1
        # All new targets have the same output sequence length
        _max_length = list(new_targets.values())[0].shape[-1]
causal_mask = self._generate_causal_mask(list(new_lengths.values())[0], _max_length)
for name in self.features:
causal_masks[name] = causal_mask

return Batch(
features=new_inputs,
targets=new_targets,
sequences=Sequence(new_lengths, masks=causal_masks),
)

def _generate_causal_mask(self, seq_lengths: torch.Tensor, max_len: int):
"""
Generate a 2D mask from a tensor of sequence lengths.
"""
        # Create position indices on the same device as the lengths tensor
        return torch.arange(max_len, device=seq_lengths.device)[None, :] < seq_lengths[:, None]
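The mask construction above is a broadcasting trick worth seeing in isolation. A standalone sketch of what `_generate_causal_mask` computes, using the same lengths as the unit tests below:

```python
import torch

seq_lengths = torch.tensor([2, 1, 3])
max_len = 3

# Compare each column index against each row's length via broadcasting.
mask = torch.arange(max_len)[None, :] < seq_lengths[:, None]
# tensor([[ True,  True, False],
#         [ True, False, False],
#         [ True,  True,  True]])
```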
137 changes: 135 additions & 2 deletions tests/unit/torch/transforms/test_sequences.py
@@ -3,8 +3,12 @@
import pytest
import torch

-from merlin.models.torch.batch import Batch
-from merlin.models.torch.transforms.sequences import BroadcastToSequence, TabularPadding
+from merlin.models.torch.batch import Batch, Sequence
+from merlin.models.torch.transforms.sequences import (
+    BroadcastToSequence,
+    TabularPadding,
+    TabularPredictNext,
+)
from merlin.models.torch.utils import module_utils
from merlin.schema import ColumnSchema, Schema, Tags

@@ -144,3 +148,132 @@ def test_unsupported_dimensions(self):

with pytest.raises(RuntimeError, match="Unsupported number of dimensions: 4"):
self.broadcast(self.input_tensors)


class TestTabularPredictNext:
@pytest.fixture
def sequence_batch(self):
a_values, a_offsets = _get_values_offsets(data=[[1, 2, 3], [3, 6], [3, 4, 5, 6]])
b_values, b_offsets = _get_values_offsets([[34, 30, 31], [30, 31], [33, 23, 50, 51]])

c_values, c_offsets = _get_values_offsets([[1, 2, 3, 4], [5, 6], [5, 6, 7, 8, 9, 10]])
d_values, d_offsets = _get_values_offsets(
[[10, 20, 30, 40], [50, 60], [50, 60, 70, 80, 90, 100]]
)

features = {
"a__values": a_values,
"a__offsets": a_offsets,
"b__values": b_values,
"b__offsets": b_offsets,
"c__values": c_values,
"c__offsets": c_offsets,
"d__values": d_values,
"d__offsets": d_offsets,
"e_dense": torch.Tensor([[1, 2, 3, 0], [5, 6, 0, 0], [4, 5, 6, 7]]),
"f_context": torch.Tensor([1, 2, 3, 4]),
}
targets = None
return Batch(features, targets)

@pytest.fixture
def sequence_schema_1(self):
return Schema(
[
ColumnSchema("a", tags=[Tags.SEQUENCE]),
ColumnSchema("b", tags=[Tags.SEQUENCE]),
ColumnSchema("e_dense", tags=[Tags.SEQUENCE]),
]
)

@pytest.fixture
def sequence_schema_2(self):
return Schema(
[
ColumnSchema("c", tags=[Tags.SEQUENCE, Tags.ID]),
ColumnSchema("d", tags=[Tags.SEQUENCE]),
]
)

@pytest.fixture
def padded_batch(self, sequence_schema_1, sequence_batch):
padding_op = TabularPadding(schema=sequence_schema_1)
return padding_op(sequence_batch)

def test_tabular_sequence_transform_wrong_inputs(self, padded_batch, sequence_schema_1):
with pytest.raises(
ValueError,
match="The target 'Tags.ID' was not found in the provided sequential schema:",
):
transform = TabularPredictNext(
schema=sequence_schema_1,
target=Tags.ID,
)

transform = TabularPredictNext(
schema=sequence_schema_1,
target="a",
apply_padding=False,
)
with pytest.raises(
ValueError,
match="The input `batch` should include information about input sequences lengths",
):
transform(Batch({"b": padded_batch.features["b"]}))

with pytest.raises(
ValueError,
match="Inputs features do not contain target column",
):
transform(Batch({"b": padded_batch.features["b"]}, sequences=padded_batch.sequences))

with pytest.raises(
ValueError, match="must be greater than 1 for sequential input to be shifted as target"
):
transform = TabularPredictNext(
schema=sequence_schema_1.select_by_name("a"), target="a", apply_padding=False
)
transform(
Batch(
{"a": torch.Tensor([[1, 2], [1, 0], [3, 4]])},
sequences=Sequence(lengths={"a": torch.Tensor([2, 1, 2])}),
)
)

def test_transform_predict_next(self, sequence_batch, padded_batch, sequence_schema_1):
transform = TabularPredictNext(schema=sequence_schema_1, target="a")

batch_output = module_utils.module_test(transform, sequence_batch)

assert list(batch_output.features.keys()) == ["a", "b", "e_dense"]
for k in ["a", "b", "e_dense"]:
assert torch.equal(batch_output.features[k], padded_batch.features[k][:, :-1])
assert torch.equal(batch_output.sequences.length("a"), torch.Tensor([2, 1, 3]))

def test_transform_predict_next_multi_sequence(
self, sequence_batch, padded_batch, sequence_schema_1, sequence_schema_2
):
import merlin.models.torch as mm

transform_1 = TabularPredictNext(schema=sequence_schema_1, target="a")
transform_2 = TabularPredictNext(schema=sequence_schema_2)
transform_block = mm.BatchBlock(
mm.ParallelBlock({"transform_1": transform_1, "transform_2": transform_2})
)
batch_output = transform_block(sequence_batch)

assert list(batch_output.features.keys()) == ["a", "b", "e_dense", "f_context", "c", "d"]
assert list(batch_output.targets.keys()) == ["a", "c"]

assert torch.equal(batch_output.sequences.length("a"), torch.Tensor([2, 1, 3]))
assert torch.equal(batch_output.sequences.length("c"), torch.Tensor([3, 1, 5]))
assert torch.all(
batch_output.sequences.mask("a") # target mask
== torch.Tensor(
[
[True, True, False],
[True, False, False],
[True, True, True],
]
)
)
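
The fixtures above rely on a `_get_values_offsets` helper defined earlier in the test file, outside this diff. A plausible sketch, assuming it flattens ragged rows into the `__values`/`__offsets` layout used throughout (the real helper may differ):

```python
import itertools
from typing import List, Tuple

import torch


def _get_values_offsets(data: List[List[int]]) -> Tuple[torch.Tensor, torch.Tensor]:
    # Flatten the ragged rows into a single 1D values tensor.
    values = torch.tensor([value for row in data for value in row])
    # Offsets mark where each row starts; the last entry is the total count.
    offsets = torch.tensor([0] + list(itertools.accumulate(len(row) for row in data)))
    return values, offsets
```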
