Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(document-search): LiteLLM Reranker #109

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
3 changes: 3 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
default_language_version:
python: python3.10

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
Expand Down
46 changes: 46 additions & 0 deletions packages/ragbits-document-search/examples/reranker_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "ragbits-document-search",
# "ragbits-core[litellm]",
# ]
# ///
import asyncio

from ragbits.document_search import DocumentSearch
from ragbits.document_search.documents.document import DocumentMeta

# Sample corpus: each joke below becomes one in-memory text document.
documents = [
    DocumentMeta.create_text_document_from_literal(joke)
    for joke in (
        "RIP boiled water. You will be mist.",
        "Why doesn't James Bond fart in bed? Because it would blow his cover.",
        "Why programmers don't like to swim? Because they're scared of the floating points.",
    )
]

# Configuration handed to DocumentSearch.from_config().
# The reranker is referenced by its full "module:Class" path and is given the
# model name through its "config" mapping.
config = {
    "embedder": {"type": "LiteLLMEmbeddings"},
    "vector_store": {"type": "InMemoryVectorStore"},
    "reranker": {
        "type": "ragbits.document_search.retrieval.rerankers.litellm:LiteLLMReranker",
        "config": {"model": "cohere/rerank-english-v3.0"},
    },
    "providers": {"txt": {"type": "DummyProvider"}},
}


async def main():
    """Ingest the sample documents, run one search query and print the results."""
    search_engine = DocumentSearch.from_config(config)

    # Documents are ingested one after another, in list order.
    for doc in documents:
        await search_engine.ingest_document(doc)

    print(await search_engine.search("I'm boiling my water and I need a joke"))


if __name__ == "__main__":
    asyncio.run(main())
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ async def search(self, query: str, search_config: SearchConfig = SearchConfig())
entries = await self.vector_store.retrieve(search_vector[0], **search_config.vector_store_kwargs)
elements.extend([Element.from_vector_db_entry(entry) for entry in entries])

return self.reranker.rerank(elements)
return self.reranker.rerank(elements, query=query)

