-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into add_exact_match_metric
- Loading branch information
Showing
31 changed files
with
866 additions
and
229 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import pytest | ||
|
||
from haystack import Document, Pipeline, ComponentError | ||
from haystack.components.extractors import NamedEntityAnnotation, NamedEntityExtractor, NamedEntityExtractorBackend | ||
|
||
|
||
@pytest.fixture
def raw_texts():
    """Provide the sample input strings shared by the extractor tests.

    The last entry is an empty string so that the tests also cover a
    document without any content.
    """
    clara = "My name is Clara and I live in Berkeley, California."
    merlin = "I'm Merlin, the happy pig!"
    new_york = "New York State declared a state of emergency after the announcement of the end of the world."
    blank = ""  # Intentionally empty.
    return [clara, merlin, new_york, blank]
|
||
|
||
@pytest.fixture
def hf_annotations():
    """Provide the annotations expected from the Hugging Face backend.

    The tests compare this list position by position against the
    predictions made for the ``raw_texts`` fixture; ``start``/``end`` are
    character offsets into the corresponding text (end exclusive).
    """
    clara = [
        NamedEntityAnnotation(entity="PER", start=11, end=16),
        NamedEntityAnnotation(entity="LOC", start=31, end=39),
        NamedEntityAnnotation(entity="LOC", start=41, end=51),
    ]
    merlin = [NamedEntityAnnotation(entity="PER", start=4, end=10)]
    new_york = [NamedEntityAnnotation(entity="LOC", start=0, end=14)]
    blank = []  # Nothing is expected for the empty text.
    return [clara, merlin, new_york, blank]
|
||
|
||
@pytest.fixture
def spacy_annotations():
    """Provide the annotations expected from the spaCy backend.

    Same spans as the Hugging Face expectations, but labelled with the
    ``PERSON``/``GPE`` tags used in these expected values. The tests compare
    this list position by position against the predictions made for the
    ``raw_texts`` fixture.
    """
    clara = [
        NamedEntityAnnotation(entity="PERSON", start=11, end=16),
        NamedEntityAnnotation(entity="GPE", start=31, end=39),
        NamedEntityAnnotation(entity="GPE", start=41, end=51),
    ]
    merlin = [NamedEntityAnnotation(entity="PERSON", start=4, end=10)]
    new_york = [NamedEntityAnnotation(entity="GPE", start=0, end=14)]
    blank = []  # Nothing is expected for the empty text.
    return [clara, merlin, new_york, blank]
|
||
|
||
def test_ner_extractor_init():
    """Check that the extractor refuses to run before ``warm_up`` and tracks its initialization state."""
    ner = NamedEntityExtractor(
        model_name_or_path="dslim/bert-base-NER",
        backend=NamedEntityExtractorBackend.HUGGING_FACE,
        device_id=-1,
    )

    # Running before warm_up() must fail with a "not initialized" error.
    with pytest.raises(ComponentError, match=r"not initialized"):
        ner.run(documents=[])

    # The flag only flips once warm_up() has been called.
    assert not ner.initialized
    ner.warm_up()
    assert ner.initialized
|
||
|
||
@pytest.mark.parametrize("batch_size", [1, 3])
def test_ner_extractor_hf_backend(raw_texts, hf_annotations, batch_size):
    """Run the Hugging Face backend over the sample texts and compare with the expected annotations."""
    ner = NamedEntityExtractor(
        model_name_or_path="dslim/bert-base-NER",
        backend=NamedEntityExtractorBackend.HUGGING_FACE,
    )
    ner.warm_up()

    _extract_and_check_predictions(
        extractor=ner, texts=raw_texts, expected=hf_annotations, batch_size=batch_size
    )
|
||
|
||
@pytest.mark.parametrize("batch_size", [1, 3])
def test_ner_extractor_spacy_backend(raw_texts, spacy_annotations, batch_size):
    """Run the spaCy backend over the sample texts and compare with the expected annotations."""
    ner = NamedEntityExtractor(
        model_name_or_path="en_core_web_trf",
        backend=NamedEntityExtractorBackend.SPACY,
    )
    ner.warm_up()

    _extract_and_check_predictions(
        extractor=ner, texts=raw_texts, expected=spacy_annotations, batch_size=batch_size
    )
|
||
|
||
@pytest.mark.parametrize("batch_size", [1, 3])
def test_ner_extractor_in_pipeline(raw_texts, hf_annotations, batch_size):
    """Run the Hugging Face extractor as a pipeline component and compare with the expected annotations."""
    ner = NamedEntityExtractor(
        backend=NamedEntityExtractorBackend.HUGGING_FACE, model_name_or_path="dslim/bert-base-NER"
    )
    pipeline = Pipeline()
    pipeline.add_component(name="ner_extractor", instance=ner)

    # Inputs are addressed by component name; batch_size is passed as a run-time input.
    documents = [Document(content=text) for text in raw_texts]
    run_inputs = {"ner_extractor": {"documents": documents, "batch_size": batch_size}}
    results = pipeline.run(run_inputs)

    outputs = results["ner_extractor"]["documents"]
    predicted = [NamedEntityExtractor.get_stored_annotations(doc) for doc in outputs]
    _check_predictions(predicted, hf_annotations)
