From becf5e911d52fde7438bdc2855181f38143f7241 Mon Sep 17 00:00:00 2001
From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Date: Mon, 18 Sep 2023 17:24:35 +0100
Subject: [PATCH] [Wav2Vec2-Conf / LLaMA] Style fix (#26188)

* torch.nn -> nn

* fix llama

* copies
---
 .../deprecated/open_llama/modeling_open_llama.py    |  2 +-
 src/transformers/models/llama/modeling_llama.py     |  2 +-
 .../modeling_wav2vec2_conformer.py                  | 14 +++++++-------
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py
index 67142229aada3d..0d36d8c0e06306 100644
--- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py
+++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py
@@ -99,7 +99,7 @@ def forward(self, hidden_states):
 
 
 # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->OpenLlama
-class OpenLlamaRotaryEmbedding(torch.nn.Module):
+class OpenLlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
 
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 5e7a879c07e88f..f2fef00f7c17a5 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -89,7 +89,7 @@ def forward(self, hidden_states):
         return self.weight * hidden_states.to(input_dtype)
 
 
-class LlamaRotaryEmbedding(torch.nn.Module):
+class LlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
 
diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
index 5041039a8ef987..f162c514297067 100644
--- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
+++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
@@ -584,7 +584,7 @@ def __init__(self, config):
         if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
             raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding")
         self.layer_norm = nn.LayerNorm(config.hidden_size)
-        self.pointwise_conv1 = torch.nn.Conv1d(
+        self.pointwise_conv1 = nn.Conv1d(
             config.hidden_size,
             2 * config.hidden_size,
             kernel_size=1,
@@ -592,8 +592,8 @@
             padding=0,
             bias=False,
         )
-        self.glu = torch.nn.GLU(dim=1)
-        self.depthwise_conv = torch.nn.Conv1d(
+        self.glu = nn.GLU(dim=1)
+        self.depthwise_conv = nn.Conv1d(
             config.hidden_size,
             config.hidden_size,
             config.conv_depthwise_kernel_size,
@@ -602,9 +602,9 @@
             groups=config.hidden_size,
             bias=False,
         )
-        self.batch_norm = torch.nn.BatchNorm1d(config.hidden_size)
+        self.batch_norm = nn.BatchNorm1d(config.hidden_size)
         self.activation = ACT2FN[config.hidden_act]
-        self.pointwise_conv2 = torch.nn.Conv1d(
+        self.pointwise_conv2 = nn.Conv1d(
             config.hidden_size,
             config.hidden_size,
             kernel_size=1,
@@ -612,7 +612,7 @@
             padding=0,
             bias=False,
         )
-        self.dropout = torch.nn.Dropout(config.conformer_conv_dropout)
+        self.dropout = nn.Dropout(config.conformer_conv_dropout)
 
     def forward(self, hidden_states):
         hidden_states = self.layer_norm(hidden_states)
@@ -798,7 +798,7 @@ def __init__(self, config):
 
         # Self-Attention
         self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
-        self.self_attn_dropout = torch.nn.Dropout(dropout)
+        self.self_attn_dropout = nn.Dropout(dropout)
         self.self_attn = Wav2Vec2ConformerSelfAttention(config)
 
         # Conformer Convolution
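
For context, this is purely a style normalization: as the untouched `nn.LayerNorm` context lines above show, all three files already import the alias with `from torch import nn`, so `torch.nn.Module` and `nn.Module` resolve to the same class, and the patch simply standardizes on the shorter, already-imported name. A minimal sketch of the convention (the `TinyConvModule` class below is illustrative, not part of the patch):

    import torch
    from torch import nn


    class TinyConvModule(nn.Module):  # preferred over subclassing `torch.nn.Module`
        def __init__(self, hidden_size: int, dropout: float = 0.1):
            super().__init__()
            # Layers follow the same rule: `nn.Conv1d`, not `torch.nn.Conv1d`.
            self.pointwise_conv = nn.Conv1d(hidden_size, hidden_size, kernel_size=1, bias=False)
            self.dropout = nn.Dropout(dropout)

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # nn.Conv1d expects input of shape (batch, channels, time).
            return self.dropout(self.pointwise_conv(hidden_states))


    # Usage: batch of 2, hidden_size 8, sequence length 16; output keeps the shape.
    module = TinyConvModule(hidden_size=8)
    output = module(torch.randn(2, 8, 16))
    assert output.shape == (2, 8, 16)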