-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(document-search): LiteLLM Reranker #109
base: main
Are you sure you want to change the base?
Changes from all commits
ce27275
2a50318
fa9e528
2dcf188
8f71200
985fa0d
1dfe701
43ef956
2ecde29
9046252
9be5157
5efc234
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "ragbits-document-search",
#     "ragbits-core[litellm]",
# ]
# ///
import asyncio

from ragbits.document_search import DocumentSearch
from ragbits.document_search.documents.document import DocumentMeta

# Three short joke documents used as the toy corpus for the example.
_JOKES = [
    "RIP boiled water. You will be mist.",
    "Why doesn't James Bond fart in bed? Because it would blow his cover.",
    "Why programmers don't like to swim? Because they're scared of the floating points.",
]

documents = [DocumentMeta.create_text_document_from_literal(joke) for joke in _JOKES]

# DocumentSearch configuration: in-memory vector store with LiteLLM embeddings,
# and the LiteLLM reranker pointed at Cohere's rerank model.
config = {
    "embedder": {"type": "LiteLLMEmbeddings"},
    "vector_store": {"type": "InMemoryVectorStore"},
    "reranker": {
        "type": "ragbits.document_search.retrieval.rerankers.litellm:LiteLLMReranker",
        "config": {"model": "cohere/rerank-english-v3.0"},
    },
    "providers": {"txt": {"type": "DummyProvider"}},
}


async def main() -> None:
    """Ingest the example documents and run a reranked search, printing the results."""
    document_search = DocumentSearch.from_config(config)

    for document in documents:
        await document_search.ingest_document(document)

    results = await document_search.search("I'm boiling my water and I need a joke")
    print(results)


if __name__ == "__main__":
    asyncio.run(main())
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from typing import List | ||
|
||
import litellm | ||
|
||
from ragbits.document_search.documents.element import Element, TextElement | ||
from ragbits.document_search.retrieval.rerankers.base import Reranker | ||
|
||
|
||
class LiteLLMReranker(Reranker):
    """
    A reranker backed by the LiteLLM unified rerank API, which routes to
    providers such as Cohere, Together AI and Azure AI.

    The attributes below are forwarded verbatim to ``litellm.rerank`` and
    follow the LiteLLM / Cohere rerank API semantics.
    """

    # Provider-prefixed model identifier, e.g. "cohere/rerank-english-v3.0".
    model: str
    # Maximum number of results to return; None lets the provider return all.
    top_n: int | None = None
    # Whether the provider should echo document texts back in the response.
    return_documents: bool = False
    # Provider-side field names to rank on — presumably for structured documents;
    # NOTE(review): confirm against the LiteLLM rerank documentation.
    rank_fields: list[str] | None = None
    # Provider-side cap on chunks produced per document.
    max_chunks_per_doc: int | None = None

    def rerank(self, chunks: list[Element], query: str) -> list[Element]:
        """
        Rerank chunks against a query with the LiteLLM rerank API.

        Args:
            chunks: The chunks to rerank.
            query: The query to rerank the chunks against.

        Returns:
            The input chunks, reordered by the relevance ranking the provider
            returned (most relevant first).

        Raises:
            ValueError: If chunks are not a list of TextElement objects.
        """
        # Validate and narrow in a single pass: every chunk must be a
        # TextElement so that its text content can be sent to the API.
        text_chunks: list[TextElement] = []
        for chunk in chunks:
            if not isinstance(chunk, TextElement):
                raise ValueError("All chunks must be TextElement instances")
            text_chunks.append(chunk)

        response = litellm.rerank(
            model=self.model,
            query=query,
            documents=[chunk.content for chunk in text_chunks],
            top_n=self.top_n,
            return_documents=self.return_documents,
            rank_fields=self.rank_fields,
            max_chunks_per_doc=self.max_chunks_per_doc,
        )
        # Each result carries an "index" into the submitted documents,
        # ordered by decreasing relevance — map it back onto the chunks.
        return [chunks[result["index"]] for result in response.results]
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from pathlib import Path | ||
|
||
import pytest | ||
|
||
from ragbits.document_search.documents.document import DocumentMeta, DocumentType | ||
from ragbits.document_search.documents.element import Element, TextElement | ||
from ragbits.document_search.documents.sources import LocalFileSource | ||
from ragbits.document_search.retrieval.rerankers.litellm import LiteLLMReranker | ||
|
||
|
||
@pytest.fixture
def mock_litellm_response(monkeypatch):
    """Patch ``litellm.rerank`` with a stub whose results swap the two documents."""

    class _StubRerankResponse:
        results = [{"index": 1}, {"index": 0}]

    monkeypatch.setattr("litellm.rerank", lambda *args, **kwargs: _StubRerankResponse())
|
||
|
||
@pytest.fixture
def reranker():
    """A LiteLLMReranker with every option set to a non-default test value."""
    options = {
        "model": "test_model",
        "top_n": 2,
        "return_documents": True,
        "rank_fields": ["content"],
        "max_chunks_per_doc": 1,
    }
    return LiteLLMReranker(**options)
|
||
|
||
@pytest.fixture
def mock_document_meta():
    """Document metadata for a dummy local text file."""
    source = LocalFileSource(path=Path("test.txt"))
    return DocumentMeta(document_type=DocumentType.TXT, source=source)
|
||
|
||
@pytest.fixture
def mock_custom_element(mock_document_meta):
    """An Element subclass that is not a TextElement, for negative-path tests."""

    class _NonTextElement(Element):
        def get_key(self):
            return "test_key"

    return _NonTextElement(element_type="test_type", document_meta=mock_document_meta)
|
||
|
||
def test_rerank_success(reranker, mock_litellm_response, mock_document_meta):
    """The reranker reorders chunks according to the indices the API returns."""
    first = TextElement(content="chunk1", document_meta=mock_document_meta)
    second = TextElement(content="chunk2", document_meta=mock_document_meta)

    reranked_chunks = reranker.rerank([first, second], "test query")

    # The stubbed response lists index 1 before index 0, so order is swapped.
    assert [chunk.content for chunk in reranked_chunks] == ["chunk2", "chunk1"]
|
||
|
||
def test_rerank_invalid_chunks(reranker, mock_custom_element):
    """Chunks that are not TextElement instances are rejected with a ValueError."""
    with pytest.raises(ValueError, match="All chunks must be TextElement instances"):
        reranker.rerank([mock_custom_element], "test query")
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe a stupid question, but why is
pydantic
needed here?