From fa92b88902e34d1cf9fe53a36e583f4c95c3ed66 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 27 Jul 2023 11:51:57 +0200 Subject: [PATCH 01/14] Add lost in the middle ranker --- haystack/nodes/ranker/lost_in_the_middle.py | 126 +++++++++++++ test/nodes/test_lost_in_the_middle.py | 189 ++++++++++++++++++++ 2 files changed, 315 insertions(+) create mode 100644 haystack/nodes/ranker/lost_in_the_middle.py create mode 100644 test/nodes/test_lost_in_the_middle.py diff --git a/haystack/nodes/ranker/lost_in_the_middle.py b/haystack/nodes/ranker/lost_in_the_middle.py new file mode 100644 index 0000000000..db27ab7362 --- /dev/null +++ b/haystack/nodes/ranker/lost_in_the_middle.py @@ -0,0 +1,126 @@ +from typing import Optional, Union, List +import logging + +from haystack.schema import Document +from haystack.nodes.ranker.base import BaseRanker + +logger = logging.getLogger(__name__) + + +class LostInTheMiddleRanker(BaseRanker): + """ + The LostInTheMiddleRanker implements a ranker that reorders documents based on the lost in the middle order. + "Lost in the Middle: How Language Models Use Long Contexts" by Liu et al. aims to layout paragraphs into LLM + context so that relevant paragraphs are at the beginning or end of the input context, while the least relevant + information should be in the middle of a context. + + See https://arxiv.org/abs/2307.03172 for more details. + """ + + def __init__(self, word_count_threshold: Optional[int] = None, truncate_document: Optional[bool] = False): + """ + Creates an instance of LostInTheMiddleRanker. + + If truncate_document is set to True, you must specify a word_count_threshold as well. + + :param word_count_threshold: The maximum number of words in all ordered documents. + :param truncate_document: Whether to truncate the last document that overflows the word count threshold. + """ + super().__init__() + if truncate_document and not word_count_threshold: + raise ValueError("If truncate_document is set to True, you must specify a word_count_threshold as well.") + self.word_count_threshold = word_count_threshold + self.truncate_document = truncate_document + + def reorder_documents(self, documents: List[Document]) -> List[Document]: + """ + Orders documents based on the lost in the middle order. + + :param documents: List of Documents to merge. + :return: Documents in the lost in the middle order. + """ + if not documents: + return [] + if len(documents) == 1: + return documents + + if any(not doc.content_type == "text" for doc in documents): + raise ValueError("Some provided documents are not textual; LostInTheMiddleRanker can process only text.") + + word_count = 0 + document_index = list(range(len(documents))) + lost_in_the_middle_indices = [0] + if self.word_count_threshold: + word_count = len(documents[0].content.split()) + if word_count >= self.word_count_threshold: + return [documents[0]] + + for doc_idx in document_index[1:]: + insertion_index = len(lost_in_the_middle_indices) // 2 + len(lost_in_the_middle_indices) % 2 + lost_in_the_middle_indices.insert(insertion_index, doc_idx) + if self.word_count_threshold: + word_count += len(documents[doc_idx].content.split()) + # if threshold is specified, check if we have enough words in all selected documents + # if yes, we can stop adding documents + if word_count >= self.word_count_threshold: + if self.truncate_document: + # truncate the last document that overflows the word count threshold + last_docs_length = len(documents[doc_idx].content.split()) + truncate_last_doc_length = last_docs_length - (word_count - self.word_count_threshold) + documents[doc_idx] = self._truncate(documents[doc_idx], truncate_last_doc_length) + break + + return [documents[idx] for idx in lost_in_the_middle_indices] + + def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> List[Document]: + """ + Reorders documents based on the lost in the middle order. + + :param query: The query to rerank documents for. + :param documents: List of Documents to reorder. + :param top_k: The number of documents to return. + + :return: The reranked documents. + """ + ordered_docs = self.reorder_documents(documents=documents) + return ordered_docs[:top_k] if top_k else ordered_docs + + def predict_batch( + self, + queries: List[str], + documents: Union[List[Document], List[List[Document]]], + top_k: Optional[int] = None, + batch_size: Optional[int] = None, + ) -> Union[List[Document], List[List[Document]]]: + """ + Reorders batch of documents based on the lost in the middle order. + + :param queries: The queries to rerank documents for (ignored). + :param documents: List of Documents to reorder. + :param top_k: The number of documents to return. + :param batch_size: The number of queries to process in one batch. + + :return: The reordered documents. + """ + if len(documents) > 0 and isinstance(documents[0], Document): + return self.predict(query="", documents=documents, top_k=top_k) # type: ignore + else: + # Docs case 2: list of lists of Documents -> rerank each list of Documents + results = [] + for cur_docs in documents: + assert isinstance(cur_docs, list) + results.append(self.predict(query="", documents=cur_docs, top_k=top_k)) + return results + + def _truncate(self, document: Document, word_count_threshold: int) -> Document: + """ + Shortens a document by cutting off the content after a specified number of words. + + :param document: Document to truncate. + :param word_count_threshold: integer representing the maximum number of words + allowed in the truncated document. + + :return: Document with truncated content. + """ + document.content = " ".join(document.content.split()[:word_count_threshold]) + return document diff --git a/test/nodes/test_lost_in_the_middle.py b/test/nodes/test_lost_in_the_middle.py new file mode 100644 index 0000000000..20b36096ec --- /dev/null +++ b/test/nodes/test_lost_in_the_middle.py @@ -0,0 +1,189 @@ +import pytest + +from haystack import Document +from haystack.nodes.ranker.lost_in_the_middle import LostInTheMiddleRanker + + +@pytest.mark.unit +def test_lost_in_the_middle_order_odd(): + # tests that lost_in_the_middle order works with an odd number of documents + docs = [ + Document("1"), + Document("2"), + Document("3"), + Document("4"), + Document("5"), + Document("6"), + Document("7"), + Document("8"), + Document("9"), + ] + dm = LostInTheMiddleRanker() + result, _ = dm.run(query="", documents=docs) + assert result["documents"] + expected_order = "1 3 5 7 9 8 6 4 2".split() + assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"])) + + +@pytest.mark.unit +def test_batch_lost_in_the_middle_order_(): + # tests that lost_in_the_middle order works with a batch of documents + docs = [ + [Document("1"), Document("2"), Document("3"), Document("4")], + [Document("5"), Document("6")], + [Document("7"), Document("8"), Document("9")], + ] + dm = LostInTheMiddleRanker() + result, _ = dm.run_batch(queries=[""], documents=docs) + + assert " ".join(doc.content for doc in result["documents"][0]) == "1 3 4 2" + assert " ".join(doc.content for doc in result["documents"][1]) == "5 6" + assert " ".join(doc.content for doc in result["documents"][2]) == "7 9 8" + + +@pytest.mark.unit +def test_lost_in_the_middle_order_even(): + # tests that lost_in_the_middle order works with an even number of documents + docs = [ + Document("1"), + Document("2"), + Document("3"), + Document("4"), + Document("5"), + Document("6"), + Document("7"), + Document("8"), + Document("9"), + Document("10"), + ] + dm = LostInTheMiddleRanker() + result, _ = dm.run(query="", documents=docs) + expected_order = "1 3 5 7 9 10 8 6 4 2".split() + assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"])) + + +@pytest.mark.unit +def test_lost_in_the_middle_order_corner(): + # tests that lost_in_the_middle order works with some basic corner cases + dm = LostInTheMiddleRanker() + + # empty doc list + docs = [] + result, _ = dm.run(query="", documents=docs) + assert len(result["documents"]) == 0 + + # single doc + docs = [Document("1")] + result, _ = dm.run(query="", documents=docs) + assert result["documents"][0].content == "1" + + # two docs + docs = [Document("1"), Document("2")] + result, _ = dm.run(query="", documents=docs) + assert result["documents"][0].content == "1" + assert result["documents"][1].content == "2" + + +@pytest.mark.unit +def test_lost_in_the_middle_init(): + # tests that LostInTheMiddleRanker initializes with default values + litm = LostInTheMiddleRanker() + assert litm.word_count_threshold is None + assert litm.truncate_document is False + + litm = LostInTheMiddleRanker(word_count_threshold=10, truncate_document=True) + assert litm.word_count_threshold == 10 + assert litm.truncate_document is True + + with pytest.raises( + ValueError, match="If truncate_document is set to True, you must specify a word_count_threshold" + ): + LostInTheMiddleRanker(truncate_document=True) + + +@pytest.mark.unit +def test_lost_in_the_middle_with_word_count_threshold(): + # tests that lost_in_the_middle with word_count_threshold works as expected + litm = LostInTheMiddleRanker(word_count_threshold=6) + docs = [ + Document("word1"), + Document("word2"), + Document("word3"), + Document("word4"), + Document("word5"), + Document("word6"), + Document("word7"), + Document("word8"), + Document("word9"), + ] + result, _ = litm.run(query="", documents=docs) + expected_order = "word1 word3 word5 word6 word4 word2".split() + assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"])) + + litm = LostInTheMiddleRanker(word_count_threshold=9) + result, _ = litm.run(query="", documents=docs) + expected_order = "word1 word3 word5 word7 word9 word8 word6 word4 word2".split() + assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"])) + + +@pytest.mark.unit +def test_word_count_threshold_greater_than_total_number_of_words_returns_all_documents(): + ranker = LostInTheMiddleRanker(word_count_threshold=100) + docs = [ + Document("word1"), + Document("word2"), + Document("word3"), + Document("word4"), + Document("word5"), + Document("word6"), + Document("word7"), + Document("word8"), + Document("word9"), + ] + ordered_docs = ranker.predict(query="test", documents=docs) + assert len(ordered_docs) == len(docs) + expected_order = "word1 word3 word5 word7 word9 word8 word6 word4 word2".split() + assert all(doc.content == expected_order[idx] for idx, doc in enumerate(ordered_docs)) + + +@pytest.mark.unit +def test_truncation_with_threshold(): + # tests that truncation works as expected + litm = LostInTheMiddleRanker(word_count_threshold=9, truncate_document=True) + docs = [ + Document("word1 word1"), + Document("word2 word2"), + Document("word3 word3"), + Document("word4 word4"), + Document("word5 word5"), + Document("word6 word6"), + Document("word7 word7"), + Document("word8 word8"), + Document("word9 word9"), + ] + result, _ = litm.run(query="", documents=docs) + expected_order = "word1 word1 word3 word3 word5 word4 word4 word2 word2" + assert expected_order == " ".join(doc.content for doc in result["documents"]) + + +@pytest.mark.unit +def test_empty_documents_returns_empty_list(): + ranker = LostInTheMiddleRanker() + assert ranker.predict(query="test", documents=[]) == [] + + +@pytest.mark.unit +def test_list_of_one_document_returns_same_document(): + ranker = LostInTheMiddleRanker() + doc = Document(content="test", content_type="text") + assert ranker.predict(query="test", documents=[doc]) == [doc] + + +@pytest.mark.unit +def test_non_textual_documents(): + # tests that merging a list of non-textual documents raises a ValueError + litm = LostInTheMiddleRanker() + doc1 = Document(content="This is a textual document.") + doc2 = Document(content_type="image", content="This is a non-textual document.") + with pytest.raises(ValueError): + litm.reorder_documents([doc1, doc2]) From c5660bf7e15af77551c0ddf3500b6b90b7d93880 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 28 Jul 2023 10:12:11 +0200 Subject: [PATCH 02/14] Add release note --- ...st-in-the-middle-ranker-6ad7dda754fad5a9.yaml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 releasenotes/notes/add-lost-in-the-middle-ranker-6ad7dda754fad5a9.yaml diff --git a/releasenotes/notes/add-lost-in-the-middle-ranker-6ad7dda754fad5a9.yaml b/releasenotes/notes/add-lost-in-the-middle-ranker-6ad7dda754fad5a9.yaml new file mode 100644 index 0000000000..08e20d6f31 --- /dev/null +++ b/releasenotes/notes/add-lost-in-the-middle-ranker-6ad7dda754fad5a9.yaml @@ -0,0 +1,16 @@ +--- +prelude: > + We're excited to introduce a new ranker to Haystack - LostInTheMiddleRanker. + It reorders documents based on the "Lost in the Middle" order, a strategy that + places the most relevant paragraphs at the beginning or end of the context, + while less relevant paragraphs are positioned in the middle. This ranker, + based on the research paper "Lost in the Middle: How Language Models Use Long + Contexts" by Liu et al., can be leveraged in Retrieval-Augmented Generation + (RAG) pipelines. +features: + - | + The LostInTheMiddleRanker can be used like other rankers in Haystack. After + initializing LostInTheMiddleRanker with the desired parameters, it can be + used to rank/reorder a list of documents based on the "Lost in the Middle" + order. We advice that you use this ranker in combination with other + rankers, and to place it towards the end of the pipeline. From 1a853fd0dd292d8c5446b8926ee3d5980836ae83 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 31 Jul 2023 21:33:57 +0200 Subject: [PATCH 03/14] PR review: implement top_k correctly, improve and add new unit tests --- haystack/nodes/ranker/lost_in_the_middle.py | 15 ++- test/nodes/test_lost_in_the_middle.py | 101 +++++++++----------- 2 files changed, 58 insertions(+), 58 deletions(-) diff --git a/haystack/nodes/ranker/lost_in_the_middle.py b/haystack/nodes/ranker/lost_in_the_middle.py index db27ab7362..8200033bd5 100644 --- a/haystack/nodes/ranker/lost_in_the_middle.py +++ b/haystack/nodes/ranker/lost_in_the_middle.py @@ -80,10 +80,10 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] = :param documents: List of Documents to reorder. :param top_k: The number of documents to return. - :return: The reranked documents. + :return: The re-ranked documents. """ ordered_docs = self.reorder_documents(documents=documents) - return ordered_docs[:top_k] if top_k else ordered_docs + return self._exclude_middle_elements(ordered_docs, top_k) if top_k else ordered_docs def predict_batch( self, @@ -112,6 +112,17 @@ def predict_batch( results.append(self.predict(query="", documents=cur_docs, top_k=top_k)) return results + def _exclude_middle_elements(self, ordered_docs: List[Document], top_k: int): + exclude_count = len(ordered_docs) - top_k + middle_index = len(ordered_docs) // 2 + half_top_k = exclude_count // 2 + + start_index = middle_index - half_top_k + len(ordered_docs) % 2 + end_index = start_index + exclude_count + remaining_elements = ordered_docs[:start_index] + ordered_docs[end_index:] + + return remaining_elements + def _truncate(self, document: Document, word_count_threshold: int) -> Document: """ Shortens a document by cutting off the content after a specified number of words. diff --git a/test/nodes/test_lost_in_the_middle.py b/test/nodes/test_lost_in_the_middle.py index 20b36096ec..fb64c2d008 100644 --- a/test/nodes/test_lost_in_the_middle.py +++ b/test/nodes/test_lost_in_the_middle.py @@ -7,34 +7,24 @@ @pytest.mark.unit def test_lost_in_the_middle_order_odd(): # tests that lost_in_the_middle order works with an odd number of documents - docs = [ - Document("1"), - Document("2"), - Document("3"), - Document("4"), - Document("5"), - Document("6"), - Document("7"), - Document("8"), - Document("9"), - ] - dm = LostInTheMiddleRanker() - result, _ = dm.run(query="", documents=docs) + docs = [Document(str(i)) for i in range(1, 10)] + litm = LostInTheMiddleRanker() + result, _ = litm.run(query="", documents=docs) assert result["documents"] expected_order = "1 3 5 7 9 8 6 4 2".split() assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"])) @pytest.mark.unit -def test_batch_lost_in_the_middle_order_(): +def test_batch_lost_in_the_middle_order(): # tests that lost_in_the_middle order works with a batch of documents docs = [ [Document("1"), Document("2"), Document("3"), Document("4")], [Document("5"), Document("6")], [Document("7"), Document("8"), Document("9")], ] - dm = LostInTheMiddleRanker() - result, _ = dm.run_batch(queries=[""], documents=docs) + litm = LostInTheMiddleRanker() + result, _ = litm.run_batch(queries=[""], documents=docs) assert " ".join(doc.content for doc in result["documents"][0]) == "1 3 4 2" assert " ".join(doc.content for doc in result["documents"][1]) == "5 6" @@ -44,20 +34,9 @@ def test_batch_lost_in_the_middle_order_(): @pytest.mark.unit def test_lost_in_the_middle_order_even(): # tests that lost_in_the_middle order works with an even number of documents - docs = [ - Document("1"), - Document("2"), - Document("3"), - Document("4"), - Document("5"), - Document("6"), - Document("7"), - Document("8"), - Document("9"), - Document("10"), - ] - dm = LostInTheMiddleRanker() - result, _ = dm.run(query="", documents=docs) + docs = [Document(str(i)) for i in range(1, 11)] + litm = LostInTheMiddleRanker() + result, _ = litm.run(query="", documents=docs) expected_order = "1 3 5 7 9 10 8 6 4 2".split() assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"])) @@ -65,21 +44,21 @@ def test_lost_in_the_middle_order_even(): @pytest.mark.unit def test_lost_in_the_middle_order_corner(): # tests that lost_in_the_middle order works with some basic corner cases - dm = LostInTheMiddleRanker() + litm = LostInTheMiddleRanker() # empty doc list docs = [] - result, _ = dm.run(query="", documents=docs) + result, _ = litm.run(query="", documents=docs) assert len(result["documents"]) == 0 # single doc docs = [Document("1")] - result, _ = dm.run(query="", documents=docs) + result, _ = litm.run(query="", documents=docs) assert result["documents"][0].content == "1" # two docs docs = [Document("1"), Document("2")] - result, _ = dm.run(query="", documents=docs) + result, _ = litm.run(query="", documents=docs) assert result["documents"][0].content == "1" assert result["documents"][1].content == "2" @@ -105,17 +84,7 @@ def test_lost_in_the_middle_init(): def test_lost_in_the_middle_with_word_count_threshold(): # tests that lost_in_the_middle with word_count_threshold works as expected litm = LostInTheMiddleRanker(word_count_threshold=6) - docs = [ - Document("word1"), - Document("word2"), - Document("word3"), - Document("word4"), - Document("word5"), - Document("word6"), - Document("word7"), - Document("word8"), - Document("word9"), - ] + docs = [Document("word" + str(i)) for i in range(1, 10)] result, _ = litm.run(query="", documents=docs) expected_order = "word1 word3 word5 word6 word4 word2".split() assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"])) @@ -129,17 +98,7 @@ def test_lost_in_the_middle_with_word_count_threshold(): @pytest.mark.unit def test_word_count_threshold_greater_than_total_number_of_words_returns_all_documents(): ranker = LostInTheMiddleRanker(word_count_threshold=100) - docs = [ - Document("word1"), - Document("word2"), - Document("word3"), - Document("word4"), - Document("word5"), - Document("word6"), - Document("word7"), - Document("word8"), - Document("word9"), - ] + docs = [Document("word" + str(i)) for i in range(1, 10)] ordered_docs = ranker.predict(query="test", documents=docs) assert len(ordered_docs) == len(docs) expected_order = "word1 word3 word5 word7 word9 word8 word6 word4 word2".split() @@ -187,3 +146,33 @@ def test_non_textual_documents(): doc2 = Document(content_type="image", content="This is a non-textual document.") with pytest.raises(ValueError): litm.reorder_documents([doc1, doc2]) + + +@pytest.mark.unit +@pytest.mark.parametrize("top_k", [1, 2, 3, 4, 5, 6, 7, 8]) +def test_lost_in_the_middle_order_odd_with_top_k(top_k: int): + # tests that lost_in_the_middle order works with an odd number of documents and a top_k parameter + docs = [Document(str(i)) for i in range(1, 10)] + litm = LostInTheMiddleRanker() + result = litm._exclude_middle_elements(docs, top_k=top_k) + assert len(result) == top_k + + reverse_top_k = len(docs) - top_k + middle = len(docs) // 2 + 1 + expected_docs = docs[: middle - reverse_top_k // 2] + docs[middle + reverse_top_k // 2 + reverse_top_k % 2 :] + assert result == expected_docs + + +@pytest.mark.unit +@pytest.mark.parametrize("top_k", [1, 2, 3, 4, 5, 6, 7, 8]) +def test_lost_in_the_middle_order_even_with_top_k(top_k: int): + # tests that lost_in_the_middle order works with an even number of documents and a top_k parameter + docs = [Document(str(i)) for i in range(1, 9)] + litm = LostInTheMiddleRanker() + result = litm._exclude_middle_elements(docs, top_k=top_k) + assert len(result) == top_k + + reverse_top_k = len(docs) - top_k + middle = len(docs) // 2 + expected_docs = docs[: middle - reverse_top_k // 2] + docs[middle + reverse_top_k // 2 + reverse_top_k % 2 :] + assert result == expected_docs From b72cd349f9447ced99362f196690bf0836b0a517 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 31 Jul 2023 21:35:42 +0200 Subject: [PATCH 04/14] Use ranker variable name in tests --- test/nodes/test_lost_in_the_middle.py | 56 +++++++++++++-------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/test/nodes/test_lost_in_the_middle.py b/test/nodes/test_lost_in_the_middle.py index fb64c2d008..c741ce756f 100644 --- a/test/nodes/test_lost_in_the_middle.py +++ b/test/nodes/test_lost_in_the_middle.py @@ -8,8 +8,8 @@ def test_lost_in_the_middle_order_odd(): # tests that lost_in_the_middle order works with an odd number of documents docs = [Document(str(i)) for i in range(1, 10)] - litm = LostInTheMiddleRanker() - result, _ = litm.run(query="", documents=docs) + ranker = LostInTheMiddleRanker() + result, _ = ranker.run(query="", documents=docs) assert result["documents"] expected_order = "1 3 5 7 9 8 6 4 2".split() assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"])) @@ -23,8 +23,8 @@ def test_batch_lost_in_the_middle_order(): [Document("5"), Document("6")], [Document("7"), Document("8"), Document("9")], ] - litm = LostInTheMiddleRanker() - result, _ = litm.run_batch(queries=[""], documents=docs) + ranker = LostInTheMiddleRanker() + result, _ = ranker.run_batch(queries=[""], documents=docs) assert " ".join(doc.content for doc in result["documents"][0]) == "1 3 4 2" assert " ".join(doc.content for doc in result["documents"][1]) == "5 6" @@ -35,8 +35,8 @@ def test_batch_lost_in_the_middle_order(): def test_lost_in_the_middle_order_even(): # tests that lost_in_the_middle order works with an even number of documents docs = [Document(str(i)) for i in range(1, 11)] - litm = LostInTheMiddleRanker() - result, _ = litm.run(query="", documents=docs) + ranker = LostInTheMiddleRanker() + result, _ = ranker.run(query="", documents=docs) expected_order = "1 3 5 7 9 10 8 6 4 2".split() assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"])) @@ -44,21 +44,21 @@ def test_lost_in_the_middle_order_even(): @pytest.mark.unit def test_lost_in_the_middle_order_corner(): # tests that lost_in_the_middle order works with some basic corner cases - litm = LostInTheMiddleRanker() + ranker = LostInTheMiddleRanker() # empty doc list docs = [] - result, _ = litm.run(query="", documents=docs) + result, _ = ranker.run(query="", documents=docs) assert len(result["documents"]) == 0 # single doc docs = [Document("1")] - result, _ = litm.run(query="", documents=docs) + result, _ = ranker.run(query="", documents=docs) assert result["documents"][0].content == "1" # two docs docs = [Document("1"), Document("2")] - result, _ = litm.run(query="", documents=docs) + result, _ = ranker.run(query="", documents=docs) assert result["documents"][0].content == "1" assert result["documents"][1].content == "2" @@ -66,13 +66,13 @@ def test_lost_in_the_middle_order_corner(): @pytest.mark.unit def test_lost_in_the_middle_init(): # tests that LostInTheMiddleRanker initializes with default values - litm = LostInTheMiddleRanker() - assert litm.word_count_threshold is None - assert litm.truncate_document is False + ranker = LostInTheMiddleRanker() + assert ranker.word_count_threshold is None + assert ranker.truncate_document is False - litm = LostInTheMiddleRanker(word_count_threshold=10, truncate_document=True) - assert litm.word_count_threshold == 10 - assert litm.truncate_document is True + ranker = LostInTheMiddleRanker(word_count_threshold=10, truncate_document=True) + assert ranker.word_count_threshold == 10 + assert ranker.truncate_document is True with pytest.raises( ValueError, match="If truncate_document is set to True, you must specify a word_count_threshold" @@ -83,14 +83,14 @@ def test_lost_in_the_middle_init(): @pytest.mark.unit def test_lost_in_the_middle_with_word_count_threshold(): # tests that lost_in_the_middle with word_count_threshold works as expected - litm = LostInTheMiddleRanker(word_count_threshold=6) + ranker = LostInTheMiddleRanker(word_count_threshold=6) docs = [Document("word" + str(i)) for i in range(1, 10)] - result, _ = litm.run(query="", documents=docs) + result, _ = ranker.run(query="", documents=docs) expected_order = "word1 word3 word5 word6 word4 word2".split() assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"])) - litm = LostInTheMiddleRanker(word_count_threshold=9) - result, _ = litm.run(query="", documents=docs) + ranker = LostInTheMiddleRanker(word_count_threshold=9) + result, _ = ranker.run(query="", documents=docs) expected_order = "word1 word3 word5 word7 word9 word8 word6 word4 word2".split() assert all(doc.content == expected_order[idx] for idx, doc in enumerate(result["documents"])) @@ -108,7 +108,7 @@ def test_word_count_threshold_greater_than_total_number_of_words_returns_all_doc @pytest.mark.unit def test_truncation_with_threshold(): # tests that truncation works as expected - litm = LostInTheMiddleRanker(word_count_threshold=9, truncate_document=True) + ranker = LostInTheMiddleRanker(word_count_threshold=9, truncate_document=True) docs = [ Document("word1 word1"), Document("word2 word2"), @@ -120,7 +120,7 @@ def test_truncation_with_threshold(): Document("word8 word8"), Document("word9 word9"), ] - result, _ = litm.run(query="", documents=docs) + result, _ = ranker.run(query="", documents=docs) expected_order = "word1 word1 word3 word3 word5 word4 word4 word2 word2" assert expected_order == " ".join(doc.content for doc in result["documents"]) @@ -141,11 +141,11 @@ def test_list_of_one_document_returns_same_document(): @pytest.mark.unit def test_non_textual_documents(): # tests that merging a list of non-textual documents raises a ValueError - litm = LostInTheMiddleRanker() + ranker = LostInTheMiddleRanker() doc1 = Document(content="This is a textual document.") doc2 = Document(content_type="image", content="This is a non-textual document.") with pytest.raises(ValueError): - litm.reorder_documents([doc1, doc2]) + ranker.reorder_documents([doc1, doc2]) @pytest.mark.unit @@ -153,8 +153,8 @@ def test_non_textual_documents(): def test_lost_in_the_middle_order_odd_with_top_k(top_k: int): # tests that lost_in_the_middle order works with an odd number of documents and a top_k parameter docs = [Document(str(i)) for i in range(1, 10)] - litm = LostInTheMiddleRanker() - result = litm._exclude_middle_elements(docs, top_k=top_k) + ranker = LostInTheMiddleRanker() + result = ranker._exclude_middle_elements(docs, top_k=top_k) assert len(result) == top_k reverse_top_k = len(docs) - top_k @@ -168,8 +168,8 @@ def test_lost_in_the_middle_order_odd_with_top_k(top_k: int): def test_lost_in_the_middle_order_even_with_top_k(top_k: int): # tests that lost_in_the_middle order works with an even number of documents and a top_k parameter docs = [Document(str(i)) for i in range(1, 9)] - litm = LostInTheMiddleRanker() - result = litm._exclude_middle_elements(docs, top_k=top_k) + ranker = LostInTheMiddleRanker() + result = ranker._exclude_middle_elements(docs, top_k=top_k) assert len(result) == top_k reverse_top_k = len(docs) - top_k From a16219debdd98c07bdf1d1a9ec7beb946899c3af Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 1 Aug 2023 07:57:23 +0200 Subject: [PATCH 05/14] Add invalid top_k handling, unit tests --- haystack/nodes/ranker/lost_in_the_middle.py | 5 ++++- test/nodes/test_lost_in_the_middle.py | 20 ++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/haystack/nodes/ranker/lost_in_the_middle.py b/haystack/nodes/ranker/lost_in_the_middle.py index 8200033bd5..98bee225d6 100644 --- a/haystack/nodes/ranker/lost_in_the_middle.py +++ b/haystack/nodes/ranker/lost_in_the_middle.py @@ -83,7 +83,8 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] = :return: The re-ranked documents. """ ordered_docs = self.reorder_documents(documents=documents) - return self._exclude_middle_elements(ordered_docs, top_k) if top_k else ordered_docs + valid_top_k = isinstance(top_k, int) and 0 < top_k < len(ordered_docs) + return self._exclude_middle_elements(ordered_docs, top_k) if valid_top_k else ordered_docs # type: ignore def predict_batch( self, @@ -113,6 +114,8 @@ def predict_batch( return results def _exclude_middle_elements(self, ordered_docs: List[Document], top_k: int): + if top_k < 1 or top_k > len(ordered_docs): + raise ValueError(f"top_k must be between 1 and {len(ordered_docs)}") exclude_count = len(ordered_docs) - top_k middle_index = len(ordered_docs) // 2 half_top_k = exclude_count // 2 diff --git a/test/nodes/test_lost_in_the_middle.py b/test/nodes/test_lost_in_the_middle.py index c741ce756f..398e8dccf8 100644 --- a/test/nodes/test_lost_in_the_middle.py +++ b/test/nodes/test_lost_in_the_middle.py @@ -176,3 +176,23 @@ def test_lost_in_the_middle_order_even_with_top_k(top_k: int): middle = len(docs) // 2 expected_docs = docs[: middle - reverse_top_k // 2] + docs[middle + reverse_top_k // 2 + reverse_top_k % 2 :] assert result == expected_docs + + +@pytest.mark.unit +@pytest.mark.parametrize("top_k", [-5, 15]) +def test_lost_in_the_middle_order_odd_with_invalid_top_k(top_k: int): + # tests that lost_in_the_middle order works with an odd number of documents and a top_k parameter + docs = [Document(str(i)) for i in range(1, 10)] + ranker = LostInTheMiddleRanker() + with pytest.raises(ValueError): + ranker._exclude_middle_elements(docs, top_k=top_k) + + +@pytest.mark.unit +@pytest.mark.parametrize("top_k", [-5, 15]) +def test_lost_in_the_middle_order_even_with_invalid_top_k(top_k: int): + # tests that lost_in_the_middle order works with an even number of documents and a top_k parameter + docs = [Document(str(i)) for i in range(1, 9)] + ranker = LostInTheMiddleRanker() + with pytest.raises(ValueError): + ranker._exclude_middle_elements(docs, top_k=top_k) From 5bc429f3a8c092cedf52ea6646b06bc136a40309 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 1 Aug 2023 09:28:26 +0200 Subject: [PATCH 06/14] Minor fixes --- haystack/nodes/ranker/lost_in_the_middle.py | 6 ++++-- ...in-the-middle-ranker-6ad7dda754fad5a9.yaml | 6 ++++-- test/nodes/test_lost_in_the_middle.py | 20 +++++-------------- 3 files changed, 13 insertions(+), 19 deletions(-) diff --git a/haystack/nodes/ranker/lost_in_the_middle.py b/haystack/nodes/ranker/lost_in_the_middle.py index 98bee225d6..8d3615948a 100644 --- a/haystack/nodes/ranker/lost_in_the_middle.py +++ b/haystack/nodes/ranker/lost_in_the_middle.py @@ -21,7 +21,9 @@ def __init__(self, word_count_threshold: Optional[int] = None, truncate_document """ Creates an instance of LostInTheMiddleRanker. - If truncate_document is set to True, you must specify a word_count_threshold as well. + If truncate_document is True, you must specify a word_count_threshold as well. If truncate_document is False + and word_count_threshold is specified, the word_count_threshold will be used as a soft limit. The last document + breaching the word_count_threshold will be included in the resulting list of Documents but won't be truncated. :param word_count_threshold: The maximum number of words in all ordered documents. :param truncate_document: Whether to truncate the last document that overflows the word count threshold. @@ -34,7 +36,7 @@ def __init__(self, word_count_threshold: Optional[int] = None, truncate_document def reorder_documents(self, documents: List[Document]) -> List[Document]: """ - Orders documents based on the lost in the middle order. + Orders documents based on the lost in the middle order. Assumes that all documents are ordered by relevance. :param documents: List of Documents to merge. :return: Documents in the lost in the middle order. diff --git a/releasenotes/notes/add-lost-in-the-middle-ranker-6ad7dda754fad5a9.yaml b/releasenotes/notes/add-lost-in-the-middle-ranker-6ad7dda754fad5a9.yaml index 08e20d6f31..2588daf593 100644 --- a/releasenotes/notes/add-lost-in-the-middle-ranker-6ad7dda754fad5a9.yaml +++ b/releasenotes/notes/add-lost-in-the-middle-ranker-6ad7dda754fad5a9.yaml @@ -12,5 +12,7 @@ features: The LostInTheMiddleRanker can be used like other rankers in Haystack. After initializing LostInTheMiddleRanker with the desired parameters, it can be used to rank/reorder a list of documents based on the "Lost in the Middle" - order. We advice that you use this ranker in combination with other - rankers, and to place it towards the end of the pipeline. + order - the most relevant documents are located at the top and bottom of + the returned list, while the least relevant documents are found in the + middle. We advise that you use this ranker in combination with other rankers, + and to place it towards the end of the pipeline. diff --git a/test/nodes/test_lost_in_the_middle.py b/test/nodes/test_lost_in_the_middle.py index 398e8dccf8..1a16e0b5e1 100644 --- a/test/nodes/test_lost_in_the_middle.py +++ b/test/nodes/test_lost_in_the_middle.py @@ -42,20 +42,10 @@ def test_lost_in_the_middle_order_even(): @pytest.mark.unit -def test_lost_in_the_middle_order_corner(): - # tests that lost_in_the_middle order works with some basic corner cases +def test_lost_in_the_middle_order_two_docs(): + # tests that lost_in_the_middle order works with two documents ranker = LostInTheMiddleRanker() - # empty doc list - docs = [] - result, _ = ranker.run(query="", documents=docs) - assert len(result["documents"]) == 0 - - # single doc - docs = [Document("1")] - result, _ = ranker.run(query="", documents=docs) - assert result["documents"][0].content == "1" - # two docs docs = [Document("1"), Document("2")] result, _ = ranker.run(query="", documents=docs) @@ -144,7 +134,7 @@ def test_non_textual_documents(): ranker = LostInTheMiddleRanker() doc1 = Document(content="This is a textual document.") doc2 = Document(content_type="image", content="This is a non-textual document.") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Some provided documents are not textual"): ranker.reorder_documents([doc1, doc2]) @@ -181,7 +171,7 @@ def test_lost_in_the_middle_order_even_with_top_k(top_k: int): @pytest.mark.unit @pytest.mark.parametrize("top_k", [-5, 15]) def test_lost_in_the_middle_order_odd_with_invalid_top_k(top_k: int): - # tests that lost_in_the_middle order works with an odd number of documents and a top_k parameter + # tests that lost_in_the_middle order works with an odd number of documents and an invalid top_k parameter docs = [Document(str(i)) for i in range(1, 10)] ranker = LostInTheMiddleRanker() with pytest.raises(ValueError): @@ -191,7 +181,7 @@ def test_lost_in_the_middle_order_odd_with_invalid_top_k(top_k: int): @pytest.mark.unit @pytest.mark.parametrize("top_k", [-5, 15]) def test_lost_in_the_middle_order_even_with_invalid_top_k(top_k: int): - # tests that lost_in_the_middle order works with an even number of documents and a top_k parameter + # tests that lost_in_the_middle order works with an even number of documents and an invalid top_k parameter docs = [Document(str(i)) for i in range(1, 9)] ranker = LostInTheMiddleRanker() with pytest.raises(ValueError): From 10a92c25a9f9699bf90257c33a78288352ccffe3 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 1 Aug 2023 10:22:00 +0200 Subject: [PATCH 07/14] Julian's feedback: more precise version of truncate --- haystack/nodes/ranker/lost_in_the_middle.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/haystack/nodes/ranker/lost_in_the_middle.py b/haystack/nodes/ranker/lost_in_the_middle.py index 8d3615948a..a3f43de224 100644 --- a/haystack/nodes/ranker/lost_in_the_middle.py +++ b/haystack/nodes/ranker/lost_in_the_middle.py @@ -131,12 +131,14 @@ def _exclude_middle_elements(self, ordered_docs: List[Document], top_k: int): def _truncate(self, document: Document, word_count_threshold: int) -> Document: """ Shortens a document by cutting off the content after a specified number of words. - :param document: Document to truncate. :param word_count_threshold: integer representing the maximum number of words allowed in the truncated document. - :return: Document with truncated content. """ - document.content = " ".join(document.content.split()[:word_count_threshold]) + words = document.content.split() + if len(words) > word_count_threshold: + # -1 to remove trailing whitespace + cut_off = sum(len(word) + 1 for word in words[:word_count_threshold]) - 1 + document.content = document.content[:cut_off] return document From c5fcc30726a6a4f363c2f1443ac8913a82db7e96 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 1 Aug 2023 11:29:20 +0200 Subject: [PATCH 08/14] Better comments for the litm algorithm --- haystack/nodes/ranker/lost_in_the_middle.py | 23 ++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/haystack/nodes/ranker/lost_in_the_middle.py b/haystack/nodes/ranker/lost_in_the_middle.py index a3f43de224..78e1bd9079 100644 --- a/haystack/nodes/ranker/lost_in_the_middle.py +++ b/haystack/nodes/ranker/lost_in_the_middle.py @@ -41,37 +41,54 @@ def reorder_documents(self, documents: List[Document]) -> List[Document]: :param documents: List of Documents to merge. :return: Documents in the lost in the middle order. """ + + # Return empty list if no documents are provided if not documents: return [] + + # If there's only one document, return it as is if len(documents) == 1: return documents + # Raise an error if any document is not textual if any(not doc.content_type == "text" for doc in documents): raise ValueError("Some provided documents are not textual; LostInTheMiddleRanker can process only text.") + # Initialize word count and indices for the "lost in the middle" order word_count = 0 document_index = list(range(len(documents))) lost_in_the_middle_indices = [0] + + # If word count threshold is set, calculate word count for the first document if self.word_count_threshold: word_count = len(documents[0].content.split()) + + # If the first document already meets the word count threshold, return it if word_count >= self.word_count_threshold: return [documents[0]] + # Start from the second document and create "lost in the middle" order for doc_idx in document_index[1:]: + # Calculate the index at which the current document should be inserted insertion_index = len(lost_in_the_middle_indices) // 2 + len(lost_in_the_middle_indices) % 2 + + # Insert the document index at the calculated position lost_in_the_middle_indices.insert(insertion_index, doc_idx) + + # If word count threshold is set, calculate the total word count if self.word_count_threshold: word_count += len(documents[doc_idx].content.split()) - # if threshold is specified, check if we have enough words in all selected documents - # if yes, we can stop adding documents + + # If the total word count meets the threshold, stop processing further documents if word_count >= self.word_count_threshold: + # If truncation is allowed, truncate the last document to meet the word count threshold if self.truncate_document: - # truncate the last document that overflows the word count threshold last_docs_length = len(documents[doc_idx].content.split()) truncate_last_doc_length = last_docs_length - (word_count - self.word_count_threshold) documents[doc_idx] = self._truncate(documents[doc_idx], truncate_last_doc_length) break + # Return the documents in the "lost in the middle" order return [documents[idx] for idx in lost_in_the_middle_indices] def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> List[Document]: From 32d06e6208c1a2684a37fe0b01cb045193cf83aa Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 1 Aug 2023 13:49:07 +0200 Subject: [PATCH 09/14] Update examples/web_lfqa_improved.py --- examples/web_lfqa_improved.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/examples/web_lfqa_improved.py b/examples/web_lfqa_improved.py index c46852617e..5c070dd093 100644 --- a/examples/web_lfqa_improved.py +++ b/examples/web_lfqa_improved.py @@ -2,8 +2,9 @@ import os from haystack import Pipeline -from haystack.nodes import PromptNode, PromptTemplate, TopPSampler, DocumentMerger +from haystack.nodes import PromptNode, PromptTemplate, TopPSampler from haystack.nodes.ranker.diversity import DiversityRanker +from haystack.nodes.ranker.lost_in_the_middle import LostInTheMiddleRanker from haystack.nodes.retriever.web import WebRetriever search_key = os.environ.get("SERPERDEV_API_KEY") @@ -22,21 +23,21 @@ """ prompt_node = PromptNode( - "gpt-3.5-turbo", default_prompt_template=PromptTemplate(prompt_text), api_key=openai_key, max_length=256 + "gpt-3.5-turbo", default_prompt_template=PromptTemplate(prompt_text), api_key=openai_key, max_length=640 ) -web_retriever = WebRetriever(api_key=search_key, top_search_results=10, mode="preprocessed_documents", top_k=25) +web_retriever = WebRetriever(api_key=search_key, top_search_results=5, mode="preprocessed_documents", top_k=50) -sampler = TopPSampler(top_p=0.95) -ranker = DiversityRanker() -merger = DocumentMerger(separator="\n\n") +sampler = TopPSampler(top_p=0.97) +diversity_ranker = DiversityRanker() +litm_ranker = LostInTheMiddleRanker(word_count_threshold=1024) pipeline = Pipeline() pipeline.add_node(component=web_retriever, name="Retriever", inputs=["Query"]) pipeline.add_node(component=sampler, name="Sampler", inputs=["Retriever"]) -pipeline.add_node(component=ranker, name="Ranker", inputs=["Sampler"]) -pipeline.add_node(component=merger, name="Merger", inputs=["Ranker"]) -pipeline.add_node(component=prompt_node, name="PromptNode", inputs=["Merger"]) +pipeline.add_node(component=diversity_ranker, name="DiversityRanker", inputs=["Sampler"]) +pipeline.add_node(component=litm_ranker, name="LostInTheMiddleRanker", inputs=["DiversityRanker"]) +pipeline.add_node(component=prompt_node, name="PromptNode", inputs=["LostInTheMiddleRanker"]) logger = logging.getLogger("boilerpy3") logger.setLevel(logging.CRITICAL) From 2c46028194a3933a1323e09b8fd5a33368df7dde Mon Sep 17 00:00:00 2001 From: Darja Fokina Date: Wed, 2 Aug 2023 11:28:45 +0200 Subject: [PATCH 10/14] Update docs --- docs/pydoc/config/ranker.yml | 2 +- haystack/nodes/ranker/lost_in_the_middle.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/pydoc/config/ranker.yml b/docs/pydoc/config/ranker.yml index e0776e7f04..052756615f 100644 --- a/docs/pydoc/config/ranker.yml +++ b/docs/pydoc/config/ranker.yml @@ -1,7 +1,7 @@ loaders: - type: python search_path: [../../../haystack/nodes/ranker] - modules: ["base", "sentence_transformers", "recentness_ranker", "diversity"] + modules: ["base", "sentence_transformers", "recentness_ranker", "diversity", "lost_in_the_middle"] ignore_when_discovered: ["__init__"] processors: - type: filter diff --git a/haystack/nodes/ranker/lost_in_the_middle.py b/haystack/nodes/ranker/lost_in_the_middle.py index 78e1bd9079..ba5a048ff4 100644 --- a/haystack/nodes/ranker/lost_in_the_middle.py +++ b/haystack/nodes/ranker/lost_in_the_middle.py @@ -12,7 +12,7 @@ class LostInTheMiddleRanker(BaseRanker): The LostInTheMiddleRanker implements a ranker that reorders documents based on the lost in the middle order. "Lost in the Middle: How Language Models Use Long Contexts" by Liu et al. aims to layout paragraphs into LLM context so that relevant paragraphs are at the beginning or end of the input context, while the least relevant - information should be in the middle of a context. + information is in the middle of the context. See https://arxiv.org/abs/2307.03172 for more details. """ From 918971efee33eb2bb3aad0816ee8136146fe4a0b Mon Sep 17 00:00:00 2001 From: Darja Fokina Date: Wed, 2 Aug 2023 11:50:48 +0200 Subject: [PATCH 11/14] Update lg in docstrings --- haystack/nodes/ranker/lost_in_the_middle.py | 23 +++++++++++---------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/haystack/nodes/ranker/lost_in_the_middle.py b/haystack/nodes/ranker/lost_in_the_middle.py index ba5a048ff4..65c81aa42d 100644 --- a/haystack/nodes/ranker/lost_in_the_middle.py +++ b/haystack/nodes/ranker/lost_in_the_middle.py @@ -9,9 +9,9 @@ class LostInTheMiddleRanker(BaseRanker): """ - The LostInTheMiddleRanker implements a ranker that reorders documents based on the lost in the middle order. - "Lost in the Middle: How Language Models Use Long Contexts" by Liu et al. aims to layout paragraphs into LLM - context so that relevant paragraphs are at the beginning or end of the input context, while the least relevant + The LostInTheMiddleRanker implements a ranker that reorders documents based on the "lost in the middle" order. + "Lost in the Middle: How Language Models Use Long Contexts" paper by Liu et al. aims to lay out paragraphs into LLM + context so that the relevant paragraphs are at the beginning or end of the input context, while the least relevant information is in the middle of the context. See https://arxiv.org/abs/2307.03172 for more details. @@ -36,10 +36,10 @@ def __init__(self, word_count_threshold: Optional[int] = None, truncate_document def reorder_documents(self, documents: List[Document]) -> List[Document]: """ - Orders documents based on the lost in the middle order. Assumes that all documents are ordered by relevance. + Ranks documents based on the "lost in the middle" order. Assumes that all documents are ordered by relevance. :param documents: List of Documents to merge. - :return: Documents in the lost in the middle order. + :return: Documents in the "lost in the middle" order. """ # Return empty list if no documents are provided @@ -93,13 +93,13 @@ def reorder_documents(self, documents: List[Document]) -> List[Document]: def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> List[Document]: """ - Reorders documents based on the lost in the middle order. + Reranks documents based on the "lost in the middle" order. - :param query: The query to rerank documents for. + :param query: The query to reorder documents for. :param documents: List of Documents to reorder. :param top_k: The number of documents to return. - :return: The re-ranked documents. + :return: The reordered documents. """ ordered_docs = self.reorder_documents(documents=documents) valid_top_k = isinstance(top_k, int) and 0 < top_k < len(ordered_docs) @@ -113,9 +113,9 @@ def predict_batch( batch_size: Optional[int] = None, ) -> Union[List[Document], List[List[Document]]]: """ - Reorders batch of documents based on the lost in the middle order. + Reranks batch of documents based on the "lost in the middle" order. - :param queries: The queries to rerank documents for (ignored). + :param queries: The queries to reorder documents for (ignored). :param documents: List of Documents to reorder. :param top_k: The number of documents to return. :param batch_size: The number of queries to process in one batch. @@ -148,8 +148,9 @@ def _exclude_middle_elements(self, ordered_docs: List[Document], top_k: int): def _truncate(self, document: Document, word_count_threshold: int) -> Document: """ Shortens a document by cutting off the content after a specified number of words. + :param document: Document to truncate. - :param word_count_threshold: integer representing the maximum number of words + :param word_count_threshold: An integer representing the maximum number of words allowed in the truncated document. :return: Document with truncated content. """ From e1353032cec0c03f72a9dbea3eb36ae48270bbb2 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 2 Aug 2023 14:48:12 +0200 Subject: [PATCH 12/14] Sebastian PR feedback --- haystack/nodes/ranker/lost_in_the_middle.py | 46 ++++------- test/nodes/test_lost_in_the_middle.py | 88 ++++++--------------- 2 files changed, 36 insertions(+), 98 deletions(-) diff --git a/haystack/nodes/ranker/lost_in_the_middle.py b/haystack/nodes/ranker/lost_in_the_middle.py index 65c81aa42d..053b2abf51 100644 --- a/haystack/nodes/ranker/lost_in_the_middle.py +++ b/haystack/nodes/ranker/lost_in_the_middle.py @@ -17,22 +17,21 @@ class LostInTheMiddleRanker(BaseRanker): See https://arxiv.org/abs/2307.03172 for more details. """ - def __init__(self, word_count_threshold: Optional[int] = None, truncate_document: Optional[bool] = False): + def __init__(self, word_count_threshold: Optional[int] = None, top_k: Optional[int] = None): """ Creates an instance of LostInTheMiddleRanker. - If truncate_document is True, you must specify a word_count_threshold as well. If truncate_document is False - and word_count_threshold is specified, the word_count_threshold will be used as a soft limit. The last document - breaching the word_count_threshold will be included in the resulting list of Documents but won't be truncated. + If 'word_count_threshold' is specified, this ranker includes all documents up until the point where adding + another document would exceed the 'word_count_threshold'. The last document that causes the threshold to + be breached will be included in the resulting list of documents, but all subsequent documents will be + discarded. - :param word_count_threshold: The maximum number of words in all ordered documents. - :param truncate_document: Whether to truncate the last document that overflows the word count threshold. + :param word_count_threshold: The maximum total number of words across all documents selected by the ranker. + :param top_k: The maximum number of documents to return. """ super().__init__() - if truncate_document and not word_count_threshold: - raise ValueError("If truncate_document is set to True, you must specify a word_count_threshold as well.") self.word_count_threshold = word_count_threshold - self.truncate_document = truncate_document + self.top_k = top_k def reorder_documents(self, documents: List[Document]) -> List[Document]: """ @@ -81,11 +80,6 @@ def reorder_documents(self, documents: List[Document]) -> List[Document]: # If the total word count meets the threshold, stop processing further documents if word_count >= self.word_count_threshold: - # If truncation is allowed, truncate the last document to meet the word count threshold - if self.truncate_document: - last_docs_length = len(documents[doc_idx].content.split()) - truncate_last_doc_length = last_docs_length - (word_count - self.word_count_threshold) - documents[doc_idx] = self._truncate(documents[doc_idx], truncate_last_doc_length) break # Return the documents in the "lost in the middle" order @@ -95,15 +89,16 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] = """ Reranks documents based on the "lost in the middle" order. - :param query: The query to reorder documents for. + :param query: The query to rerank documents for (ignored). :param documents: List of Documents to reorder. :param top_k: The number of documents to return. :return: The reordered documents. """ - ordered_docs = self.reorder_documents(documents=documents) - valid_top_k = isinstance(top_k, int) and 0 < top_k < len(ordered_docs) - return self._exclude_middle_elements(ordered_docs, top_k) if valid_top_k else ordered_docs # type: ignore + top_k = top_k or self.top_k + documents_to_reorder = documents[:top_k] if top_k else documents + ranked_docs = self.reorder_documents(documents=documents_to_reorder) + return ranked_docs def predict_batch( self, @@ -118,7 +113,7 @@ def predict_batch( :param queries: The queries to reorder documents for (ignored). :param documents: List of Documents to reorder. :param top_k: The number of documents to return. - :param batch_size: The number of queries to process in one batch. + :param batch_size: The number of queries to process in one batch (ignored). :return: The reordered documents. """ @@ -132,19 +127,6 @@ def predict_batch( results.append(self.predict(query="", documents=cur_docs, top_k=top_k)) return results - def _exclude_middle_elements(self, ordered_docs: List[Document], top_k: int): - if top_k < 1 or top_k > len(ordered_docs): - raise ValueError(f"top_k must be between 1 and {len(ordered_docs)}") - exclude_count = len(ordered_docs) - top_k - middle_index = len(ordered_docs) // 2 - half_top_k = exclude_count // 2 - - start_index = middle_index - half_top_k + len(ordered_docs) % 2 - end_index = start_index + exclude_count - remaining_elements = ordered_docs[:start_index] + ordered_docs[end_index:] - - return remaining_elements - def _truncate(self, document: Document, word_count_threshold: int) -> Document: """ Shortens a document by cutting off the content after a specified number of words. diff --git a/test/nodes/test_lost_in_the_middle.py b/test/nodes/test_lost_in_the_middle.py index 1a16e0b5e1..38d41bd17a 100644 --- a/test/nodes/test_lost_in_the_middle.py +++ b/test/nodes/test_lost_in_the_middle.py @@ -58,16 +58,9 @@ def test_lost_in_the_middle_init(): # tests that LostInTheMiddleRanker initializes with default values ranker = LostInTheMiddleRanker() assert ranker.word_count_threshold is None - assert ranker.truncate_document is False - ranker = LostInTheMiddleRanker(word_count_threshold=10, truncate_document=True) + ranker = LostInTheMiddleRanker(word_count_threshold=10) assert ranker.word_count_threshold == 10 - assert ranker.truncate_document is True - - with pytest.raises( - ValueError, match="If truncate_document is set to True, you must specify a word_count_threshold" - ): - LostInTheMiddleRanker(truncate_document=True) @pytest.mark.unit @@ -95,26 +88,6 @@ def test_word_count_threshold_greater_than_total_number_of_words_returns_all_doc assert all(doc.content == expected_order[idx] for idx, doc in enumerate(ordered_docs)) -@pytest.mark.unit -def test_truncation_with_threshold(): - # tests that truncation works as expected - ranker = LostInTheMiddleRanker(word_count_threshold=9, truncate_document=True) - docs = [ - Document("word1 word1"), - Document("word2 word2"), - Document("word3 word3"), - Document("word4 word4"), - Document("word5 word5"), - Document("word6 word6"), - Document("word7 word7"), - Document("word8 word8"), - Document("word9 word9"), - ] - result, _ = ranker.run(query="", documents=docs) - expected_order = "word1 word1 word3 word3 word5 word4 word4 word2 word2" - assert expected_order == " ".join(doc.content for doc in result["documents"]) - - @pytest.mark.unit def test_empty_documents_returns_empty_list(): ranker = LostInTheMiddleRanker() @@ -139,50 +112,33 @@ def test_non_textual_documents(): @pytest.mark.unit -@pytest.mark.parametrize("top_k", [1, 2, 3, 4, 5, 6, 7, 8]) -def test_lost_in_the_middle_order_odd_with_top_k(top_k: int): +@pytest.mark.parametrize("top_k", [1, 2, 3, 4, 5, 6, 7, 8, 12, 20]) +def test_lost_in_the_middle_order_with_postive_top_k(top_k: int): # tests that lost_in_the_middle order works with an odd number of documents and a top_k parameter docs = [Document(str(i)) for i in range(1, 10)] ranker = LostInTheMiddleRanker() - result = ranker._exclude_middle_elements(docs, top_k=top_k) - assert len(result) == top_k - - reverse_top_k = len(docs) - top_k - middle = len(docs) // 2 + 1 - expected_docs = docs[: middle - reverse_top_k // 2] + docs[middle + reverse_top_k // 2 + reverse_top_k % 2 :] - assert result == expected_docs + result = ranker.predict(query="irrelevant", documents=docs, top_k=top_k) + if top_k < len(docs): + # top_k is less than the number of documents, so only the top_k documents should be returned in LITM order + assert len(result) == top_k + expected_order = ranker.predict(query="irrelevant", documents=[Document(str(i)) for i in range(1, top_k + 1)]) + assert result == expected_order + else: + # top_k is greater than the number of documents, so all documents should be returned in LITM order + assert len(result) == len(docs) + assert result == ranker.predict(query="irrelevant", documents=docs) @pytest.mark.unit -@pytest.mark.parametrize("top_k", [1, 2, 3, 4, 5, 6, 7, 8]) -def test_lost_in_the_middle_order_even_with_top_k(top_k: int): - # tests that lost_in_the_middle order works with an even number of documents and a top_k parameter - docs = [Document(str(i)) for i in range(1, 9)] - ranker = LostInTheMiddleRanker() - result = ranker._exclude_middle_elements(docs, top_k=top_k) - assert len(result) == top_k - - reverse_top_k = len(docs) - top_k - middle = len(docs) // 2 - expected_docs = docs[: middle - reverse_top_k // 2] + docs[middle + reverse_top_k // 2 + reverse_top_k % 2 :] - assert result == expected_docs - - -@pytest.mark.unit -@pytest.mark.parametrize("top_k", [-5, 15]) -def test_lost_in_the_middle_order_odd_with_invalid_top_k(top_k: int): +@pytest.mark.parametrize("top_k", [-20, -10, -5, -1]) +def test_lost_in_the_middle_order_with_negative_top_k(top_k: int): # tests that lost_in_the_middle order works with an odd number of documents and an invalid top_k parameter docs = [Document(str(i)) for i in range(1, 10)] ranker = LostInTheMiddleRanker() - with pytest.raises(ValueError): - ranker._exclude_middle_elements(docs, top_k=top_k) - - -@pytest.mark.unit -@pytest.mark.parametrize("top_k", [-5, 15]) -def test_lost_in_the_middle_order_even_with_invalid_top_k(top_k: int): - # tests that lost_in_the_middle order works with an even number of documents and an invalid top_k parameter - docs = [Document(str(i)) for i in range(1, 9)] - ranker = LostInTheMiddleRanker() - with pytest.raises(ValueError): - ranker._exclude_middle_elements(docs, top_k=top_k) + result = ranker.predict(query="irrelevant", documents=docs, top_k=top_k) + if top_k < len(docs) * -1: + assert len(result) == 0 # top_k is too negative, so no documents should be returned + else: + # top_k is negative, subtract it from the total number of documents to get the expected number of documents + expected_docs = ranker.predict(query="irrelevant", documents=docs, top_k=len(docs) + top_k) + assert result == expected_docs From d0d1484dc78700b0c53d42f540f8595781cd70bd Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 2 Aug 2023 15:39:29 +0200 Subject: [PATCH 13/14] Add check for invalid values of word_count_threshold --- examples/web_lfqa_improved.py | 2 +- haystack/nodes/ranker/lost_in_the_middle.py | 7 ++++++- test/nodes/test_lost_in_the_middle.py | 10 ++++++++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/examples/web_lfqa_improved.py b/examples/web_lfqa_improved.py index 5c070dd093..cfef824f7c 100644 --- a/examples/web_lfqa_improved.py +++ b/examples/web_lfqa_improved.py @@ -23,7 +23,7 @@ """ prompt_node = PromptNode( - "gpt-3.5-turbo", default_prompt_template=PromptTemplate(prompt_text), api_key=openai_key, max_length=640 + "gpt-3.5-turbo", default_prompt_template=PromptTemplate(prompt_text), api_key=openai_key, max_length=768 ) web_retriever = WebRetriever(api_key=search_key, top_search_results=5, mode="preprocessed_documents", top_k=50) diff --git a/haystack/nodes/ranker/lost_in_the_middle.py b/haystack/nodes/ranker/lost_in_the_middle.py index 053b2abf51..aab2504f86 100644 --- a/haystack/nodes/ranker/lost_in_the_middle.py +++ b/haystack/nodes/ranker/lost_in_the_middle.py @@ -30,6 +30,11 @@ def __init__(self, word_count_threshold: Optional[int] = None, top_k: Optional[i :param top_k: The maximum number of documents to return. """ super().__init__() + if isinstance(word_count_threshold, int) and word_count_threshold <= 0: + raise ValueError( + f"Invalid value for word_count_threshold: {word_count_threshold}. " + f"word_count_threshold must be a positive integer." + ) self.word_count_threshold = word_count_threshold self.top_k = top_k @@ -89,7 +94,7 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] = """ Reranks documents based on the "lost in the middle" order. - :param query: The query to rerank documents for (ignored). + :param query: The query to reorder documents for (ignored). :param documents: List of Documents to reorder. :param top_k: The number of documents to return. diff --git a/test/nodes/test_lost_in_the_middle.py b/test/nodes/test_lost_in_the_middle.py index 38d41bd17a..db89c3926e 100644 --- a/test/nodes/test_lost_in_the_middle.py +++ b/test/nodes/test_lost_in_the_middle.py @@ -63,6 +63,16 @@ def test_lost_in_the_middle_init(): assert ranker.word_count_threshold == 10 +@pytest.mark.unit +def test_lost_in_the_middle_init_invalid_word_count_threshold(): + # tests that LostInTheMiddleRanker raises an error when word_count_threshold is <= 0 + with pytest.raises(ValueError, match="Invalid value for word_count_threshold"): + LostInTheMiddleRanker(word_count_threshold=0) + + with pytest.raises(ValueError, match="Invalid value for word_count_threshold"): + LostInTheMiddleRanker(word_count_threshold=-5) + + @pytest.mark.unit def test_lost_in_the_middle_with_word_count_threshold(): # tests that lost_in_the_middle with word_count_threshold works as expected From 1d8ca6868d979e3f713b6718c4e30058d8ddd2ea Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 2 Aug 2023 15:43:57 +0200 Subject: [PATCH 14/14] Remove _truncate as it is not needed any more --- haystack/nodes/ranker/lost_in_the_middle.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/haystack/nodes/ranker/lost_in_the_middle.py b/haystack/nodes/ranker/lost_in_the_middle.py index aab2504f86..1ffbe5bcd3 100644 --- a/haystack/nodes/ranker/lost_in_the_middle.py +++ b/haystack/nodes/ranker/lost_in_the_middle.py @@ -131,19 +131,3 @@ def predict_batch( assert isinstance(cur_docs, list) results.append(self.predict(query="", documents=cur_docs, top_k=top_k)) return results - - def _truncate(self, document: Document, word_count_threshold: int) -> Document: - """ - Shortens a document by cutting off the content after a specified number of words. - - :param document: Document to truncate. - :param word_count_threshold: An integer representing the maximum number of words - allowed in the truncated document. - :return: Document with truncated content. - """ - words = document.content.split() - if len(words) > word_count_threshold: - # -1 to remove trailing whitespace - cut_off = sum(len(word) + 1 for word in words[:word_count_threshold]) - 1 - document.content = document.content[:cut_off] - return document