Add padding operator to the PyTorch API (#1177)
* add padding op

* remove unused mask prefix

* Apply suggestions from code review

Co-authored-by: Marc Romeyn <[email protected]>

* Add test for tracing the model with torchscript

* fix linting

* add module_test to test_padded_targets

---------

Co-authored-by: Marc Romeyn <[email protected]>
Co-authored-by: edknv <[email protected]>
3 people authored Jul 4, 2023
1 parent 0b1c198 commit e2930f8
Showing 2 changed files with 272 additions and 0 deletions.
175 changes: 175 additions & 0 deletions merlin/models/torch/transforms/sequences.py
@@ -0,0 +1,175 @@
#
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Dict, List, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from merlin.models.torch.batch import Batch, Sequence
from merlin.schema import Schema, Tags


class TabularPadding(nn.Module):
"""A PyTorch module for padding tabular sequence data.
Parameters
----------
schema : Schema
The schema of the tabular data, which defines the column names of input features.
max_sequence_length : Optional[int], default=None
The maximum length of the sequences after padding.
If None, sequences will be padded to the maximum length in the current batch.
Example usage::
features = {
'feature1': torch.tensor([[4, 3], [5, 2]),
'feature2': torch.tensor([[3,8], [7,9]])
}
schema = Schema(["feature1", "feature2"])
_max_sequence_length = 10
padding_op = TabularBatchPadding(
schema=schema, max_sequence_length=_max_sequence_length
)
padded_batch = padding_op(Batch(feaures))
Notes:
- If the schema contains continuous list features,
ensure that they are normalized within the range of [0, 1].
This is necessary because we will be padding them
to a max_sequence_length using the minimum value of 0.0.
- The current class only supports right padding.
"""

def __init__(
self,
schema: Schema,
max_sequence_length: Optional[int] = None,
):
super().__init__()
self.schema = schema
self.max_sequence_length = max_sequence_length
self.features: List[str] = self.schema.column_names
self.sparse_features = self.schema.select_by_tag(Tags.SEQUENCE).column_names
self.padding_idx = 0

def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Batch) -> Batch:
_max_sequence_length = self.max_sequence_length
if not _max_sequence_length:
# Infer the maximum length from the current batch
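            # e.g. offsets = [0, 2, 2, 5] imply row lengths [2, 0, 3], so the batch max length is 3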
batch_max_sequence_length = 0
for key, val in batch.features.items():
if key.endswith("__offsets"):
offsets = val
max_row_length = int(torch.max(offsets[1:] - offsets[:-1]))
batch_max_sequence_length = max(max_row_length, batch_max_sequence_length)
_max_sequence_length = batch_max_sequence_length

# Store the non-padded lengths of list features
seq_inputs_lengths = self._get_sequence_lengths(batch.features)
seq_shapes: List[torch.Tensor] = list(seq_inputs_lengths.values())
if not torch.all(torch.stack([torch.all(x == seq_shapes[0]) for x in seq_shapes])):
raise ValueError(
"The sequential inputs must have the same length for each row in the batch, "
f"but they are different: {seq_shapes}"
)
# Pad the features of the batch
batch_padded = {}
for key, value in batch.features.items():
if key.endswith("__offsets"):
col_name = key[: -len("__offsets")]
if col_name in self.features:
padded_values = self._pad_ragged_tensor(
batch.features[f"{col_name}__values"], value, _max_sequence_length
)
batch_padded[col_name] = padded_values
elif key.endswith("__values"):
continue
else:
col_name = key
if col_name in self.features and seq_inputs_lengths.get(col_name) is not None:
# pad dense list features
batch_padded[col_name] = self._pad_dense_tensor(value, _max_sequence_length)

# Pad targets of the batch
targets_padded = None
if batch.targets is not None:
targets_padded = {}
for key, value in batch.targets.items():
if key.endswith("__offsets"):
col_name = key[: -len("__offsets")]
padded_values = self._pad_ragged_tensor(
batch.targets[f"{col_name}__values"], value, _max_sequence_length
)
targets_padded[col_name] = padded_values
elif key.endswith("__values"):
continue
else:
targets_padded[key] = value

return Batch(
features=batch_padded, targets=targets_padded, sequences=Sequence(seq_inputs_lengths)
)

def _get_sequence_lengths(self, sequences: Dict[str, torch.Tensor]):
"""Compute the effective length of each sequence in a dictionary of sequences."""
seq_inputs_lengths = {}
for key, val in sequences.items():
if key.endswith("__offsets"):
seq_inputs_lengths[key[: -len("__offsets")]] = val[1:] - val[:-1]
elif key in self.sparse_features:
seq_inputs_lengths[key] = (val != self.padding_idx).sum(-1)
return seq_inputs_lengths

def _squeeze(self, tensor: torch.Tensor):
"""Squeeze a tensor of shape (N,1) to shape (N)."""
if len(tensor.shape) == 2:
return tensor.squeeze(1)
return tensor

def _get_indices(self, offsets: torch.Tensor, diff_offsets: torch.Tensor):
"""Compute indices for a sparse tensor from offsets and their differences."""
row_ids = torch.arange(len(offsets) - 1, device=offsets.device)
row_ids_repeated = torch.repeat_interleave(row_ids, diff_offsets)
row_offset_repeated = torch.repeat_interleave(offsets[:-1], diff_offsets)
col_ids = (
torch.arange(len(row_offset_repeated), device=offsets.device) - row_offset_repeated
)
indices = torch.cat([row_ids_repeated.unsqueeze(-1), col_ids.unsqueeze(-1)], dim=1)
return indices

def _pad_ragged_tensor(self, values: torch.Tensor, offsets: torch.Tensor, padding_length: int):
"""Pad a ragged features represented by "values" and "offsets" to a dense tensor
of length `padding_length`.
"""
values = self._squeeze(values)
offsets = self._squeeze(offsets)
num_rows = len(offsets) - 1
diff_offsets = offsets[1:] - offsets[:-1]
max_length = int(diff_offsets.max())
indices = self._get_indices(offsets, diff_offsets)
sparse_tensor = torch.sparse_coo_tensor(
indices.T, values, torch.Size([num_rows, max_length]), device=values.device
)

return self._pad_dense_tensor(sparse_tensor.to_dense(), padding_length)

def _pad_dense_tensor(self, tensor: torch.Tensor, length: int) -> torch.Tensor:
"""Pad a dense tensor along its second dimension to a specified length."""
if len(tensor.shape) == 2:
pad_diff = length - tensor.shape[1]
return F.pad(input=tensor, pad=(0, pad_diff, 0, 0))
return tensor
97 changes: 97 additions & 0 deletions tests/unit/torch/transforms/test_sequences.py
@@ -0,0 +1,97 @@
from itertools import accumulate

import pytest
import torch

from merlin.models.torch.batch import Batch
from merlin.models.torch.transforms.sequences import TabularPadding
from merlin.models.torch.utils import module_utils
from merlin.schema import ColumnSchema, Schema, Tags


def _get_values_offsets(data):
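    # e.g. data = [[1, 2], [], [3, 4, 5]] -> values [1, 2, 3, 4, 5], offsets [0, 2, 2, 5]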
values = []
row_lengths = []
for row in data:
row_lengths.append(len(row))
values += row
offsets = [0] + list(accumulate(row_lengths))
return torch.tensor(values), torch.tensor(offsets)


class TestPadBatch:
@pytest.fixture
def sequence_batch(self):
a_values, a_offsets = _get_values_offsets(data=[[1, 2], [], [3, 4, 5]])
b_values, b_offsets = _get_values_offsets([[34, 30], [], [33, 23, 50]])
features = {
"a__values": a_values,
"a__offsets": a_offsets,
"b__values": b_values,
"b__offsets": b_offsets,
"c_dense": torch.Tensor([[1, 2, 0], [0, 0, 0], [4, 5, 6]]),
"d_context": torch.Tensor([1, 2, 3]),
}
targets = None
return Batch(features, targets)

@pytest.fixture
def sequence_schema(self):
return Schema(
[
ColumnSchema("a", tags=[Tags.SEQUENCE]),
ColumnSchema("b", tags=[Tags.SEQUENCE]),
ColumnSchema("c_dense", tags=[Tags.SEQUENCE]),
ColumnSchema("d_context", tags=[Tags.CONTEXT]),
]
)

def test_padded_features(self, sequence_batch, sequence_schema):
_max_sequence_length = 8
padding_op = TabularPadding(
schema=sequence_schema, max_sequence_length=_max_sequence_length
)
padded_batch = module_utils.module_test(padding_op, sequence_batch)

assert torch.equal(padded_batch.sequences.length("a"), torch.Tensor([2, 0, 3]))
assert set(padded_batch.features.keys()) == set(["a", "b", "c_dense"])
for feature in ["a", "b", "c_dense"]:
assert padded_batch.features[feature].shape[1] == _max_sequence_length

def test_batch_invalid_lengths(self):
        # Test that sequence features with mismatched row lengths raise a ValueError
a_values, a_offsets = _get_values_offsets(data=[[1, 2], [], [3, 4, 5]])
b_values, b_offsets = _get_values_offsets([[34], [23, 56], [33, 23, 50, 4]])

with pytest.raises(
ValueError,
match="The sequential inputs must have the same length for each row in the batch",
):
padding_op = TabularPadding(schema=Schema(["a", "b"]))
padding_op(
inputs=None,
batch=Batch(
{
"a__values": a_values,
"a__offsets": a_offsets,
"b__values": b_values,
"b__offsets": b_offsets,
}
),
)

def test_padded_targets(self, sequence_batch, sequence_schema):
_max_sequence_length = 8
target_values, target_offsets = _get_values_offsets([[10, 11], [], [12, 13, 14]])
sequence_batch.targets = {
"target_1": torch.Tensor([3, 4, 6]),
"target_2__values": target_values,
"target_2__offsets": target_offsets,
}
padding_op = TabularPadding(
schema=sequence_schema, max_sequence_length=_max_sequence_length
)
padded_batch = module_utils.module_test(padding_op, sequence_batch)

assert padded_batch.targets["target_2"].shape[1] == _max_sequence_length
assert torch.equal(padded_batch.targets["target_1"], sequence_batch.targets["target_1"])
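
For reference, a minimal end-to-end sketch of how the new operator can be applied outside the test harness. The `item_ids` column name and the example values are illustrative, not part of this change; the ragged `__values`/`__offsets` encoding follows the fixtures above:

import torch

from merlin.models.torch.batch import Batch
from merlin.models.torch.transforms.sequences import TabularPadding
from merlin.schema import ColumnSchema, Schema, Tags

# Ragged rows [1, 2], [], [3, 4, 5] encoded as flat values plus row offsets
features = {
    "item_ids__values": torch.tensor([1, 2, 3, 4, 5]),
    "item_ids__offsets": torch.tensor([0, 2, 2, 5]),
}
schema = Schema([ColumnSchema("item_ids", tags=[Tags.SEQUENCE])])

padding_op = TabularPadding(schema=schema, max_sequence_length=4)
padded = padding_op(inputs=None, batch=Batch(features))

# padded.features["item_ids"] has shape (3, 4), right-padded with zeros
# padded.sequences.length("item_ids") == tensor([2, 0, 3])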
