diff --git a/docs/pydoc/config/ranker.yml b/docs/pydoc/config/ranker.yml
index e0776e7f04..052756615f 100644
--- a/docs/pydoc/config/ranker.yml
+++ b/docs/pydoc/config/ranker.yml
@@ -1,7 +1,7 @@
 loaders:
   - type: python
     search_path: [../../../haystack/nodes/ranker]
-    modules: ["base", "sentence_transformers", "recentness_ranker", "diversity"]
+    modules: ["base", "sentence_transformers", "recentness_ranker", "diversity", "lost_in_the_middle"]
     ignore_when_discovered: ["__init__"]
 processors:
   - type: filter
diff --git a/examples/web_lfqa_improved.py b/examples/web_lfqa_improved.py
index c46852617e..cfef824f7c 100644
--- a/examples/web_lfqa_improved.py
+++ b/examples/web_lfqa_improved.py
@@ -2,8 +2,9 @@
 import os
 
 from haystack import Pipeline
-from haystack.nodes import PromptNode, PromptTemplate, TopPSampler, DocumentMerger
+from haystack.nodes import PromptNode, PromptTemplate, TopPSampler
 from haystack.nodes.ranker.diversity import DiversityRanker
+from haystack.nodes.ranker.lost_in_the_middle import LostInTheMiddleRanker
 from haystack.nodes.retriever.web import WebRetriever
 
 search_key = os.environ.get("SERPERDEV_API_KEY")
@@ -22,21 +23,21 @@
 """
 
 prompt_node = PromptNode(
-    "gpt-3.5-turbo", default_prompt_template=PromptTemplate(prompt_text), api_key=openai_key, max_length=256
+    "gpt-3.5-turbo", default_prompt_template=PromptTemplate(prompt_text), api_key=openai_key, max_length=768
 )
 
-web_retriever = WebRetriever(api_key=search_key, top_search_results=10, mode="preprocessed_documents", top_k=25)
+web_retriever = WebRetriever(api_key=search_key, top_search_results=5, mode="preprocessed_documents", top_k=50)
 
-sampler = TopPSampler(top_p=0.95)
-ranker = DiversityRanker()
-merger = DocumentMerger(separator="\n\n")
+sampler = TopPSampler(top_p=0.97)
+diversity_ranker = DiversityRanker()
+litm_ranker = LostInTheMiddleRanker(word_count_threshold=1024)
 
 pipeline = Pipeline()
 pipeline.add_node(component=web_retriever, name="Retriever", inputs=["Query"])
 pipeline.add_node(component=sampler, name="Sampler", inputs=["Retriever"])
-pipeline.add_node(component=ranker, name="Ranker", inputs=["Sampler"])
-pipeline.add_node(component=merger, name="Merger", inputs=["Ranker"])
-pipeline.add_node(component=prompt_node, name="PromptNode", inputs=["Merger"])
+pipeline.add_node(component=diversity_ranker, name="DiversityRanker", inputs=["Sampler"])
+pipeline.add_node(component=litm_ranker, name="LostInTheMiddleRanker", inputs=["DiversityRanker"])
+pipeline.add_node(component=prompt_node, name="PromptNode", inputs=["LostInTheMiddleRanker"])
 
 logger = logging.getLogger("boilerpy3")
 logger.setLevel(logging.CRITICAL)
diff --git a/haystack/nodes/ranker/lost_in_the_middle.py b/haystack/nodes/ranker/lost_in_the_middle.py
new file mode 100644
index 0000000000..1ffbe5bcd3
--- /dev/null
+++ b/haystack/nodes/ranker/lost_in_the_middle.py
@@ -0,0 +1,133 @@
+from typing import Optional, Union, List
+import logging
+
+from haystack.schema import Document
+from haystack.nodes.ranker.base import BaseRanker
+
+logger = logging.getLogger(__name__)
+
+
+class LostInTheMiddleRanker(BaseRanker):
+    """
+    The LostInTheMiddleRanker reorders documents according to the "lost in the middle" order. Based on the findings
+    of the paper "Lost in the Middle: How Language Models Use Long Contexts" by Liu et al., it lays out paragraphs
+    in the LLM context so that the most relevant paragraphs are placed at the beginning or end of the input context,
+    while the least relevant information sits in the middle of the context.
+
+    See https://arxiv.org/abs/2307.03172 for more details.
+    """
+
+    def __init__(self, word_count_threshold: Optional[int] = None, top_k: Optional[int] = None):
+        """
+        Creates an instance of LostInTheMiddleRanker.
+
+        If 'word_count_threshold' is specified, this ranker includes all documents up until the point where adding
+        another document would exceed the 'word_count_threshold'. The last document that causes the threshold to
+        be breached will be included in the resulting list of documents, but all subsequent documents will be
+        discarded.
+
+        :param word_count_threshold: The maximum total number of words across all documents selected by the ranker.
+        :param top_k: The maximum number of documents to return.
+        """
+        super().__init__()
+        if isinstance(word_count_threshold, int) and word_count_threshold <= 0:
+            raise ValueError(
+                f"Invalid value for word_count_threshold: {word_count_threshold}. "
+                f"word_count_threshold must be a positive integer."
+            )
+        self.word_count_threshold = word_count_threshold
+        self.top_k = top_k
+
+    def reorder_documents(self, documents: List[Document]) -> List[Document]:
+        """
+        Ranks documents based on the "lost in the middle" order. Assumes that all documents are ordered by relevance.
+
+        :param documents: List of Documents to reorder.
+        :return: Documents in the "lost in the middle" order.
+        """
+
+        # Return an empty list if no documents are provided
+        if not documents:
+            return []
+
+        # If there's only one document, return it as is
+        if len(documents) == 1:
+            return documents
+
+        # Raise an error if any document is not textual
+        if any(doc.content_type != "text" for doc in documents):
+            raise ValueError("Some provided documents are not textual; LostInTheMiddleRanker can process only text.")
+
+        # Initialize the word count and the indices for the "lost in the middle" order
+        word_count = 0
+        document_index = list(range(len(documents)))
+        lost_in_the_middle_indices = [0]
+
+        # If the word count threshold is set, calculate the word count for the first document
+        if self.word_count_threshold:
+            word_count = len(documents[0].content.split())
+
+            # If the first document already meets the word count threshold, return it
+            if word_count >= self.word_count_threshold:
+                return [documents[0]]
+
+        # Start from the second document and create the "lost in the middle" order
+        for doc_idx in document_index[1:]:
+            # Calculate the index at which the current document should be inserted
+            insertion_index = len(lost_in_the_middle_indices) // 2 + len(lost_in_the_middle_indices) % 2
+
+            # Insert the document index at the calculated position
+            lost_in_the_middle_indices.insert(insertion_index, doc_idx)
+
+            # If the word count threshold is set, add the current document to the total word count
+            if self.word_count_threshold:
+                word_count += len(documents[doc_idx].content.split())
+
+                # If the total word count meets the threshold, stop processing further documents
+                if word_count >= self.word_count_threshold:
+                    break
+
+        # Return the documents in the "lost in the middle" order
+        return [documents[idx] for idx in lost_in_the_middle_indices]
+
+    def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> List[Document]:
+        """
+        Reranks documents based on the "lost in the middle" order.
+
+        :param query: The query to reorder documents for (ignored).
+        :param documents: List of Documents to reorder.
+        :param top_k: The number of documents to return.
+
+        :return: The reordered documents.
+ """ + top_k = top_k or self.top_k + documents_to_reorder = documents[:top_k] if top_k else documents + ranked_docs = self.reorder_documents(documents=documents_to_reorder) + return ranked_docs + + def predict_batch( + self, + queries: List[str], + documents: Union[List[Document], List[List[Document]]], + top_k: Optional[int] = None, + batch_size: Optional[int] = None, + ) -> Union[List[Document], List[List[Document]]]: + """ + Reranks batch of documents based on the "lost in the middle" order. + + :param queries: The queries to reorder documents for (ignored). + :param documents: List of Documents to reorder. + :param top_k: The number of documents to return. + :param batch_size: The number of queries to process in one batch (ignored). + + :return: The reordered documents. + """ + if len(documents) > 0 and isinstance(documents[0], Document): + return self.predict(query="", documents=documents, top_k=top_k) # type: ignore + else: + # Docs case 2: list of lists of Documents -> rerank each list of Documents + results = [] + for cur_docs in documents: + assert isinstance(cur_docs, list) + results.append(self.predict(query="", documents=cur_docs, top_k=top_k)) + return results diff --git a/releasenotes/notes/add-lost-in-the-middle-ranker-6ad7dda754fad5a9.yaml b/releasenotes/notes/add-lost-in-the-middle-ranker-6ad7dda754fad5a9.yaml new file mode 100644 index 0000000000..2588daf593 --- /dev/null +++ b/releasenotes/notes/add-lost-in-the-middle-ranker-6ad7dda754fad5a9.yaml @@ -0,0 +1,18 @@ +--- +prelude: > + We're excited to introduce a new ranker to Haystack - LostInTheMiddleRanker. + It reorders documents based on the "Lost in the Middle" order, a strategy that + places the most relevant paragraphs at the beginning or end of the context, + while less relevant paragraphs are positioned in the middle. This ranker, + based on the research paper "Lost in the Middle: How Language Models Use Long + Contexts" by Liu et al., can be leveraged in Retrieval-Augmented Generation + (RAG) pipelines. +features: + - | + The LostInTheMiddleRanker can be used like other rankers in Haystack. After + initializing LostInTheMiddleRanker with the desired parameters, it can be + used to rank/reorder a list of documents based on the "Lost in the Middle" + order - the most relevant documents are located at the top and bottom of + the returned list, while the least relevant documents are found in the + middle. We advise that you use this ranker in combination with other rankers, + and to place it towards the end of the pipeline. 
diff --git a/test/nodes/test_lost_in_the_middle.py b/test/nodes/test_lost_in_the_middle.py
new file mode 100644
index 0000000000..db89c3926e
--- /dev/null
+++ b/test/nodes/test_lost_in_the_middle.py
@@ -0,0 +1,154 @@
+import pytest
+
+from haystack import Document
+from haystack.nodes.ranker.lost_in_the_middle import LostInTheMiddleRanker
+
+
+@pytest.mark.unit
+def test_lost_in_the_middle_order_odd():
+    # tests that lost_in_the_middle order works with an odd number of documents
+    docs = [Document(str(i)) for i in range(1, 10)]
+    ranker = LostInTheMiddleRanker()
+    result, _ = ranker.run(query="", documents=docs)
+    assert result["documents"]
+    expected_order = "1 3 5 7 9 8 6 4 2".split()
+    assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"]))
+
+
+@pytest.mark.unit
+def test_batch_lost_in_the_middle_order():
+    # tests that lost_in_the_middle order works with a batch of documents
+    docs = [
+        [Document("1"), Document("2"), Document("3"), Document("4")],
+        [Document("5"), Document("6")],
+        [Document("7"), Document("8"), Document("9")],
+    ]
+    ranker = LostInTheMiddleRanker()
+    result, _ = ranker.run_batch(queries=[""], documents=docs)
+
+    assert " ".join(doc.content for doc in result["documents"][0]) == "1 3 4 2"
+    assert " ".join(doc.content for doc in result["documents"][1]) == "5 6"
+    assert " ".join(doc.content for doc in result["documents"][2]) == "7 9 8"
+
+
+@pytest.mark.unit
+def test_lost_in_the_middle_order_even():
+    # tests that lost_in_the_middle order works with an even number of documents
+    docs = [Document(str(i)) for i in range(1, 11)]
+    ranker = LostInTheMiddleRanker()
+    result, _ = ranker.run(query="", documents=docs)
+    expected_order = "1 3 5 7 9 10 8 6 4 2".split()
+    assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"]))
+
+
+@pytest.mark.unit
+def test_lost_in_the_middle_order_two_docs():
+    # tests that lost_in_the_middle order works with two documents
+    ranker = LostInTheMiddleRanker()
+
+    # two docs
+    docs = [Document("1"), Document("2")]
+    result, _ = ranker.run(query="", documents=docs)
+    assert result["documents"][0].content == "1"
+    assert result["documents"][1].content == "2"
+
+
+@pytest.mark.unit
+def test_lost_in_the_middle_init():
+    # tests that LostInTheMiddleRanker initializes with default values
+    ranker = LostInTheMiddleRanker()
+    assert ranker.word_count_threshold is None
+
+    ranker = LostInTheMiddleRanker(word_count_threshold=10)
+    assert ranker.word_count_threshold == 10
+
+
+@pytest.mark.unit
+def test_lost_in_the_middle_init_invalid_word_count_threshold():
+    # tests that LostInTheMiddleRanker raises an error when word_count_threshold is <= 0
+    with pytest.raises(ValueError, match="Invalid value for word_count_threshold"):
+        LostInTheMiddleRanker(word_count_threshold=0)
+
+    with pytest.raises(ValueError, match="Invalid value for word_count_threshold"):
+        LostInTheMiddleRanker(word_count_threshold=-5)
+
+
+@pytest.mark.unit
+def test_lost_in_the_middle_with_word_count_threshold():
+    # tests that lost_in_the_middle with word_count_threshold works as expected
+    ranker = LostInTheMiddleRanker(word_count_threshold=6)
+    docs = [Document("word" + str(i)) for i in range(1, 10)]
+    result, _ = ranker.run(query="", documents=docs)
+    expected_order = "word1 word3 word5 word6 word4 word2".split()
+    assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"]))
+
+    ranker = LostInTheMiddleRanker(word_count_threshold=9)
+    result, _ = ranker.run(query="", documents=docs)
+    expected_order = "word1 word3 word5 word7 word9 word8 word6 word4 word2".split()
+    assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"]))
+
+
+@pytest.mark.unit
+def test_word_count_threshold_greater_than_total_number_of_words_returns_all_documents():
+    ranker = LostInTheMiddleRanker(word_count_threshold=100)
+    docs = [Document("word" + str(i)) for i in range(1, 10)]
+    ordered_docs = ranker.predict(query="test", documents=docs)
+    assert len(ordered_docs) == len(docs)
+    expected_order = "word1 word3 word5 word7 word9 word8 word6 word4 word2".split()
+    assert all(doc.content == expected_order[idx] for idx, doc in enumerate(ordered_docs))
+
+
+@pytest.mark.unit
+def test_empty_documents_returns_empty_list():
+    ranker = LostInTheMiddleRanker()
+    assert ranker.predict(query="test", documents=[]) == []
+
+
+@pytest.mark.unit
+def test_list_of_one_document_returns_same_document():
+    ranker = LostInTheMiddleRanker()
+    doc = Document(content="test", content_type="text")
+    assert ranker.predict(query="test", documents=[doc]) == [doc]
+
+
+@pytest.mark.unit
+def test_non_textual_documents():
+    # tests that reordering a list of non-textual documents raises a ValueError
+    ranker = LostInTheMiddleRanker()
+    doc1 = Document(content="This is a textual document.")
+    doc2 = Document(content_type="image", content="This is a non-textual document.")
+    with pytest.raises(ValueError, match="Some provided documents are not textual"):
+        ranker.reorder_documents([doc1, doc2])
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize("top_k", [1, 2, 3, 4, 5, 6, 7, 8, 12, 20])
+def test_lost_in_the_middle_order_with_positive_top_k(top_k: int):
+    # tests that lost_in_the_middle order works with an odd number of documents and a top_k parameter
+    docs = [Document(str(i)) for i in range(1, 10)]
+    ranker = LostInTheMiddleRanker()
+    result = ranker.predict(query="irrelevant", documents=docs, top_k=top_k)
+    if top_k < len(docs):
+        # top_k is less than the number of documents, so only the top_k documents should be returned in LITM order
+        assert len(result) == top_k
+        expected_order = ranker.predict(query="irrelevant", documents=[Document(str(i)) for i in range(1, top_k + 1)])
+        assert result == expected_order
+    else:
+        # top_k is greater than the number of documents, so all documents should be returned in LITM order
+        assert len(result) == len(docs)
+        assert result == ranker.predict(query="irrelevant", documents=docs)
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize("top_k", [-20, -10, -5, -1])
+def test_lost_in_the_middle_order_with_negative_top_k(top_k: int):
+    # tests that lost_in_the_middle order works with an odd number of documents and an invalid top_k parameter
+    docs = [Document(str(i)) for i in range(1, 10)]
+    ranker = LostInTheMiddleRanker()
+    result = ranker.predict(query="irrelevant", documents=docs, top_k=top_k)
+    if top_k < len(docs) * -1:
+        assert len(result) == 0  # top_k is too negative, so no documents should be returned
+    else:
+        # top_k is negative; add it to the total number of documents to get the expected number of documents
+        expected_docs = ranker.predict(query="irrelevant", documents=docs, top_k=len(docs) + top_k)
+        assert result == expected_docs
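For completeness, the interleaving rule that the ranker and the tests above rely on can be summarized in a few lines. This is a simplified re-implementation for illustration only; the function name is made up, and the real logic lives in LostInTheMiddleRanker.reorder_documents, which additionally handles word_count_threshold and non-textual documents:

    from typing import List

    def lost_in_the_middle_order(items: List[str]) -> List[str]:
        # Each successive index is inserted into the middle of the running
        # order, so earlier (more relevant) items drift towards both edges.
        order: List[int] = [0] if items else []
        for idx in range(1, len(items)):
            insertion_index = len(order) // 2 + len(order) % 2
            order.insert(insertion_index, idx)
        return [items[i] for i in order]

    # Prints "1 3 5 7 9 8 6 4 2", matching test_lost_in_the_middle_order_odd.
    print(" ".join(lost_in_the_middle_order([str(i) for i in range(1, 10)])))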