Skip to content

Commit

Permalink
Merge pull request #3537 from MattGPT-ai/fix.text-pair-regressor-stat…
Browse files Browse the repository at this point in the history
…e-dict-key-bug

GH-3536: fix state dict key mismatch for embeddings in TextPairRegres…
  • Loading branch information
helpmefindaname authored Aug 23, 2024
2 parents 7a11174 + d6874a6 commit 3685529
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions flair/models/pairwise_regression_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import typing
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
Expand Down Expand Up @@ -91,7 +90,7 @@ def label_type(self):

def get_used_tokens(
self, corpus: Corpus, context_length: int = 0, respect_document_boundaries: bool = True
) -> typing.Iterable[List[str]]:
) -> Iterable[List[str]]:
for sentence_pair in _iter_dataset(corpus.get_all_sentences()):
yield [t.text for t in sentence_pair.first]
yield [t.text for t in sentence_pair.first.left_context(context_length, respect_document_boundaries)]
Expand Down Expand Up @@ -204,10 +203,16 @@ def _get_state_dict(self):
return model_state

@classmethod
def _init_model_with_state_dict(cls, state, **kwargs):
# add DefaultClassifier arguments
def _init_model_with_state_dict(cls, state: Dict[str, Any], **kwargs):
"""Initializes a TextPairRegressor model from a state dictionary (exported by _get_state_dict).
Requires keys 'state_dict', 'document_embeddings', and 'label_type' in the state dictionary.
"""
if "document_embeddings" in state:
state["embeddings"] = state.pop("document_embeddings") # need to rename this parameter
# add Model arguments
for arg in [
"document_embeddings",
"embeddings",
"label_type",
"embed_separately",
"dropout",
Expand Down

0 comments on commit 3685529

Please sign in to comment.