Quality
hSterz committed Nov 16, 2020
1 parent d029259 commit 5a58ca4
Showing 1 changed file with 11 additions and 44 deletions.
55 changes: 11 additions & 44 deletions src/transformers/adapter_bert.py
@@ -104,10 +104,7 @@ def enable_adapters(self, adapter_names: list, unfreeze_adapters: bool, unfreeze
param.requires_grad = True

def get_adapter_preparams(
- self,
- adapter_config,
- hidden_states,
- input_tensor,
+ self, adapter_config, hidden_states, input_tensor,
):
"""
Retrieves the hidden_states, query (for Fusion), and residual connection according to the set configuration
@@ -216,12 +213,7 @@ def adapter_fusion(self, hidden_states, adapter_stack, residual, query):

fusion_name = ",".join(adapter_stack)

- hidden_states = self.adapter_fusion_layer[fusion_name](
- query,
- up_list,
- up_list,
- residual,
- )
+ hidden_states = self.adapter_fusion_layer[fusion_name](query, up_list, up_list, residual,)
return hidden_states

def adapters_forward(self, hidden_states, input_tensor, adapter_names=None):
@@ -240,9 +232,7 @@ def adapters_forward(self, hidden_states, input_tensor, adapter_names=None):

for adapter_stack in adapter_names:
hidden_states = self.adapter_stack_layer(
- hidden_states=hidden_states,
- input_tensor=input_tensor,
- adapter_stack=adapter_stack,
+ hidden_states=hidden_states, input_tensor=input_tensor, adapter_stack=adapter_stack,
)

last_config = self.config.adapters.get(adapter_names[-1][-1])
@@ -312,10 +302,7 @@ def enable_adapters(self, adapter_names: list, unfreeze_adapters: bool, unfreeze
param.requires_grad = True

def get_adapter_preparams(
- self,
- adapter_config,
- hidden_states,
- input_tensor,
+ self, adapter_config, hidden_states, input_tensor,
):
"""
Retrieves the hidden_states, query (for Fusion), and residual connection according to the set configuration
@@ -444,9 +431,7 @@ def adapters_forward(self, hidden_states, input_tensor, adapter_names=None):

for adapter_stack in adapter_names:
hidden_states = self.adapter_stack_layer(
- hidden_states=hidden_states,
- input_tensor=input_tensor,
- adapter_stack=adapter_stack,
+ hidden_states=hidden_states, input_tensor=input_tensor, adapter_stack=adapter_stack,
)

last_config = self.config.adapters.get(adapter_names[-1][-1])
@@ -693,10 +678,7 @@ def add_qa_head(
self.add_prediction_head(head_name, config, overwrite_ok)

def add_prediction_head(
- self,
- head_name,
- config,
- overwrite_ok=False,
+ self, head_name, config, overwrite_ok=False,
):
if head_name not in self.config.prediction_heads or overwrite_ok:
self.config.prediction_heads[head_name] = config
@@ -766,10 +748,7 @@ def forward_head(self, outputs, head_name=None, attention_mask=None, labels=None

if return_dict:
return SequenceClassifierOutput(
- loss=loss,
- logits=logits,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)
else:
return outputs
@@ -787,10 +766,7 @@ def forward_head(self, outputs, head_name=None, attention_mask=None, labels=None

if return_dict:
return SequenceClassifierOutput(
- loss=loss,
- logits=logits,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)
else:
return outputs
@@ -807,10 +783,7 @@ def forward_head(self, outputs, head_name=None, attention_mask=None, labels=None

if return_dict:
return MultipleChoiceModelOutput(
- loss=loss,
- logits=logits,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)
else:
return outputs
@@ -835,10 +808,7 @@ def forward_head(self, outputs, head_name=None, attention_mask=None, labels=None

if return_dict:
return TokenClassifierOutput(
- loss=loss,
- logits=logits,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
)
else:
return outputs
@@ -850,10 +820,7 @@ def forward_head(self, outputs, head_name=None, attention_mask=None, labels=None
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)

- outputs = (
- start_logits,
- end_logits,
- ) + outputs[2:]
+ outputs = (start_logits, end_logits,) + outputs[2:]
if labels is not None:
start_positions, end_positions = labels
if len(start_positions.size()) > 1:
