utils.py

import argparse
import os
import random
import sys

import numpy as np
import torch


def get_args():
    parser = argparse.ArgumentParser('Interface for Inductive Dynamic Representation Learning for Link Prediction on Temporal Graphs')
    # select dataset and training mode
    parser.add_argument('-d', '--data', type=str, default='wikipedia', help='data source to use, try wikipedia or reddit')
    parser.add_argument('--data_usage', type=float, default=1.0, help='fraction of data to use (0-1)')
    parser.add_argument('-m', '--mode', type=str, default='i', choices=['t', 'i'], help='transductive (t) or inductive (i)')
    # method-related hyper-parameters
    parser.add_argument('--n_walks', type=int, default=64, help='number of sampled walks to form a subgraph')
    parser.add_argument('--n_steps', type=int, default=2, help='number of walk steps, which also sets the number of layers')
    parser.add_argument('--bias', type=float, default=0.0, help='hyperparameter alpha controlling the sampling preference for recent edges; the default 0 gives uniform sampling')
    parser.add_argument('--agg', type=str, default='walk', choices=['tree', 'walk'],
                        help='tree-based hierarchical aggregation or walk-based flat LSTM aggregation; we only use the default here')
    parser.add_argument('--pos_enc', type=str, default='lp', choices=['spd', 'lp', 'saw'], help='way to encode distances: shortest-path distance, landing probabilities, or self-based anonymous walk (baseline)')
    parser.add_argument('--pos_dim', type=int, default=172, help='dimension of the positional embedding')
    parser.add_argument('--pos_sample', type=str, default='binary', choices=['multinomial', 'binary'], help='two equivalent sampling methods with empirically different running times')
    parser.add_argument('--walk_pool', type=str, default='attn', choices=['attn', 'sum'], help='how to pool the encoded walks, using attention or a simple sum; sum overrides all the other walk_* arguments')
    parser.add_argument('--walk_n_head', type=int, default=8, help='number of heads to use for walk attention')
    parser.add_argument('--walk_mutual', action='store_true', help='whether to do mutual queries between source- and target-node random walks')
    parser.add_argument('--walk_linear_out', action='store_true', default=False, help="whether to linearly project each node's embedding")
    parser.add_argument('--attn_agg_method', type=str, default='attn', choices=['attn', 'lstm', 'mean'], help='local aggregation method, we only use the default here')
    parser.add_argument('--attn_mode', type=str, default='prod', choices=['prod', 'map'],
                        help='use dot-product or mapping-based attention, we only use the default here')
    parser.add_argument('--attn_n_head', type=int, default=2, help='number of heads used in the tree-shaped attention layer, we only use the default here')
    parser.add_argument('--time', type=str, default='time', choices=['time', 'pos', 'empty'], help='how to use time information, we only use the default here')
    parser.add_argument('--w', type=int, default=16, help='size of the sampled-neighbor pool')
    # general training hyper-parameters
    parser.add_argument('--n_epoch', type=int, default=10, help='number of epochs')
    parser.add_argument('--bs', type=int, default=64, help='batch size')
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--drop_out', type=float, default=0.1, help='dropout probability for all dropout layers')
    parser.add_argument('--tolerance', type=float, default=0,
                        help='tolerated marginal improvement for the early stopper')
    # parameters controlling computation settings but not affecting results in general
    parser.add_argument('--seed', type=int, default=0, help='random seed for all randomized algorithms')
    parser.add_argument('--ngh_cache', action='store_true',
                        help='cache previously computed temporal neighbors to speed up repeated lookups (currently not suggested due to overwhelming memory consumption)')
    parser.add_argument('--gpu', type=int, default=0, help='which gpu to use')
    parser.add_argument('--verbosity', type=int, default=1, help='verbosity of the program output')

    try:
        args = parser.parse_args()
    except:  # on any parsing failure, show the full help text instead of just the usage line
        parser.print_help()
        sys.exit(0)
    return args, sys.argv
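
# Illustrative usage sketch: a typical entry point would parse the flags above and
# seed all RNGs before building the model. The training-script context here is an
# assumption, not something this module mandates.
#
#   args, argv = get_args()
#   set_random_seed(args.seed)
#   device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() else 'cpu')

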
class EarlyStopMonitor(object):
    """Stops training once the validation metric has not improved by more than
    `tolerance` (relative) for `max_round` consecutive epochs."""

    def __init__(self, max_round=3, higher_better=True, tolerance=1e-3):
        self.max_round = max_round
        self.num_round = 0  # consecutive epochs without sufficient improvement
        self.epoch_count = 0
        self.best_epoch = 0
        self.last_best = None
        self.higher_better = higher_better
        self.tolerance = tolerance

    def early_stop_check(self, curr_val):
        # negate lower-is-better metrics so the comparison below is uniform
        if not self.higher_better:
            curr_val *= -1
        if self.last_best is None:
            self.last_best = curr_val
        elif (curr_val - self.last_best) / np.abs(self.last_best) > self.tolerance:
            # relative improvement exceeds the tolerance: record the new best
            self.last_best = curr_val
            self.num_round = 0
            self.best_epoch = self.epoch_count
        else:
            self.num_round += 1
        self.epoch_count += 1
        return self.num_round >= self.max_round
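
# Illustrative usage sketch (the validation metric `val_ap` is a hypothetical
# stand-in for whatever score the training loop produces):
#
#   early_stopper = EarlyStopMonitor(tolerance=args.tolerance)
#   for epoch in range(args.n_epoch):
#       val_ap = ...  # evaluate the model on the validation split
#       if early_stopper.early_stop_check(val_ap):
#           break

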
class RandEdgeSampler(object):
    """Samples random (source, destination) node pairs, e.g. as negative edges."""

    def __init__(self, src_list, dst_list):
        # each argument is a sequence of node-id arrays (e.g. one per split)
        src_list = np.concatenate(src_list)
        dst_list = np.concatenate(dst_list)
        self.src_list = np.unique(src_list)
        self.dst_list = np.unique(dst_list)

    def sample(self, size):
        # draw `size` node ids independently and uniformly from each side
        src_index = np.random.randint(0, len(self.src_list), size)
        dst_index = np.random.randint(0, len(self.dst_list), size)
        return self.src_list[src_index], self.dst_list[dst_index]
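
# Illustrative usage sketch (`train_src`/`train_dst` etc. are hypothetical node-id
# arrays; any iterable of such arrays works, since __init__ concatenates them):
#
#   sampler = RandEdgeSampler([train_src, val_src], [train_dst, val_dst])
#   neg_src, neg_dst = sampler.sample(size=args.bs)  # negative edges for one batch

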
def set_random_seed(seed):
    """Seeds every source of randomness (PyTorch, NumPy, Python) for reproducibility."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True  # trade speed for reproducible cuDNN kernels
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    # note: this only affects subprocesses; the current interpreter's hash
    # randomization was fixed at startup
    os.environ['PYTHONHASHSEED'] = str(seed)
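

# A minimal smoke test of the helpers above, guarded so that importing this
# module stays side-effect free; the node ids and scores below are synthetic,
# purely illustrative values.
if __name__ == '__main__':
    set_random_seed(0)

    sampler = RandEdgeSampler([np.array([1, 2, 3])], [np.array([4, 5, 6])])
    neg_src, neg_dst = sampler.sample(5)
    print('sampled negative edges:', neg_src, neg_dst)

    monitor = EarlyStopMonitor(max_round=3)
    for score in [0.5, 0.6, 0.6, 0.6, 0.6]:
        if monitor.early_stop_check(score):
            print('early stop triggered after epoch', monitor.epoch_count)
            break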