Merge branch 'master' into explainer_fix
chaojun-zhang authored Jul 23, 2024
2 parents 963acba + 00c4564 commit 4d8d16c
Showing 5 changed files with 179 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added the `PatchTransformerAggregation` layer ([#9487](https://github.com/pyg-team/pytorch_geometric/pull/9487))
- Added the `nn.nlp.LLM` model ([#9462](https://github.com/pyg-team/pytorch_geometric/pull/9462))
- Added an example of training GNNs for a graph-level regression task ([#9070](https://github.com/pyg-team/pytorch_geometric/pull/9070))
- Added `utils.from_rdmol`/`utils.to_rdmol` functionality ([#9452](https://github.com/pyg-team/pytorch_geometric/pull/9452))
6 changes: 6 additions & 0 deletions test/graphgym/test_logger.py
@@ -1,9 +1,15 @@
from torch_geometric.graphgym.config import set_run_dir
from torch_geometric.graphgym.loader import create_loader
from torch_geometric.graphgym.logger import Logger, LoggerCallback
from torch_geometric.testing import withPackage


@withPackage('yacs', 'pytorch_lightning')
def test_logger_callback():
loaders = create_loader()
assert len(loaders) == 3

set_run_dir('.')
logger = LoggerCallback()
assert isinstance(logger.train_logger, Logger)
assert isinstance(logger.val_logger, Logger)
27 changes: 27 additions & 0 deletions test/nn/aggr/test_patch_transformer.py
@@ -0,0 +1,27 @@
import torch

from torch_geometric.nn import PatchTransformerAggregation
from torch_geometric.testing import withCUDA


@withCUDA
def test_patch_transformer_aggregation(device: torch.device) -> None:
aggr = PatchTransformerAggregation(
in_channels=16,
out_channels=32,
patch_size=2,
hidden_channels=8,
num_transformer_blocks=1,
heads=2,
dropout=0.2,
aggr=['sum', 'mean', 'min', 'max', 'var', 'std'],
).to(device)
aggr.reset_parameters()
assert str(aggr) == 'PatchTransformerAggregation(16, 32, patch_size=2)'

index = torch.tensor([0, 0, 1, 1, 1, 2], device=device)
x = torch.randn(index.size(0), 16, device=device)

out = aggr(x, index)
assert out.device == device
assert out.size() == (3, aggr.out_channels)
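
As a point of reference beyond the test above, here is a minimal sketch of using PatchTransformerAggregation as a graph-level readout over a standard PyG mini-batch; the tensors, sizes, and variable names are illustrative assumptions and not part of this commit.

import torch

from torch_geometric.nn import PatchTransformerAggregation

# Hypothetical mini-batch: 6 nodes with 16-dimensional embeddings,
# grouped into 3 graphs via a sorted `batch` vector.
x = torch.randn(6, 16)
batch = torch.tensor([0, 0, 1, 1, 1, 2])

readout = PatchTransformerAggregation(
    in_channels=16,
    out_channels=32,
    patch_size=2,
    hidden_channels=8,
)

# One 32-dimensional embedding per graph:
graph_emb = readout(x, index=batch)
print(graph_emb.shape)  # torch.Size([3, 32])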
2 changes: 2 additions & 0 deletions torch_geometric/nn/aggr/__init__.py
@@ -25,6 +25,7 @@
from .set_transformer import SetTransformerAggregation
from .lcm import LCMAggregation
from .variance_preserving import VariancePreservingAggregation
from .patch_transformer import PatchTransformerAggregation

__all__ = classes = [
'Aggregation',
@@ -53,4 +54,5 @@
'SetTransformerAggregation',
'LCMAggregation',
'VariancePreservingAggregation',
'PatchTransformerAggregation',
]
143 changes: 143 additions & 0 deletions torch_geometric/nn/aggr/patch_transformer.py
@@ -0,0 +1,143 @@
import math
from typing import List, Optional, Union

import torch
from torch import Tensor

from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.nn.aggr import Aggregation
from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock
from torch_geometric.nn.encoding import PositionalEncoding
from torch_geometric.utils import scatter


class PatchTransformerAggregation(Aggregation):
r"""Performs patch transformer aggregation in which the elements to
aggregate are processed by multi-head attention blocks across patches, as
described in the `"Simplifying Temporal Heterogeneous Network for
Continuous-Time Link Prediction"
<https://dl.acm.org/doi/pdf/10.1145/3583780.3615059>`_ paper.

Args:
in_channels (int): Size of each input sample.
out_channels (int): Size of each output sample.
patch_size (int): Number of elements in a patch.
hidden_channels (int): Intermediate size of each sample.
num_transformer_blocks (int, optional): Number of transformer blocks.
(default: :obj:`1`)
heads (int, optional): Number of multi-head-attentions.
(default: :obj:`1`)
dropout (float, optional): Dropout probability of attention weights.
(default: :obj:`0.0`)
aggr (str or list[str], optional): The aggregation scheme(s) to use,
*e.g.*, :obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
:obj:`"var"`, :obj:`"std"`. (default: :obj:`"mean"`)
"""
def __init__(
self,
in_channels: int,
out_channels: int,
patch_size: int,
hidden_channels: int,
num_transformer_blocks: int = 1,
heads: int = 1,
dropout: float = 0.0,
aggr: Union[str, List[str]] = 'mean',
) -> None:
super().__init__()

self.in_channels = in_channels
self.out_channels = out_channels
self.patch_size = patch_size
self.aggrs = [aggr] if isinstance(aggr, str) else aggr

assert len(self.aggrs) > 0
for aggr in self.aggrs:
assert aggr in ['sum', 'mean', 'min', 'max', 'var', 'std']

self.lin = torch.nn.Linear(in_channels, hidden_channels)
self.pad_projector = torch.nn.Linear(
patch_size * hidden_channels,
hidden_channels,
)
self.pe = PositionalEncoding(hidden_channels)

self.blocks = torch.nn.ModuleList([
MultiheadAttentionBlock(
channels=hidden_channels,
heads=heads,
layer_norm=True,
dropout=dropout,
) for _ in range(num_transformer_blocks)
])

self.fc = torch.nn.Linear(
hidden_channels * len(self.aggrs),
out_channels,
)

def reset_parameters(self) -> None:
self.lin.reset_parameters()
self.pad_projector.reset_parameters()
self.pe.reset_parameters()
for block in self.blocks:
block.reset_parameters()
self.fc.reset_parameters()

@disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])
def forward(
self,
x: Tensor,
index: Tensor,
ptr: Optional[Tensor] = None,
dim_size: Optional[int] = None,
dim: int = -2,
max_num_elements: Optional[int] = None,
) -> Tensor:

if max_num_elements is None:
if ptr is not None:
count = ptr.diff()
else:
count = scatter(torch.ones_like(index), index, dim=0,
dim_size=dim_size, reduce='sum')
max_num_elements = int(count.max()) + 1

# Set `max_num_elements` to a multiple of `patch_size`:
max_num_elements = (math.floor(max_num_elements / self.patch_size) *
self.patch_size)

x = self.lin(x)

# TODO If groups are heavily unbalanced, this will create a lot of
# "empty" patches. Try to figure out a way to fix this.
# [batch_size, num_patches * patch_size, hidden_channels]
x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim,
max_num_elements=max_num_elements)

# [batch_size, num_patches, patch_size * hidden_channels]
x = x.view(x.size(0), max_num_elements // self.patch_size,
self.patch_size * x.size(-1))

# [batch_size, num_patches, hidden_channels]
x = self.pad_projector(x)

x = x + self.pe(torch.arange(x.size(1), device=x.device))

# [batch_size, num_patches, hidden_channels]
for block in self.blocks:
x = block(x, x)

# [batch_size, hidden_channels]
outs: List[Tensor] = []
for aggr in self.aggrs:
out = getattr(torch, aggr)(x, dim=1)
outs.append(out[0] if isinstance(out, tuple) else out)
out = torch.cat(outs, dim=1) if len(outs) > 1 else outs[0]

# [batch_size, out_channels]
return self.fc(out)

def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_channels}, '
f'{self.out_channels}, patch_size={self.patch_size})')

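For reference, a short standalone sketch of the patching step in forward() above: the padded dense batch is viewed as patches of patch_size elements, and each flattened patch is projected back to hidden_channels. The concrete sizes and the projector defined below are assumptions for illustration only, not taken from this commit.

import torch

# Illustrative sizes only: 3 groups, each padded to 4 elements,
# with patch_size=2 and hidden_channels=8.
batch_size, max_num_elements, patch_size, hidden_channels = 3, 4, 2, 8

x = torch.randn(batch_size, max_num_elements, hidden_channels)

# [batch_size, num_patches, patch_size * hidden_channels]
num_patches = max_num_elements // patch_size
patches = x.view(batch_size, num_patches, patch_size * hidden_channels)
print(patches.shape)  # torch.Size([3, 2, 16])

# A linear layer analogous to `pad_projector` maps each flattened patch
# back to `hidden_channels`:
pad_projector = torch.nn.Linear(patch_size * hidden_channels, hidden_channels)
print(pad_projector(patches).shape)  # torch.Size([3, 2, 8])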