Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add top_k parameter to rerank endpoint #396

Merged
merged 4 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions libs/infinity_emb/infinity_emb/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,21 @@ async def embed(
return embeddings, usage

async def rerank(
    self,
    *,
    query: str,
    docs: list[str],
    raw_scores: bool = False,
    top_k: Optional[int] = None,
) -> tuple[list[float], int]:
    """Rerank multiple documents against a single query.

    Kwargs:
        query (str): query the documents are scored against
        docs (list[str]): docs to be reranked
        raw_scores (bool): return raw model scores instead of
            sigmoid-normalized scores
        top_k (Optional[int]): number of top scores to return after
            reranking. If top_k is None, <= 0 or out of range, all
            scores are returned.

    Raises:
        ValueError: raised if engine is not started yet

    Returns:
        tuple[list[float], int]: relevance scores and token usage
    """
    # Fail fast with a clear error instead of hanging on a stopped engine.
    self._assert_running()
    # The batch handler performs the actual scheduling, scoring,
    # optional top_k truncation and sigmoid normalization.
    scores, usage = await self._batch_handler.rerank(
        query=query,
        docs=docs,
        raw_scores=raw_scores,
        top_k=top_k,
    )

    return scores, usage
Expand Down
1 change: 1 addition & 0 deletions libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ class RerankInput(BaseModel):
return_documents: bool = False
raw_scores: bool = False
model: str = "default/not-specified"
top_k: Optional[int] = Field(default=None, gt=0)


class _ReRankObject(BaseModel):
Expand Down
13 changes: 11 additions & 2 deletions libs/infinity_emb/infinity_emb/inference/batch_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import time
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from typing import Any, List, Sequence, Set, Union
from typing import Any, List, Sequence, Set, Union, Optional

import numpy as np

Expand Down Expand Up @@ -147,14 +147,20 @@ async def embed(
return embeddings, usage

async def rerank(
self, query: str, docs: list[str], raw_scores: bool = False
self,
query: str,
docs: list[str],
raw_scores: bool = False,
top_k: Optional[int] = None,
) -> tuple[list[float], int]:
"""Schedule a query to be reranked with documents. Awaits until reranked.

Args:
query (str): query for reranking
docs (list[str]): documents to be reranked
raw_scores (bool): return raw scores instead of sigmoid
top_k (Optional[int]): number of top scores to return after reranking
if top_k is None, <= 0 or out of range, all scores are returned

Raises:
ModelNotDeployedError: If loaded model does not expose `embed`
Expand All @@ -172,6 +178,9 @@ async def rerank(
rerankables = [ReRankSingle(query=query, document=doc) for doc in docs]
scores, usage = await self._schedule(rerankables)

if top_k is not None and top_k > 0:
scores = scores[:top_k]

if not raw_scores:
# perform sigmoid on scores
scores = (1 / (1 + np.exp(-np.array(scores)))).tolist()
Expand Down
5 changes: 4 additions & 1 deletion libs/infinity_emb/infinity_emb/infinity_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,10 @@ async def _rerank(data: RerankInput):
start = time.perf_counter()

scores, usage = await engine.rerank(
query=data.query, docs=data.documents, raw_scores=data.raw_scores
query=data.query,
docs=data.documents,
raw_scores=data.raw_scores,
top_k=data.top_k,
)

duration = (time.perf_counter() - start) * 1000
Expand Down
56 changes: 56 additions & 0 deletions libs/infinity_emb/tests/end_to_end/test_torch_reranker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import torch
import sys
from asgi_lifespan import LifespanManager
from httpx import AsyncClient
from transformers import pipeline # type: ignore[import-untyped]
Expand Down Expand Up @@ -89,6 +90,61 @@ async def test_reranker(client, model_base, helpers):
for i, pred in enumerate(predictions):
assert abs(rdata_results[i]["relevance_score"] - pred["score"]) < 0.01

@pytest.mark.anyio
async def test_reranker_top_k(client):
    """top_k limits the number of results returned by the /rerank endpoint.

    NOTE(review): if the backend does not sort scores before truncating,
    top_k=1 would always keep the *first* document rather than the best one.
    A stronger test would iterate itertools.permutations of the documents
    (and raw_scores/return_documents variants) and assert the Paris answer
    always wins -- confirm backend sorting before relying on that.
    """
    query = "Where is the Eiffel Tower located?"
    documents = [
        "The Eiffel Tower is located in Paris, France",
        "The Eiffel Tower is located in the United States.",
        "The Eiffel Tower is located in the United Kingdom.",
    ]

    # top_k smaller than the corpus -> exactly top_k results come back
    for top_k in (1, 2):
        response = await client.post(
            f"{PREFIX}/rerank",
            json={"model": MODEL, "query": query, "documents": documents, "top_k": top_k},
        )
        assert response.status_code == 200
        assert len(response.json()["results"]) == top_k

    # top_k far beyond the corpus size is clamped to all documents
    response = await client.post(
        f"{PREFIX}/rerank",
        json={
            "model": MODEL,
            "query": query,
            "documents": documents,
            "top_k": sys.maxsize,
        },
    )
    assert response.status_code == 200
    assert len(response.json()["results"]) == len(documents)
@pytest.mark.anyio
async def test_reranker_invalid_top_k(client):
    """top_k must be a positive integer: zero and negatives are rejected with 422."""
    eiffel_query = "Where is the Eiffel Tower located?"
    eiffel_docs = [
        "The Eiffel Tower is located in Paris, France",
        "The Eiffel Tower is located in the United States.",
        "The Eiffel Tower is located in the United Kingdom.",
    ]
    # Pydantic enforces top_k > 0 on the request model, so both values
    # must fail validation before reaching the engine.
    for bad_top_k in (-1, 0):
        response = await client.post(
            f"{PREFIX}/rerank",
            json={
                "model": MODEL,
                "query": eiffel_query,
                "documents": eiffel_docs,
                "top_k": bad_top_k,
            },
        )
        assert response.status_code == 422

@pytest.mark.anyio
async def test_reranker_cant_embed_or_classify(client):
Expand Down
44 changes: 44 additions & 0 deletions libs/infinity_emb/tests/unit_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,50 @@
np.testing.assert_almost_equal(rankings[:3], [0.83, 0.085, 0.028], decimal=2)


@pytest.mark.anyio
@pytest.mark.parametrize("engine", [InferenceEngine.torch, InferenceEngine.optimum])
async def test_engine_reranker_top_k(engine):
    """Engine-level rerank: invalid/oversized top_k returns all scores,
    a valid top_k truncates the score list.

    Fixes vs. previous revision: removed unused `model_unpatched` and
    `query_docs` locals (ruff F841) and actually uses the parametrized
    `engine` value instead of hardcoding InferenceEngine.torch, so the
    optimum case is exercised too.
    """
    query = "Where is Paris?"
    documents = [
        "Paris is the capital of France.",
        "Berlin is the capital of Germany.",
        "You can now purchase my favorite dish",
    ]
    async_engine = AsyncEmbeddingEngine.from_args(
        EngineArgs(
            model_name_or_path="mixedbread-ai/mxbai-rerank-xsmall-v1",
            engine=engine,
            model_warmup=False,
        )
    )

    # One engine lifecycle for all cases; rerank is stateless per call.
    async with async_engine:
        # None, non-positive, and out-of-range values all mean "no truncation".
        for top_k in (None, -1, 0, len(documents) + 10_000):
            rankings, _usage = await async_engine.rerank(
                query=query, docs=documents, top_k=top_k
            )
            assert len(rankings) == len(documents)

        # A valid top_k truncates to exactly top_k scores.
        rankings, _usage = await async_engine.rerank(
            query=query, docs=documents, top_k=3
        )
        assert len(rankings) == 3



@pytest.mark.anyio
async def test_async_api_torch_CLASSIFY():
sentences = ["This is awesome.", "I am depressed."]
Expand Down
Loading