|
||
|
||
def _extract_and_check_predictions(extractor, texts, expected, batch_size):
    """Wrap ``texts`` in documents, run ``extractor`` on them and verify the stored annotations.

    :param extractor: Extractor whose ``run`` method is called with the documents.
    :param texts: Strings to wrap in ``Document`` objects.
    :param expected: Expected annotations, one list per text.
    :param batch_size: Batch size forwarded to ``extractor.run``.
    """
    docs = [Document(content=text) for text in texts]
    result = extractor.run(documents=docs, batch_size=batch_size)
    outputs = result["documents"]

    # The extractor must hand back the very same Document objects it was given.
    for given, returned in zip(docs, outputs):
        assert given is returned

    predicted = [NamedEntityExtractor.get_stored_annotations(doc) for doc in outputs]
    _check_predictions(predicted, expected)
|
||
|
||
def _check_predictions(predicted, expected):
    """Assert that ``predicted`` and ``expected`` hold matching annotations.

    Both arguments are lists with one list of annotations per document.
    The outer lengths, the per-document lengths and each annotation's
    ``entity``, ``start`` and ``end`` attributes must all agree.

    :raises AssertionError: On the first mismatch found.
    """
    assert len(predicted) == len(expected)
    for doc_predicted, doc_expected in zip(predicted, expected):
        assert len(doc_predicted) == len(doc_expected)

        for got, want in zip(doc_predicted, doc_expected):
            for field in ("entity", "start", "end"):
                assert getattr(got, field) == getattr(want, field)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from typing import Dict, Any | ||
from pathlib import Path | ||
from datetime import datetime | ||
|
||
from haystack import Pipeline | ||
from haystack.components.others import Multiplexer | ||
from haystack.components.converters import PyPDFToDocument, TextFileToDocument | ||
from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter | ||
from haystack.components.routers import FileTypeRouter, DocumentJoiner | ||
from haystack.components.writers import DocumentWriter | ||
from haystack.document_stores import InMemoryDocumentStore | ||
|
||
|
||
# Documents produced by the pipeline are written to this in-memory store.
document_store = InMemoryDocumentStore()

# Indexing pipeline: route files by MIME type, convert them to documents,
# join the two converter outputs, then clean, split and store the documents.
p = Pipeline()
p.add_component(name="file_type_router", instance=FileTypeRouter(mime_types=["text/plain", "application/pdf"]))
p.add_component(name="metadata_multiplexer", instance=Multiplexer(Dict[str, Any]))
p.add_component(name="text_file_converter", instance=TextFileToDocument())
p.add_component(name="pdf_file_converter", instance=PyPDFToDocument())
p.add_component(name="joiner", instance=DocumentJoiner())
p.add_component(name="cleaner", instance=DocumentCleaner())
p.add_component(
    name="splitter",
    instance=DocumentSplitter(split_by="sentence", split_length=250, split_overlap=30),
)
p.add_component(name="writer", instance=DocumentWriter(document_store=document_store))

# Connections are made in this exact order. The multiplexer feeds the same
# metadata value to the ``meta`` input of both converters.
for _sender, _receiver in (
    ("file_type_router.text/plain", "text_file_converter.sources"),
    ("file_type_router.application/pdf", "pdf_file_converter.sources"),
    ("metadata_multiplexer", "text_file_converter.meta"),
    ("metadata_multiplexer", "pdf_file_converter.meta"),
    ("text_file_converter.documents", "joiner.documents"),
    ("pdf_file_converter.documents", "joiner.documents"),
    ("joiner.documents", "cleaner.documents"),
    ("cleaner.documents", "splitter.documents"),
    ("splitter.documents", "writer.documents"),
):
    p.connect(_sender, _receiver)
del _sender, _receiver  # keep the module namespace free of loop leftovers

# Feed every entry of the current directory to the router and tag the run
# with the current timestamp as ``date_added`` metadata.
result = p.run(
    {
        "file_type_router": {"sources": list(Path(".").iterdir())},
        "metadata_multiplexer": {"value": {"date_added": datetime.now().isoformat()}},
    }
)

# Every stored document must carry the metadata injected through the multiplexer.
assert all("date_added" in doc.meta for doc in document_store.filter_documents())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from haystack.components.extractors.named_entity_extractor import ( | ||
NamedEntityAnnotation, | ||
NamedEntityExtractor, | ||
NamedEntityExtractorBackend, | ||
) | ||
|
||
__all__ = ["NamedEntityExtractor", "NamedEntityExtractorBackend", "NamedEntityAnnotation"] |
Oops, something went wrong.