-
Notifications
You must be signed in to change notification settings - Fork 6
/
datasets.py
147 lines (110 loc) · 6 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import numpy as np
import torch
from torch.nn import functional as F
import dgl
from dgl import ops
from sklearn.metrics import roc_auc_score
class Dataset:
    """Node-classification dataset with several train/val/test data splits, loaded from an ``.npz`` archive.

    The archive is read from ``os.path.join('data', name.replace('-', '_') + '.npz')`` and must contain
    the arrays ``node_features``, ``node_labels``, ``edges``, ``train_masks``, ``val_masks`` and
    ``test_masks``. The graph is made bidirected unless the substring ``'directed'`` occurs in ``name``.

    If the labels take exactly two distinct values, the task is treated as binary. In that case there
    is a single target, labels are cast to float, the loss is ``binary_cross_entropy_with_logits`` and
    the metric is ROC AUC. Otherwise the task is multiclass, with ``cross_entropy`` and accuracy.

    Args:
        name: dataset name; ``'-'`` is mapped to ``'_'`` to form the file name.
        add_self_loops: add self-loops to ``self.graph``.
        device: device the graph, features, labels and split indices are moved to.
        use_sgc_features: append features propagated by :meth:`compute_sgc_features`.
        use_identity_features: append an ``n x n`` identity matrix (a one-hot row per node).
        use_adjacency_features: append the dense adjacency matrix with self-loops removed.
        do_not_use_original_features: drop the features stored in the archive.

    Raises:
        ValueError: if the original features are dropped but no replacement feature type is requested.
    """

    def __init__(self, name, add_self_loops=False, device='cpu', use_sgc_features=False, use_identity_features=False,
                 use_adjacency_features=False, do_not_use_original_features=False):
        if do_not_use_original_features and not any([use_sgc_features, use_identity_features, use_adjacency_features]):
            raise ValueError('If original node features are not used, at least one of the arguments '
                             'use_sgc_features, use_identity_features, use_adjacency_features should be used.')

        print('Preparing data...')
        path = os.path.join('data', f'{name.replace("-", "_")}.npz')
        # np.load on an .npz returns an NpzFile that keeps the archive open; the context manager
        # closes it once every array has been copied into a torch tensor.
        with np.load(path) as data:
            node_features = torch.tensor(data['node_features'])
            labels = torch.tensor(data['node_labels'])
            edges = torch.tensor(data['edges'])
            # Iterated row by row below, one mask per data split
            # (presumably shaped (num_splits, num_nodes) -- confirm against the data files).
            train_masks = torch.tensor(data['train_masks'])
            val_masks = torch.tensor(data['val_masks'])
            test_masks = torch.tensor(data['test_masks'])

        # edges is indexed as (E, 2) pairs: column 0 = source node, column 1 = destination node.
        graph = dgl.graph((edges[:, 0], edges[:, 1]), num_nodes=len(node_features), idtype=torch.int)
        if 'directed' not in name:
            graph = dgl.to_bidirected(graph)
        if add_self_loops:
            graph = dgl.add_self_loop(graph)

        num_classes = len(labels.unique())
        # Binary tasks are modelled with a single logit per node; BCE-with-logits needs float targets.
        num_targets = 1 if num_classes == 2 else num_classes
        if num_targets == 1:
            labels = labels.float()

        # Convert each split's mask into the indices of its nonzero entries.
        train_idx_list = [torch.where(mask)[0] for mask in train_masks]
        val_idx_list = [torch.where(mask)[0] for mask in val_masks]
        test_idx_list = [torch.where(mask)[0] for mask in test_masks]

        node_features = self.augment_node_features(graph=graph,
                                                   node_features=node_features,
                                                   use_sgc_features=use_sgc_features,
                                                   use_identity_features=use_identity_features,
                                                   use_adjacency_features=use_adjacency_features,
                                                   do_not_use_original_features=do_not_use_original_features)

        self.name = name
        self.device = device

        self.graph = graph.to(device)
        self.node_features = node_features.to(device)
        self.labels = labels.to(device)

        self.train_idx_list = [idx.to(device) for idx in train_idx_list]
        self.val_idx_list = [idx.to(device) for idx in val_idx_list]
        self.test_idx_list = [idx.to(device) for idx in test_idx_list]
        self.num_data_splits = len(train_idx_list)
        self.cur_data_split = 0

        self.num_node_features = node_features.shape[1]
        self.num_targets = num_targets

        self.loss_fn = F.binary_cross_entropy_with_logits if num_targets == 1 else F.cross_entropy
        self.metric = 'ROC AUC' if num_targets == 1 else 'accuracy'

    @property
    def train_idx(self):
        """Training node indices of the current data split."""
        return self.train_idx_list[self.cur_data_split]

    @property
    def val_idx(self):
        """Validation node indices of the current data split."""
        return self.val_idx_list[self.cur_data_split]

    @property
    def test_idx(self):
        """Test node indices of the current data split."""
        return self.test_idx_list[self.cur_data_split]

    def next_data_split(self):
        """Advance to the next data split, wrapping around to the first one after the last."""
        self.cur_data_split = (self.cur_data_split + 1) % self.num_data_splits

    def compute_metrics(self, logits):
        """Score ``logits`` on the train, val and test nodes of the current data split.

        Args:
            logits: model outputs for all nodes, indexed with each split's node indices.

        Returns:
            dict mapping ``'train <metric>'``, ``'val <metric>'`` and ``'test <metric>'`` (in that order)
            to Python floats, where ``<metric>`` is ``self.metric``.
        """
        splits = (('train', self.train_idx), ('val', self.val_idx), ('test', self.test_idx))
        metrics = {}
        if self.num_targets == 1:
            for split, idx in splits:
                # ROC AUC depends only on the ranking of scores, so raw logits need no sigmoid.
                # detach() allows logits that still require grad. float() accepts both a NumPy
                # scalar and a plain Python float, whereas .item() exists only on the former.
                score = roc_auc_score(y_true=self.labels[idx].cpu().numpy(),
                                      y_score=logits[idx].detach().cpu().numpy())
                metrics[f'{split} {self.metric}'] = float(score)
        else:
            preds = logits.argmax(dim=1)
            for split, idx in splits:
                metrics[f'{split} {self.metric}'] = (preds[idx] == self.labels[idx]).float().mean().item()
        return metrics

    @staticmethod
    def augment_node_features(graph, node_features, use_sgc_features, use_identity_features, use_adjacency_features,
                              do_not_use_original_features):
        """Build the node feature matrix by concatenating the requested feature blocks along dim 1.

        Blocks are appended in this order:
        1. the original features (unless ``do_not_use_original_features``);
        2. SGC features computed from the original features;
        3. an ``n x n`` identity matrix;
        4. the dense adjacency matrix of ``graph`` with self-loops removed.

        ``n`` is ``graph.num_nodes()``.

        Returns:
            The resulting feature tensor. If no block is added, this is ``node_features`` itself.
        """
        n = graph.num_nodes()
        original_node_features = node_features

        if do_not_use_original_features:
            # Zero-width (n, 0) tensor so the blocks below can be concatenated along dim 1.
            node_features = torch.empty((n, 0))

        if use_sgc_features:
            sgc_features = Dataset.compute_sgc_features(graph, original_node_features)
            node_features = torch.cat([node_features, sgc_features], dim=1)

        if use_identity_features:
            node_features = torch.cat([node_features, torch.eye(n)], dim=1)

        if use_adjacency_features:
            graph_without_self_loops = dgl.remove_self_loop(graph)
            adj_matrix = graph_without_self_loops.adjacency_matrix().to_dense()
            node_features = torch.cat([node_features, adj_matrix], dim=1)

        return node_features

    @staticmethod
    def compute_sgc_features(graph, node_features, num_props=5):
        """Propagate ``node_features`` ``num_props`` times over a normalized graph (SGC-style).

        Existing self-loops are removed, then one self-loop per node is added. Each step computes
        ``h_v <- sum over edges (u -> v) of h_u / sqrt(d_u * d_v)``, where ``d`` are the out-degrees of
        the graph with self-loops. For a symmetric (bidirected) graph this is multiplication by
        ``D^{-1/2} (A + I) D^{-1/2}``.
        """
        graph = dgl.remove_self_loop(graph)
        graph = dgl.add_self_loop(graph)

        degrees = graph.out_degrees().float()
        degree_edge_products = ops.u_mul_v(graph, degrees, degrees)
        norm_coefs = 1 / degree_edge_products ** 0.5

        for _ in range(num_props):
            node_features = ops.u_mul_e_sum(graph, node_features, norm_coefs)

        return node_features