From c66c8daafde5338223936b6a73ae1218069fcb88 Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Thu, 13 Jun 2024 01:33:58 +0900 Subject: [PATCH 1/4] Add MPLP model and example --- examples/mplp.py | 599 ++++++++++++++++++++++++++ torch_geometric/nn/models/__init__.py | 3 + torch_geometric/nn/models/mplp.py | 430 ++++++++++++++++++ 3 files changed, 1032 insertions(+) create mode 100644 examples/mplp.py create mode 100644 torch_geometric/nn/models/mplp.py diff --git a/examples/mplp.py b/examples/mplp.py new file mode 100644 index 000000000000..77ee9039ab07 --- /dev/null +++ b/examples/mplp.py @@ -0,0 +1,599 @@ +"""Implementation of MPLP from the `Pure Message Passing Can Estimate Common Neighbor + for Link Prediction `_ paper. +Based on the code https://github.com/Barcavin/efficient-node-labelling +""" +import argparse +import random +import numpy as np +import os +import sys +import time +from pathlib import Path +from typing import List, Tuple + +import torch +from torch_sparse import SparseTensor +from torch.nn import BCEWithLogitsLoss +from torch.utils.data import DataLoader + +from torch_geometric.data import Data +from torch_geometric.transforms import ToSparseTensor, ToUndirected +import torch_geometric.transforms as T + + +from sklearn.metrics import roc_auc_score + +from ogb.linkproppred import PygLinkPropPredDataset +from ogb.linkproppred import Evaluator + +from torch_geometric.nn.models import MLP, MPLP_GCN, MPLP + +from torch_geometric.utils import degree +from tqdm import tqdm + + +######################## +######## Utils ######### +######################## + + +def get_dataset(root, name: str): + dataset = PygLinkPropPredDataset(name=name, root=root) + data = dataset[0] + """ + SparseTensor's value is NxNx1 for collab. due to edge_weight is |E|x1 + NeuralNeighborCompletion just set edge_weight=None + ELPH use edge_weight + """ + + split_edge = dataset.get_edge_split() + if 'edge_weight' in data: + data.edge_weight = data.edge_weight.view(-1).to(torch.float) + if 'edge' in split_edge['train']: + key = 'edge' + else: + key = 'source_node' + print("-"*20) + print(f"train: {split_edge['train'][key].shape[0]}") + print(f"{split_edge['train'][key]}") + print(f"valid: {split_edge['valid'][key].shape[0]}") + print(f"test: {split_edge['test'][key].shape[0]}") + print(f"max_degree:{degree(data.edge_index[0], data.num_nodes).max()}") + data = ToUndirected()(data) + data = ToSparseTensor(remove_edge_index=False)(data) + data.full_adj_t = data.adj_t + # make node feature as float + if data.x is not None: + data.x = data.x.to(torch.float) + if name != 'ogbl-ddi': + del data.edge_index + return data, split_edge + +def set_random_seeds(random_seed=0): + r"""Sets the seed for generating random numbers.""" + torch.manual_seed(random_seed) + torch.cuda.manual_seed_all(random_seed) + np.random.seed(random_seed) + random.seed(random_seed) + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + +def data_summary(name: str, data: Data): + num_nodes = data.num_nodes + num_edges = data.num_edges + n_degree = data.adj_t.sum(dim=1).to(torch.float) + avg_degree = n_degree.mean().item() + degree_std = n_degree.std().item() + max_degree = n_degree.max().long().item() + density = num_edges / (num_nodes * (num_nodes - 1) / 2) + if data.x is not None: + attr_dim = data.x.shape[1] + else: + attr_dim = '-' # no 
attribute + + print("-"*30+'Dataset and Features'+"-"*60) + print("{:<10}|{:<10}|{:<10}|{:<15}|{:<15}|{:<15}|{:<10}|{:<15}"\ + .format('Dataset','#Nodes','#Edges','Avg. node deg.','Std. node deg.','Max. node deg.', 'Density','Attr. Dimension')) + print("-"*110) + print("{:<10}|{:<10}|{:<10}|{:<15.2f}|{:<15.2f}|{:<15}|{:<9.4f}%|{:<15}"\ + .format(name, num_nodes, num_edges, avg_degree, degree_std, max_degree, density*100, attr_dim)) + print("-"*110) + +def initial_embedding(data, hidden_channels, device): + embedding= torch.nn.Embedding(data.num_nodes, hidden_channels).to(device) + torch.nn.init.xavier_uniform_(embedding.weight) + return embedding + +def create_input(data): + if hasattr(data, 'emb') and data.emb is not None: + x = data.emb.weight + else: + x = data.x + return x + + +######################## +##### Train utils ###### +######################## + + +def __elem2spm(element: torch.Tensor, sizes: List[int], val: torch.Tensor=None) -> SparseTensor: + # Convert adjacency matrix to a 1-d vector + col = torch.bitwise_and(element, 0xffffffff) + row = torch.bitwise_right_shift(element, 32) + if val is None: + sp_tensor = SparseTensor(row=row, col=col, sparse_sizes=sizes).to_device( + element.device).fill_value_(1.0) + else: + sp_tensor = SparseTensor(row=row, col=col, value=val, sparse_sizes=sizes).to_device( + element.device) + return sp_tensor + + +def __spm2elem(spm: SparseTensor) -> torch.Tensor: + # Convert 1-d vector to an adjacency matrix + sizes = spm.sizes() + elem = torch.bitwise_left_shift(spm.storage.row(), + 32).add_(spm.storage.col()) + val = spm.storage.value() + return elem, val + +def __spmdiff(adj1: SparseTensor, + adj2: SparseTensor, keep_val=False) -> Tuple[SparseTensor, SparseTensor]: + ''' + return elements in adj1 but not in adj2 and in adj2 but not adj1 + ''' + element1, val1 = __spm2elem(adj1) + element2, val2 = __spm2elem(adj2) + + if element1.shape[0] == 0: + retelem1 = element1 + retelem2 = element2 + else: + idx = torch.searchsorted(element1[:-1], element2) + matchedmask = (element1[idx] == element2) + + maskelem1 = torch.ones_like(element1, dtype=torch.bool) + maskelem1[idx[matchedmask]] = 0 + retelem1 = element1[maskelem1] + + if keep_val and val1 is not None: + retval1 = val1[maskelem1] + return __elem2spm(retelem1, adj1.sizes(), retval1) + else: + return __elem2spm(retelem1, adj1.sizes()) + + +def get_train_test(args): + if args.dataset == "ogbl-citation2": + evaluator = Evaluator(name='ogbl-citation2') + return train_mrr, test_mrr, evaluator + else: + evaluator = Evaluator(name='ogbl-ddi') + return train_hits, test_hits, evaluator + +def train_hits(encoder, predictor, data, split_edge, optimizer, batch_size, + mask_target, num_neg): + encoder.train() + predictor.train() + device = data.adj_t.device() + criterion = BCEWithLogitsLoss(reduction='mean') + pos_train_edge = split_edge['train']['edge'].to(device) + + optimizer.zero_grad() + total_loss = total_examples = 0 + num_pos_max = max(data.adj_t.nnz()//2, pos_train_edge.size(0)) + neg_edge_epoch = torch.randint(0, data.adj_t.size(0), + size=(2, num_pos_max*num_neg), + dtype=torch.long, device=device) + for perm in tqdm(DataLoader(range(pos_train_edge.size(0)), batch_size, + shuffle=True),desc='Train'): + edge = pos_train_edge[perm].t() + if mask_target: + adj_t = data.adj_t + undirected_edges = torch.cat((edge, edge.flip(0)), dim=-1) + target_adj = SparseTensor.from_edge_index(undirected_edges, sparse_sizes=adj_t.sizes()) + adj_t = __spmdiff(adj_t, target_adj, keep_val=True) + else: + adj_t = 
data.adj_t + + h = encoder(data.x, adj_t) + + neg_edge = neg_edge_epoch[:,perm] + train_edges = torch.cat((edge, neg_edge), dim=-1) + train_label = torch.cat((torch.ones(edge.size()[1]), torch.zeros(neg_edge.size()[1])), dim=0).to(device) + out = predictor(h, adj_t, train_edges).squeeze() + loss = criterion(out, train_label) + + loss.backward() + + if data.x is not None: + torch.nn.utils.clip_grad_norm_(data.x, 1.0) + torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.0) + torch.nn.utils.clip_grad_norm_(predictor.parameters(), 1.0) + optimizer.step() + optimizer.zero_grad() + total_examples += train_label.size(0) + total_loss += loss.item() * train_label.size(0) + + return total_loss / total_examples + + +@torch.no_grad() +def test_hits(encoder, predictor, data, split_edge, evaluator, + batch_size, fast_inference): + encoder.eval() + predictor.eval() + device = data.adj_t.device() + adj_t = data.adj_t + h = encoder(data.x, adj_t) + + def test_split(split, cache_mode=None): + pos_test_edge = split_edge[split]['edge'].to(device) + neg_test_edge = split_edge[split]['edge_neg'].to(device) + + pos_test_preds = [] + for perm in DataLoader(range(pos_test_edge.size(0)), batch_size): + edge = pos_test_edge[perm].t() + out = predictor(h, adj_t, edge, cache_mode=cache_mode) + pos_test_preds += [out.squeeze().cpu()] + pos_test_pred = torch.cat(pos_test_preds, dim=0) + + neg_test_preds = [] + for perm in DataLoader(range(neg_test_edge.size(0)), batch_size): + edge = neg_test_edge[perm].t() + neg_test_preds += [predictor(h, adj_t, edge, cache_mode=cache_mode).squeeze().cpu()] + neg_test_pred = torch.cat(neg_test_preds, dim=0) + return pos_test_pred, neg_test_pred + + pos_valid_pred, neg_valid_pred = test_split('valid') + + start_time = time.perf_counter() + if fast_inference: + # caching + predictor(h, adj_t, None, cache_mode='build') + cache_mode='use' + else: + cache_mode=None + + pos_test_pred, neg_test_pred = test_split('test', cache_mode=cache_mode) + end_time = time.perf_counter() + total_time = end_time - start_time + print(f'Inference for one epoch Took {total_time:.4f} seconds') + if fast_inference: + # delete cache + predictor(h, adj_t, None, cache_mode='delete') + + results = {} + K = 100 + evaluator.K = K + valid_hits = evaluator.eval({ + 'y_pred_pos': pos_valid_pred, + 'y_pred_neg': neg_valid_pred, + })[f'hits@{K}'] + test_hits = evaluator.eval({ + 'y_pred_pos': pos_test_pred, + 'y_pred_neg': neg_test_pred, + })[f'hits@{K}'] + + results[f'Hits@{K}'] = (valid_hits, test_hits) + + valid_result = torch.cat((torch.ones(pos_valid_pred.size()), torch.zeros(neg_valid_pred.size())), dim=0) + valid_pred = torch.cat((pos_valid_pred, neg_valid_pred), dim=0) + + test_result = torch.cat((torch.ones(pos_test_pred.size()), torch.zeros(neg_test_pred.size())), dim=0) + test_pred = torch.cat((pos_test_pred, neg_test_pred), dim=0) + + results['AUC'] = (roc_auc_score(valid_result.cpu().numpy(),valid_pred.cpu().numpy()),roc_auc_score(test_result.cpu().numpy(),test_pred.cpu().numpy())) + + return results + + +def make_symmetric(sparse_tensor, reduce='sum'): + # Extract COO format + indices = sparse_tensor.coalesce().indices() + row, col = indices[0], indices[1] + value = sparse_tensor.coalesce().values() + + # Concatenate the original and transposed entries + all_row = torch.cat([row, col]) + all_col = torch.cat([col, row]) + all_value = torch.cat([value, value]) + + # Create a new COO matrix with these entries + new_indices = torch.stack([all_row, all_col]) + new_value = all_value + + # Remove duplicates by 
summing the values for symmetric entries + unique_indices, inverse_indices = torch.unique(new_indices, dim=1, return_inverse=True) + unique_value = torch.zeros(unique_indices.size(1), device=value.device).scatter_reduce_(0, inverse_indices, new_value, reduce="amax") + + # Create the symmetric sparse tensor + symmetric_tensor = torch.sparse_coo_tensor(unique_indices, unique_value, sparse_tensor.size()) + + return symmetric_tensor + + +def train_mrr(encoder, predictor, data, split_edge, optimizer, batch_size, + mask_target, num_neg): + encoder.train() + predictor.train() + device = data.adj_t.device() + criterion = BCEWithLogitsLoss(reduction='mean') + source_edge = split_edge['train']['source_node'].to(device) + target_edge = split_edge['train']['target_node'].to(device) + adjmask = torch.ones_like(source_edge, dtype=torch.bool) + + optimizer.zero_grad() + total_loss = total_examples = 0 + for perm in tqdm(DataLoader(range(source_edge.size(0)), batch_size, + shuffle=True),desc='Train'): + if mask_target: + adjmask[perm] = 0 + tei = torch.stack((source_edge[adjmask], target_edge[adjmask]), dim=0) # TODO: check if both direction is removed + + adj_t = SparseTensor.from_edge_index(tei, + sparse_sizes=(data.num_nodes, data.num_nodes)).to_device( + source_edge.device, non_blocking=True) + adjmask[perm] = 1 + + adj_t = adj_t.to_symmetric() + + #adj_t = torch.sparse_coo_tensor(tei, torch.ones((tei.size(1)), device=tei.device), (data.num_nodes, data.num_nodes)) + #adj_t = adj_t.coalesce() + #adj_t = make_symmetric(adj_t).coalesce() + else: + adj_t = data.adj_t + + + h = encoder(data.x, adj_t) + dst_neg = torch.randint(0, data.num_nodes, perm.size()*num_neg, + dtype=torch.long, device=device) + + edge = torch.stack((source_edge[perm], target_edge[perm]), dim=0) + neg_edge = torch.stack((source_edge[perm].repeat(num_neg), dst_neg), dim=0) + train_edges = torch.cat((edge, neg_edge), dim=-1) + train_label = torch.cat((torch.ones(edge.size()[1]), torch.zeros(neg_edge.size()[1])), dim=0).to(device) + out = predictor(h, adj_t, train_edges).squeeze() + loss = criterion(out, train_label) + + loss.backward() + + if data.x is not None: + torch.nn.utils.clip_grad_norm_(data.x, 1.0) + torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.0) + torch.nn.utils.clip_grad_norm_(predictor.parameters(), 1.0) + optimizer.step() + optimizer.zero_grad() + total_examples += train_label.size(0) + total_loss += loss.item() * train_label.size(0) + + return total_loss / total_examples + + +@torch.no_grad() +def test_mrr(encoder, predictor, data, split_edge, evaluator, + batch_size, fast_inference): + encoder.eval() + predictor.eval() + device = data.adj_t.device() + adj_t = data.adj_t + h = encoder(data.x, adj_t) + + def test_split(split, cache_mode=None): + source = split_edge[split]['source_node'].to(device) + target = split_edge[split]['target_node'].to(device) + target_neg = split_edge[split]['target_node_neg'].to(device) + + pos_preds = [] + for perm in DataLoader(range(source.size(0)), batch_size): + src, dst = source[perm], target[perm] + pos_preds += [predictor(h, adj_t, torch.stack((src, dst)), cache_mode=cache_mode).squeeze().cpu()] + pos_pred = torch.cat(pos_preds, dim=0) + + neg_preds = [] + source = source.view(-1, 1).repeat(1, 1000).view(-1) + target_neg = target_neg.view(-1) + for perm in DataLoader(range(source.size(0)), batch_size): + src, dst_neg = source[perm], target_neg[perm] + neg_preds += [predictor(h, adj_t, torch.stack((src, dst_neg)), cache_mode=cache_mode).squeeze().cpu()] + neg_pred = 
torch.cat(neg_preds, dim=0).view(-1, 1000) + + return pos_pred, neg_pred + + pos_valid_pred, neg_valid_pred = test_split('valid') + + start_time = time.perf_counter() + if fast_inference: + # caching + predictor(h, adj_t, None, cache_mode='build') + cache_mode='use' + else: + cache_mode=None + + pos_test_pred, neg_test_pred = test_split('test', cache_mode=cache_mode) + end_time = time.perf_counter() + total_time = end_time - start_time + print(f'Inference for one epoch Took {total_time:.4f} seconds') + if fast_inference: + # delete cache + predictor(h, adj_t, None, cache_mode='delete') + + valid_mrr = evaluator.eval({ + 'y_pred_pos': pos_valid_pred, + 'y_pred_neg': neg_valid_pred, + })['mrr_list'].mean().item() + test_mrr = evaluator.eval({ + 'y_pred_pos': pos_test_pred, + 'y_pred_neg': neg_test_pred, + })['mrr_list'].mean().item() + + results = { + "MRR": (valid_mrr, test_mrr), + } + + return results + + +######################## +######## Main ######### +######################## + +def main(): + parser = argparse.ArgumentParser(description='MPLP') + # dataset setting + parser.add_argument('--dataset', type=str, default='collab') + parser.add_argument('--val_ratio', type=float, default=0.1) + parser.add_argument('--test_ratio', type=float, default=0.2) + parser.add_argument('--dataset_dir', type=str, default='./data') + + # MPLP settings + parser.add_argument('--signature_dim', type=int, default=1024, help="the node signature dimension `F` in MPLP") + parser.add_argument('--minimum_degree_onehot', type=int, default=-1, help='the minimum degree of hubs with onehot encoding to reduce variance') + parser.add_argument('--use_degree', type=str, default='none', choices=["none","mlp","AA","RA"], help="rescale vector norm to facilitate weighted count") + parser.add_argument('--signature_sampling', type=str, default='torchhd', help='whether to use torchhd to randomize vectors', choices=["torchhd","gaussian","onehot"]) + parser.add_argument('--fast_inference', type=str2bool, default='False', help='whether to enable a faster inference by caching the node vectors') + parser.add_argument('--mask_target', type=str2bool, default='True', help='whether to mask the target edges to remove the shortcut') + + # model setting + parser.add_argument('--encoder', type=str, default='gcn') + parser.add_argument('--hidden_channels', type=int, default=256) + parser.add_argument('--xdp', type=float, default=0.2) + parser.add_argument('--feat_dropout', type=float, default=0.5) + parser.add_argument('--label_dropout', type=float, default=0.5) + parser.add_argument('--num_layers', type=int, default=2) + parser.add_argument('--device', type=int, default=0) + parser.add_argument('--use_feature', type=str2bool, default='True', help='whether to use node features as input') + parser.add_argument('--feature_combine', type=str, default='hadamard', choices=['hadamard','plus_minus'], help='how to represent a link with two nodes features') + parser.add_argument('--jk', type=str2bool, default='True', help='whether to use Jumping Knowledge') + parser.add_argument('--batchnorm_affine', type=str2bool, default='True', help='whether to use Affine in BatchNorm') + parser.add_argument('--use_embedding', type=str2bool, default='False', help='whether to train node embedding') + + # training setting + parser.add_argument('--batch_size', type=int, default=64 * 1024) + parser.add_argument('--test_batch_size', type=int, default=100000) + parser.add_argument('--epochs', type=int, default=20000) + parser.add_argument('--num_neg', type=int, 
default=1) + parser.add_argument('--num_hops', type=int, default=2) + parser.add_argument('--lr', type=float, default=0.005) + parser.add_argument('--weight_decay', type=float, default=0) + parser.add_argument('--log_steps', type=int, default=20) + parser.add_argument('--patience', type=int, default=100, help='number of patience steps for early stopping') + parser.add_argument('--runs', type=int, default=10) + parser.add_argument('--metric', type=str, default='Hits@100', help='main evaluation metric') + + # misc + parser.add_argument('--data_split_only', type=str2bool, default='False') + parser.add_argument('--print_summary', type=str, default='') + + args = parser.parse_args() + # start time + start_time = time.time() + set_random_seeds(234) + + device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu' + device = torch.device(device) + + data, split_edge = get_dataset(args.dataset_dir, args.dataset) + if args.dataset == "ogbl-citation2": + args.metric = "MRR" + if data.x is None: + args.use_feature = False + + if args.print_summary: + data_summary(args.dataset, data);exit(0) + else: + print(args) + + # Save command line input. + cmd_input = 'python ' + ' '.join(sys.argv) + '\n' + print('Command line input: ' + cmd_input + ' is saved.') + + train, test, evaluator = get_train_test(args) + + val_max = 0.0 + for run in range(args.runs): + if args.minimum_degree_onehot > 0: + d_v = data.adj_t.sum(dim=0).to_dense() + nodes_to_one_hot = d_v >= args.minimum_degree_onehot + one_hot_dim = nodes_to_one_hot.sum() + print(f"number of nodes to onehot: {int(one_hot_dim)}") + data = data.to(device) + if args.use_embedding: + emb = initial_embedding(data, args.hidden_channels, device) + else: + emb = None + if 'gcn' in args.encoder: + encoder = MPLP_GCN(data.num_features, args.hidden_channels, + args.hidden_channels, args.num_layers, + args.feat_dropout, args.xdp, args.use_feature, args.jk, args.encoder, emb).to(device) + elif args.encoder == 'mlp': + encoder = MLP(num_layers=args.num_layers, in_channels=data.num_features, + hidden_channels=args.hidden_channels, out_channels=args.hidden_channels, + dropout=args.feat_dropout, act=None).to(device) + + predictor_in_dim = args.hidden_channels * int(args.use_feature or args.use_embedding) + + predictor = MPLP(predictor_in_dim, args.hidden_channels, + args.num_layers, args.feat_dropout, args.label_dropout, args.num_hops, + signature_sampling=args.signature_sampling, + use_degree=args.use_degree, signature_dim=args.signature_dim, + minimum_degree_onehot=args.minimum_degree_onehot, batchnorm_affine=args.batchnorm_affine, + feature_combine=args.feature_combine) + + predictor = predictor.to(device) + + encoder.reset_parameters() + predictor.reset_parameters() + parameters = list(encoder.parameters()) + list(predictor.parameters()) + optimizer = torch.optim.Adam(parameters, lr=args.lr, weight_decay=args.weight_decay) + total_params = sum(p.numel() for param in parameters for p in param) + print(f'Total number of parameters is {total_params}') + + cnt_wait = 0 + best_val = 0.0 + + for epoch in range(1, 1 + args.epochs): + loss = train(encoder, predictor, data, split_edge, + optimizer, args.batch_size, args.mask_target, + num_neg=args.num_neg) + + results = test(encoder, predictor, data, split_edge, + evaluator, args.test_batch_size, args.fast_inference) + + if results[args.metric][0] >= best_val: + best_val = results[args.metric][0] + cnt_wait = 0 + else: + cnt_wait +=1 + + if epoch % args.log_steps == 0: + for key, result in results.items(): + 
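                    # `results` maps a metric name to a (valid, test) pair:
                    # test_hits() yields {'Hits@100': ..., 'AUC': ...} while
                    # test_mrr() (used for ogbl-citation2) yields {'MRR': ...},
                    # so the unpacking below works for either evaluator.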
valid_hits, test_hits = result + to_print = (f'Run: {run + 1:02d}, ' + + f'Epoch: {epoch:02d}, '+ + f'Loss: {loss:.4f}, '+ + f'Valid: {100 * valid_hits:.2f}%, '+ + f'Test: {100 * test_hits:.2f}%') + print(key) + print(to_print) + print('---') + + if cnt_wait >= args.patience: + break + print(f'Highest Valid: {best_val}') + # end time + end_time = time.time() + print(f"Total time: {end_time - start_time:.4f}s") + +if __name__ == "__main__": + main() + diff --git a/torch_geometric/nn/models/__init__.py b/torch_geometric/nn/models/__init__.py index 334970da5c62..dc9a6f8aa4bf 100644 --- a/torch_geometric/nn/models/__init__.py +++ b/torch_geometric/nn/models/__init__.py @@ -28,6 +28,7 @@ from .pmlp import PMLP from .neural_fingerprint import NeuralFingerprint from .visnet import ViSNet +from .mplp import MPLP_GCN, MPLP # Deprecated: from torch_geometric.explain.algorithm.captum import (to_captum_input, @@ -75,4 +76,6 @@ 'PMLP', 'NeuralFingerprint', 'ViSNet', + 'MPLP_GCN', + 'MPLP', ] diff --git a/torch_geometric/nn/models/mplp.py b/torch_geometric/nn/models/mplp.py new file mode 100644 index 000000000000..afb657a4326d --- /dev/null +++ b/torch_geometric/nn/models/mplp.py @@ -0,0 +1,430 @@ + +import math + +from torch import Tensor + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch_geometric.nn import GCNConv, SGConv +from torch_geometric.nn.models import MLP +from torch_geometric.utils import k_hop_subgraph as pyg_k_hop_subgraph, to_edge_index + +from torch_sparse import SparseTensor, matmul +from torch_sparse.matmul import spmm_max, spmm_mean, spmm_add + + +from functools import partial + +from typing import Final + +import torchhd + +USE_CUSTOM_MLP=True + +######################## +###### NodeLabel ####### +######################## + +MINIMUM_SIGNATURE_DIM=64 + +class NodeLabel(torch.nn.Module): + def __init__(self, dim: int=1024, signature_sampling="torchhd", + minimum_degree_onehot: int=-1): + super().__init__() + self.dim = dim + self.signature_sampling = signature_sampling + self.cached_two_hop_adj = None + self.minimum_degree_onehot = minimum_degree_onehot + + def forward(self, edges: Tensor, adj_t: SparseTensor, node_weight: Tensor=None, cache_mode=None): + if cache_mode is not None: + return self.propagation_only_cache(edges, adj_t, node_weight, cache_mode) + else: + return self.propagation_only(edges, adj_t, node_weight) + + def get_random_node_vectors(self, adj_t: SparseTensor, node_weight) -> Tensor: + num_nodes = adj_t.size(0) + device = adj_t.device() + if self.minimum_degree_onehot > 0: + degree = adj_t.sum(dim=1) + nodes_to_one_hot = degree >= self.minimum_degree_onehot + one_hot_dim = nodes_to_one_hot.sum() + if one_hot_dim + MINIMUM_SIGNATURE_DIM > self.dim: + raise ValueError(f"There are {int(one_hot_dim)} nodes with degree higher than {self.minimum_degree_onehot}, select a higher threshold to choose fewer nodes as hub") + embedding = torch.zeros(num_nodes, self.dim, device=device) + if one_hot_dim>0: + one_hot_embedding = F.one_hot(torch.arange(0, one_hot_dim)).float().to(device) + embedding[nodes_to_one_hot,:one_hot_dim] = one_hot_embedding + else: + embedding = torch.zeros(num_nodes, self.dim, device=device) + nodes_to_one_hot = torch.zeros(num_nodes, dtype=torch.bool, device=device) + one_hot_dim = 0 + rand_dim = self.dim - one_hot_dim + + if self.signature_sampling == "torchhd": + scale = math.sqrt(1 / rand_dim) + node_vectors = torchhd.random(num_nodes - one_hot_dim, rand_dim, device=device) + node_vectors.mul_(scale) # make them unit 
vectors + elif self.signature_sampling == "gaussian": + node_vectors = F.normalize(torch.nn.init.normal_(torch.empty((num_nodes - one_hot_dim, rand_dim), dtype=torch.float32, device=device))) + elif self.signature_sampling == "onehot": + embedding = torch.zeros(num_nodes, num_nodes, device=device) + node_vectors = F.one_hot(torch.arange(0, num_nodes)).float().to(device) + + embedding[~nodes_to_one_hot, one_hot_dim:] = node_vectors + + if node_weight is not None: + node_weight = node_weight.unsqueeze(1) # Note: not sqrt here because it can cause problem for MLP when output is negative + # thus, it requires the MLP to approximate one more sqrt? + embedding.mul_(node_weight) + return embedding + + def propagation_only(self, edges: Tensor, adj_t: SparseTensor, node_weight=None): + adj_t, new_edges, subset_nodes = subgraph(edges, adj_t, 2) + node_weight = node_weight[subset_nodes] if node_weight is not None else None + x = self.get_random_node_vectors(adj_t, node_weight=node_weight) + subset = new_edges.view(-1) # flatten the target nodes [row, col] + + subset_unique, inverse_indices = torch.unique(subset, return_inverse=True) + one_hop_x_subgraph_nodes = matmul(adj_t, x) + one_hop_x = one_hop_x_subgraph_nodes[subset] + two_hop_x = matmul(adj_t[subset_unique], one_hop_x_subgraph_nodes)[inverse_indices] + degree_one_hop = adj_t.sum(dim=1) + + one_hop_x = one_hop_x.view(2, new_edges.size(1), -1) + two_hop_x = two_hop_x.view(2, new_edges.size(1), -1) + + count_1_1 = dot_product(one_hop_x[0,:,:], one_hop_x[1,:,:]) + count_1_2 = dot_product(one_hop_x[0,:,:], two_hop_x[1,:,:]) + count_2_1 = dot_product(two_hop_x[0,:,:] , one_hop_x[1,:,:]) + count_2_2 = dot_product((two_hop_x[0,:,:]-degree_one_hop[new_edges[0]].view(-1,1)*x[new_edges[0]]), + (two_hop_x[1,:,:]-degree_one_hop[new_edges[1]].view(-1,1)*x[new_edges[1]])) + + count_self_1_2 = dot_product(one_hop_x[0,:,:] , two_hop_x[0,:,:]) + count_self_2_1 = dot_product(one_hop_x[1,:,:] , two_hop_x[1,:,:]) + degree_u = degree_one_hop[new_edges[0]] + degree_v = degree_one_hop[new_edges[1]] + return count_1_1, count_1_2, count_2_1, count_2_2, count_self_1_2, count_self_2_1, degree_u, degree_v + + def propagation_only_cache(self, edges: Tensor, adj_t: SparseTensor, node_weight=None, cache_mode=None): + if cache_mode == 'build': + # get the 2-hop subgraph of the target edges + x = self.get_random_node_vectors(adj_t, node_weight=node_weight) + + degree_one_hop = adj_t.sum(dim=1) + + one_hop_x = matmul(adj_t, x) + two_iter_x = matmul(adj_t, one_hop_x) + + # caching + self.cached_x = x + self.cached_degree_one_hop = degree_one_hop + + self.cached_one_hop_x = one_hop_x + self.cached_two_iter_x = two_iter_x + return + if cache_mode == 'delete': + del self.cached_x + del self.cached_degree_one_hop + del self.cached_one_hop_x + del self.cached_two_iter_x + return + if cache_mode == 'use': + # loading + x = self.cached_x + degree_one_hop = self.cached_degree_one_hop + + one_hop_x = self.cached_one_hop_x + two_iter_x = self.cached_two_iter_x + count_1_1 = dot_product(one_hop_x[edges[0]] , one_hop_x[edges[1]]) + count_1_2 = dot_product(one_hop_x[edges[0]] , two_iter_x[edges[1]]) + count_2_1 = dot_product(two_iter_x[edges[0]] , one_hop_x[edges[1]]) + count_2_2 = dot_product(two_iter_x[edges[0]]-degree_one_hop[edges[0]].view(-1,1)*x[edges[0]],\ + two_iter_x[edges[1]]-degree_one_hop[edges[1]].view(-1,1)*x[edges[1]]) + + count_self_1_2 = dot_product(one_hop_x[edges[0]] , two_iter_x[edges[0]]) + count_self_2_1 = dot_product(one_hop_x[edges[1]] , two_iter_x[edges[1]]) + + 
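            # Why these dot products estimate counts (a sketch of the MPLP idea):
            # the rows of `x` are quasi-orthogonal random unit vectors, so
            # E[x_i . x_j] is ~1 when i == j and ~0 otherwise. The inner product
            # of two propagated signatures therefore concentrates on the number
            # of shared terms, e.g.
            #     <sum_{i in N(u)} x_i, sum_{j in N(v)} x_j>  ~=  |N(u) & N(v)|,
            # so count_1_1 estimates the 1-hop/1-hop common neighbours, while
            # count_1_2 / count_2_1 mix 1-hop with 2-hop neighbourhoods.
            # count_2_2 first subtracts degree * x to discard length-2 walks that
            # return to the endpoint itself; the variance of every estimate
            # shrinks as the signature dimension grows.
            # A tiny self-contained illustration (not executed here):
            #     F = 4096
            #     sig = F.normalize(torch.randn(5, F), dim=1)
            #     # node u has neighbours {2, 3, 4}, node v has neighbours {2, 3}
            #     hu, hv = sig[[2, 3, 4]].sum(0), sig[[2, 3]].sum(0)
            #     (hu * hv).sum()   # ~= 2.0, the number of common neighbours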
degree_u = degree_one_hop[edges[0]] + degree_v = degree_one_hop[edges[1]] + return count_1_1, count_1_2, count_2_1, count_2_2, count_self_1_2, count_self_2_1, degree_u, degree_v + +def subgraph(edges: Tensor, adj_t: SparseTensor, k: int=2): + row,col = edges + nodes = torch.cat((row,col),dim=-1) + edge_index,_ = to_edge_index(adj_t) + subset, new_edge_index, inv, edge_mask = pyg_k_hop_subgraph(nodes, k, edge_index=edge_index, + num_nodes=adj_t.size(0), relabel_nodes=True) + # subset[inv] = nodes. The new node id is based on `subset`'s order. + # inv means the new idx (in subset) of the old nodes in `nodes` + new_adj_t = SparseTensor(row=new_edge_index[0], col=new_edge_index[1], + sparse_sizes=(subset.size(0), subset.size(0))) + new_edges = inv.view(2,-1) + return new_adj_t, new_edges, subset + +def dotproduct_naive(tensor1, tensor2): + return (tensor1 * tensor2).sum(dim=-1) + +def dotproduct_bmm(tensor1, tensor2): + return torch.bmm(tensor1.unsqueeze(1), tensor2.unsqueeze(2)).view(-1) + +dot_product = dotproduct_naive + +######################## +######### MLP ########## +######################## + +class CustomMLP(nn.Module): + def __init__( + self, + num_layers, + input_dim, + hidden_dim, + output_dim, + dropout_ratio, + norm_type="none", + tailnormactdrop=False, + affine=True, + ): + super(CustomMLP, self).__init__() + self.num_layers = num_layers + self.norm_type = norm_type + self.tailnormactdrop = tailnormactdrop + self.affine = affine # the affine in batchnorm + self.layers = [] + if num_layers == 1: + self.layers.append(nn.Linear(input_dim, output_dim)) + if tailnormactdrop: + self.__build_normactdrop(self.layers, output_dim, dropout_ratio) + else: + self.layers.append(nn.Linear(input_dim, hidden_dim)) + self.__build_normactdrop(self.layers, hidden_dim, dropout_ratio) + for i in range(num_layers - 2): + self.layers.append(nn.Linear(hidden_dim, hidden_dim)) + self.__build_normactdrop(self.layers, hidden_dim, dropout_ratio) + self.layers.append(nn.Linear(hidden_dim, output_dim)) + if tailnormactdrop: + self.__build_normactdrop(self.layers, hidden_dim, dropout_ratio) + self.layers = nn.Sequential(*self.layers) + + def __build_normactdrop(self, layers, dim, dropout): + if self.norm_type == "batch": + layers.append(nn.BatchNorm1d(dim, affine=self.affine)) + elif self.norm_type == "layer": + layers.append(nn.LayerNorm(dim)) + layers.append(nn.Dropout(dropout, inplace=True)) + layers.append(nn.ReLU(inplace=True)) + + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, nn.Linear) or isinstance(m, nn.BatchNorm1d): + m.reset_parameters() + + def forward(self, feats, adj_t=None): + return self.layers(feats) + +######################## +######### GNN ########## +######################## + +# Addpted from NCNC +class PureConv(nn.Module): + aggr: Final[str] + def __init__(self, indim, outdim, aggr="gcn") -> None: + super().__init__() + self.aggr = aggr + if indim == outdim: + self.lin = nn.Identity() + else: + raise NotImplementedError + + def forward(self, x, adj_t): + x = self.lin(x) + if self.aggr == "mean": + return spmm_mean(adj_t, x) + elif self.aggr == "max": + return spmm_max(adj_t, x)[0] + elif self.aggr == "sum": + return spmm_add(adj_t, x) + elif self.aggr == "gcn": + norm = torch.rsqrt_((1+adj_t.sum(dim=-1))).reshape(-1, 1) + x = norm * x + x = spmm_add(adj_t, x) + x + x = norm * x + return x + +class MPLP_GCN(torch.nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, num_layers, + dropout, xdropout, use_feature=True, jk=False, 
gcn_name='gcn', embedding=None): + super(MPLP_GCN, self).__init__() + + self.use_feature = use_feature + self.embedding = embedding + self.dropout = dropout + self.xdropout = xdropout + self.input_size = 0 + self.jk = jk + if jk: + self.register_parameter("jkparams", nn.Parameter(torch.randn((num_layers,)))) + if self.use_feature: + self.input_size += in_channels + if self.embedding is not None: + self.input_size += embedding.embedding_dim + self.convs = torch.nn.ModuleList() + + if self.input_size > 0: + if gcn_name == 'gcn': + conv_func = partial(GCNConv, cached=False) + elif 'pure' in gcn_name: + conv_func = partial(SGConv, apply_linearity=False) + self.xemb = nn.Sequential(nn.Dropout(xdropout)) + if ("pure" in gcn_name or num_layers==0): + self.xemb.append(nn.Linear(self.input_size, hidden_channels)) + self.xemb.append(nn.Dropout(dropout, inplace=True) if dropout > 1e-6 else nn.Identity()) + self.input_size = hidden_channels + self.convs.append(conv_func(self.input_size, hidden_channels)) + for _ in range(num_layers - 2): + self.convs.append( + conv_func(hidden_channels, hidden_channels)) + self.convs.append(conv_func(hidden_channels, out_channels)) + + + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, nn.Linear) or isinstance(m, nn.BatchNorm1d): + m.reset_parameters() + + def forward(self, x, adj_t): + if self.input_size > 0: + xs = [] + if self.use_feature: + xs.append(x) + if self.embedding is not None: + xs.append(self.embedding.weight) + x = torch.cat(xs, dim=1) + x = self.xemb(x) + jkx = [] + for conv in self.convs: + x = conv(x, adj_t) + if self.jk: + jkx.append(x) + if self.jk: # JumpingKnowledge Connection + jkx = torch.stack(jkx, dim=0) + sftmax = self.jkparams.reshape(-1, 1, 1) + x = torch.sum(jkx*sftmax, dim=0) + return x + +######################## +######### MPLP ######### +######################## + +class MPLP(torch.nn.Module): + def __init__(self, in_channels, hidden_channels, num_layers, + feat_dropout, label_dropout, num_hops=2, signature_sampling='torchhd', use_degree='none', + signature_dim=1024, minimum_degree_onehot=-1, batchnorm_affine=True, + feature_combine="hadamard"): + super(MPLP, self).__init__() + + self.in_channels = in_channels + self.feat_dropout = feat_dropout + self.label_dropout = label_dropout + self.num_hops = num_hops + self.signature_sampling=signature_sampling + self.use_degree = use_degree + self.feature_combine = feature_combine + if self.use_degree == 'mlp': + if USE_CUSTOM_MLP: + self.node_weight_encode = CustomMLP(2, in_channels + 1, 32, 1, feat_dropout, norm_type="batch", affine=batchnorm_affine) + else: + self.node_weight_encode = MLP(num_layers=2, in_channels=in_channels + 1, hidden_channels=32, out_channels=1, + dropout=self.label_dropout, act='relu', + norm="BatchNorm", norm_kwargs={"affine": batchnorm_affine}) + struct_dim = 8 + self.nodelabel = NodeLabel(signature_dim, signature_sampling=self.signature_sampling, + minimum_degree_onehot= minimum_degree_onehot) + if USE_CUSTOM_MLP: + self.struct_encode = CustomMLP(1, struct_dim, struct_dim, struct_dim, self.label_dropout, "batch", tailnormactdrop=True, affine=batchnorm_affine) + else: + self.struct_encode = MLP(num_layers=1, in_channels=struct_dim, hidden_channels=struct_dim, out_channels=struct_dim, + dropout=self.label_dropout, act='relu', plain_last=False, + norm="BatchNorm", norm_kwargs={"affine": batchnorm_affine}) + + dense_dim = struct_dim + in_channels + if in_channels > 0: + if feature_combine == "hadamard": + feat_encode_input_dim = in_channels + 
elif feature_combine == "plus_minus": + feat_encode_input_dim = in_channels * 2 + if USE_CUSTOM_MLP: + self.feat_encode = CustomMLP(2, feat_encode_input_dim, in_channels, in_channels, self.feat_dropout, "batch", tailnormactdrop=True, affine=batchnorm_affine) + else: + self.feat_encode = MLP(num_layers=1, in_channels=feat_encode_input_dim, hidden_channels=in_channels, out_channels=in_channels, + dropout=self.label_dropout, act='relu', plain_last=False, + norm="BatchNorm", norm_kwargs={"affine": batchnorm_affine}) + self.classifier = nn.Linear(dense_dim, 1) + + + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, nn.Linear) or isinstance(m, nn.BatchNorm1d): + m.reset_parameters() + + def forward(self, x, adj, edges, cache_mode=None): + """ + Args: + x: [N, in_channels] node embedding after GNN + adj: [N, N] adjacency matrix + edges: [2, E] target edges + fast_inference: bool. If True, only caching the message-passing without calculating the structural features + """ + if cache_mode in ["use","delete"]: + # no need to compute node_weight + node_weight = None + elif self.use_degree == 'none': + node_weight = None + elif self.use_degree == 'mlp': # 'mlp' for now + xs = [] + if self.in_channels > 0: + xs.append(x) + degree = adj.sum(dim=1).view(-1,1).to(adj.device()) + xs.append(degree) + node_weight_feat = torch.cat(xs, dim=1) + node_weight = self.node_weight_encode(node_weight_feat).squeeze(-1) + 1 # like residual, can be learned as 0 if needed + else: + # AA or RA + degree = adj.sum(dim=1).view(-1,1).to(adj.device()).squeeze(-1) + 1 # degree at least 1. then log(degree) > 0. + if self.use_degree == 'AA': + node_weight = torch.sqrt(torch.reciprocal(torch.log(degree))) + elif self.use_degree == 'RA': + node_weight = torch.sqrt(torch.reciprocal(degree)) + node_weight = torch.nan_to_num(node_weight, nan=0.0, posinf=0.0, neginf=0.0) + + if cache_mode in ["build","delete"]: + propped = self.nodelabel(edges, adj, node_weight=node_weight, cache_mode=cache_mode) + return + else: + propped = self.nodelabel(edges, adj, node_weight=node_weight, cache_mode=cache_mode) + propped_stack = torch.stack([*propped], dim=1) + out = self.struct_encode(propped_stack) + + if self.in_channels > 0: + x_i = x[edges[0]] + x_j = x[edges[1]] + if self.feature_combine == "hadamard": + x_ij = x_i * x_j + elif self.feature_combine == "plus_minus": + x_ij = torch.cat([x_i+x_j, torch.abs(x_i-x_j)], dim=1) + x_ij = self.feat_encode(x_ij) + x = torch.cat([x_ij, out], dim=1) + else: + x = out + logit = self.classifier(x) + return logit + + def precompute(self, adj): + self(None, adj, None, cache_mode="build") + return self + + From 9b2ea20e8bfd9ce8e9875eb4d554d9b12874e0b5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 Jun 2024 16:37:29 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/mplp.py | 312 ++++++++++++++++++------------ torch_geometric/nn/models/mplp.py | 278 +++++++++++++++----------- 2 files changed, 353 insertions(+), 237 deletions(-) diff --git a/examples/mplp.py b/examples/mplp.py index 77ee9039ab07..2ff03af4b669 100644 --- a/examples/mplp.py +++ b/examples/mplp.py @@ -3,34 +3,27 @@ Based on the code https://github.com/Barcavin/efficient-node-labelling """ import argparse -import random -import numpy as np import os +import random import sys import time from pathlib import Path from typing import List, Tuple +import numpy as np 
import torch -from torch_sparse import SparseTensor +from ogb.linkproppred import Evaluator, PygLinkPropPredDataset +from sklearn.metrics import roc_auc_score from torch.nn import BCEWithLogitsLoss from torch.utils.data import DataLoader +from torch_sparse import SparseTensor +from tqdm import tqdm +import torch_geometric.transforms as T from torch_geometric.data import Data +from torch_geometric.nn.models import MLP, MPLP, MPLP_GCN from torch_geometric.transforms import ToSparseTensor, ToUndirected -import torch_geometric.transforms as T - - -from sklearn.metrics import roc_auc_score - -from ogb.linkproppred import PygLinkPropPredDataset -from ogb.linkproppred import Evaluator - -from torch_geometric.nn.models import MLP, MPLP_GCN, MPLP - from torch_geometric.utils import degree -from tqdm import tqdm - ######################## ######## Utils ######### @@ -53,7 +46,7 @@ def get_dataset(root, name: str): key = 'edge' else: key = 'source_node' - print("-"*20) + print("-" * 20) print(f"train: {split_edge['train'][key].shape[0]}") print(f"{split_edge['train'][key]}") print(f"valid: {split_edge['valid'][key].shape[0]}") @@ -69,6 +62,7 @@ def get_dataset(root, name: str): del data.edge_index return data, split_edge + def set_random_seeds(random_seed=0): r"""Sets the seed for generating random numbers.""" torch.manual_seed(random_seed) @@ -76,6 +70,7 @@ def set_random_seeds(random_seed=0): np.random.seed(random_seed) random.seed(random_seed) + def str2bool(v): if isinstance(v, bool): return v @@ -86,6 +81,7 @@ def str2bool(v): else: raise argparse.ArgumentTypeError('Boolean value expected.') + def data_summary(name: str, data: Data): num_nodes = data.num_nodes num_edges = data.num_edges @@ -97,21 +93,23 @@ def data_summary(name: str, data: Data): if data.x is not None: attr_dim = data.x.shape[1] else: - attr_dim = '-' # no attribute + attr_dim = '-' # no attribute - print("-"*30+'Dataset and Features'+"-"*60) + print("-" * 30 + 'Dataset and Features' + "-" * 60) print("{:<10}|{:<10}|{:<10}|{:<15}|{:<15}|{:<15}|{:<10}|{:<15}"\ .format('Dataset','#Nodes','#Edges','Avg. node deg.','Std. node deg.','Max. node deg.', 'Density','Attr. 
Dimension')) - print("-"*110) + print("-" * 110) print("{:<10}|{:<10}|{:<10}|{:<15.2f}|{:<15.2f}|{:<15}|{:<9.4f}%|{:<15}"\ .format(name, num_nodes, num_edges, avg_degree, degree_std, max_degree, density*100, attr_dim)) - print("-"*110) + print("-" * 110) + def initial_embedding(data, hidden_channels, device): - embedding= torch.nn.Embedding(data.num_nodes, hidden_channels).to(device) + embedding = torch.nn.Embedding(data.num_nodes, hidden_channels).to(device) torch.nn.init.xavier_uniform_(embedding.weight) return embedding + def create_input(data): if hasattr(data, 'emb') and data.emb is not None: x = data.emb.weight @@ -125,16 +123,18 @@ def create_input(data): ######################## -def __elem2spm(element: torch.Tensor, sizes: List[int], val: torch.Tensor=None) -> SparseTensor: +def __elem2spm(element: torch.Tensor, sizes: List[int], + val: torch.Tensor = None) -> SparseTensor: # Convert adjacency matrix to a 1-d vector col = torch.bitwise_and(element, 0xffffffff) row = torch.bitwise_right_shift(element, 32) if val is None: - sp_tensor = SparseTensor(row=row, col=col, sparse_sizes=sizes).to_device( - element.device).fill_value_(1.0) + sp_tensor = SparseTensor(row=row, col=col, + sparse_sizes=sizes).to_device( + element.device).fill_value_(1.0) else: - sp_tensor = SparseTensor(row=row, col=col, value=val, sparse_sizes=sizes).to_device( - element.device) + sp_tensor = SparseTensor(row=row, col=col, value=val, + sparse_sizes=sizes).to_device(element.device) return sp_tensor @@ -146,11 +146,11 @@ def __spm2elem(spm: SparseTensor) -> torch.Tensor: val = spm.storage.value() return elem, val -def __spmdiff(adj1: SparseTensor, - adj2: SparseTensor, keep_val=False) -> Tuple[SparseTensor, SparseTensor]: - ''' - return elements in adj1 but not in adj2 and in adj2 but not adj1 - ''' + +def __spmdiff(adj1: SparseTensor, adj2: SparseTensor, + keep_val=False) -> Tuple[SparseTensor, SparseTensor]: + """Return elements in adj1 but not in adj2 and in adj2 but not adj1 + """ element1, val1 = __spm2elem(adj1) element2, val2 = __spm2elem(adj2) @@ -164,7 +164,7 @@ def __spmdiff(adj1: SparseTensor, maskelem1 = torch.ones_like(element1, dtype=torch.bool) maskelem1[idx[matchedmask]] = 0 retelem1 = element1[maskelem1] - + if keep_val and val1 is not None: retval1 = val1[maskelem1] return __elem2spm(retelem1, adj1.sizes(), retval1) @@ -180,36 +180,41 @@ def get_train_test(args): evaluator = Evaluator(name='ogbl-ddi') return train_hits, test_hits, evaluator -def train_hits(encoder, predictor, data, split_edge, optimizer, batch_size, - mask_target, num_neg): + +def train_hits(encoder, predictor, data, split_edge, optimizer, batch_size, + mask_target, num_neg): encoder.train() predictor.train() device = data.adj_t.device() criterion = BCEWithLogitsLoss(reduction='mean') pos_train_edge = split_edge['train']['edge'].to(device) - + optimizer.zero_grad() total_loss = total_examples = 0 - num_pos_max = max(data.adj_t.nnz()//2, pos_train_edge.size(0)) - neg_edge_epoch = torch.randint(0, data.adj_t.size(0), - size=(2, num_pos_max*num_neg), - dtype=torch.long, device=device) - for perm in tqdm(DataLoader(range(pos_train_edge.size(0)), batch_size, - shuffle=True),desc='Train'): + num_pos_max = max(data.adj_t.nnz() // 2, pos_train_edge.size(0)) + neg_edge_epoch = torch.randint(0, data.adj_t.size(0), + size=(2, num_pos_max * num_neg), + dtype=torch.long, device=device) + for perm in tqdm( + DataLoader(range(pos_train_edge.size(0)), batch_size, + shuffle=True), desc='Train'): edge = pos_train_edge[perm].t() if mask_target: 
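            # Mask the batch's positive edges out of the message-passing graph:
            # if (u, v) stayed in adj_t, the structural features would see the
            # very edge being predicted and the model could exploit that shortcut.
            # __spmdiff removes the batch edges in both directions and, with
            # keep_val=True, keeps the weights of the surviving edges (relevant
            # for weighted graphs such as ogbl-collab).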
adj_t = data.adj_t undirected_edges = torch.cat((edge, edge.flip(0)), dim=-1) - target_adj = SparseTensor.from_edge_index(undirected_edges, sparse_sizes=adj_t.sizes()) + target_adj = SparseTensor.from_edge_index( + undirected_edges, sparse_sizes=adj_t.sizes()) adj_t = __spmdiff(adj_t, target_adj, keep_val=True) else: adj_t = data.adj_t h = encoder(data.x, adj_t) - neg_edge = neg_edge_epoch[:,perm] + neg_edge = neg_edge_epoch[:, perm] train_edges = torch.cat((edge, neg_edge), dim=-1) - train_label = torch.cat((torch.ones(edge.size()[1]), torch.zeros(neg_edge.size()[1])), dim=0).to(device) + train_label = torch.cat( + (torch.ones(edge.size()[1]), torch.zeros(neg_edge.size()[1])), + dim=0).to(device) out = predictor(h, adj_t, train_edges).squeeze() loss = criterion(out, train_label) @@ -223,13 +228,13 @@ def train_hits(encoder, predictor, data, split_edge, optimizer, batch_size, optimizer.zero_grad() total_examples += train_label.size(0) total_loss += loss.item() * train_label.size(0) - + return total_loss / total_examples @torch.no_grad() -def test_hits(encoder, predictor, data, split_edge, evaluator, - batch_size, fast_inference): +def test_hits(encoder, predictor, data, split_edge, evaluator, batch_size, + fast_inference): encoder.eval() predictor.eval() device = data.adj_t.device() @@ -250,7 +255,10 @@ def test_split(split, cache_mode=None): neg_test_preds = [] for perm in DataLoader(range(neg_test_edge.size(0)), batch_size): edge = neg_test_edge[perm].t() - neg_test_preds += [predictor(h, adj_t, edge, cache_mode=cache_mode).squeeze().cpu()] + neg_test_preds += [ + predictor(h, adj_t, edge, + cache_mode=cache_mode).squeeze().cpu() + ] neg_test_pred = torch.cat(neg_test_preds, dim=0) return pos_test_pred, neg_test_pred @@ -260,10 +268,10 @@ def test_split(split, cache_mode=None): if fast_inference: # caching predictor(h, adj_t, None, cache_mode='build') - cache_mode='use' + cache_mode = 'use' else: - cache_mode=None - + cache_mode = None + pos_test_pred, neg_test_pred = test_split('test', cache_mode=cache_mode) end_time = time.perf_counter() total_time = end_time - start_time @@ -271,7 +279,7 @@ def test_split(split, cache_mode=None): if fast_inference: # delete cache predictor(h, adj_t, None, cache_mode='delete') - + results = {} K = 100 evaluator.K = K @@ -286,13 +294,19 @@ def test_split(split, cache_mode=None): results[f'Hits@{K}'] = (valid_hits, test_hits) - valid_result = torch.cat((torch.ones(pos_valid_pred.size()), torch.zeros(neg_valid_pred.size())), dim=0) + valid_result = torch.cat((torch.ones( + pos_valid_pred.size()), torch.zeros(neg_valid_pred.size())), dim=0) valid_pred = torch.cat((pos_valid_pred, neg_valid_pred), dim=0) - test_result = torch.cat((torch.ones(pos_test_pred.size()), torch.zeros(neg_test_pred.size())), dim=0) + test_result = torch.cat( + (torch.ones(pos_test_pred.size()), torch.zeros(neg_test_pred.size())), + dim=0) test_pred = torch.cat((pos_test_pred, neg_test_pred), dim=0) - results['AUC'] = (roc_auc_score(valid_result.cpu().numpy(),valid_pred.cpu().numpy()),roc_auc_score(test_result.cpu().numpy(),test_pred.cpu().numpy())) + results['AUC'] = (roc_auc_score(valid_result.cpu().numpy(), + valid_pred.cpu().numpy()), + roc_auc_score(test_result.cpu().numpy(), + test_pred.cpu().numpy())) return results @@ -313,17 +327,22 @@ def make_symmetric(sparse_tensor, reduce='sum'): new_value = all_value # Remove duplicates by summing the values for symmetric entries - unique_indices, inverse_indices = torch.unique(new_indices, dim=1, return_inverse=True) - unique_value = 
torch.zeros(unique_indices.size(1), device=value.device).scatter_reduce_(0, inverse_indices, new_value, reduce="amax") + unique_indices, inverse_indices = torch.unique(new_indices, dim=1, + return_inverse=True) + unique_value = torch.zeros(unique_indices.size(1), + device=value.device).scatter_reduce_( + 0, inverse_indices, new_value, + reduce="amax") # Create the symmetric sparse tensor - symmetric_tensor = torch.sparse_coo_tensor(unique_indices, unique_value, sparse_tensor.size()) + symmetric_tensor = torch.sparse_coo_tensor(unique_indices, unique_value, + sparse_tensor.size()) return symmetric_tensor -def train_mrr(encoder, predictor, data, split_edge, optimizer, batch_size, - mask_target, num_neg): +def train_mrr(encoder, predictor, data, split_edge, optimizer, batch_size, + mask_target, num_neg): encoder.train() predictor.train() device = data.adj_t.device() @@ -331,18 +350,21 @@ def train_mrr(encoder, predictor, data, split_edge, optimizer, batch_size, source_edge = split_edge['train']['source_node'].to(device) target_edge = split_edge['train']['target_node'].to(device) adjmask = torch.ones_like(source_edge, dtype=torch.bool) - + optimizer.zero_grad() total_loss = total_examples = 0 - for perm in tqdm(DataLoader(range(source_edge.size(0)), batch_size, - shuffle=True),desc='Train'): + for perm in tqdm( + DataLoader(range(source_edge.size(0)), batch_size, shuffle=True), + desc='Train'): if mask_target: adjmask[perm] = 0 - tei = torch.stack((source_edge[adjmask], target_edge[adjmask]), dim=0) # TODO: check if both direction is removed - - adj_t = SparseTensor.from_edge_index(tei, - sparse_sizes=(data.num_nodes, data.num_nodes)).to_device( - source_edge.device, non_blocking=True) + tei = torch.stack( + (source_edge[adjmask], target_edge[adjmask]), + dim=0) # TODO: check if both direction is removed + + adj_t = SparseTensor.from_edge_index( + tei, sparse_sizes=(data.num_nodes, data.num_nodes)).to_device( + source_edge.device, non_blocking=True) adjmask[perm] = 1 adj_t = adj_t.to_symmetric() @@ -353,15 +375,18 @@ def train_mrr(encoder, predictor, data, split_edge, optimizer, batch_size, else: adj_t = data.adj_t - h = encoder(data.x, adj_t) - dst_neg = torch.randint(0, data.num_nodes, perm.size()*num_neg, - dtype=torch.long, device=device) + dst_neg = torch.randint(0, data.num_nodes, + perm.size() * num_neg, dtype=torch.long, + device=device) edge = torch.stack((source_edge[perm], target_edge[perm]), dim=0) - neg_edge = torch.stack((source_edge[perm].repeat(num_neg), dst_neg), dim=0) + neg_edge = torch.stack((source_edge[perm].repeat(num_neg), dst_neg), + dim=0) train_edges = torch.cat((edge, neg_edge), dim=-1) - train_label = torch.cat((torch.ones(edge.size()[1]), torch.zeros(neg_edge.size()[1])), dim=0).to(device) + train_label = torch.cat( + (torch.ones(edge.size()[1]), torch.zeros(neg_edge.size()[1])), + dim=0).to(device) out = predictor(h, adj_t, train_edges).squeeze() loss = criterion(out, train_label) @@ -375,13 +400,13 @@ def train_mrr(encoder, predictor, data, split_edge, optimizer, batch_size, optimizer.zero_grad() total_examples += train_label.size(0) total_loss += loss.item() * train_label.size(0) - + return total_loss / total_examples @torch.no_grad() -def test_mrr(encoder, predictor, data, split_edge, evaluator, - batch_size, fast_inference): +def test_mrr(encoder, predictor, data, split_edge, evaluator, batch_size, + fast_inference): encoder.eval() predictor.eval() device = data.adj_t.device() @@ -396,7 +421,10 @@ def test_split(split, cache_mode=None): pos_preds = [] for 
perm in DataLoader(range(source.size(0)), batch_size): src, dst = source[perm], target[perm] - pos_preds += [predictor(h, adj_t, torch.stack((src, dst)), cache_mode=cache_mode).squeeze().cpu()] + pos_preds += [ + predictor(h, adj_t, torch.stack((src, dst)), + cache_mode=cache_mode).squeeze().cpu() + ] pos_pred = torch.cat(pos_preds, dim=0) neg_preds = [] @@ -404,7 +432,10 @@ def test_split(split, cache_mode=None): target_neg = target_neg.view(-1) for perm in DataLoader(range(source.size(0)), batch_size): src, dst_neg = source[perm], target_neg[perm] - neg_preds += [predictor(h, adj_t, torch.stack((src, dst_neg)), cache_mode=cache_mode).squeeze().cpu()] + neg_preds += [ + predictor(h, adj_t, torch.stack((src, dst_neg)), + cache_mode=cache_mode).squeeze().cpu() + ] neg_pred = torch.cat(neg_preds, dim=0).view(-1, 1000) return pos_pred, neg_pred @@ -415,10 +446,10 @@ def test_split(split, cache_mode=None): if fast_inference: # caching predictor(h, adj_t, None, cache_mode='build') - cache_mode='use' + cache_mode = 'use' else: - cache_mode=None - + cache_mode = None + pos_test_pred, neg_test_pred = test_split('test', cache_mode=cache_mode) end_time = time.perf_counter() total_time = end_time - start_time @@ -426,15 +457,15 @@ def test_split(split, cache_mode=None): if fast_inference: # delete cache predictor(h, adj_t, None, cache_mode='delete') - + valid_mrr = evaluator.eval({ - 'y_pred_pos': pos_valid_pred, - 'y_pred_neg': neg_valid_pred, - })['mrr_list'].mean().item() + 'y_pred_pos': pos_valid_pred, + 'y_pred_neg': neg_valid_pred, + })['mrr_list'].mean().item() test_mrr = evaluator.eval({ - 'y_pred_pos': pos_test_pred, - 'y_pred_neg': neg_test_pred, - })['mrr_list'].mean().item() + 'y_pred_pos': pos_test_pred, + 'y_pred_neg': neg_test_pred, + })['mrr_list'].mean().item() results = { "MRR": (valid_mrr, test_mrr), @@ -447,6 +478,7 @@ def test_split(split, cache_mode=None): ######## Main ######### ######################## + def main(): parser = argparse.ArgumentParser(description='MPLP') # dataset setting @@ -456,12 +488,25 @@ def main(): parser.add_argument('--dataset_dir', type=str, default='./data') # MPLP settings - parser.add_argument('--signature_dim', type=int, default=1024, help="the node signature dimension `F` in MPLP") - parser.add_argument('--minimum_degree_onehot', type=int, default=-1, help='the minimum degree of hubs with onehot encoding to reduce variance') - parser.add_argument('--use_degree', type=str, default='none', choices=["none","mlp","AA","RA"], help="rescale vector norm to facilitate weighted count") - parser.add_argument('--signature_sampling', type=str, default='torchhd', help='whether to use torchhd to randomize vectors', choices=["torchhd","gaussian","onehot"]) - parser.add_argument('--fast_inference', type=str2bool, default='False', help='whether to enable a faster inference by caching the node vectors') - parser.add_argument('--mask_target', type=str2bool, default='True', help='whether to mask the target edges to remove the shortcut') + parser.add_argument('--signature_dim', type=int, default=1024, + help="the node signature dimension `F` in MPLP") + parser.add_argument( + '--minimum_degree_onehot', type=int, default=-1, help= + 'the minimum degree of hubs with onehot encoding to reduce variance') + parser.add_argument( + '--use_degree', type=str, default='none', + choices=["none", "mlp", "AA", "RA"], + help="rescale vector norm to facilitate weighted count") + parser.add_argument('--signature_sampling', type=str, default='torchhd', + help='whether to use torchhd to 
randomize vectors', + choices=["torchhd", "gaussian", "onehot"]) + parser.add_argument( + '--fast_inference', type=str2bool, default='False', + help='whether to enable a faster inference by caching the node vectors' + ) + parser.add_argument( + '--mask_target', type=str2bool, default='True', + help='whether to mask the target edges to remove the shortcut') # model setting parser.add_argument('--encoder', type=str, default='gcn') @@ -471,11 +516,17 @@ def main(): parser.add_argument('--label_dropout', type=float, default=0.5) parser.add_argument('--num_layers', type=int, default=2) parser.add_argument('--device', type=int, default=0) - parser.add_argument('--use_feature', type=str2bool, default='True', help='whether to use node features as input') - parser.add_argument('--feature_combine', type=str, default='hadamard', choices=['hadamard','plus_minus'], help='how to represent a link with two nodes features') - parser.add_argument('--jk', type=str2bool, default='True', help='whether to use Jumping Knowledge') - parser.add_argument('--batchnorm_affine', type=str2bool, default='True', help='whether to use Affine in BatchNorm') - parser.add_argument('--use_embedding', type=str2bool, default='False', help='whether to train node embedding') + parser.add_argument('--use_feature', type=str2bool, default='True', + help='whether to use node features as input') + parser.add_argument('--feature_combine', type=str, default='hadamard', + choices=['hadamard', 'plus_minus'], + help='how to represent a link with two nodes features') + parser.add_argument('--jk', type=str2bool, default='True', + help='whether to use Jumping Knowledge') + parser.add_argument('--batchnorm_affine', type=str2bool, default='True', + help='whether to use Affine in BatchNorm') + parser.add_argument('--use_embedding', type=str2bool, default='False', + help='whether to train node embedding') # training setting parser.add_argument('--batch_size', type=int, default=64 * 1024) @@ -486,9 +537,11 @@ def main(): parser.add_argument('--lr', type=float, default=0.005) parser.add_argument('--weight_decay', type=float, default=0) parser.add_argument('--log_steps', type=int, default=20) - parser.add_argument('--patience', type=int, default=100, help='number of patience steps for early stopping') + parser.add_argument('--patience', type=int, default=100, + help='number of patience steps for early stopping') parser.add_argument('--runs', type=int, default=10) - parser.add_argument('--metric', type=str, default='Hits@100', help='main evaluation metric') + parser.add_argument('--metric', type=str, default='Hits@100', + help='main evaluation metric') # misc parser.add_argument('--data_split_only', type=str2bool, default='False') @@ -509,14 +562,15 @@ def main(): args.use_feature = False if args.print_summary: - data_summary(args.dataset, data);exit(0) + data_summary(args.dataset, data) + exit(0) else: print(args) - + # Save command line input. 
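    # Example invocations (hypothetical; flag values are only illustrative and
    # the dataset name must be a valid OGB link-prediction dataset):
    #     python examples/mplp.py --dataset ogbl-collab --use_degree RA
    #     python examples/mplp.py --dataset ogbl-citation2 --fast_inference True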
cmd_input = 'python ' + ' '.join(sys.argv) + '\n' print('Command line input: ' + cmd_input + ' is saved.') - + train, test, evaluator = get_train_test(args) val_max = 0.0 @@ -533,28 +587,36 @@ def main(): emb = None if 'gcn' in args.encoder: encoder = MPLP_GCN(data.num_features, args.hidden_channels, - args.hidden_channels, args.num_layers, - args.feat_dropout, args.xdp, args.use_feature, args.jk, args.encoder, emb).to(device) + args.hidden_channels, args.num_layers, + args.feat_dropout, args.xdp, args.use_feature, + args.jk, args.encoder, emb).to(device) elif args.encoder == 'mlp': - encoder = MLP(num_layers=args.num_layers, in_channels=data.num_features, - hidden_channels=args.hidden_channels, out_channels=args.hidden_channels, + encoder = MLP(num_layers=args.num_layers, + in_channels=data.num_features, + hidden_channels=args.hidden_channels, + out_channels=args.hidden_channels, dropout=args.feat_dropout, act=None).to(device) - predictor_in_dim = args.hidden_channels * int(args.use_feature or args.use_embedding) + predictor_in_dim = args.hidden_channels * int(args.use_feature + or args.use_embedding) predictor = MPLP(predictor_in_dim, args.hidden_channels, - args.num_layers, args.feat_dropout, args.label_dropout, args.num_hops, - signature_sampling=args.signature_sampling, - use_degree=args.use_degree, signature_dim=args.signature_dim, - minimum_degree_onehot=args.minimum_degree_onehot, batchnorm_affine=args.batchnorm_affine, - feature_combine=args.feature_combine) + args.num_layers, args.feat_dropout, + args.label_dropout, args.num_hops, + signature_sampling=args.signature_sampling, + use_degree=args.use_degree, + signature_dim=args.signature_dim, + minimum_degree_onehot=args.minimum_degree_onehot, + batchnorm_affine=args.batchnorm_affine, + feature_combine=args.feature_combine) predictor = predictor.to(device) encoder.reset_parameters() predictor.reset_parameters() parameters = list(encoder.parameters()) + list(predictor.parameters()) - optimizer = torch.optim.Adam(parameters, lr=args.lr, weight_decay=args.weight_decay) + optimizer = torch.optim.Adam(parameters, lr=args.lr, + weight_decay=args.weight_decay) total_params = sum(p.numel() for param in parameters for p in param) print(f'Total number of parameters is {total_params}') @@ -562,27 +624,27 @@ def main(): best_val = 0.0 for epoch in range(1, 1 + args.epochs): - loss = train(encoder, predictor, data, split_edge, - optimizer, args.batch_size, args.mask_target, + loss = train(encoder, predictor, data, split_edge, optimizer, + args.batch_size, args.mask_target, num_neg=args.num_neg) - results = test(encoder, predictor, data, split_edge, - evaluator, args.test_batch_size, args.fast_inference) + results = test(encoder, predictor, data, split_edge, evaluator, + args.test_batch_size, args.fast_inference) if results[args.metric][0] >= best_val: best_val = results[args.metric][0] cnt_wait = 0 else: - cnt_wait +=1 + cnt_wait += 1 if epoch % args.log_steps == 0: for key, result in results.items(): valid_hits, test_hits = result to_print = (f'Run: {run + 1:02d}, ' + - f'Epoch: {epoch:02d}, '+ - f'Loss: {loss:.4f}, '+ - f'Valid: {100 * valid_hits:.2f}%, '+ - f'Test: {100 * test_hits:.2f}%') + f'Epoch: {epoch:02d}, ' + + f'Loss: {loss:.4f}, ' + + f'Valid: {100 * valid_hits:.2f}%, ' + + f'Test: {100 * test_hits:.2f}%') print(key) print(to_print) print('---') @@ -594,6 +656,6 @@ def main(): end_time = time.time() print(f"Total time: {end_time - start_time:.4f}s") + if __name__ == "__main__": main() - diff --git 
a/torch_geometric/nn/models/mplp.py b/torch_geometric/nn/models/mplp.py index afb657a4326d..837535738d54 100644 --- a/torch_geometric/nn/models/mplp.py +++ b/torch_geometric/nn/models/mplp.py @@ -1,50 +1,48 @@ - import math - -from torch import Tensor +from functools import partial +from typing import Final import torch import torch.nn as nn import torch.nn.functional as F +import torchhd +from torch import Tensor +from torch_sparse import SparseTensor, matmul +from torch_sparse.matmul import spmm_add, spmm_max, spmm_mean from torch_geometric.nn import GCNConv, SGConv from torch_geometric.nn.models import MLP -from torch_geometric.utils import k_hop_subgraph as pyg_k_hop_subgraph, to_edge_index - -from torch_sparse import SparseTensor, matmul -from torch_sparse.matmul import spmm_max, spmm_mean, spmm_add +from torch_geometric.utils import k_hop_subgraph as pyg_k_hop_subgraph +from torch_geometric.utils import to_edge_index - -from functools import partial - -from typing import Final - -import torchhd - -USE_CUSTOM_MLP=True +USE_CUSTOM_MLP = True ######################## ###### NodeLabel ####### ######################## -MINIMUM_SIGNATURE_DIM=64 +MINIMUM_SIGNATURE_DIM = 64 + class NodeLabel(torch.nn.Module): - def __init__(self, dim: int=1024, signature_sampling="torchhd", - minimum_degree_onehot: int=-1): + def __init__(self, dim: int = 1024, signature_sampling="torchhd", + minimum_degree_onehot: int = -1): super().__init__() self.dim = dim self.signature_sampling = signature_sampling self.cached_two_hop_adj = None self.minimum_degree_onehot = minimum_degree_onehot - def forward(self, edges: Tensor, adj_t: SparseTensor, node_weight: Tensor=None, cache_mode=None): + def forward(self, edges: Tensor, adj_t: SparseTensor, + node_weight: Tensor = None, cache_mode=None): if cache_mode is not None: - return self.propagation_only_cache(edges, adj_t, node_weight, cache_mode) + return self.propagation_only_cache(edges, adj_t, node_weight, + cache_mode) else: return self.propagation_only(edges, adj_t, node_weight) - def get_random_node_vectors(self, adj_t: SparseTensor, node_weight) -> Tensor: + def get_random_node_vectors(self, adj_t: SparseTensor, + node_weight) -> Tensor: num_nodes = adj_t.size(0) device = adj_t.device() if self.minimum_degree_onehot > 0: @@ -52,63 +50,82 @@ def get_random_node_vectors(self, adj_t: SparseTensor, node_weight) -> Tensor: nodes_to_one_hot = degree >= self.minimum_degree_onehot one_hot_dim = nodes_to_one_hot.sum() if one_hot_dim + MINIMUM_SIGNATURE_DIM > self.dim: - raise ValueError(f"There are {int(one_hot_dim)} nodes with degree higher than {self.minimum_degree_onehot}, select a higher threshold to choose fewer nodes as hub") + raise ValueError( + f"There are {int(one_hot_dim)} nodes with degree higher than {self.minimum_degree_onehot}, select a higher threshold to choose fewer nodes as hub" + ) embedding = torch.zeros(num_nodes, self.dim, device=device) - if one_hot_dim>0: - one_hot_embedding = F.one_hot(torch.arange(0, one_hot_dim)).float().to(device) - embedding[nodes_to_one_hot,:one_hot_dim] = one_hot_embedding + if one_hot_dim > 0: + one_hot_embedding = F.one_hot(torch.arange( + 0, one_hot_dim)).float().to(device) + embedding[nodes_to_one_hot, :one_hot_dim] = one_hot_embedding else: embedding = torch.zeros(num_nodes, self.dim, device=device) - nodes_to_one_hot = torch.zeros(num_nodes, dtype=torch.bool, device=device) + nodes_to_one_hot = torch.zeros(num_nodes, dtype=torch.bool, + device=device) one_hot_dim = 0 rand_dim = self.dim - one_hot_dim if 
self.signature_sampling == "torchhd": scale = math.sqrt(1 / rand_dim) - node_vectors = torchhd.random(num_nodes - one_hot_dim, rand_dim, device=device) + node_vectors = torchhd.random(num_nodes - one_hot_dim, rand_dim, + device=device) node_vectors.mul_(scale) # make them unit vectors elif self.signature_sampling == "gaussian": - node_vectors = F.normalize(torch.nn.init.normal_(torch.empty((num_nodes - one_hot_dim, rand_dim), dtype=torch.float32, device=device))) + node_vectors = F.normalize( + torch.nn.init.normal_( + torch.empty((num_nodes - one_hot_dim, rand_dim), + dtype=torch.float32, device=device))) elif self.signature_sampling == "onehot": embedding = torch.zeros(num_nodes, num_nodes, device=device) - node_vectors = F.one_hot(torch.arange(0, num_nodes)).float().to(device) + node_vectors = F.one_hot(torch.arange( + 0, num_nodes)).float().to(device) embedding[~nodes_to_one_hot, one_hot_dim:] = node_vectors if node_weight is not None: - node_weight = node_weight.unsqueeze(1) # Note: not sqrt here because it can cause problem for MLP when output is negative - # thus, it requires the MLP to approximate one more sqrt? + node_weight = node_weight.unsqueeze( + 1 + ) # Note: not sqrt here because it can cause problem for MLP when output is negative + # thus, it requires the MLP to approximate one more sqrt? embedding.mul_(node_weight) return embedding - def propagation_only(self, edges: Tensor, adj_t: SparseTensor, node_weight=None): + def propagation_only(self, edges: Tensor, adj_t: SparseTensor, + node_weight=None): adj_t, new_edges, subset_nodes = subgraph(edges, adj_t, 2) - node_weight = node_weight[subset_nodes] if node_weight is not None else None + node_weight = node_weight[ + subset_nodes] if node_weight is not None else None x = self.get_random_node_vectors(adj_t, node_weight=node_weight) - subset = new_edges.view(-1) # flatten the target nodes [row, col] + subset = new_edges.view(-1) # flatten the target nodes [row, col] - subset_unique, inverse_indices = torch.unique(subset, return_inverse=True) + subset_unique, inverse_indices = torch.unique(subset, + return_inverse=True) one_hop_x_subgraph_nodes = matmul(adj_t, x) one_hop_x = one_hop_x_subgraph_nodes[subset] - two_hop_x = matmul(adj_t[subset_unique], one_hop_x_subgraph_nodes)[inverse_indices] + two_hop_x = matmul(adj_t[subset_unique], + one_hop_x_subgraph_nodes)[inverse_indices] degree_one_hop = adj_t.sum(dim=1) one_hop_x = one_hop_x.view(2, new_edges.size(1), -1) two_hop_x = two_hop_x.view(2, new_edges.size(1), -1) - count_1_1 = dot_product(one_hop_x[0,:,:], one_hop_x[1,:,:]) - count_1_2 = dot_product(one_hop_x[0,:,:], two_hop_x[1,:,:]) - count_2_1 = dot_product(two_hop_x[0,:,:] , one_hop_x[1,:,:]) - count_2_2 = dot_product((two_hop_x[0,:,:]-degree_one_hop[new_edges[0]].view(-1,1)*x[new_edges[0]]), - (two_hop_x[1,:,:]-degree_one_hop[new_edges[1]].view(-1,1)*x[new_edges[1]])) - - count_self_1_2 = dot_product(one_hop_x[0,:,:] , two_hop_x[0,:,:]) - count_self_2_1 = dot_product(one_hop_x[1,:,:] , two_hop_x[1,:,:]) + count_1_1 = dot_product(one_hop_x[0, :, :], one_hop_x[1, :, :]) + count_1_2 = dot_product(one_hop_x[0, :, :], two_hop_x[1, :, :]) + count_2_1 = dot_product(two_hop_x[0, :, :], one_hop_x[1, :, :]) + count_2_2 = dot_product( + (two_hop_x[0, :, :] - + degree_one_hop[new_edges[0]].view(-1, 1) * x[new_edges[0]]), + (two_hop_x[1, :, :] - + degree_one_hop[new_edges[1]].view(-1, 1) * x[new_edges[1]])) + + count_self_1_2 = dot_product(one_hop_x[0, :, :], two_hop_x[0, :, :]) + count_self_2_1 = dot_product(one_hop_x[1, :, :], 
two_hop_x[1, :, :]) degree_u = degree_one_hop[new_edges[0]] degree_v = degree_one_hop[new_edges[1]] return count_1_1, count_1_2, count_2_1, count_2_2, count_self_1_2, count_self_2_1, degree_u, degree_v - def propagation_only_cache(self, edges: Tensor, adj_t: SparseTensor, node_weight=None, cache_mode=None): + def propagation_only_cache(self, edges: Tensor, adj_t: SparseTensor, + node_weight=None, cache_mode=None): if cache_mode == 'build': # get the 2-hop subgraph of the target edges x = self.get_random_node_vectors(adj_t, node_weight=node_weight) @@ -138,44 +155,50 @@ def propagation_only_cache(self, edges: Tensor, adj_t: SparseTensor, node_weight one_hop_x = self.cached_one_hop_x two_iter_x = self.cached_two_iter_x - count_1_1 = dot_product(one_hop_x[edges[0]] , one_hop_x[edges[1]]) - count_1_2 = dot_product(one_hop_x[edges[0]] , two_iter_x[edges[1]]) - count_2_1 = dot_product(two_iter_x[edges[0]] , one_hop_x[edges[1]]) + count_1_1 = dot_product(one_hop_x[edges[0]], one_hop_x[edges[1]]) + count_1_2 = dot_product(one_hop_x[edges[0]], two_iter_x[edges[1]]) + count_2_1 = dot_product(two_iter_x[edges[0]], one_hop_x[edges[1]]) count_2_2 = dot_product(two_iter_x[edges[0]]-degree_one_hop[edges[0]].view(-1,1)*x[edges[0]],\ two_iter_x[edges[1]]-degree_one_hop[edges[1]].view(-1,1)*x[edges[1]]) - count_self_1_2 = dot_product(one_hop_x[edges[0]] , two_iter_x[edges[0]]) - count_self_2_1 = dot_product(one_hop_x[edges[1]] , two_iter_x[edges[1]]) + count_self_1_2 = dot_product(one_hop_x[edges[0]], two_iter_x[edges[0]]) + count_self_2_1 = dot_product(one_hop_x[edges[1]], two_iter_x[edges[1]]) degree_u = degree_one_hop[edges[0]] degree_v = degree_one_hop[edges[1]] return count_1_1, count_1_2, count_2_1, count_2_2, count_self_1_2, count_self_2_1, degree_u, degree_v -def subgraph(edges: Tensor, adj_t: SparseTensor, k: int=2): - row,col = edges - nodes = torch.cat((row,col),dim=-1) - edge_index,_ = to_edge_index(adj_t) - subset, new_edge_index, inv, edge_mask = pyg_k_hop_subgraph(nodes, k, edge_index=edge_index, - num_nodes=adj_t.size(0), relabel_nodes=True) + +def subgraph(edges: Tensor, adj_t: SparseTensor, k: int = 2): + row, col = edges + nodes = torch.cat((row, col), dim=-1) + edge_index, _ = to_edge_index(adj_t) + subset, new_edge_index, inv, edge_mask = pyg_k_hop_subgraph( + nodes, k, edge_index=edge_index, num_nodes=adj_t.size(0), + relabel_nodes=True) # subset[inv] = nodes. The new node id is based on `subset`'s order. 
# inv means the new idx (in subset) of the old nodes in `nodes` - new_adj_t = SparseTensor(row=new_edge_index[0], col=new_edge_index[1], - sparse_sizes=(subset.size(0), subset.size(0))) - new_edges = inv.view(2,-1) + new_adj_t = SparseTensor(row=new_edge_index[0], col=new_edge_index[1], + sparse_sizes=(subset.size(0), subset.size(0))) + new_edges = inv.view(2, -1) return new_adj_t, new_edges, subset + def dotproduct_naive(tensor1, tensor2): return (tensor1 * tensor2).sum(dim=-1) + def dotproduct_bmm(tensor1, tensor2): return torch.bmm(tensor1.unsqueeze(1), tensor2.unsqueeze(2)).view(-1) + dot_product = dotproduct_naive ######################## ######### MLP ########## ######################## + class CustomMLP(nn.Module): def __init__( self, @@ -192,21 +215,24 @@ def __init__( self.num_layers = num_layers self.norm_type = norm_type self.tailnormactdrop = tailnormactdrop - self.affine = affine # the affine in batchnorm + self.affine = affine # the affine in batchnorm self.layers = [] if num_layers == 1: self.layers.append(nn.Linear(input_dim, output_dim)) if tailnormactdrop: - self.__build_normactdrop(self.layers, output_dim, dropout_ratio) + self.__build_normactdrop(self.layers, output_dim, + dropout_ratio) else: self.layers.append(nn.Linear(input_dim, hidden_dim)) self.__build_normactdrop(self.layers, hidden_dim, dropout_ratio) for i in range(num_layers - 2): self.layers.append(nn.Linear(hidden_dim, hidden_dim)) - self.__build_normactdrop(self.layers, hidden_dim, dropout_ratio) + self.__build_normactdrop(self.layers, hidden_dim, + dropout_ratio) self.layers.append(nn.Linear(hidden_dim, output_dim)) if tailnormactdrop: - self.__build_normactdrop(self.layers, hidden_dim, dropout_ratio) + self.__build_normactdrop(self.layers, hidden_dim, + dropout_ratio) self.layers = nn.Sequential(*self.layers) def __build_normactdrop(self, layers, dim, dropout): @@ -225,13 +251,16 @@ def reset_parameters(self): def forward(self, feats, adj_t=None): return self.layers(feats) + ######################## ######### GNN ########## ######################## + # Addpted from NCNC class PureConv(nn.Module): aggr: Final[str] + def __init__(self, indim, outdim, aggr="gcn") -> None: super().__init__() self.aggr = aggr @@ -249,15 +278,17 @@ def forward(self, x, adj_t): elif self.aggr == "sum": return spmm_add(adj_t, x) elif self.aggr == "gcn": - norm = torch.rsqrt_((1+adj_t.sum(dim=-1))).reshape(-1, 1) + norm = torch.rsqrt_((1 + adj_t.sum(dim=-1))).reshape(-1, 1) x = norm * x x = spmm_add(adj_t, x) + x x = norm * x return x + class MPLP_GCN(torch.nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, num_layers, - dropout, xdropout, use_feature=True, jk=False, gcn_name='gcn', embedding=None): + dropout, xdropout, use_feature=True, jk=False, gcn_name='gcn', + embedding=None): super(MPLP_GCN, self).__init__() self.use_feature = use_feature @@ -267,30 +298,31 @@ def __init__(self, in_channels, hidden_channels, out_channels, num_layers, self.input_size = 0 self.jk = jk if jk: - self.register_parameter("jkparams", nn.Parameter(torch.randn((num_layers,)))) + self.register_parameter("jkparams", + nn.Parameter(torch.randn((num_layers, )))) if self.use_feature: self.input_size += in_channels if self.embedding is not None: self.input_size += embedding.embedding_dim self.convs = torch.nn.ModuleList() - + if self.input_size > 0: if gcn_name == 'gcn': conv_func = partial(GCNConv, cached=False) elif 'pure' in gcn_name: conv_func = partial(SGConv, apply_linearity=False) self.xemb = 
nn.Sequential(nn.Dropout(xdropout)) - if ("pure" in gcn_name or num_layers==0): + if ("pure" in gcn_name or num_layers == 0): self.xemb.append(nn.Linear(self.input_size, hidden_channels)) - self.xemb.append(nn.Dropout(dropout, inplace=True) if dropout > 1e-6 else nn.Identity()) + self.xemb.append( + nn.Dropout(dropout, inplace=True) if dropout > + 1e-6 else nn.Identity()) self.input_size = hidden_channels self.convs.append(conv_func(self.input_size, hidden_channels)) for _ in range(num_layers - 2): - self.convs.append( - conv_func(hidden_channels, hidden_channels)) + self.convs.append(conv_func(hidden_channels, hidden_channels)) self.convs.append(conv_func(hidden_channels, out_channels)) - def reset_parameters(self): for m in self.modules(): if isinstance(m, nn.Linear) or isinstance(m, nn.BatchNorm1d): @@ -310,20 +342,23 @@ def forward(self, x, adj_t): x = conv(x, adj_t) if self.jk: jkx.append(x) - if self.jk: # JumpingKnowledge Connection + if self.jk: # JumpingKnowledge Connection jkx = torch.stack(jkx, dim=0) sftmax = self.jkparams.reshape(-1, 1, 1) - x = torch.sum(jkx*sftmax, dim=0) + x = torch.sum(jkx * sftmax, dim=0) return x + ######################## ######### MPLP ######### ######################## + class MPLP(torch.nn.Module): - def __init__(self, in_channels, hidden_channels, num_layers, - feat_dropout, label_dropout, num_hops=2, signature_sampling='torchhd', use_degree='none', - signature_dim=1024, minimum_degree_onehot=-1, batchnorm_affine=True, + def __init__(self, in_channels, hidden_channels, num_layers, feat_dropout, + label_dropout, num_hops=2, signature_sampling='torchhd', + use_degree='none', signature_dim=1024, + minimum_degree_onehot=-1, batchnorm_affine=True, feature_combine="hadamard"): super(MPLP, self).__init__() @@ -331,25 +366,37 @@ def __init__(self, in_channels, hidden_channels, num_layers, self.feat_dropout = feat_dropout self.label_dropout = label_dropout self.num_hops = num_hops - self.signature_sampling=signature_sampling + self.signature_sampling = signature_sampling self.use_degree = use_degree self.feature_combine = feature_combine if self.use_degree == 'mlp': if USE_CUSTOM_MLP: - self.node_weight_encode = CustomMLP(2, in_channels + 1, 32, 1, feat_dropout, norm_type="batch", affine=batchnorm_affine) + self.node_weight_encode = CustomMLP(2, in_channels + 1, 32, 1, + feat_dropout, + norm_type="batch", + affine=batchnorm_affine) else: - self.node_weight_encode = MLP(num_layers=2, in_channels=in_channels + 1, hidden_channels=32, out_channels=1, - dropout=self.label_dropout, act='relu', - norm="BatchNorm", norm_kwargs={"affine": batchnorm_affine}) + self.node_weight_encode = MLP( + num_layers=2, in_channels=in_channels + 1, + hidden_channels=32, out_channels=1, + dropout=self.label_dropout, act='relu', norm="BatchNorm", + norm_kwargs={"affine": batchnorm_affine}) struct_dim = 8 - self.nodelabel = NodeLabel(signature_dim, signature_sampling=self.signature_sampling, - minimum_degree_onehot= minimum_degree_onehot) + self.nodelabel = NodeLabel(signature_dim, + signature_sampling=self.signature_sampling, + minimum_degree_onehot=minimum_degree_onehot) if USE_CUSTOM_MLP: - self.struct_encode = CustomMLP(1, struct_dim, struct_dim, struct_dim, self.label_dropout, "batch", tailnormactdrop=True, affine=batchnorm_affine) + self.struct_encode = CustomMLP(1, struct_dim, struct_dim, + struct_dim, self.label_dropout, + "batch", tailnormactdrop=True, + affine=batchnorm_affine) else: - self.struct_encode = MLP(num_layers=1, in_channels=struct_dim, hidden_channels=struct_dim, 
out_channels=struct_dim, - dropout=self.label_dropout, act='relu', plain_last=False, - norm="BatchNorm", norm_kwargs={"affine": batchnorm_affine}) + self.struct_encode = MLP(num_layers=1, in_channels=struct_dim, + hidden_channels=struct_dim, + out_channels=struct_dim, + dropout=self.label_dropout, act='relu', + plain_last=False, norm="BatchNorm", + norm_kwargs={"affine": batchnorm_affine}) dense_dim = struct_dim + in_channels if in_channels > 0: @@ -358,54 +405,63 @@ def __init__(self, in_channels, hidden_channels, num_layers, elif feature_combine == "plus_minus": feat_encode_input_dim = in_channels * 2 if USE_CUSTOM_MLP: - self.feat_encode = CustomMLP(2, feat_encode_input_dim, in_channels, in_channels, self.feat_dropout, "batch", tailnormactdrop=True, affine=batchnorm_affine) + self.feat_encode = CustomMLP(2, feat_encode_input_dim, + in_channels, in_channels, + self.feat_dropout, "batch", + tailnormactdrop=True, + affine=batchnorm_affine) else: - self.feat_encode = MLP(num_layers=1, in_channels=feat_encode_input_dim, hidden_channels=in_channels, out_channels=in_channels, - dropout=self.label_dropout, act='relu', plain_last=False, - norm="BatchNorm", norm_kwargs={"affine": batchnorm_affine}) + self.feat_encode = MLP( + num_layers=1, in_channels=feat_encode_input_dim, + hidden_channels=in_channels, out_channels=in_channels, + dropout=self.label_dropout, act='relu', plain_last=False, + norm="BatchNorm", norm_kwargs={"affine": batchnorm_affine}) self.classifier = nn.Linear(dense_dim, 1) - def reset_parameters(self): for m in self.modules(): if isinstance(m, nn.Linear) or isinstance(m, nn.BatchNorm1d): m.reset_parameters() - + def forward(self, x, adj, edges, cache_mode=None): + """Args: + x: [N, in_channels] node embedding after GNN + adj: [N, N] adjacency matrix + edges: [2, E] target edges + fast_inference: bool. If True, only caching the message-passing without calculating the structural features """ - Args: - x: [N, in_channels] node embedding after GNN - adj: [N, N] adjacency matrix - edges: [2, E] target edges - fast_inference: bool. If True, only caching the message-passing without calculating the structural features - """ - if cache_mode in ["use","delete"]: + if cache_mode in ["use", "delete"]: # no need to compute node_weight node_weight = None elif self.use_degree == 'none': node_weight = None - elif self.use_degree == 'mlp': # 'mlp' for now + elif self.use_degree == 'mlp': # 'mlp' for now xs = [] if self.in_channels > 0: xs.append(x) - degree = adj.sum(dim=1).view(-1,1).to(adj.device()) + degree = adj.sum(dim=1).view(-1, 1).to(adj.device()) xs.append(degree) node_weight_feat = torch.cat(xs, dim=1) - node_weight = self.node_weight_encode(node_weight_feat).squeeze(-1) + 1 # like residual, can be learned as 0 if needed + node_weight = self.node_weight_encode(node_weight_feat).squeeze( + -1) + 1 # like residual, can be learned as 0 if needed else: # AA or RA - degree = adj.sum(dim=1).view(-1,1).to(adj.device()).squeeze(-1) + 1 # degree at least 1. then log(degree) > 0. + degree = adj.sum(dim=1).view(-1, 1).to(adj.device()).squeeze( + -1) + 1 # degree at least 1. then log(degree) > 0. 
if self.use_degree == 'AA': node_weight = torch.sqrt(torch.reciprocal(torch.log(degree))) elif self.use_degree == 'RA': node_weight = torch.sqrt(torch.reciprocal(degree)) - node_weight = torch.nan_to_num(node_weight, nan=0.0, posinf=0.0, neginf=0.0) + node_weight = torch.nan_to_num(node_weight, nan=0.0, posinf=0.0, + neginf=0.0) - if cache_mode in ["build","delete"]: - propped = self.nodelabel(edges, adj, node_weight=node_weight, cache_mode=cache_mode) + if cache_mode in ["build", "delete"]: + propped = self.nodelabel(edges, adj, node_weight=node_weight, + cache_mode=cache_mode) return else: - propped = self.nodelabel(edges, adj, node_weight=node_weight, cache_mode=cache_mode) + propped = self.nodelabel(edges, adj, node_weight=node_weight, + cache_mode=cache_mode) propped_stack = torch.stack([*propped], dim=1) out = self.struct_encode(propped_stack) @@ -415,7 +471,7 @@ def forward(self, x, adj, edges, cache_mode=None): if self.feature_combine == "hadamard": x_ij = x_i * x_j elif self.feature_combine == "plus_minus": - x_ij = torch.cat([x_i+x_j, torch.abs(x_i-x_j)], dim=1) + x_ij = torch.cat([x_i + x_j, torch.abs(x_i - x_j)], dim=1) x_ij = self.feat_encode(x_ij) x = torch.cat([x_ij, out], dim=1) else: @@ -426,5 +482,3 @@ def forward(self, x, adj, edges, cache_mode=None): def precompute(self, adj): self(None, adj, None, cache_mode="build") return self - - From 0096d0af4e014687bf8cd13afdd7a2aceb2799ce Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Thu, 13 Jun 2024 01:59:57 +0900 Subject: [PATCH 3/4] Updates --- examples/mplp.py | 22 +++++----------------- torch_geometric/nn/models/mplp.py | 4 ++-- 2 files changed, 7 insertions(+), 19 deletions(-) diff --git a/examples/mplp.py b/examples/mplp.py index 2ff03af4b669..63c2c506e1ca 100644 --- a/examples/mplp.py +++ b/examples/mplp.py @@ -525,8 +525,6 @@ def main(): help='whether to use Jumping Knowledge') parser.add_argument('--batchnorm_affine', type=str2bool, default='True', help='whether to use Affine in BatchNorm') - parser.add_argument('--use_embedding', type=str2bool, default='False', - help='whether to train node embedding') # training setting parser.add_argument('--batch_size', type=int, default=64 * 1024) @@ -581,24 +579,14 @@ def main(): one_hot_dim = nodes_to_one_hot.sum() print(f"number of nodes to onehot: {int(one_hot_dim)}") data = data.to(device) - if args.use_embedding: - emb = initial_embedding(data, args.hidden_channels, device) - else: - emb = None - if 'gcn' in args.encoder: - encoder = MPLP_GCN(data.num_features, args.hidden_channels, + + emb = initial_embedding(data, args.hidden_channels, device) + encoder = MPLP_GCN(data.num_features, args.hidden_channels, args.hidden_channels, args.num_layers, args.feat_dropout, args.xdp, args.use_feature, args.jk, args.encoder, emb).to(device) - elif args.encoder == 'mlp': - encoder = MLP(num_layers=args.num_layers, - in_channels=data.num_features, - hidden_channels=args.hidden_channels, - out_channels=args.hidden_channels, - dropout=args.feat_dropout, act=None).to(device) - - predictor_in_dim = args.hidden_channels * int(args.use_feature - or args.use_embedding) + + predictor_in_dim = args.hidden_channels predictor = MPLP(predictor_in_dim, args.hidden_channels, args.num_layers, args.feat_dropout, diff --git a/torch_geometric/nn/models/mplp.py b/torch_geometric/nn/models/mplp.py index 837535738d54..b95b176e230e 100644 --- a/torch_geometric/nn/models/mplp.py +++ b/torch_geometric/nn/models/mplp.py @@ -10,7 +10,7 @@ from torch_sparse import SparseTensor, matmul from torch_sparse.matmul 
import spmm_add, spmm_max, spmm_mean -from torch_geometric.nn import GCNConv, SGConv +from torch_geometric.nn import GCNConv, LGConv from torch_geometric.nn.models import MLP from torch_geometric.utils import k_hop_subgraph as pyg_k_hop_subgraph from torch_geometric.utils import to_edge_index @@ -310,7 +310,7 @@ def __init__(self, in_channels, hidden_channels, out_channels, num_layers, if gcn_name == 'gcn': conv_func = partial(GCNConv, cached=False) elif 'pure' in gcn_name: - conv_func = partial(SGConv, apply_linearity=False) + conv_func = LGConv self.xemb = nn.Sequential(nn.Dropout(xdropout)) if ("pure" in gcn_name or num_layers == 0): self.xemb.append(nn.Linear(self.input_size, hidden_channels)) From 2f714450f4d3996973d9035eb49d3ca45a199375 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 Jun 2024 17:05:38 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/mplp.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/mplp.py b/examples/mplp.py index 63c2c506e1ca..40df01248352 100644 --- a/examples/mplp.py +++ b/examples/mplp.py @@ -582,9 +582,9 @@ def main(): emb = initial_embedding(data, args.hidden_channels, device) encoder = MPLP_GCN(data.num_features, args.hidden_channels, - args.hidden_channels, args.num_layers, - args.feat_dropout, args.xdp, args.use_feature, - args.jk, args.encoder, emb).to(device) + args.hidden_channels, args.num_layers, + args.feat_dropout, args.xdp, args.use_feature, + args.jk, args.encoder, emb).to(device) predictor_in_dim = args.hidden_channels
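
For reference, the MPLP_GCN encoder and MPLP predictor introduced by this patch can be exercised together roughly as sketched below. The toy graph, feature dimensions, and hyper-parameter values are illustrative assumptions for a quick smoke test, not the defaults used by examples/mplp.py.

import torch
from torch_sparse import SparseTensor

from torch_geometric.nn.models import MPLP, MPLP_GCN

# Toy undirected graph with random node features (assumed sizes).
num_nodes, in_dim, hidden = 100, 16, 32
x = torch.randn(num_nodes, in_dim)
row = torch.randint(0, num_nodes, (500, ))
col = torch.randint(0, num_nodes, (500, ))
adj_t = SparseTensor(row=torch.cat([row, col]), col=torch.cat([col, row]),
                     sparse_sizes=(num_nodes, num_nodes))

# GNN encoder producing node representations.
encoder = MPLP_GCN(in_dim, hidden, hidden, num_layers=2, dropout=0.0,
                   xdropout=0.0, use_feature=True, jk=True, gcn_name='gcn',
                   embedding=None)

# MPLP predictor: estimates common-neighbor-style structural counts from
# quasi-orthogonal node signatures and scores each candidate link.
predictor = MPLP(hidden, hidden, num_layers=2, feat_dropout=0.0,
                 label_dropout=0.0, num_hops=2,
                 signature_sampling='torchhd', signature_dim=1024)

h = encoder(x, adj_t)                        # [num_nodes, hidden]
edges = torch.randint(0, num_nodes, (2, 8))  # candidate links to score
logits = predictor(h, adj_t, edges)          # one logit per candidate link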