Multi-dim prompt tuning mask padding
calpt committed Jul 6, 2024
1 parent c8c5c12 commit 8de5c22
Showing 2 changed files with 19 additions and 14 deletions.
src/adapters/models/bert/modeling_bert.py (2 additions, 0 deletions)
@@ -157,6 +157,8 @@ def forward(
         past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         output_attentions: Optional[bool] = False,
     ) -> Tuple[torch.Tensor]:
+        attention_mask = prefix_attention_mask(attention_mask, [2, 3])  # type: ignore
+
         if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
             # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
             logger.warning_once(
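For context, a minimal sketch of what the new `[2, 3]` argument does at this call site. It assumes the SDPA attention layer receives a 4D additive mask of shape (batch, 1, query_len, key_len); `prompt_len` is a hypothetical stand-in for `prompt_tokens_length`. Since prompt tuning prepends virtual tokens, both the query axis (dim 2) and the key axis (dim 3) must grow by the prompt length, which F.pad makes easy to see:

import torch
import torch.nn.functional as F

batch, seq_len, prompt_len = 2, 5, 4
mask = torch.zeros(batch, 1, seq_len, seq_len)  # assumed (batch, 1, query_len, key_len) additive mask

# Pad both trailing axes on the left by prompt_len; a value of 0 in an additive
# mask means "attend", matching the helper's default prefix_value of 0.
padded = F.pad(mask, (prompt_len, 0, prompt_len, 0), value=0.0)
print(padded.shape)  # torch.Size([2, 1, 9, 9])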
src/adapters/utils.py (17 additions, 14 deletions)
@@ -17,7 +17,7 @@
 from functools import partial
 from os.path import basename, isdir, isfile, join
 from pathlib import Path
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
 from urllib.parse import urlparse
 from zipfile import ZipFile, is_zipfile
 
@@ -865,7 +865,7 @@ def get_adapter_info(adapter_id: str, source: str = "ah") -> Optional[AdapterInf
         raise ValueError("Please specify either 'ah' or 'hf' as source.")
 
 
-def prefix_attention_mask(attention_mask, dim: int = 3, prefix_value: int = 0):
+def prefix_attention_mask(attention_mask, dim: Union[int, List[int]] = 3, prefix_value: int = 0):
     """
     Adds a prefix to an attention mask. The length of the prefix is determined by the `prefix_attention_mask_length`
     attribute in the ForwardContext.
@@ -890,18 +890,21 @@ def prefix_attention_mask(attention_mask, dim: int = 3, prefix_value: int = 0):
         and forward_context is not None
         and getattr(forward_context, "prompt_tokens_length", None) is not None
     ):
-        # Create a tensor of ones with the desired shape
-        ones_shape = list(attention_mask.shape)
-        ones_shape[dim] = forward_context.prompt_tokens_length
-
-        prefix_attention_mask = torch.full(
-            ones_shape,
-            prefix_value,
-            dtype=attention_mask.dtype,
-        ).to(attention_mask.device)
-
-        # Concatenate the prefix_attention_mask along the specified dimension
-        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=dim)
+        if isinstance(dim, int):
+            dim = [dim]
+        for d in dim:
+            # Create a tensor of ones with the desired shape
+            ones_shape = list(attention_mask.shape)
+            ones_shape[d] = forward_context.prompt_tokens_length
+
+            prefix_attention_mask = torch.full(
+                ones_shape,
+                prefix_value,
+                dtype=attention_mask.dtype,
+            ).to(attention_mask.device)
+
+            # Concatenate the prefix_attention_mask along the specified dimension
+            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=d)
 
     return attention_mask
 
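To see the effect of the widened signature, here is a standalone sketch that mirrors the patched helper. `pad_mask` and the explicit `prompt_len` argument are hypothetical stand-ins, since the real `prefix_attention_mask` reads `prompt_tokens_length` from the active ForwardContext:

from typing import List, Union

import torch


def pad_mask(mask: torch.Tensor, dim: Union[int, List[int]] = 3, prefix_value: int = 0, prompt_len: int = 4) -> torch.Tensor:
    # Mirrors the new loop: normalize to a list of dims, then prepend a block
    # of prefix_value along each requested dimension in turn.
    dims = [dim] if isinstance(dim, int) else dim
    for d in dims:
        prefix_shape = list(mask.shape)
        prefix_shape[d] = prompt_len
        prefix = torch.full(prefix_shape, prefix_value, dtype=mask.dtype, device=mask.device)
        mask = torch.cat((prefix, mask), dim=d)
    return mask


mask = torch.ones(2, 1, 5, 5)
print(pad_mask(mask, dim=3).shape)       # torch.Size([2, 1, 5, 9]) -- previous single-dim behaviour
print(pad_mask(mask, dim=[2, 3]).shape)  # torch.Size([2, 1, 9, 9]) -- new multi-dim behaviour

Looping torch.cat once per requested dimension keeps the change small; the cost is one intermediate tensor per extra dim, which is negligible for attention masks.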
