Skip to content

Commit

Permalink
add tests for top_k
Browse files Browse the repository at this point in the history
  • Loading branch information
aloababa committed Oct 5, 2024
1 parent cf4f97a commit 47524b7
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 0 deletions.
56 changes: 56 additions & 0 deletions libs/infinity_emb/tests/end_to_end/test_torch_reranker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import torch
import sys
from asgi_lifespan import LifespanManager
from httpx import AsyncClient
from transformers import pipeline # type: ignore[import-untyped]
Expand Down Expand Up @@ -89,6 +90,61 @@ async def test_reranker(client, model_base, helpers):
for i, pred in enumerate(predictions):
assert abs(rdata_results[i]["relevance_score"] - pred["score"]) < 0.01

@pytest.mark.anyio
async def test_reranker_top_k(client):
query = "Where is the Eiffel Tower located?"
documents = [
"The Eiffel Tower is located in Paris, France",
"The Eiffel Tower is located in the United States.",
"The Eiffel Tower is located in the United Kingdom.",
]

response = await client.post(
f"{PREFIX}/rerank",
json={"model": MODEL, "query": query, "documents": documents, "top_k": 1},
)
assert response.status_code == 200
rdata = response.json()
rdata_results = rdata["results"]
assert len(rdata_results) == 1

response = await client.post(
f"{PREFIX}/rerank",
json={"model": MODEL, "query": query, "documents": documents, "top_k": 2},
)
assert response.status_code == 200
rdata = response.json()
rdata_results = rdata["results"]
assert len(rdata_results) == 2

response = await client.post(
f"{PREFIX}/rerank",
json={"model": MODEL, "query": query, "documents": documents, "top_k": sys.maxsize},
)
assert response.status_code == 200
rdata = response.json()
rdata_results = rdata["results"]
assert len(rdata_results) == len(documents)

@pytest.mark.anyio
async def test_reranker_invalid_top_k(client):
query = "Where is the Eiffel Tower located?"
documents = [
"The Eiffel Tower is located in Paris, France",
"The Eiffel Tower is located in the United States.",
"The Eiffel Tower is located in the United Kingdom.",
]
response = await client.post(
f"{PREFIX}/rerank",
json={"model": MODEL, "query": query, "documents": documents, "top_k": -1},
)
assert response.status_code == 422

response = await client.post(
f"{PREFIX}/rerank",
json={"model": MODEL, "query": query, "documents": documents, "top_k": 0},
)
assert response.status_code == 422

@pytest.mark.anyio
async def test_reranker_cant_embed_or_classify(client):
Expand Down
44 changes: 44 additions & 0 deletions libs/infinity_emb/tests/unit_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,50 @@ async def test_engine_reranker_torch_opt(engine):
np.testing.assert_almost_equal(rankings[:3], [0.83, 0.085, 0.028], decimal=2)


@pytest.mark.anyio
@pytest.mark.parametrize("engine", [InferenceEngine.torch, InferenceEngine.optimum])
async def test_engine_reranker_top_k(engine):
model_unpatched = CrossEncoder(

Check failure on line 104 in libs/infinity_emb/tests/unit_test/test_engine.py

View workflow job for this annotation

GitHub Actions / lint-infinity_emb / make lint #3.9

Ruff (F841)

tests/unit_test/test_engine.py:104:5: F841 Local variable `model_unpatched` is assigned to but never used

Check failure on line 104 in libs/infinity_emb/tests/unit_test/test_engine.py

View workflow job for this annotation

GitHub Actions / lint-infinity_emb / make lint #3.12

Ruff (F841)

tests/unit_test/test_engine.py:104:5: F841 Local variable `model_unpatched` is assigned to but never used
"mixedbread-ai/mxbai-rerank-xsmall-v1",
)
query = "Where is Paris?"
documents = [
"Paris is the capital of France.",
"Berlin is the capital of Germany.",
"You can now purchase my favorite dish",
]
engine = AsyncEmbeddingEngine.from_args(
EngineArgs(
model_name_or_path="mixedbread-ai/mxbai-rerank-xsmall-v1",
engine=InferenceEngine.torch,
model_warmup=False,
)
)

query_docs = [(query, doc) for doc in documents]

Check failure on line 121 in libs/infinity_emb/tests/unit_test/test_engine.py

View workflow job for this annotation

GitHub Actions / lint-infinity_emb / make lint #3.9

Ruff (F841)

tests/unit_test/test_engine.py:121:5: F841 Local variable `query_docs` is assigned to but never used

Check failure on line 121 in libs/infinity_emb/tests/unit_test/test_engine.py

View workflow job for this annotation

GitHub Actions / lint-infinity_emb / make lint #3.12

Ruff (F841)

tests/unit_test/test_engine.py:121:5: F841 Local variable `query_docs` is assigned to but never used

async with engine:
rankings, usage = await engine.rerank(query=query, docs=documents, top_k=None)
assert len(rankings) == len(documents)

async with engine:
rankings, usage = await engine.rerank(query=query, docs=documents, top_k=-1)
assert len(rankings) == len(documents)

async with engine:
rankings, usage = await engine.rerank(query=query, docs=documents, top_k=0)
assert len(rankings) == len(documents)

async with engine:
rankings, usage = await engine.rerank(query=query, docs=documents, top_k=3)
assert len(rankings) == 3

async with engine:
rankings, usage = await engine.rerank(query=query, docs=documents, top_k=len(documents)+sys.maxsize)
assert len(rankings) == len(documents)



@pytest.mark.anyio
async def test_async_api_torch_CLASSIFY():
sentences = ["This is awesome.", "I am depressed."]
Expand Down

0 comments on commit 47524b7

Please sign in to comment.