Skip to content

Commit

Permalink
input check
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Aug 25, 2023
1 parent d7e6cd6 commit cda6693
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ def warm_up(self):
@component.output_types(embeddings=List[np.ndarray])
def run(self, texts: List[str]):
"""Embed a list of strings."""
if not isinstance(texts, list) or not isinstance(texts[0], str):
raise ValueError(
"SentenceTransformersTextEmbedder expects a list of strings as input."
"In case you want to embed Documents, please use the SentenceTransformersDocumentEmbedder."
)
self.warm_up()
texts_to_embed = [self.prefix + text + self.suffix for text in texts]
embeddings = self.embedding_backend.embed(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,18 @@ def test_run(self):
assert len(embeddings) == len(texts)
for embedding in embeddings:
assert isinstance(embedding, np.ndarray)

@pytest.mark.unit
def test_run_wrong_input_format(self):
embedder = SentenceTransformersTextEmbedder(model_name_or_path="model")
embedder.embedding_backend = MagicMock()
# embedder.embedding_backend.embed = lambda x, **kwargs: list(np.random.rand(len(x), 16))

string_input = "text"
list_integers_input = [1, 2, 3]

with pytest.raises(ValueError, match="SentenceTransformersTextEmbedder expects a list of strings as input"):
embedder.run(texts=string_input)

with pytest.raises(ValueError, match="SentenceTransformersTextEmbedder expects a list of strings as input"):
embedder.run(texts=list_integers_input)

0 comments on commit cda6693

Please sign in to comment.