From cc3670956cc6f2b34ce0d6411000a2810d4b931f Mon Sep 17 00:00:00 2001
From: Niklas Muennighoff
Date: Thu, 22 Feb 2024 16:58:09 +0100
Subject: [PATCH] Add attn bias arg to HF wrapper (#458)

---
 hf_olmo/modeling_olmo.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/hf_olmo/modeling_olmo.py b/hf_olmo/modeling_olmo.py
index 83814c5cf..6a279cb10 100644
--- a/hf_olmo/modeling_olmo.py
+++ b/hf_olmo/modeling_olmo.py
@@ -50,6 +50,7 @@ def forward(
         input_ids: torch.LongTensor = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        attention_bias: Optional[torch.Tensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -70,6 +71,7 @@ def forward(
             input_ids=input_ids,
             input_embeddings=inputs_embeds,
             attention_mask=attention_mask,
+            attention_bias=attention_bias,
             past_key_values=past_key_values,
             use_cache=use_cache,
             output_hidden_states=output_hidden_states,
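
A minimal sketch of how the new argument could be exercised through the HF wrapper. Assumptions (not part of the patch): importing hf_olmo registers the OLMo architecture with transformers' Auto classes, "allenai/OLMo-7B" stands in for any OLMo checkpoint, and attention_bias is an additive float bias broadcastable to (batch, n_heads, q_len, k_len), following the usual convention of 0 for "attend" and -inf for "mask".

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    import hf_olmo  # noqa: F401  # assumed to register OLMo with transformers

    # Illustrative checkpoint name; substitute any OLMo checkpoint.
    tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-7B")
    model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-7B")

    inputs = tokenizer("Language modeling is", return_tensors="pt")
    seq_len = inputs["input_ids"].shape[1]

    # Assumed convention: additive bias, 0 where attention is allowed and
    # -inf where it is masked. Here we hand-build an ordinary causal bias.
    attention_bias = torch.zeros(1, 1, seq_len, seq_len)
    attention_bias.masked_fill_(
        torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1),
        float("-inf"),
    )

    # With this patch, the wrapper forwards attention_bias to the inner model.
    outputs = model(**inputs, attention_bias=attention_bias)

Before this change the wrapper silently dropped any attention_bias a caller supplied, so custom biases (e.g. non-causal or block-sparse patterns) could only be passed to the inner OLMo model directly; the two added lines simply thread the argument through.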