3284 bug support transformers 4310 #1910
Annotations
3 errors
test:
flair/embeddings/transformer.py#L1
Black format check
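This is Black's --check --diff output; every change below is whitespace-only (4-space indentation for the exploded parameter lists, and spaces on both sides of a slice colon whose bounds are compound expressions). A minimal sketch of the slice rule using Black's Python API, with an illustrative snippet string:

    # Black treats the colon in a slice like a binary operator once either
    # bound is itself an expression, so it puts a space on both sides.
    import black

    src = "x = y[a + 1: b - 1]\n"
    print(black.format_str(src, mode=black.Mode()), end="")
    # -> x = y[a + 1 : b - 1]

Running something like "black --check --diff flair/embeddings/transformer.py" locally should reproduce the diff below.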
--- /home/runner/work/flair/flair/flair/embeddings/transformer.py 2023-07-31 13:47:49.696067 +0000
+++ /home/runner/work/flair/flair/flair/embeddings/transformer.py 2023-07-31 13:51:30.386564 +0000
@@ -68,15 +68,15 @@
     return hidden_states[:, :, : input_ids.size()[1]]
 
 
 @torch.jit.script_if_tracing
 def combine_strided_tensors(
-        hidden_states: torch.Tensor,
-        overflow_to_sample_mapping: torch.Tensor,
-        half_stride: int,
-        max_length: int,
-        default_value: int,
+    hidden_states: torch.Tensor,
+    overflow_to_sample_mapping: torch.Tensor,
+    half_stride: int,
+    max_length: int,
+    default_value: int,
 ) -> torch.Tensor:
     _, counts = torch.unique(overflow_to_sample_mapping, sorted=True, return_counts=True)
     sentence_count = int(overflow_to_sample_mapping.max().item() + 1)
     token_count = max_length + (max_length - 2) * int(counts.max().item() - 1)
     if hidden_states.dim() == 2:
@@ -92,13 +92,13 @@
 
     for sentence_id in torch.arange(0, sentence_hidden_states.shape[0]):
         selected_sentences = hidden_states[overflow_to_sample_mapping == sentence_id]
         if selected_sentences.size(0) > 1:
             start_part = selected_sentences[0, : half_stride + 1]
-            mid_part = selected_sentences[:, half_stride + 1: max_length - 1 - half_stride]
+            mid_part = selected_sentences[:, half_stride + 1 : max_length - 1 - half_stride]
             mid_part = torch.reshape(mid_part, (mid_part.shape[0] * mid_part.shape[1],) + mid_part.shape[2:])
-            end_part = selected_sentences[selected_sentences.shape[0] - 1, max_length - half_stride - 1:]
+            end_part = selected_sentences[selected_sentences.shape[0] - 1, max_length - half_stride - 1 :]
             sentence_hidden_state = torch.cat((start_part, mid_part, end_part), dim=0)
             sentence_hidden_states[sentence_id, : sentence_hidden_state.shape[0]] = torch.cat(
                 (start_part, mid_part, end_part), dim=0
             )
         else:
@@ -107,25 +107,25 @@
     return sentence_hidden_states
 
 
 @torch.jit.script_if_tracing
 def fill_masked_elements(
-        all_token_embeddings: torch.Tensor,
-        sentence_hidden_states: torch.Tensor,
-        mask: torch.Tensor,
-        word_ids: torch.Tensor,
-        lengths: torch.LongTensor,
+    all_token_embeddings: torch.Tensor,
+    sentence_hidden_states: torch.Tensor,
+    mask: torch.Tensor,
+    word_ids: torch.Tensor,
+    lengths: torch.LongTensor,
 ):
     for i in torch.arange(int(all_token_embeddings.shape[0])):
         r = insert_missing_embeddings(sentence_hidden_states[i][mask[i] & (word_ids[i] >= 0)], word_ids[i], lengths[i])
         all_token_embeddings[i, : lengths[i], :] = r
     return all_token_embeddings
 
 
 @torch.jit.script_if_tracing
 def insert_missing_embeddings(
-        token_embeddings: torch.Tensor, word_id: torch.Tensor, length: torch.LongTensor
+    token_embeddings: torch.Tensor, word_id: torch.Tensor, length: torch.LongTensor
 ) -> torch.Tensor:
     # in some cases we need to insert zero vectors for tokens without embedding.
     if token_embeddings.shape[0] == 0:
         if token_embeddings.dim() == 2:
             token_embeddings = torch.zeros(
@@ -164,14 +164,14 @@
     return token_embeddings
 
 
 @torch.jit.script_if_tracing
 def fill_mean_token_embeddings(
-        all_token_embeddings: torch.Tensor,
-        sentence_hidden_states: torch.Tensor,
-        word_ids: torch.Tensor,
-        token_lengths: torch.Tensor,
+    all_token_embeddings: torch.Tensor,
+    sentence_hidden_states: torch.Tensor,
+    word_ids: torch.Tensor,
+    token_lengths: torch.Tensor,
 ):
     for i in torch.arange(all_token_embeddings.shape[0]):
         for _id in torch.arange(token_lengths[i]):  # type: ignore[call-overload]
             all_token_embeddings[i, _id, :] = torch.nan_to_num(
                 sentence_hidden_states[i][word_ids[i] == _id].mean(dim=0)
@@ -194,11 +194,11 @@
     for i in torch.arange(sentence_hidden_states.shape[0]):
         result[i], _ =
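For context on what combine_strided_tensors undoes: when a sentence is longer than the model's window, the tokenizer splits it into overlapping strided windows and records the window-to-sentence assignment in overflow_to_sample_mapping. A minimal sketch with the Hugging Face fast-tokenizer API (model name and lengths are illustrative):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    enc = tok(
        ["one long sentence that does not fit into a single window " * 4],
        truncation=True,
        max_length=16,
        stride=4,  # tokens shared by consecutive windows
        return_overflowing_tokens=True,
    )
    # one entry per window; here every window belongs to input sentence 0
    print(enc["overflow_to_sample_mapping"])

combine_strided_tensors then rebuilds one hidden-state tensor per sentence by concatenating the start of the first window, the overlap-trimmed middle of each window, and the tail of the last window (the start_part/mid_part/end_part slices in the diff above).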
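fill_mean_token_embeddings in the same file implements mean subtoken pooling: a word's embedding is the average of the hidden states of all subwords that word_ids assigns to it, with nan_to_num guarding words that received no subword. A standalone sketch of that computation (shapes and values are illustrative):

    import torch

    hidden = torch.randn(5, 4)                # 5 subwords, hidden size 4
    word_ids = torch.tensor([0, 0, 1, 2, 2])  # subword index -> word index
    pooled = torch.stack(
        [torch.nan_to_num(hidden[word_ids == w].mean(dim=0)) for w in range(3)]
    )
    print(pooled.shape)  # torch.Size([3, 4])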
test:
tests/test_datasets.py#L780
test_masakhane_corpus[False]
AssertionError: Mismatch in number of sentences for fon@v2/dev
assert 621 == 623
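The second failure is a count mismatch in the dataset test: the loaded Fon v2 dev split has 621 sentences while the test's hard-coded expectation is 623, which usually means the upstream download changed. A repro sketch against flair's dataset API; NER_MASAKHANE is flair's MasakhaNER loader, but the exact languages/version arguments here are assumptions about its signature:

    from flair.datasets import NER_MASAKHANE

    # assumed arguments; see tests/test_datasets.py::test_masakhane_corpus
    corpus = NER_MASAKHANE(languages="fon", version="v2")
    print(len(corpus.dev))  # the test expects 623; the current data yields 621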
test
Process completed with exit code 1.