feat: Add LostInTheMiddleRanker #5457
Merged
14 commits
fa92b88 Add lost in the middle ranker (vblagoje)
c5660bf Add release note (vblagoje)
1a853fd PR review: implement top_k correctly, improve and add new unit tests (vblagoje)
b72cd34 Use ranker variable name in tests (vblagoje)
a16219d Add invalid top_k handling, unit tests (vblagoje)
5bc429f Minor fixes (vblagoje)
10a92c2 Julian's feedback: more precise version of truncate (vblagoje)
c5fcc30 Better comments for the litm algorithm (vblagoje)
32d06e6 Update examples/web_lfqa_improved.py (vblagoje)
2c46028 Update docs (dfokina)
918971e Update lg in docstrings (dfokina)
e135303 Sebastian PR feedback (vblagoje)
d0d1484 Add check for invalid values of word_count_threshold (vblagoje)
1d8ca68 Remove _truncate as it is not needed any more (vblagoje)
@@ -0,0 +1,133 @@
from typing import Optional, Union, List
import logging

from haystack.schema import Document
from haystack.nodes.ranker.base import BaseRanker

logger = logging.getLogger(__name__)


class LostInTheMiddleRanker(BaseRanker):
    """
    The LostInTheMiddleRanker reorders documents based on the "lost in the middle" order. Following the findings of
    the paper "Lost in the Middle: How Language Models Use Long Contexts" by Liu et al., it lays out paragraphs in the
    LLM context so that the most relevant paragraphs are at the beginning or end of the input context, while the least
    relevant information is in the middle of the context.

    See https://arxiv.org/abs/2307.03172 for more details.
    """

    def __init__(self, word_count_threshold: Optional[int] = None, top_k: Optional[int] = None):
        """
        Creates an instance of LostInTheMiddleRanker.

        If 'word_count_threshold' is specified, this ranker includes all documents up until the point where adding
        another document would exceed the 'word_count_threshold'. The last document that causes the threshold to
        be breached will be included in the resulting list of documents, but all subsequent documents will be
        discarded.

        :param word_count_threshold: The total word count across selected documents at which the ranker stops adding documents.
        :param top_k: The maximum number of documents to return.
        """
        super().__init__()
        if isinstance(word_count_threshold, int) and word_count_threshold <= 0:
            raise ValueError(
                f"Invalid value for word_count_threshold: {word_count_threshold}. "
                f"word_count_threshold must be a positive integer."
            )
        self.word_count_threshold = word_count_threshold
        self.top_k = top_k

    def reorder_documents(self, documents: List[Document]) -> List[Document]:
        """
        Ranks documents based on the "lost in the middle" order. Assumes that all documents are ordered by relevance.

        :param documents: List of Documents to reorder.
        :return: Documents in the "lost in the middle" order.
        """

        # Return empty list if no documents are provided
        if not documents:
            return []

        # If there's only one document, return it as is
        if len(documents) == 1:
            return documents

        # Raise an error if any document is not textual
        if any(doc.content_type != "text" for doc in documents):
            raise ValueError("Some provided documents are not textual; LostInTheMiddleRanker can process only text.")

        # Initialize word count and indices for the "lost in the middle" order
        word_count = 0
        document_index = list(range(len(documents)))
        lost_in_the_middle_indices = [0]

        # If word count threshold is set, calculate word count for the first document
        if self.word_count_threshold:
            word_count = len(documents[0].content.split())

            # If the first document already meets the word count threshold, return it
            if word_count >= self.word_count_threshold:
                return [documents[0]]

        # Start from the second document and create "lost in the middle" order
        for doc_idx in document_index[1:]:
            # Calculate the index at which the current document should be inserted
            insertion_index = len(lost_in_the_middle_indices) // 2 + len(lost_in_the_middle_indices) % 2

            # Insert the document index at the calculated position
            lost_in_the_middle_indices.insert(insertion_index, doc_idx)

            # If word count threshold is set, calculate the total word count
            if self.word_count_threshold:
                word_count += len(documents[doc_idx].content.split())

                # If the total word count meets the threshold, stop processing further documents
                if word_count >= self.word_count_threshold:
                    break

        # Return the documents in the "lost in the middle" order
        return [documents[idx] for idx in lost_in_the_middle_indices]

    def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> List[Document]:
        """
        Reranks documents based on the "lost in the middle" order.

        :param query: The query to reorder documents for (ignored).
        :param documents: List of Documents to reorder.
        :param top_k: The number of documents to return.

        :return: The reordered documents.
        """
        top_k = top_k or self.top_k
        documents_to_reorder = documents[:top_k] if top_k else documents
        ranked_docs = self.reorder_documents(documents=documents_to_reorder)
        return ranked_docs

    def predict_batch(
        self,
        queries: List[str],
        documents: Union[List[Document], List[List[Document]]],
        top_k: Optional[int] = None,
        batch_size: Optional[int] = None,
    ) -> Union[List[Document], List[List[Document]]]:
        """
        Reranks batch of documents based on the "lost in the middle" order.

        :param queries: The queries to reorder documents for (ignored).
        :param documents: List of Documents to reorder.
        :param top_k: The number of documents to return.
        :param batch_size: The number of queries to process in one batch (ignored).

        :return: The reordered documents.
        """
        if len(documents) > 0 and isinstance(documents[0], Document):
            return self.predict(query="", documents=documents, top_k=top_k)  # type: ignore
        else:
            # Docs case 2: list of lists of Documents -> rerank each list of Documents
            results = []
            for cur_docs in documents:
                assert isinstance(cur_docs, list)
                results.append(self.predict(query="", documents=cur_docs, top_k=top_k))
            return results
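The ordering built by `reorder_documents` can be tried outside Haystack. The following is a minimal sketch, not the PR's code: it works on plain strings instead of `Document` objects, and the function name `litm_order` is made up for illustration. It mirrors the insertion rule and the word-count cutoff shown in the diff.

```python
from typing import List, Optional


def litm_order(texts: List[str], word_count_threshold: Optional[int] = None) -> List[str]:
    """Hypothetical helper: reorder relevance-sorted texts into "lost in the middle" order."""
    if not texts:
        return []

    order = [0]  # indices into `texts`; the most relevant text stays first
    word_count = len(texts[0].split())
    if word_count_threshold and word_count >= word_count_threshold:
        return [texts[0]]

    for idx in range(1, len(texts)):
        # Insert at the middle of the current order, rounding up. Odd-ranked
        # texts (1st, 3rd, ...) therefore fill the front, and even-ranked
        # texts (2nd, 4th, ...) fill the back in reverse.
        order.insert((len(order) + 1) // 2, idx)
        word_count += len(texts[idx].split())
        # The text that reaches the threshold is kept; later ones are dropped.
        if word_count_threshold and word_count >= word_count_threshold:
            break

    return [texts[i] for i in order]


print(litm_order([str(i) for i in range(1, 10)]))
print(litm_order([f"word{i}" for i in range(1, 10)], word_count_threshold=6))
```

With no threshold set, the result is the same as `texts[0::2] + texts[1::2][::-1]`; the explicit loop is only needed so that the word count can cut the list short.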
18 changes: 18 additions & 0 deletions
releasenotes/notes/add-lost-in-the-middle-ranker-6ad7dda754fad5a9.yaml
@@ -0,0 +1,18 @@
---
prelude: >
    We're excited to introduce a new ranker to Haystack - LostInTheMiddleRanker.
    It reorders documents based on the "Lost in the Middle" order, a strategy that
    places the most relevant paragraphs at the beginning or end of the context,
    while less relevant paragraphs are positioned in the middle. This ranker,
    based on the research paper "Lost in the Middle: How Language Models Use Long
    Contexts" by Liu et al., can be leveraged in Retrieval-Augmented Generation
    (RAG) pipelines.
features:
  - |
    The LostInTheMiddleRanker can be used like other rankers in Haystack. After
    initializing LostInTheMiddleRanker with the desired parameters, it can be
    used to rank/reorder a list of documents based on the "Lost in the Middle"
    order - the most relevant documents are located at the top and bottom of
    the returned list, while the least relevant documents are found in the
    middle. We advise using this ranker in combination with other rankers
    and placing it towards the end of the pipeline.
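The advice to place the ranker towards the end of the pipeline follows from its assumption that incoming documents are already sorted by relevance. Below is a minimal sketch of that two-stage idea, using plain `(text, score)` tuples and a sort as a stand-in for a real relevance ranker. `rank_by_score` and `litm_reorder` are hypothetical names for illustration, not Haystack APIs.

```python
from typing import List, Tuple

Scored = Tuple[str, float]  # (text, relevance score); stand-in for a Haystack Document


def rank_by_score(docs: List[Scored]) -> List[Scored]:
    """Stand-in for a relevance ranker: most relevant first."""
    return sorted(docs, key=lambda d: d[1], reverse=True)


def litm_reorder(docs: List[Scored]) -> List[Scored]:
    """Odd ranks go to the front, even ranks go to the back in reverse order."""
    return docs[0::2] + docs[1::2][::-1]


retrieved = [("c", 0.4), ("a", 0.9), ("e", 0.1), ("b", 0.7), ("d", 0.3)]
final = litm_reorder(rank_by_score(retrieved))
print([text for text, _ in final])  # ['a', 'c', 'e', 'd', 'b']
```

The two highest-scoring texts, `a` and `b`, end up at the two ends of the list, and the lowest-scoring text, `e`, sits in the middle.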
@@ -0,0 +1,154 @@
import pytest

from haystack import Document
from haystack.nodes.ranker.lost_in_the_middle import LostInTheMiddleRanker


@pytest.mark.unit
def test_lost_in_the_middle_order_odd():
    # tests that lost_in_the_middle order works with an odd number of documents
    docs = [Document(str(i)) for i in range(1, 10)]
    ranker = LostInTheMiddleRanker()
    result, _ = ranker.run(query="", documents=docs)
    assert result["documents"]
    expected_order = "1 3 5 7 9 8 6 4 2".split()
    assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"]))


@pytest.mark.unit
def test_batch_lost_in_the_middle_order():
    # tests that lost_in_the_middle order works with a batch of documents
    docs = [
        [Document("1"), Document("2"), Document("3"), Document("4")],
        [Document("5"), Document("6")],
        [Document("7"), Document("8"), Document("9")],
    ]
    ranker = LostInTheMiddleRanker()
    result, _ = ranker.run_batch(queries=[""], documents=docs)

    assert " ".join(doc.content for doc in result["documents"][0]) == "1 3 4 2"
    assert " ".join(doc.content for doc in result["documents"][1]) == "5 6"
    assert " ".join(doc.content for doc in result["documents"][2]) == "7 9 8"


@pytest.mark.unit
def test_lost_in_the_middle_order_even():
    # tests that lost_in_the_middle order works with an even number of documents
    docs = [Document(str(i)) for i in range(1, 11)]
    ranker = LostInTheMiddleRanker()
    result, _ = ranker.run(query="", documents=docs)
    expected_order = "1 3 5 7 9 10 8 6 4 2".split()
    assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"]))


@pytest.mark.unit
def test_lost_in_the_middle_order_two_docs():
    # tests that lost_in_the_middle order works with two documents
    ranker = LostInTheMiddleRanker()

    # two docs
    docs = [Document("1"), Document("2")]
    result, _ = ranker.run(query="", documents=docs)
    assert result["documents"][0].content == "1"
    assert result["documents"][1].content == "2"


@pytest.mark.unit
def test_lost_in_the_middle_init():
    # tests that LostInTheMiddleRanker initializes with default values
    ranker = LostInTheMiddleRanker()
    assert ranker.word_count_threshold is None

    ranker = LostInTheMiddleRanker(word_count_threshold=10)
    assert ranker.word_count_threshold == 10


@pytest.mark.unit
def test_lost_in_the_middle_init_invalid_word_count_threshold():
    # tests that LostInTheMiddleRanker raises an error when word_count_threshold is <= 0
    with pytest.raises(ValueError, match="Invalid value for word_count_threshold"):
        LostInTheMiddleRanker(word_count_threshold=0)

    with pytest.raises(ValueError, match="Invalid value for word_count_threshold"):
        LostInTheMiddleRanker(word_count_threshold=-5)


@pytest.mark.unit
def test_lost_in_the_middle_with_word_count_threshold():
    # tests that lost_in_the_middle with word_count_threshold works as expected
    ranker = LostInTheMiddleRanker(word_count_threshold=6)
    docs = [Document("word" + str(i)) for i in range(1, 10)]
    result, _ = ranker.run(query="", documents=docs)
    expected_order = "word1 word3 word5 word6 word4 word2".split()
    assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"]))

    ranker = LostInTheMiddleRanker(word_count_threshold=9)
    result, _ = ranker.run(query="", documents=docs)
    expected_order = "word1 word3 word5 word7 word9 word8 word6 word4 word2".split()
    assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"]))


@pytest.mark.unit
def test_word_count_threshold_greater_than_total_number_of_words_returns_all_documents():
    ranker = LostInTheMiddleRanker(word_count_threshold=100)
    docs = [Document("word" + str(i)) for i in range(1, 10)]
    ordered_docs = ranker.predict(query="test", documents=docs)
    assert len(ordered_docs) == len(docs)
    expected_order = "word1 word3 word5 word7 word9 word8 word6 word4 word2".split()
    assert all(doc.content == expected_order[idx] for idx, doc in enumerate(ordered_docs))


@pytest.mark.unit
def test_empty_documents_returns_empty_list():
    ranker = LostInTheMiddleRanker()
    assert ranker.predict(query="test", documents=[]) == []


@pytest.mark.unit
def test_list_of_one_document_returns_same_document():
    ranker = LostInTheMiddleRanker()
    doc = Document(content="test", content_type="text")
    assert ranker.predict(query="test", documents=[doc]) == [doc]


@pytest.mark.unit
def test_non_textual_documents():
    # tests that reordering a list that contains a non-textual document raises a ValueError
    ranker = LostInTheMiddleRanker()
    doc1 = Document(content="This is a textual document.")
    doc2 = Document(content_type="image", content="This is a non-textual document.")
    with pytest.raises(ValueError, match="Some provided documents are not textual"):
        ranker.reorder_documents([doc1, doc2])


@pytest.mark.unit
@pytest.mark.parametrize("top_k", [1, 2, 3, 4, 5, 6, 7, 8, 12, 20])
def test_lost_in_the_middle_order_with_positive_top_k(top_k: int):
    # tests that lost_in_the_middle order works with an odd number of documents and a top_k parameter
    docs = [Document(str(i)) for i in range(1, 10)]
    ranker = LostInTheMiddleRanker()
    result = ranker.predict(query="irrelevant", documents=docs, top_k=top_k)
    if top_k < len(docs):
        # top_k is less than the number of documents, so only the top_k documents should be returned in LITM order
        assert len(result) == top_k
        expected_order = ranker.predict(query="irrelevant", documents=[Document(str(i)) for i in range(1, top_k + 1)])
        assert result == expected_order
    else:
        # top_k is greater than the number of documents, so all documents should be returned in LITM order
        assert len(result) == len(docs)
        assert result == ranker.predict(query="irrelevant", documents=docs)


@pytest.mark.unit
@pytest.mark.parametrize("top_k", [-20, -10, -5, -1])
def test_lost_in_the_middle_order_with_negative_top_k(top_k: int):
    # tests that lost_in_the_middle order works with an odd number of documents and an invalid top_k parameter
    docs = [Document(str(i)) for i in range(1, 10)]
    ranker = LostInTheMiddleRanker()
    result = ranker.predict(query="irrelevant", documents=docs, top_k=top_k)
    if top_k < len(docs) * -1:
        assert len(result) == 0  # top_k is too negative, so no documents should be returned
    else:
        # top_k is negative, subtract it from the total number of documents to get the expected number of documents
        expected_docs = ranker.predict(query="irrelevant", documents=docs, top_k=len(docs) + top_k)
        assert result == expected_docs
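The `top_k` tests rely on ordinary Python slice semantics, because `predict` truncates with `documents[:top_k]` before reordering. The sketch below isolates just that step on plain strings; `select_top_k` is a hypothetical helper written for illustration and is not part of the PR.

```python
from typing import List, Optional


def select_top_k(docs: List[str], top_k: Optional[int], default_top_k: Optional[int] = None) -> List[str]:
    """Hypothetical helper mirroring the top_k handling in `predict`."""
    top_k = top_k or default_top_k  # None and 0 both fall back to the default
    return docs[:top_k] if top_k else docs


docs = [str(i) for i in range(1, 10)]
print(select_top_k(docs, 3))    # ['1', '2', '3']
print(select_top_k(docs, 20))   # all nine documents
print(select_top_k(docs, -5))   # ['1', '2', '3', '4']
print(select_top_k(docs, -20))  # []
print(select_top_k(docs, 0))    # all nine documents: 0 is falsy, so nothing is truncated
```

A negative `top_k` therefore drops that many documents from the end of the list, which is why the negative test compares against `top_k=len(docs) + top_k`.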
I wonder whether we can come up with a name that expresses that relevant documents will be ranked so that they end up at the top and bottom and irrelevant documents end up in the middle.
LITM is a very technical name that hardly anyone using Haystack will know. I don't have a better idea yet. Maybe you can come up with something together with @dfokina?
I actually like the LITM name because it is unique, descriptive, and clearly associated with that now "famous" paper. So I'd keep it if there are no objections...
I like the LITM name in the sense that it is a clear reference to the paper. I don't think users would specifically search for this type of ranking without understanding the background for it.
Anyway, I tried to think of possible alternatives, but no good ideas so far. The ones I come up with are pretty bad (IrrelevantMiddleRanker, TopBottomRanker, IrrelevantDocsInTheMiddleRanker...)