Commit 18944fb

extend args accepted by Embedding
donglihe-hub committed Aug 7, 2023
1 parent 2c52c3b commit 18944fb
Showing 4 changed files with 20 additions and 15 deletions.
2 changes: 1 addition & 1 deletion libmultilabel/nn/networks/kim_cnn.py
@@ -28,7 +28,7 @@ def __init__(
         activation="relu",
     ):
         super(KimCNN, self).__init__()
-        self.embedding = Embedding(embed_vecs, embed_dropout)
+        self.embedding = Embedding(embed_vecs, dropout=embed_dropout)
         self.encoder = CNNEncoder(
             embed_vecs.shape[1], filter_sizes, num_filter_per_size, activation, post_encoder_dropout, num_pool=1
         )
2 changes: 1 addition & 1 deletion libmultilabel/nn/networks/labelwise_attention_networks.py
@@ -27,7 +27,7 @@ class LabelwiseAttentionNetwork(ABC, nn.Module):
 
     def __init__(self, embed_vecs, num_classes, embed_dropout, encoder_dropout, post_encoder_dropout, hidden_dim):
         super(LabelwiseAttentionNetwork, self).__init__()
-        self.embedding = Embedding(embed_vecs, embed_dropout)
+        self.embedding = Embedding(embed_vecs, dropout=embed_dropout)
         self.encoder = self._get_encoder(embed_vecs.shape[1], encoder_dropout, post_encoder_dropout)
         self.attention = self._get_attention()
         self.output = LabelwiseLinearOutput(hidden_dim, num_classes)
29 changes: 17 additions & 12 deletions libmultilabel/nn/networks/modules.py
@@ -3,24 +3,29 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
 from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
 
 
 class Embedding(nn.Module):
-    """Embedding layer with dropout
-
-    Args:
-        embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
-        dropout (float): The dropout rate of the word embedding. Defaults to 0.2.
-    """
-
-    def __init__(self, embed_vecs, dropout=0.2):
-        super(Embedding, self).__init__()
-        self.embedding = nn.Embedding.from_pretrained(embed_vecs, freeze=False, padding_idx=0)
+    """Embedding layer with dropout."""
+
+    def __init__(self, embed_vecs: Tensor, freeze: bool = False, sparse: bool = False, dropout: float = 0.2):
+        """Construct the embedding layer with dropout from pre-trained word vectors.
+
+        Args:
+            embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
+            freeze (bool): If True, the tensor does not get updated in the learning process.
+                Equivalent to embedding.weight.requires_grad = False. Default: False.
+            sparse (bool): If True, gradient w.r.t. weight matrix will be a sparse tensor. Default: False.
+            dropout (float): The dropout rate of the word embedding. Defaults to 0.2.
+        """
+        super().__init__()
+        self.embedding = nn.Embedding.from_pretrained(embed_vecs, freeze=freeze, padding_idx=0, sparse=sparse)
         self.dropout = nn.Dropout(dropout)
 
-    def forward(self, input):
-        return self.dropout(self.embedding(input))
+    def forward(self, inputs: Tensor) -> Tensor:
+        return self.dropout(self.embedding(inputs))
 
 
 class RNNEncoder(ABC, nn.Module):
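
A minimal usage sketch of the extended constructor (the vocabulary size, random vectors, and optimizer pairing below are hypothetical; only the Embedding signature comes from this commit):

    import torch
    from libmultilabel.nn.networks.modules import Embedding

    # Hypothetical pre-trained vectors: 5,000-word vocabulary, 300-dim embeddings.
    # Row 0 is reserved for padding (Embedding sets padding_idx=0 internally).
    embed_vecs = torch.randn(5000, 300)

    # The two new keyword arguments:
    #   freeze=True  keeps the pre-trained vectors fixed during training
    #   sparse=True  makes gradients w.r.t. the weight matrix sparse (pair with a
    #                sparse-capable optimizer such as torch.optim.SparseAdam)
    embedding = Embedding(embed_vecs, freeze=False, sparse=True, dropout=0.2)

    token_ids = torch.randint(1, 5000, (8, 100))  # batch of 8 zero-padded sequences
    output = embedding(token_ids)                 # shape: (8, 100, 300)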
2 changes: 1 addition & 1 deletion libmultilabel/nn/networks/xml_cnn.py
@@ -33,7 +33,7 @@ def __init__(
         activation="relu",
     ):
         super(XMLCNN, self).__init__()
-        self.embedding = Embedding(embed_vecs, embed_dropout)
+        self.embedding = Embedding(embed_vecs, dropout=embed_dropout)
         self.encoder = CNNEncoder(embed_vecs.shape[1], filter_sizes, num_filter_per_size, activation, num_pool=num_pool)
         total_output_size = len(filter_sizes) * num_filter_per_size * num_pool
         self.linear1 = nn.Linear(total_output_size, hidden_dim)
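
Why every call site switches to dropout=embed_dropout instead of staying positional: in the new signature, freeze is the second positional parameter, so the old positional call would bind the dropout rate to freeze. A small sketch of the pitfall (the tensor values are hypothetical):

    import torch
    from libmultilabel.nn.networks.modules import Embedding

    embed_vecs = torch.randn(100, 50)  # hypothetical pre-trained vectors
    embed_dropout = 0.4

    # Pre-commit positional call against the new signature: 0.4 binds to
    # freeze, a truthy value, so the embedding weights are silently frozen.
    risky = Embedding(embed_vecs, embed_dropout)

    # Keyword form used throughout this commit:
    safe = Embedding(embed_vecs, dropout=embed_dropout)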