async def ingest_document(
self, document: Union[DocumentMeta, Document], document_processor: Optional[BaseProvider] = None
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
import abc

from pydantic import BaseModel

from ragbits.document_search.documents.element import Element


class Reranker(abc.ABC):
class Reranker(BaseModel, abc.ABC):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a naive question, but why does `Reranker` need to inherit from pydantic's `BaseModel` here?

"""
Reranks chunks retrieved from vector store.
"""

@staticmethod
@abc.abstractmethod
def rerank(chunks: list[Element]) -> list[Element]:
def rerank(self, chunks: list[Element], query: str) -> list[Element]:
"""
Rerank chunks.

Args:
chunks: The chunks to rerank.
query: The query to rerank the chunks against.

Returns:
The reranked chunks.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import List

import litellm

from ragbits.document_search.documents.element import Element, TextElement
from ragbits.document_search.retrieval.rerankers.base import Reranker


# NOTE(review, @akotyla, Oct 21, 2024): to make this class resolvable by its short
# name in a config such as
#     "reranker": {"type": "LiteLLMReranker", "config": {"model": "cohere/rerank-english-v3.0"}}
# it should also be added to
# https://github.com/deepsense-ai/ragbits/blob/main/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/__init__.py#L9
class LiteLLMReranker(Reranker):
    """
    A LiteLLM reranker for providers such as Cohere, Together AI, Azure AI.

    Every attribute is forwarded unchanged to `litellm.rerank` as the keyword
    argument of the same name.

    NOTE(review): a reviewer could not find LiteLLM documentation for these
    parameters and suggested they probably match Cohere's rerank request
    (https://docs.cohere.com/v2/reference/rerank#request) - confirm before
    relying on the exact semantics.

    Attributes:
        model: Name of the reranking model, e.g. "cohere/rerank-english-v3.0".
        top_n: Passed as `top_n`; presumably limits how many results come back.
        return_documents: Passed as `return_documents`; presumably asks the
            provider to echo the document texts in its response.
        rank_fields: Passed as `rank_fields`; presumably names the document
            fields to rank on.
        max_chunks_per_doc: Passed as `max_chunks_per_doc`; presumably caps how
            many chunks the provider derives from a single document.
    """

    model: str
    top_n: int | None = None
    return_documents: bool = False
    rank_fields: list[str] | None = None
    max_chunks_per_doc: int | None = None

    def rerank(self, chunks: List[Element], query: str) -> List[Element]:
        """
        Reranking with LiteLLM API.

        Args:
            chunks: The chunks to rerank.
            query: The query to rerank the chunks against.

        Returns:
            The reranked chunks, in the order of the indices reported by the
            provider. An empty input yields an empty list without calling the
            provider.

        Raises:
            ValueError: If chunks are not a list of TextElement objects.
        """
        # Validate and extract in a single pass; the isinstance check also narrows
        # the type, so `documents` holds only text contents and never None.
        documents: list[str] = []
        for chunk in chunks:
            if not isinstance(chunk, TextElement):
                raise ValueError("All chunks must be TextElement instances")
            documents.append(chunk.content)

        # Nothing to rank: skip the provider round-trip.
        if not documents:
            return []

        # NOTE(review): a reviewer suggested litellm's async `arerank`
        # (https://docs.litellm.ai/docs/rerank#async-usage). That would turn this
        # method into a coroutine, so the abstract `Reranker.rerank` and its caller
        # would have to change together with it.
        response = litellm.rerank(
            model=self.model,
            query=query,
            documents=documents,
            top_n=self.top_n,
            return_documents=self.return_documents,
            rank_fields=self.rank_fields,
            max_chunks_per_doc=self.max_chunks_per_doc,
        )

        # Each result carries the position of a chunk in the submitted list;
        # the order of the results is the new ranking.
        return [chunks[result["index"]] for result in response.results]
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ class NoopReranker(Reranker):
A no-op reranker that does not change the order of the chunks.
"""

@staticmethod
def rerank(chunks: List[Element]) -> List[Element]:
def rerank(self, chunks: List[Element], query: str) -> List[Element]: # pylint: disable=unused-argument
"""
No reranking, returning the same chunks as in input.

Args:
chunks: The chunks to rerank.
query: The query to rerank the chunks against.

Returns:
The reranked chunks.
Expand Down
65 changes: 65 additions & 0 deletions packages/ragbits-document-search/tests/unit/test_rerankers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from pathlib import Path

import pytest

from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element, TextElement
from ragbits.document_search.documents.sources import LocalFileSource
from ragbits.document_search.retrieval.rerankers.litellm import LiteLLMReranker


@pytest.fixture
def mock_litellm_response(monkeypatch):
    """Replace `litellm.rerank` with a stub whose results list index 1 before index 0."""

    class _StubRerankResponse:
        results = [{"index": 1}, {"index": 0}]

    # The stub ignores whatever arguments it is called with.
    monkeypatch.setattr("litellm.rerank", lambda *args, **kwargs: _StubRerankResponse())


@pytest.fixture
def reranker():
    """Build a `LiteLLMReranker` with every setting given an explicit value."""
    settings = {
        "model": "test_model",
        "top_n": 2,
        "return_documents": True,
        "rank_fields": ["content"],
        "max_chunks_per_doc": 1,
    }
    return LiteLLMReranker(**settings)


@pytest.fixture
def mock_document_meta():
    """Provide TXT document metadata whose source is the local path "test.txt"."""
    source = LocalFileSource(path=Path("test.txt"))
    return DocumentMeta(document_type=DocumentType.TXT, source=source)


@pytest.fixture
def mock_custom_element(mock_document_meta):
    """Provide an element that subclasses `Element` directly, i.e. is not a `TextElement`."""

    class CustomElement(Element):
        def get_key(self):
            return "test_key"

    element = CustomElement(element_type="test_type", document_meta=mock_document_meta)
    return element


def test_rerank_success(reranker, mock_litellm_response, mock_document_meta):
    """Two text chunks come back in the order dictated by the patched `litellm.rerank`."""
    chunks = [TextElement(content=text, document_meta=mock_document_meta) for text in ("chunk1", "chunk2")]

    reranked_chunks = reranker.rerank(chunks, "test query")

    assert [chunk.content for chunk in reranked_chunks[:2]] == ["chunk2", "chunk1"]


def test_rerank_invalid_chunks(reranker, mock_custom_element):
    """Reranking a chunk that is not a `TextElement` raises `ValueError`."""
    expected_message = "All chunks must be TextElement instances"

    with pytest.raises(ValueError, match=expected_message):
        reranker.rerank([mock_custom_element], "test query")
14 changes: 2 additions & 12 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading