diff --git a/examples/explain/attention_exaplainer.py b/examples/explain/attention_exaplainer.py index 65e7c0f3d15db..314c759ec94ac 100644 --- a/examples/explain/attention_exaplainer.py +++ b/examples/explain/attention_exaplainer.py @@ -5,7 +5,7 @@ import torch_geometric from torch_geometric.datasets import Planetoid -from torch_geometric.explain import Explainer, AttentionExplainer +from torch_geometric.explain import AttentionExplainer, Explainer from torch_geometric.nn import GATConv if torch.cuda.is_available(): @@ -19,14 +19,15 @@ dataset = Planetoid(path, name='Cora') data = dataset[0].to(device) - # GAT Node Classification ===================================================== + class GAT(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = GATConv(dataset.num_features, 8, heads=8, dropout=0.6) - self.conv2 = GATConv(64, dataset.num_classes, heads=1, concat=False, dropout=0.6) + self.conv2 = GATConv(64, dataset.num_classes, heads=1, concat=False, + dropout=0.6) def forward(self, x, edge_index): x = F.dropout(x, p=0.6, training=self.training) diff --git a/examples/explain/graphmask_explainer.py b/examples/explain/graphmask_explainer.py index c7de53cf75428..a7659a9afd604 100644 --- a/examples/explain/graphmask_explainer.py +++ b/examples/explain/graphmask_explainer.py @@ -19,7 +19,6 @@ dataset = Planetoid(path, name='Cora') data = dataset[0].to(device) - # GCN Node Classification =====================================================