feat: Add LostInTheMiddleRanker #5457

Merged · 14 commits · Aug 2, 2023
2 changes: 1 addition & 1 deletion docs/pydoc/config/ranker.yml

@@ -1,7 +1,7 @@
 loaders:
   - type: python
     search_path: [../../../haystack/nodes/ranker]
-    modules: ["base", "sentence_transformers", "recentness_ranker", "diversity"]
+    modules: ["base", "sentence_transformers", "recentness_ranker", "diversity", "lost_in_the_middle"]
     ignore_when_discovered: ["__init__"]
 processors:
   - type: filter
19 changes: 10 additions & 9 deletions examples/web_lfqa_improved.py

@@ -2,8 +2,9 @@
 import os
 
 from haystack import Pipeline
-from haystack.nodes import PromptNode, PromptTemplate, TopPSampler, DocumentMerger
+from haystack.nodes import PromptNode, PromptTemplate, TopPSampler
 from haystack.nodes.ranker.diversity import DiversityRanker
+from haystack.nodes.ranker.lost_in_the_middle import LostInTheMiddleRanker
 from haystack.nodes.retriever.web import WebRetriever
 
 search_key = os.environ.get("SERPERDEV_API_KEY")
@@ -22,21 +23,21 @@
 """
 
 prompt_node = PromptNode(
-    "gpt-3.5-turbo", default_prompt_template=PromptTemplate(prompt_text), api_key=openai_key, max_length=256
+    "gpt-3.5-turbo", default_prompt_template=PromptTemplate(prompt_text), api_key=openai_key, max_length=768
 )
 
-web_retriever = WebRetriever(api_key=search_key, top_search_results=10, mode="preprocessed_documents", top_k=25)
+web_retriever = WebRetriever(api_key=search_key, top_search_results=5, mode="preprocessed_documents", top_k=50)
 
-sampler = TopPSampler(top_p=0.95)
-ranker = DiversityRanker()
-merger = DocumentMerger(separator="\n\n")
+sampler = TopPSampler(top_p=0.97)
+diversity_ranker = DiversityRanker()
+litm_ranker = LostInTheMiddleRanker(word_count_threshold=1024)
 
 pipeline = Pipeline()
 pipeline.add_node(component=web_retriever, name="Retriever", inputs=["Query"])
 pipeline.add_node(component=sampler, name="Sampler", inputs=["Retriever"])
-pipeline.add_node(component=ranker, name="Ranker", inputs=["Sampler"])
-pipeline.add_node(component=merger, name="Merger", inputs=["Ranker"])
-pipeline.add_node(component=prompt_node, name="PromptNode", inputs=["Merger"])
+pipeline.add_node(component=diversity_ranker, name="DiversityRanker", inputs=["Sampler"])
+pipeline.add_node(component=litm_ranker, name="LostInTheMiddleRanker", inputs=["DiversityRanker"])
+pipeline.add_node(component=prompt_node, name="PromptNode", inputs=["LostInTheMiddleRanker"])
 
 logger = logging.getLogger("boilerpy3")
 logger.setLevel(logging.CRITICAL)
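For orientation, a minimal sketch of how the updated example pipeline might be invoked once assembled (the query string is invented; this assumes the elided parts of the file define `openai_key`, and that, as in other Haystack v1 examples, PromptNode output is surfaced under the "results" key):

# Hypothetical invocation of the pipeline built above
response = pipeline.run(query="What are the main causes of coral bleaching?")
print(response["results"][0])  # the generated long-form answer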
133 changes: 133 additions & 0 deletions haystack/nodes/ranker/lost_in_the_middle.py

@@ -0,0 +1,133 @@
from typing import Optional, Union, List
import logging

from haystack.schema import Document
from haystack.nodes.ranker.base import BaseRanker

logger = logging.getLogger(__name__)


class LostInTheMiddleRanker(BaseRanker):

[Review thread on the class name]

Member: I wonder whether we can come up with a name that expresses that relevant documents will be ranked so that they end up at the top and bottom, while irrelevant documents end up in the middle. LITM is a very technical name that hardly anyone using Haystack will know. I don't have a better idea yet. Maybe you can come up with something together with @dfokina?

Member Author: I actually like the LITM name because it is unique, descriptive, and associated with that now-"famous" paper. So I'd keep it if there are no objections...

Contributor: I like the LITM name in the sense that it is a clear reference to the paper. I don't think users would specifically search for this type of ranking without understanding the background for it. Anyway, I tried to think about possible alternatives, but no good ideas so far. They all come out pretty bad (IrrelevantMiddleRanker, TopBottomRanker, IrrelevantDocsInTheMiddleRanker...).

"""
The LostInTheMiddleRanker implements a ranker that reorders documents based on the "lost in the middle" order.
"Lost in the Middle: How Language Models Use Long Contexts" paper by Liu et al. aims to lay out paragraphs into LLM
context so that the relevant paragraphs are at the beginning or end of the input context, while the least relevant
information is in the middle of the context.

See https://arxiv.org/abs/2307.03172 for more details.
"""

    def __init__(self, word_count_threshold: Optional[int] = None, top_k: Optional[int] = None):
        """
        Creates an instance of LostInTheMiddleRanker.

        If 'word_count_threshold' is specified, this ranker includes all documents up until the point where adding
        another document would exceed the 'word_count_threshold'. The last document that causes the threshold to
        be breached will be included in the resulting list of documents, but all subsequent documents will be
        discarded.

        :param word_count_threshold: The maximum total number of words across all documents selected by the ranker.
        :param top_k: The maximum number of documents to return.
        """
        super().__init__()
        if isinstance(word_count_threshold, int) and word_count_threshold <= 0:
            raise ValueError(
                f"Invalid value for word_count_threshold: {word_count_threshold}. "
                f"word_count_threshold must be a positive integer."
            )
        self.word_count_threshold = word_count_threshold
        self.top_k = top_k

    def reorder_documents(self, documents: List[Document]) -> List[Document]:
        """
        Ranks documents based on the "lost in the middle" order. Assumes that all documents are ordered by relevance.

        :param documents: List of Documents to reorder.
        :return: Documents in the "lost in the middle" order.
        """

        # Return empty list if no documents are provided
        if not documents:
            return []

        # If there's only one document, return it as is
        if len(documents) == 1:
            return documents

        # Raise an error if any document is not textual
        if any(not doc.content_type == "text" for doc in documents):
            raise ValueError("Some provided documents are not textual; LostInTheMiddleRanker can process only text.")

        # Initialize word count and indices for the "lost in the middle" order
        word_count = 0
        document_index = list(range(len(documents)))
        lost_in_the_middle_indices = [0]

        # If word count threshold is set, calculate word count for the first document
        if self.word_count_threshold:
            word_count = len(documents[0].content.split())

            # If the first document already meets the word count threshold, return it
            if word_count >= self.word_count_threshold:
                return [documents[0]]

        # Start from the second document and create "lost in the middle" order
        for doc_idx in document_index[1:]:
            # Calculate the index at which the current document should be inserted
            insertion_index = len(lost_in_the_middle_indices) // 2 + len(lost_in_the_middle_indices) % 2
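            # For example, with relevance-ordered indices 0..4 the order builds up as
            # [0] -> [0, 1] -> [0, 2, 1] -> [0, 2, 3, 1] -> [0, 2, 4, 3, 1]:
            # insertion_index is ceil(len/2), so relevant documents accumulate at both ends.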

            # Insert the document index at the calculated position
            lost_in_the_middle_indices.insert(insertion_index, doc_idx)

            # If word count threshold is set, calculate the total word count
            if self.word_count_threshold:
                word_count += len(documents[doc_idx].content.split())

                # If the total word count meets the threshold, stop processing further documents
                if word_count >= self.word_count_threshold:
                    break

        # Return the documents in the "lost in the middle" order
        return [documents[idx] for idx in lost_in_the_middle_indices]

    def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> List[Document]:
        """
        Reranks documents based on the "lost in the middle" order.

        :param query: The query to reorder documents for (ignored).
        :param documents: List of Documents to reorder.
        :param top_k: The number of documents to return.

        :return: The reordered documents.
        """
        top_k = top_k or self.top_k
        documents_to_reorder = documents[:top_k] if top_k else documents
        ranked_docs = self.reorder_documents(documents=documents_to_reorder)
        return ranked_docs

    def predict_batch(
        self,
        queries: List[str],
        documents: Union[List[Document], List[List[Document]]],
        top_k: Optional[int] = None,
        batch_size: Optional[int] = None,
    ) -> Union[List[Document], List[List[Document]]]:
        """
        Reranks a batch of documents based on the "lost in the middle" order.

        :param queries: The queries to reorder documents for (ignored).
        :param documents: List of Documents (or list of lists of Documents) to reorder.
        :param top_k: The number of documents to return.
        :param batch_size: The number of queries to process in one batch (ignored).

        :return: The reordered documents.
        """
        if len(documents) > 0 and isinstance(documents[0], Document):
            # Docs case 1: single list of Documents -> rerank it directly
            return self.predict(query="", documents=documents, top_k=top_k)  # type: ignore
        else:
            # Docs case 2: list of lists of Documents -> rerank each list of Documents
            results = []
            for cur_docs in documents:
                assert isinstance(cur_docs, list)
                results.append(self.predict(query="", documents=cur_docs, top_k=top_k))
            return results
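To make the reordering and the word_count_threshold semantics concrete, here is a short, self-contained usage sketch (not part of the diff; the document contents are invented for illustration):

from haystack import Document
from haystack.nodes.ranker.lost_in_the_middle import LostInTheMiddleRanker

docs = [Document(content=f"doc {i}") for i in range(1, 10)]  # assumed ordered by relevance

# Plain reordering: the most relevant documents end up at both ends of the list.
ranker = LostInTheMiddleRanker()
print([d.content for d in ranker.predict(query="", documents=docs)])
# -> ['doc 1', 'doc 3', 'doc 5', 'doc 7', 'doc 9', 'doc 8', 'doc 6', 'doc 4', 'doc 2']

# With a word budget: documents are added until the threshold is reached; the
# document that breaches the threshold is kept, everything after it is dropped.
budgeted = LostInTheMiddleRanker(word_count_threshold=6)
print([d.content for d in budgeted.predict(query="", documents=docs)])
# -> ['doc 1', 'doc 3', 'doc 2']  (each document is two words; "doc 3" hits the budget)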
@@ -0,0 +1,18 @@
---
prelude: >
  We're excited to introduce a new ranker to Haystack - LostInTheMiddleRanker.
  It reorders documents based on the "Lost in the Middle" order, a strategy that
  places the most relevant paragraphs at the beginning or end of the context,
  while less relevant paragraphs are positioned in the middle. This ranker,
  based on the research paper "Lost in the Middle: How Language Models Use Long
  Contexts" by Liu et al., can be leveraged in Retrieval-Augmented Generation
  (RAG) pipelines.
features:
  - |
    The LostInTheMiddleRanker can be used like other rankers in Haystack. After
    initializing LostInTheMiddleRanker with the desired parameters, it can be
    used to rank/reorder a list of documents based on the "Lost in the Middle"
    order - the most relevant documents are located at the top and bottom of
    the returned list, while the least relevant documents are found in the
    middle. We advise using this ranker in combination with other rankers and
    placing it towards the end of the pipeline.
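As an illustration of the placement advice above, a minimal runnable sketch (node names and document contents are placeholders, not prescribed by the release note):

from haystack import Document, Pipeline
from haystack.nodes.ranker.lost_in_the_middle import LostInTheMiddleRanker

litm = LostInTheMiddleRanker(word_count_threshold=1024)

pipeline = Pipeline()
# In a real RAG pipeline this node would come after a retriever and a relevance
# ranker (e.g. inputs=["Ranker"]); it is attached to "Query" here only so the
# sketch runs standalone.
pipeline.add_node(component=litm, name="LostInTheMiddleRanker", inputs=["Query"])

docs = [Document(content=f"passage {i}") for i in range(1, 6)]  # invented contents
result = pipeline.run(query="", documents=docs)
print([d.content for d in result["documents"]])
# -> ['passage 1', 'passage 3', 'passage 5', 'passage 4', 'passage 2']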
154 changes: 154 additions & 0 deletions test/nodes/test_lost_in_the_middle.py

@@ -0,0 +1,154 @@
import pytest

from haystack import Document
from haystack.nodes.ranker.lost_in_the_middle import LostInTheMiddleRanker


@pytest.mark.unit
def test_lost_in_the_middle_order_odd():
    # tests that lost_in_the_middle order works with an odd number of documents
    docs = [Document(str(i)) for i in range(1, 10)]
    ranker = LostInTheMiddleRanker()
    result, _ = ranker.run(query="", documents=docs)
    assert result["documents"]
    expected_order = "1 3 5 7 9 8 6 4 2".split()
    assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"]))


@pytest.mark.unit
def test_batch_lost_in_the_middle_order():
    # tests that lost_in_the_middle order works with a batch of documents
    docs = [
        [Document("1"), Document("2"), Document("3"), Document("4")],
        [Document("5"), Document("6")],
        [Document("7"), Document("8"), Document("9")],
    ]
    ranker = LostInTheMiddleRanker()
    result, _ = ranker.run_batch(queries=[""], documents=docs)

    assert " ".join(doc.content for doc in result["documents"][0]) == "1 3 4 2"
    assert " ".join(doc.content for doc in result["documents"][1]) == "5 6"
    assert " ".join(doc.content for doc in result["documents"][2]) == "7 9 8"


@pytest.mark.unit
def test_lost_in_the_middle_order_even():
    # tests that lost_in_the_middle order works with an even number of documents
    docs = [Document(str(i)) for i in range(1, 11)]
    ranker = LostInTheMiddleRanker()
    result, _ = ranker.run(query="", documents=docs)
    expected_order = "1 3 5 7 9 10 8 6 4 2".split()
    assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"]))


@pytest.mark.unit
def test_lost_in_the_middle_order_two_docs():
    # tests that lost_in_the_middle order works with two documents
    ranker = LostInTheMiddleRanker()

    # two docs
    docs = [Document("1"), Document("2")]
    result, _ = ranker.run(query="", documents=docs)
    assert result["documents"][0].content == "1"
    assert result["documents"][1].content == "2"


@pytest.mark.unit
def test_lost_in_the_middle_init():
    # tests that LostInTheMiddleRanker initializes with default values
    ranker = LostInTheMiddleRanker()
    assert ranker.word_count_threshold is None

    ranker = LostInTheMiddleRanker(word_count_threshold=10)
    assert ranker.word_count_threshold == 10


@pytest.mark.unit
def test_lost_in_the_middle_init_invalid_word_count_threshold():
    # tests that LostInTheMiddleRanker raises an error when word_count_threshold is <= 0
    with pytest.raises(ValueError, match="Invalid value for word_count_threshold"):
        LostInTheMiddleRanker(word_count_threshold=0)

    with pytest.raises(ValueError, match="Invalid value for word_count_threshold"):
        LostInTheMiddleRanker(word_count_threshold=-5)


@pytest.mark.unit
def test_lost_in_the_middle_with_word_count_threshold():
    # tests that lost_in_the_middle with word_count_threshold works as expected
    ranker = LostInTheMiddleRanker(word_count_threshold=6)
    docs = [Document("word" + str(i)) for i in range(1, 10)]
    result, _ = ranker.run(query="", documents=docs)
    expected_order = "word1 word3 word5 word6 word4 word2".split()
    assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"]))

    ranker = LostInTheMiddleRanker(word_count_threshold=9)
    result, _ = ranker.run(query="", documents=docs)
    expected_order = "word1 word3 word5 word7 word9 word8 word6 word4 word2".split()
    assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"]))


@pytest.mark.unit
def test_word_count_threshold_greater_than_total_number_of_words_returns_all_documents():
    ranker = LostInTheMiddleRanker(word_count_threshold=100)
    docs = [Document("word" + str(i)) for i in range(1, 10)]
    ordered_docs = ranker.predict(query="test", documents=docs)
    assert len(ordered_docs) == len(docs)
    expected_order = "word1 word3 word5 word7 word9 word8 word6 word4 word2".split()
    assert all(doc.content == expected_order[idx] for idx, doc in enumerate(ordered_docs))


@pytest.mark.unit
def test_empty_documents_returns_empty_list():
    ranker = LostInTheMiddleRanker()
    assert ranker.predict(query="test", documents=[]) == []


@pytest.mark.unit
def test_list_of_one_document_returns_same_document():
    ranker = LostInTheMiddleRanker()
    doc = Document(content="test", content_type="text")
    assert ranker.predict(query="test", documents=[doc]) == [doc]


@pytest.mark.unit
def test_non_textual_documents():
    # tests that reordering a list containing non-textual documents raises a ValueError
    ranker = LostInTheMiddleRanker()
    doc1 = Document(content="This is a textual document.")
    doc2 = Document(content_type="image", content="This is a non-textual document.")
    with pytest.raises(ValueError, match="Some provided documents are not textual"):
        ranker.reorder_documents([doc1, doc2])


@pytest.mark.unit
@pytest.mark.parametrize("top_k", [1, 2, 3, 4, 5, 6, 7, 8, 12, 20])
def test_lost_in_the_middle_order_with_positive_top_k(top_k: int):
    # tests that lost_in_the_middle order works with an odd number of documents and a top_k parameter
    docs = [Document(str(i)) for i in range(1, 10)]
    ranker = LostInTheMiddleRanker()
    result = ranker.predict(query="irrelevant", documents=docs, top_k=top_k)
    if top_k < len(docs):
        # top_k is less than the number of documents, so only the top_k documents are returned in LITM order
        assert len(result) == top_k
        expected_order = ranker.predict(query="irrelevant", documents=[Document(str(i)) for i in range(1, top_k + 1)])
        assert result == expected_order
    else:
        # top_k is greater than or equal to the number of documents, so all documents are returned in LITM order
        assert len(result) == len(docs)
        assert result == ranker.predict(query="irrelevant", documents=docs)


@pytest.mark.unit
@pytest.mark.parametrize("top_k", [-20, -10, -5, -1])
def test_lost_in_the_middle_order_with_negative_top_k(top_k: int):
    # tests that lost_in_the_middle order behaves predictably with a negative top_k parameter
    docs = [Document(str(i)) for i in range(1, 10)]
    ranker = LostInTheMiddleRanker()
    result = ranker.predict(query="irrelevant", documents=docs, top_k=top_k)
    if top_k < len(docs) * -1:
        assert len(result) == 0  # top_k is too negative, so no documents should be returned
    else:
        # top_k is negative, so add it to the total number of documents to get the expected number of documents
        expected_docs = ranker.predict(query="irrelevant", documents=docs, top_k=len(docs) + top_k)
        assert result == expected_docs
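A brief note on why the negative top_k expectation above holds: predict() selects documents with documents[:top_k], and a negative Python slice bound drops elements from the end, so the test's arithmetic follows directly from slicing semantics:

# Slicing with a negative bound, as used by predict() (stand-ins for nine Documents):
docs = list(range(1, 10))
print(docs[:-2])   # -> [1, 2, 3, 4, 5, 6, 7]: the last two are dropped
print(docs[:-20])  # -> []: the bound exceeds the list length, nothing is kept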