Add padding operator to the PyTorch API (#1177)
* add padding op
* remove unused mask prefix
* Apply suggestions from code review
  Co-authored-by: Marc Romeyn <[email protected]>
* Add test for tracing the model with torchscript
* fix linting
* add module_test to test_padded_targets
---------
Co-authored-by: Marc Romeyn <[email protected]>
Co-authored-by: edknv <[email protected]>
1 parent 0b1c198, commit e2930f8
Showing 2 changed files with 272 additions and 0 deletions.
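Before the diffs, a minimal usage sketch of the new operator, assembled from the API this commit introduces (the import paths, the `a__values`/`a__offsets` ragged encoding, and the `Batch`/`Schema` construction all mirror the test file below; the specific feature values are illustrative only):

import torch

from merlin.models.torch.batch import Batch
from merlin.models.torch.transforms.sequences import TabularPadding
from merlin.schema import ColumnSchema, Schema, Tags

# Ragged feature "a" with rows [1, 2], [], [3, 4, 5] in values/offsets form.
features = {
    "a__values": torch.tensor([1, 2, 3, 4, 5]),
    "a__offsets": torch.tensor([0, 2, 2, 5]),
}
schema = Schema([ColumnSchema("a", tags=[Tags.SEQUENCE])])

padding_op = TabularPadding(schema=schema, max_sequence_length=4)
padded = padding_op(inputs=None, batch=Batch(features))

padded.features["a"]          # dense (3, 4) tensor, right-padded with zeros
padded.sequences.length("a")  # original row lengths: tensor([2, 0, 3])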
@@ -0,0 +1,175 @@
#
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Dict, List, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from merlin.models.torch.batch import Batch, Sequence
from merlin.schema import Schema, Tags


class TabularPadding(nn.Module):
    """A PyTorch module for padding tabular sequence data.

    Parameters
    ----------
    schema : Schema
        The schema of the tabular data, which defines the column names of input features.
    max_sequence_length : Optional[int], default=None
        The maximum length of the sequences after padding.
        If None, sequences will be padded to the maximum length in the current batch.

    Example usage::
        features = {
            'feature1': torch.tensor([[4, 3], [5, 2]]),
            'feature2': torch.tensor([[3, 8], [7, 9]])
        }
        schema = Schema(["feature1", "feature2"])
        _max_sequence_length = 10
        padding_op = TabularPadding(
            schema=schema, max_sequence_length=_max_sequence_length
        )
        padded_batch = padding_op(inputs=None, batch=Batch(features))

    Notes:
        - If the schema contains continuous list features,
          ensure that they are normalized within the range of [0, 1].
          This is necessary because we will be padding them
          to a max_sequence_length using the minimum value of 0.0.
        - The current class only supports right padding.
    """

    def __init__(
        self,
        schema: Schema,
        max_sequence_length: Optional[int] = None,
    ):
        super().__init__()
        self.schema = schema
        self.max_sequence_length = max_sequence_length
        self.features: List[str] = self.schema.column_names
        self.sparse_features = self.schema.select_by_tag(Tags.SEQUENCE).column_names
        self.padding_idx = 0

    def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Batch) -> Batch:
        _max_sequence_length = self.max_sequence_length
        if not _max_sequence_length:
            # Infer the maximum length from the current batch
            batch_max_sequence_length = 0
            for key, val in batch.features.items():
                if key.endswith("__offsets"):
                    offsets = val
                    max_row_length = int(torch.max(offsets[1:] - offsets[:-1]))
                    batch_max_sequence_length = max(max_row_length, batch_max_sequence_length)
            _max_sequence_length = batch_max_sequence_length

        # Store the non-padded lengths of list features
        seq_inputs_lengths = self._get_sequence_lengths(batch.features)
        seq_shapes: List[torch.Tensor] = list(seq_inputs_lengths.values())
        if not torch.all(torch.stack([torch.all(x == seq_shapes[0]) for x in seq_shapes])):
            raise ValueError(
                "The sequential inputs must have the same length for each row in the batch, "
                f"but they are different: {seq_shapes}"
            )
        # Pad the features of the batch
        batch_padded = {}
        for key, value in batch.features.items():
            if key.endswith("__offsets"):
                col_name = key[: -len("__offsets")]
                if col_name in self.features:
                    padded_values = self._pad_ragged_tensor(
                        batch.features[f"{col_name}__values"], value, _max_sequence_length
                    )
                    batch_padded[col_name] = padded_values
            elif key.endswith("__values"):
                continue
            else:
                col_name = key
                if col_name in self.features and seq_inputs_lengths.get(col_name) is not None:
                    # pad dense list features
                    batch_padded[col_name] = self._pad_dense_tensor(value, _max_sequence_length)

        # Pad targets of the batch
        targets_padded = None
        if batch.targets is not None:
            targets_padded = {}
            for key, value in batch.targets.items():
                if key.endswith("__offsets"):
                    col_name = key[: -len("__offsets")]
                    padded_values = self._pad_ragged_tensor(
                        batch.targets[f"{col_name}__values"], value, _max_sequence_length
                    )
                    targets_padded[col_name] = padded_values
                elif key.endswith("__values"):
                    continue
                else:
                    targets_padded[key] = value

        return Batch(
            features=batch_padded, targets=targets_padded, sequences=Sequence(seq_inputs_lengths)
        )

    def _get_sequence_lengths(self, sequences: Dict[str, torch.Tensor]):
        """Compute the effective length of each sequence in a dictionary of sequences."""
        seq_inputs_lengths = {}
        for key, val in sequences.items():
            if key.endswith("__offsets"):
                seq_inputs_lengths[key[: -len("__offsets")]] = val[1:] - val[:-1]
            elif key in self.sparse_features:
                seq_inputs_lengths[key] = (val != self.padding_idx).sum(-1)
        return seq_inputs_lengths

    def _squeeze(self, tensor: torch.Tensor):
        """Squeeze a tensor of shape (N,1) to shape (N)."""
        if len(tensor.shape) == 2:
            return tensor.squeeze(1)
        return tensor

    def _get_indices(self, offsets: torch.Tensor, diff_offsets: torch.Tensor):
        """Compute indices for a sparse tensor from offsets and their differences."""
        row_ids = torch.arange(len(offsets) - 1, device=offsets.device)
        row_ids_repeated = torch.repeat_interleave(row_ids, diff_offsets)
        row_offset_repeated = torch.repeat_interleave(offsets[:-1], diff_offsets)
        col_ids = (
            torch.arange(len(row_offset_repeated), device=offsets.device) - row_offset_repeated
        )
        indices = torch.cat([row_ids_repeated.unsqueeze(-1), col_ids.unsqueeze(-1)], dim=1)
        return indices

    def _pad_ragged_tensor(self, values: torch.Tensor, offsets: torch.Tensor, padding_length: int):
        """Pad a ragged feature represented by "values" and "offsets" to a dense tensor
        of length `padding_length`.
        """
        values = self._squeeze(values)
        offsets = self._squeeze(offsets)
        num_rows = len(offsets) - 1
        diff_offsets = offsets[1:] - offsets[:-1]
        max_length = int(diff_offsets.max())
        indices = self._get_indices(offsets, diff_offsets)
        sparse_tensor = torch.sparse_coo_tensor(
            indices.T, values, torch.Size([num_rows, max_length]), device=values.device
        )

        return self._pad_dense_tensor(sparse_tensor.to_dense(), padding_length)

    def _pad_dense_tensor(self, tensor: torch.Tensor, length: int) -> torch.Tensor:
        """Pad a dense tensor along its second dimension to a specified length."""
        if len(tensor.shape) == 2:
            pad_diff = length - tensor.shape[1]
            return F.pad(input=tensor, pad=(0, pad_diff, 0, 0))
        return tensor
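The heart of the module is `_pad_ragged_tensor`: a ragged feature stored as a flat `values` tensor plus row `offsets` is densified by turning the offsets into sparse COO indices, and the dense result is then right-padded with `F.pad`. The standalone sketch below reproduces that technique in plain PyTorch; the `pad_ragged` helper is hypothetical and meant only to illustrate the idea, not part of the API added here.

import torch
import torch.nn.functional as F


def pad_ragged(values: torch.Tensor, offsets: torch.Tensor, length: int) -> torch.Tensor:
    """Densify a ragged (values, offsets) feature and right-pad it to `length` columns."""
    row_lengths = offsets[1:] - offsets[:-1]  # number of elements in each row
    row_ids = torch.repeat_interleave(torch.arange(len(offsets) - 1), row_lengths)
    col_ids = torch.arange(len(values)) - torch.repeat_interleave(offsets[:-1], row_lengths)
    dense = torch.sparse_coo_tensor(
        torch.stack([row_ids, col_ids]), values, (len(offsets) - 1, int(row_lengths.max()))
    ).to_dense()
    return F.pad(dense, (0, length - dense.shape[1]))  # pad columns on the right with zeros


values = torch.tensor([1, 2, 3, 4, 5])
offsets = torch.tensor([0, 2, 2, 5])
pad_ragged(values, offsets, length=4)
# tensor([[1, 2, 0, 0],
#         [0, 0, 0, 0],
#         [3, 4, 5, 0]])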
@@ -0,0 +1,97 @@
from itertools import accumulate

import pytest
import torch

from merlin.models.torch.batch import Batch
from merlin.models.torch.transforms.sequences import TabularPadding
from merlin.models.torch.utils import module_utils
from merlin.schema import ColumnSchema, Schema, Tags


def _get_values_offsets(data):
    values = []
    row_lengths = []
    for row in data:
        row_lengths.append(len(row))
        values += row
    offsets = [0] + list(accumulate(row_lengths))
    return torch.tensor(values), torch.tensor(offsets)


class TestPadBatch:
    @pytest.fixture
    def sequence_batch(self):
        a_values, a_offsets = _get_values_offsets(data=[[1, 2], [], [3, 4, 5]])
        b_values, b_offsets = _get_values_offsets([[34, 30], [], [33, 23, 50]])
        features = {
            "a__values": a_values,
            "a__offsets": a_offsets,
            "b__values": b_values,
            "b__offsets": b_offsets,
            "c_dense": torch.Tensor([[1, 2, 0], [0, 0, 0], [4, 5, 6]]),
            "d_context": torch.Tensor([1, 2, 3]),
        }
        targets = None
        return Batch(features, targets)

    @pytest.fixture
    def sequence_schema(self):
        return Schema(
            [
                ColumnSchema("a", tags=[Tags.SEQUENCE]),
                ColumnSchema("b", tags=[Tags.SEQUENCE]),
                ColumnSchema("c_dense", tags=[Tags.SEQUENCE]),
                ColumnSchema("d_context", tags=[Tags.CONTEXT]),
            ]
        )

    def test_padded_features(self, sequence_batch, sequence_schema):
        _max_sequence_length = 8
        padding_op = TabularPadding(
            schema=sequence_schema, max_sequence_length=_max_sequence_length
        )
        padded_batch = module_utils.module_test(padding_op, sequence_batch)

        assert torch.equal(padded_batch.sequences.length("a"), torch.Tensor([2, 0, 3]))
        assert set(padded_batch.features.keys()) == set(["a", "b", "c_dense"])
        for feature in ["a", "b", "c_dense"]:
            assert padded_batch.features[feature].shape[1] == _max_sequence_length

    def test_batch_invalid_lengths(self):
        # Sequence features with mismatched row lengths should raise a ValueError
        a_values, a_offsets = _get_values_offsets(data=[[1, 2], [], [3, 4, 5]])
        b_values, b_offsets = _get_values_offsets([[34], [23, 56], [33, 23, 50, 4]])

        with pytest.raises(
            ValueError,
            match="The sequential inputs must have the same length for each row in the batch",
        ):
            padding_op = TabularPadding(schema=Schema(["a", "b"]))
            padding_op(
                inputs=None,
                batch=Batch(
                    {
                        "a__values": a_values,
                        "a__offsets": a_offsets,
                        "b__values": b_values,
                        "b__offsets": b_offsets,
                    }
                ),
            )

    def test_padded_targets(self, sequence_batch, sequence_schema):
        _max_sequence_length = 8
        target_values, target_offsets = _get_values_offsets([[10, 11], [], [12, 13, 14]])
        sequence_batch.targets = {
            "target_1": torch.Tensor([3, 4, 6]),
            "target_2__values": target_values,
            "target_2__offsets": target_offsets,
        }
        padding_op = TabularPadding(
            schema=sequence_schema, max_sequence_length=_max_sequence_length
        )
        padded_batch = module_utils.module_test(padding_op, sequence_batch)

        assert padded_batch.targets["target_2"].shape[1] == _max_sequence_length
        assert torch.equal(padded_batch.targets["target_1"], sequence_batch.targets["target_1"])
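For reference, the `_get_values_offsets` helper above builds exactly the values/offsets encoding the operator consumes: the values are the rows concatenated in order, and the offsets are the cumulative row lengths with a leading 0. A quick equivalent, assuming the same nested-list input:

from itertools import accumulate

import torch

data = [[1, 2], [], [3, 4, 5]]
values = torch.tensor([v for row in data for v in row])               # tensor([1, 2, 3, 4, 5])
offsets = torch.tensor([0] + list(accumulate(len(r) for r in data)))  # tensor([0, 2, 2, 5])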