main_lightning.py
import argparse
import random

import numpy as np
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from model_lightning import GAT


def fix_seed(seed):
    """Seed all RNGs so runs are reproducible."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark must be off for deterministic behavior; the original set it
    # to True, which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dropout", default=0.6, type=float)
    parser.add_argument("--hidden", default=258, type=int)
    parser.add_argument("--n_heads", default=1, type=int)
    parser.add_argument("--learning_rate", default=0.001, type=float)
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--split", default=0.8, type=float)
    parser.add_argument("--patience", default=10, type=int)
    parser.add_argument("--epoch", default=100, type=int)
    parser.add_argument("--gpus", default=1, type=int)
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--num_workers", default=1, type=int)
    args = parser.parse_args()

    fix_seed(args.seed)

    config = vars(args)  # convert the parsed namespace to a dictionary
    config["num_feature"] = 512  # input feature dimension expected by GAT
print("Building DataLoader & Model ...")
model = GAT(config)
print("Done")
print("Building Trainer ...")
logger = TensorBoardLogger(save_dir="./runs2", name="A{}_H{}_lr{}_batch{}_drop{}".format(config["n_heads"],
config["hidden"],
config["learning_rate"],
config["batch_size"],
config["dropout"]))
early_stopping = EarlyStopping("val_loss", patience=config["patience"])
checkpoint = ModelCheckpoint(dirpath="./output",
filename="{epoch}_{val_acc:.2f}_" +
"A{}_H{}_lr{}_batch{}_drop{}".format(config["n_heads"],
config["hidden"],
config["learning_rate"],
config["batch_size"],
config["dropout"]),
monitor="val_acc",
mode="max")
trainer = Trainer(accelerator="gpu", devices=config["gpus"], max_epochs=config["epoch"], logger=logger,
callbacks=[early_stopping, checkpoint], log_every_n_steps=1)
print("+--------------------------------------------+")
print("| Training Start |")
print("+--------------------------------------------+")
trainer.fit(model)
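
For reference, a typical invocation of the script might look like the following; the flag values shown are just the script's own defaults, not tuned settings:

    python main_lightning.py --hidden 258 --n_heads 1 --learning_rate 0.001 --batch_size 32 --dropout 0.6 --epoch 100 --gpus 1 --seed 42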
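
Since trainer.fit(model) is called without explicit dataloaders, model_lightning.GAT presumably implements the LightningModule data hooks itself and logs the "val_loss" / "val_acc" metrics that the callbacks monitor. Below is a minimal sketch of the interface this script assumes; it is hypothetical (the real GAT lives in model_lightning.py and may differ), and it uses a plain linear head on random data as a stand-in rather than actual graph attention:

    # Hypothetical sketch of the interface main_lightning.py relies on.
    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl


    class GAT(pl.LightningModule):
        def __init__(self, config):
            super().__init__()
            self.save_hyperparameters(config)
            # A real implementation would build attention layers from
            # num_feature / hidden / n_heads / dropout; this is a stand-in.
            self.net = nn.Sequential(
                nn.Linear(config["num_feature"], config["hidden"]),
                nn.Dropout(config["dropout"]),
                nn.Linear(config["hidden"], 2),
            )

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = nn.functional.cross_entropy(self.net(x), y)
            self.log("train_loss", loss)
            return loss

        def validation_step(self, batch, batch_idx):
            # EarlyStopping watches "val_loss"; ModelCheckpoint watches "val_acc".
            x, y = batch
            logits = self.net(x)
            self.log("val_loss", nn.functional.cross_entropy(logits, y))
            self.log("val_acc", (logits.argmax(dim=1) == y).float().mean())

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(),
                                    lr=self.hparams.learning_rate)

        def train_dataloader(self):
            # Random placeholder data; the real module would load its dataset
            # and apply the --split ratio here.
            data = TensorDataset(torch.randn(256, self.hparams.num_feature),
                                 torch.randint(0, 2, (256,)))
            return DataLoader(data, batch_size=self.hparams.batch_size,
                              num_workers=self.hparams.num_workers, shuffle=True)

        def val_dataloader(self):
            data = TensorDataset(torch.randn(64, self.hparams.num_feature),
                                 torch.randint(0, 2, (64,)))
            return DataLoader(data, batch_size=self.hparams.batch_size,
                              num_workers=self.hparams.num_workers)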