Skip to content

3284 bug support transformers 4310 #1910

3284 bug support transformers 4310

3284 bug support transformers 4310 #1910

Triggered via pull request July 31, 2023 13:47
Status: Failure
Total duration: 27m 53s
Artifacts

ci.yml

on: pull_request
Fit to window
Zoom out
Zoom in

Annotations

3 errors
test: flair/embeddings/transformer.py#L1
Black format check --- /home/runner/work/flair/flair/flair/embeddings/transformer.py 2023-07-31 13:47:49.696067 +0000 +++ /home/runner/work/flair/flair/flair/embeddings/transformer.py 2023-07-31 13:51:30.386564 +0000 @@ -68,15 +68,15 @@ return hidden_states[:, :, : input_ids.size()[1]] @torch.jit.script_if_tracing def combine_strided_tensors( - hidden_states: torch.Tensor, - overflow_to_sample_mapping: torch.Tensor, - half_stride: int, - max_length: int, - default_value: int, + hidden_states: torch.Tensor, + overflow_to_sample_mapping: torch.Tensor, + half_stride: int, + max_length: int, + default_value: int, ) -> torch.Tensor: _, counts = torch.unique(overflow_to_sample_mapping, sorted=True, return_counts=True) sentence_count = int(overflow_to_sample_mapping.max().item() + 1) token_count = max_length + (max_length - 2) * int(counts.max().item() - 1) if hidden_states.dim() == 2: @@ -92,13 +92,13 @@ for sentence_id in torch.arange(0, sentence_hidden_states.shape[0]): selected_sentences = hidden_states[overflow_to_sample_mapping == sentence_id] if selected_sentences.size(0) > 1: start_part = selected_sentences[0, : half_stride + 1] - mid_part = selected_sentences[:, half_stride + 1: max_length - 1 - half_stride] + mid_part = selected_sentences[:, half_stride + 1 : max_length - 1 - half_stride] mid_part = torch.reshape(mid_part, (mid_part.shape[0] * mid_part.shape[1],) + mid_part.shape[2:]) - end_part = selected_sentences[selected_sentences.shape[0] - 1, max_length - half_stride - 1:] + end_part = selected_sentences[selected_sentences.shape[0] - 1, max_length - half_stride - 1 :] sentence_hidden_state = torch.cat((start_part, mid_part, end_part), dim=0) sentence_hidden_states[sentence_id, : sentence_hidden_state.shape[0]] = torch.cat( (start_part, mid_part, end_part), dim=0 ) else: @@ -107,25 +107,25 @@ return sentence_hidden_states @torch.jit.script_if_tracing def fill_masked_elements( - all_token_embeddings: torch.Tensor, - sentence_hidden_states: torch.Tensor, - 
mask: torch.Tensor, - word_ids: torch.Tensor, - lengths: torch.LongTensor, + all_token_embeddings: torch.Tensor, + sentence_hidden_states: torch.Tensor, + mask: torch.Tensor, + word_ids: torch.Tensor, + lengths: torch.LongTensor, ): for i in torch.arange(int(all_token_embeddings.shape[0])): r = insert_missing_embeddings(sentence_hidden_states[i][mask[i] & (word_ids[i] >= 0)], word_ids[i], lengths[i]) all_token_embeddings[i, : lengths[i], :] = r return all_token_embeddings @torch.jit.script_if_tracing def insert_missing_embeddings( - token_embeddings: torch.Tensor, word_id: torch.Tensor, length: torch.LongTensor + token_embeddings: torch.Tensor, word_id: torch.Tensor, length: torch.LongTensor ) -> torch.Tensor: # in some cases we need to insert zero vectors for tokens without embedding. if token_embeddings.shape[0] == 0: if token_embeddings.dim() == 2: token_embeddings = torch.zeros( @@ -164,14 +164,14 @@ return token_embeddings @torch.jit.script_if_tracing def fill_mean_token_embeddings( - all_token_embeddings: torch.Tensor, - sentence_hidden_states: torch.Tensor, - word_ids: torch.Tensor, - token_lengths: torch.Tensor, + all_token_embeddings: torch.Tensor, + sentence_hidden_states: torch.Tensor, + word_ids: torch.Tensor, + token_lengths: torch.Tensor, ): for i in torch.arange(all_token_embeddings.shape[0]): for _id in torch.arange(token_lengths[i]): # type: ignore[call-overload] all_token_embeddings[i, _id, :] = torch.nan_to_num( sentence_hidden_states[i][word_ids[i] == _id].mean(dim=0) @@ -194,11 +194,11 @@ for i in torch.arange(sentence_hidden_states.shape[0]): result[i], _ =
test: tests/test_datasets.py#L780
test_masakhane_corpus[False]: AssertionError: Mismatch in number of sentences for fon@v2/dev (assert 621 == 623)
test
Process completed with exit code 1.