Skip to content

Commit

Permalink
Fix conv1d erro.
Browse files Browse the repository at this point in the history
See issues #26 and #51 .
  • Loading branch information
zezhishao authored Dec 10, 2023
1 parent 3d7e747 commit 566e273
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions step/step_arch/graphwavenet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,13 @@ def __init__(self, num_nodes, support_len, dropout=0.3, gcn_bool=True, addaptadj
# dilated convolutions
self.filter_convs.append(nn.Conv2d(in_channels=residual_channels, out_channels=dilation_channels, kernel_size=(1,kernel_size),dilation=new_dilation))

self.gate_convs.append(nn.Conv1d(in_channels=residual_channels, out_channels=dilation_channels, kernel_size=(1, kernel_size), dilation=new_dilation))
self.gate_convs.append(nn.Conv2d(in_channels=residual_channels, out_channels=dilation_channels, kernel_size=(1, kernel_size), dilation=new_dilation))

# 1x1 convolution for residual connection
self.residual_convs.append(nn.Conv1d(in_channels=dilation_channels, out_channels=residual_channels, kernel_size=(1, 1)))
self.residual_convs.append(nn.Conv2d(in_channels=dilation_channels, out_channels=residual_channels, kernel_size=(1, 1)))

# 1x1 convolution for skip connection
self.skip_convs.append(nn.Conv1d(in_channels=dilation_channels, out_channels=skip_channels, kernel_size=(1, 1)))
self.skip_convs.append(nn.Conv2d(in_channels=dilation_channels, out_channels=skip_channels, kernel_size=(1, 1)))
self.bn.append(nn.BatchNorm2d(residual_channels))
new_dilation *= 2
receptive_field += additional_scope
Expand Down

0 comments on commit 566e273

Please sign in to comment.