Skip to content

Commit

Permalink
Merge branch 'main' into chatgpt-llm-generator
Browse files Browse the repository at this point in the history
  • Loading branch information
ZanSara committed Sep 5, 2023
2 parents 9eb7900 + d540883 commit 13104de
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 34 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/examples_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ on:
env:
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }}
SLACK_WEBHOOK_TYPE: INCOMING_WEBHOOK
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
HUGGINGFACE_API_KEY: ${{ secrets.HUGGINGFACE_API_KEY }}
PYTHON_VERSION: "3.8"

jobs:
Expand Down
12 changes: 6 additions & 6 deletions haystack/nodes/ranker/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@

class CohereRanker(BaseRanker):
"""
Re-Ranking can be used on top of a retriever to boost the performance for document search.
This is particularly useful if the retriever has a high recall but is bad in sorting the documents by relevance.
You can use re-ranking on top of a Retriever to boost the performance for document search.
This is particularly useful if the Retriever has a high recall but is bad in sorting the documents by relevance.
Cohere models are trained with a context length of 510 tokens - the model takes into account both the input
Cohere models are trained with a context length of 512 tokens - the model takes into account both the input
from the query and document. If your query is larger than 256 tokens, it will be truncated to the first 256 tokens.
Cohere breaks down a query-document pair into 510 token chunks. For example, if your query is 50 tokens and your
Cohere breaks down a query-document pair into 512 token chunks. For example, if your query is 50 tokens and your
document is 1024 tokens, your document will be broken into the following chunks:
```bash
relevance_score_1 = <query[0,50], document[0,460]>
Expand All @@ -55,7 +55,7 @@ def __init__(
:param api_key: Cohere API key.
:param model_name_or_path: Cohere model name. Check the list of supported models in the [Cohere documentation](https://docs.cohere.com/docs/models).
:param top_k: The maximum number of documents to return.
:param max_chunks_per_doc: If your document exceeds 512 tokens, this will determine the maximum number of
:param max_chunks_per_doc: If your document exceeds 512 tokens, this determines the maximum number of
chunks a document can be split into. If None, the default of 10 is used.
For example, if your document is 6000 tokens, with the default of 10, the document will be split into 10
chunks each of 512 tokens and the last 880 tokens will be disregarded.
Expand Down Expand Up @@ -190,7 +190,7 @@ def predict_batch(
"""
Use Cohere Reranking endpoint to re-rank the supplied lists of Documents.
Returns a lists of Documents sorted by (desc.) similarity with the corresponding queries.
Returns a lists of Documents sorted by (descending) similarity with the corresponding queries.
- If you provide a list containing a single query...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


@component
class AnswersBuilder:
class AnswerBuilder:
"""
A component to parse the output of a Generator to `Answer` objects using regular expressions.
"""
Expand All @@ -32,7 +32,7 @@ def __init__(self, pattern: Optional[str] = None, reference_pattern: Optional[st
Default: `None`.
"""
if pattern:
AnswersBuilder._check_num_groups_in_regex(pattern)
AnswerBuilder._check_num_groups_in_regex(pattern)

self.pattern = pattern
self.reference_pattern = reference_pattern
Expand Down Expand Up @@ -80,7 +80,7 @@ def run(
)

if pattern:
AnswersBuilder._check_num_groups_in_regex(pattern)
AnswerBuilder._check_num_groups_in_regex(pattern)

documents = documents or []
pattern = pattern or self.pattern
Expand All @@ -90,10 +90,10 @@ def run(
for i, (query, reply_list, meta_list) in enumerate(zip(queries, replies, metadata)):
doc_list = documents[i] if i < len(documents) else []

extracted_answer_strings = AnswersBuilder._extract_answer_strings(reply_list, pattern)
extracted_answer_strings = AnswerBuilder._extract_answer_strings(reply_list, pattern)

if doc_list and reference_pattern:
reference_idxs = AnswersBuilder._extract_reference_idxs(reply_list, reference_pattern)
reference_idxs = AnswerBuilder._extract_reference_idxs(reply_list, reference_pattern)
else:
reference_idxs = [[doc_idx for doc_idx, _ in enumerate(doc_list)] for _ in reply_list]

Expand All @@ -120,7 +120,7 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(self, pattern=self.pattern, reference_pattern=self.reference_pattern)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AnswersBuilder":
def from_dict(cls, data: Dict[str, Any]) -> "AnswerBuilder":
"""
Deserialize this component from a dictionary.
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
preview:
- |
Add the `AnswerBuilder` component for Haystack 2.0 that creates Answer objects from the string output of Generators.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,43 @@
import pytest

from haystack.preview import GeneratedAnswer, Document
from haystack.preview.components.builders.answers_builder import AnswersBuilder
from haystack.preview.components.builders.answer_builder import AnswerBuilder


class TestAnswersBuilder:
class TestAnswerBuilder:
@pytest.mark.unit
def test_to_dict(self):
component = AnswersBuilder()
component = AnswerBuilder()
data = component.to_dict()
assert data == {"type": "AnswersBuilder", "init_parameters": {"pattern": None, "reference_pattern": None}}
assert data == {"type": "AnswerBuilder", "init_parameters": {"pattern": None, "reference_pattern": None}}

@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
component = AnswersBuilder(pattern="pattern", reference_pattern="reference_pattern")
component = AnswerBuilder(pattern="pattern", reference_pattern="reference_pattern")
data = component.to_dict()
assert data == {
"type": "AnswersBuilder",
"type": "AnswerBuilder",
"init_parameters": {"pattern": "pattern", "reference_pattern": "reference_pattern"},
}

@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "AnswersBuilder",
"type": "AnswerBuilder",
"init_parameters": {"pattern": "pattern", "reference_pattern": "reference_pattern"},
}
component = AnswersBuilder.from_dict(data)
component = AnswerBuilder.from_dict(data)
assert component.pattern == "pattern"
assert component.reference_pattern == "reference_pattern"

@pytest.mark.unit
def test_run_unmatching_input_len(self):
component = AnswersBuilder()
component = AnswerBuilder()
with pytest.raises(ValueError):
component.run(queries=["query"], replies=[["reply1"], ["reply2"]], metadata=[[]])

def test_run_without_pattern(self):
component = AnswersBuilder()
component = AnswerBuilder()
answers = component.run(queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]])
assert len(answers) == 1
assert len(answers[0]) == 1
Expand All @@ -50,7 +50,7 @@ def test_run_without_pattern(self):
assert isinstance(answers[0][0], GeneratedAnswer)

