diff --git a/CHANGELOG.md b/CHANGELOG.md
index 28c12e17b832..7d666ad72af8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added the `PatchTransformerAggregation` layer ([#9487](https://github.com/pyg-team/pytorch_geometric/pull/9487))
 - Added the `nn.nlp.LLM` model ([#9462](https://github.com/pyg-team/pytorch_geometric/pull/9462))
 - Added an example of training GNNs for a graph-level regression task ([#9070](https://github.com/pyg-team/pytorch_geometric/pull/9070))
 - Added `utils.from_rdmol`/`utils.to_rdmol` functionality ([#9452](https://github.com/pyg-team/pytorch_geometric/pull/9452))
diff --git a/test/graphgym/test_logger.py b/test/graphgym/test_logger.py
index 74dd0c33d09a..d5f87ef0d30d 100644
--- a/test/graphgym/test_logger.py
+++ b/test/graphgym/test_logger.py
@@ -1,9 +1,15 @@
+from torch_geometric.graphgym.config import set_run_dir
+from torch_geometric.graphgym.loader import create_loader
 from torch_geometric.graphgym.logger import Logger, LoggerCallback
 from torch_geometric.testing import withPackage
 
 
 @withPackage('yacs', 'pytorch_lightning')
 def test_logger_callback():
+    loaders = create_loader()
+    assert len(loaders) == 3
+
+    set_run_dir('.')
     logger = LoggerCallback()
     assert isinstance(logger.train_logger, Logger)
     assert isinstance(logger.val_logger, Logger)
diff --git a/test/nn/aggr/test_patch_transformer.py b/test/nn/aggr/test_patch_transformer.py
new file mode 100644
index 000000000000..430531048117
--- /dev/null
+++ b/test/nn/aggr/test_patch_transformer.py
@@ -0,0 +1,27 @@
+import torch
+
+from torch_geometric.nn import PatchTransformerAggregation
+from torch_geometric.testing import withCUDA
+
+
+@withCUDA
+def test_patch_transformer_aggregation(device: torch.device) -> None:
+    aggr = PatchTransformerAggregation(
+        in_channels=16,
+        out_channels=32,
+        patch_size=2,
+        hidden_channels=8,
+        num_transformer_blocks=1,
+        heads=2,
+        dropout=0.2,
+        aggr=['sum', 'mean', 'min', 'max', 'var', 'std'],
+    ).to(device)
+    aggr.reset_parameters()
+    assert str(aggr) == 'PatchTransformerAggregation(16, 32, patch_size=2)'
+
+    index = torch.tensor([0, 0, 1, 1, 1, 2], device=device)
+    x = torch.randn(index.size(0), 16, device=device)
+
+    out = aggr(x, index)
+    assert out.device == device
+    assert out.size() == (3, aggr.out_channels)
diff --git a/torch_geometric/nn/aggr/__init__.py b/torch_geometric/nn/aggr/__init__.py
index 0b3425849059..aaf8c95e9135 100644
--- a/torch_geometric/nn/aggr/__init__.py
+++ b/torch_geometric/nn/aggr/__init__.py
@@ -25,6 +25,7 @@
 from .set_transformer import SetTransformerAggregation
 from .lcm import LCMAggregation
 from .variance_preserving import VariancePreservingAggregation
+from .patch_transformer import PatchTransformerAggregation
 
 __all__ = classes = [
     'Aggregation',
@@ -53,4 +54,5 @@
     'SetTransformerAggregation',
     'LCMAggregation',
     'VariancePreservingAggregation',
+    'PatchTransformerAggregation',
 ]
diff --git a/torch_geometric/nn/aggr/patch_transformer.py b/torch_geometric/nn/aggr/patch_transformer.py
new file mode 100644
index 000000000000..480408488e0a
--- /dev/null
+++ b/torch_geometric/nn/aggr/patch_transformer.py
@@ -0,0 +1,143 @@
+import math
+from typing import List, Optional, Union
+
+import torch
+from torch import Tensor
+
+from torch_geometric.experimental import disable_dynamic_shapes
+from torch_geometric.nn.aggr import Aggregation
+from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock
+from torch_geometric.nn.encoding import PositionalEncoding
+from torch_geometric.utils import scatter
+
+
+class PatchTransformerAggregation(Aggregation):
+    r"""Performs patch transformer aggregation in which the elements to
+    aggregate are processed by multi-head attention blocks across patches, as
+    described in the `"Simplifying Temporal Heterogeneous Network for
+    Continuous-Time Link Prediction"
+    `_ paper.
+
+    Args:
+        in_channels (int): Size of each input sample.
+        out_channels (int): Size of each output sample.
+        patch_size (int): Number of elements in a patch.
+        hidden_channels (int): Intermediate size of each sample.
+        num_transformer_blocks (int, optional): Number of transformer blocks.
+            (default: :obj:`1`)
+        heads (int, optional): Number of multi-head-attentions.
+            (default: :obj:`1`)
+        dropout (float, optional): Dropout probability of attention weights.
+            (default: :obj:`0.0`)
+        aggr (str or list[str], optional): The aggregation module, *e.g.*,
+            :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
+            :obj:`"var"`, :obj:`"std"`. (default: :obj:`"mean"`)
+    """
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        patch_size: int,
+        hidden_channels: int,
+        num_transformer_blocks: int = 1,
+        heads: int = 1,
+        dropout: float = 0.0,
+        aggr: Union[str, List[str]] = 'mean',
+    ) -> None:
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.patch_size = patch_size
+        self.aggrs = [aggr] if isinstance(aggr, str) else aggr
+
+        assert len(self.aggrs) > 0
+        for aggr in self.aggrs:
+            assert aggr in ['sum', 'mean', 'min', 'max', 'var', 'std']
+
+        self.lin = torch.nn.Linear(in_channels, hidden_channels)
+        self.pad_projector = torch.nn.Linear(
+            patch_size * hidden_channels,
+            hidden_channels,
+        )
+        self.pe = PositionalEncoding(hidden_channels)
+
+        self.blocks = torch.nn.ModuleList([
+            MultiheadAttentionBlock(
+                channels=hidden_channels,
+                heads=heads,
+                layer_norm=True,
+                dropout=dropout,
+            ) for _ in range(num_transformer_blocks)
+        ])
+
+        self.fc = torch.nn.Linear(
+            hidden_channels * len(self.aggrs),
+            out_channels,
+        )
+
+    def reset_parameters(self) -> None:
+        self.lin.reset_parameters()
+        self.pad_projector.reset_parameters()
+        self.pe.reset_parameters()
+        for block in self.blocks:
+            block.reset_parameters()
+        self.fc.reset_parameters()
+
+    @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])
+    def forward(
+        self,
+        x: Tensor,
+        index: Tensor,
+        ptr: Optional[Tensor] = None,
+        dim_size: Optional[int] = None,
+        dim: int = -2,
+        max_num_elements: Optional[int] = None,
+    ) -> Tensor:
+
+        if max_num_elements is None:
+            if ptr is not None:
+                count = ptr.diff()
+            else:
+                count = scatter(torch.ones_like(index), index, dim=0,
+                                dim_size=dim_size, reduce='sum')
+            max_num_elements = int(count.max()) + 1
+
+        # Set `max_num_elements` to a multiple of `patch_size`:
+        max_num_elements = (math.floor(max_num_elements / self.patch_size) *
+                            self.patch_size)
+
+        x = self.lin(x)
+
+        # TODO If groups are heavily unbalanced, this will create a lot of
+        # "empty" patches. Try to figure out a way to fix this.
+        # [batch_size, num_patches * patch_size, hidden_channels]
+        x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim,
+                                   max_num_elements=max_num_elements)
+
+        # [batch_size, num_patches, patch_size * hidden_channels]
+        x = x.view(x.size(0), max_num_elements // self.patch_size,
+                   self.patch_size * x.size(-1))
+
+        # [batch_size, num_patches, hidden_channels]
+        x = self.pad_projector(x)
+
+        x = x + self.pe(torch.arange(x.size(1), device=x.device))
+
+        # [batch_size, num_patches, hidden_channels]
+        for block in self.blocks:
+            x = block(x, x)
+
+        # [batch_size, hidden_channels]
+        outs: List[Tensor] = []
+        for aggr in self.aggrs:
+            out = getattr(torch, aggr)(x, dim=1)
+            outs.append(out[0] if isinstance(out, tuple) else out)
+        out = torch.cat(outs, dim=1) if len(outs) > 1 else outs[0]
+
+        # [batch_size, out_channels]
+        return self.fc(out)
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}({self.in_channels}, '
+                f'{self.out_channels}, patch_size={self.patch_size})')
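For context, a minimal usage sketch of the new layer, pieced together from the constructor and `forward` signatures introduced in this diff. The channel sizes, group assignment, and choice of `aggr` readouts below are illustrative assumptions, not values prescribed by the PR.

```python
import torch

from torch_geometric.nn import PatchTransformerAggregation

# Illustrative hyperparameters; any combination accepted by the constructor works the same way.
aggr = PatchTransformerAggregation(
    in_channels=16,
    out_channels=32,
    patch_size=2,
    hidden_channels=8,
    num_transformer_blocks=2,
    heads=2,
    dropout=0.1,
    aggr=['mean', 'max'],  # per-group readouts are concatenated before the final `fc` projection
)

# Six elements assigned to three groups:
index = torch.tensor([0, 0, 1, 1, 1, 2])
x = torch.randn(index.size(0), 16)

out = aggr(x, index)
print(out.size())  # torch.Size([3, 32]): one row per group, `out_channels` wide

# `max_num_elements` is optional; when given, it skips the per-group count
# computation inside `forward` and is snapped down to a multiple of `patch_size`:
out = aggr(x, index, max_num_elements=4)
```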