def test_run_with_pattern_with_capturing_group(self):
component = AnswersBuilder(pattern=r"Answer: (.*)")
component = AnswerBuilder(pattern=r"Answer: (.*)")
answers = component.run(queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]])
assert len(answers) == 1
assert len(answers[0]) == 1
Expand All @@ -61,7 +61,7 @@ def test_run_with_pattern_with_capturing_group(self):
assert isinstance(answers[0][0], GeneratedAnswer)

def test_run_with_pattern_without_capturing_group(self):
component = AnswersBuilder(pattern=r"'.*'")
component = AnswerBuilder(pattern=r"'.*'")
answers = component.run(queries=["test query"], replies=[["Answer: 'AnswerString'"]], metadata=[[{}]])
assert len(answers) == 1
assert len(answers[0]) == 1
Expand All @@ -73,10 +73,10 @@ def test_run_with_pattern_without_capturing_group(self):

def test_run_with_pattern_with_more_than_one_capturing_group(self):
with pytest.raises(ValueError, match="contains multiple capture groups"):
component = AnswersBuilder(pattern=r"Answer: (.*), (.*)")
component = AnswerBuilder(pattern=r"Answer: (.*), (.*)")

def test_run_with_pattern_set_at_runtime(self):
component = AnswersBuilder(pattern="unused pattern")
component = AnswerBuilder(pattern="unused pattern")
answers = component.run(
queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]], pattern=r"Answer: (.*)"
)
Expand All @@ -89,7 +89,7 @@ def test_run_with_pattern_set_at_runtime(self):
assert isinstance(answers[0][0], GeneratedAnswer)

def test_run_with_documents_without_reference_pattern(self):
component = AnswersBuilder()
component = AnswerBuilder()
answers = component.run(
queries=["test query"],
replies=[["Answer: AnswerString"]],
Expand All @@ -106,7 +106,7 @@ def test_run_with_documents_without_reference_pattern(self):
assert answers[0][0].documents[1].content == "test doc 2"

def test_run_with_documents_with_reference_pattern(self):
component = AnswersBuilder(reference_pattern="\\[(\\d+)\\]")
component = AnswerBuilder(reference_pattern="\\[(\\d+)\\]")
answers = component.run(
queries=["test query"],
replies=[["Answer: AnswerString[2]"]],
Expand All @@ -122,7 +122,7 @@ def test_run_with_documents_with_reference_pattern(self):
assert answers[0][0].documents[0].content == "test doc 2"

def test_run_with_documents_with_reference_pattern_and_no_match(self, caplog):
component = AnswersBuilder(reference_pattern="\\[(\\d+)\\]")
component = AnswerBuilder(reference_pattern="\\[(\\d+)\\]")
with caplog.at_level(logging.WARNING):
answers = component.run(
queries=["test query"],
Expand All @@ -139,7 +139,7 @@ def test_run_with_documents_with_reference_pattern_and_no_match(self, caplog):
assert "Document index '3' referenced in Generator output is out of range." in caplog.text

def test_run_with_reference_pattern_set_at_runtime(self):
component = AnswersBuilder(reference_pattern="unused pattern")
component = AnswerBuilder(reference_pattern="unused pattern")
answers = component.run(
queries=["test query"],
replies=[["Answer: AnswerString[2][3]"]],
Expand Down

0 comments on commit 13104de

Please sign in to comment.