From 47524b72572bc44a0801e8d48cf0a8c0f355d444 Mon Sep 17 00:00:00 2001 From: Benjamin Gustin Date: Sat, 5 Oct 2024 21:15:16 +0200 Subject: [PATCH] add tests for top_k --- .../tests/end_to_end/test_torch_reranker.py | 56 +++++++++++++++++++ .../tests/unit_test/test_engine.py | 44 +++++++++++++++ 2 files changed, 100 insertions(+) diff --git a/libs/infinity_emb/tests/end_to_end/test_torch_reranker.py b/libs/infinity_emb/tests/end_to_end/test_torch_reranker.py index fe1667b0..39004bce 100644 --- a/libs/infinity_emb/tests/end_to_end/test_torch_reranker.py +++ b/libs/infinity_emb/tests/end_to_end/test_torch_reranker.py @@ -1,5 +1,6 @@ import pytest import torch +import sys from asgi_lifespan import LifespanManager from httpx import AsyncClient from transformers import pipeline # type: ignore[import-untyped] @@ -89,6 +90,61 @@ async def test_reranker(client, model_base, helpers): for i, pred in enumerate(predictions): assert abs(rdata_results[i]["relevance_score"] - pred["score"]) < 0.01 +@pytest.mark.anyio +async def test_reranker_top_k(client): + query = "Where is the Eiffel Tower located?" 
+ documents = [ + "The Eiffel Tower is located in Paris, France", + "The Eiffel Tower is located in the United States.", + "The Eiffel Tower is located in the United Kingdom.", + ] + + response = await client.post( + f"{PREFIX}/rerank", + json={"model": MODEL, "query": query, "documents": documents, "top_k": 1}, + ) + assert response.status_code == 200 + rdata = response.json() + rdata_results = rdata["results"] + assert len(rdata_results) == 1 + + response = await client.post( + f"{PREFIX}/rerank", + json={"model": MODEL, "query": query, "documents": documents, "top_k": 2}, + ) + assert response.status_code == 200 + rdata = response.json() + rdata_results = rdata["results"] + assert len(rdata_results) == 2 + + response = await client.post( + f"{PREFIX}/rerank", + json={"model": MODEL, "query": query, "documents": documents, "top_k": sys.maxsize}, + ) + assert response.status_code == 200 + rdata = response.json() + rdata_results = rdata["results"] + assert len(rdata_results) == len(documents) + +@pytest.mark.anyio +async def test_reranker_invalid_top_k(client): + query = "Where is the Eiffel Tower located?" 
+ documents = [ + "The Eiffel Tower is located in Paris, France", + "The Eiffel Tower is located in the United States.", + "The Eiffel Tower is located in the United Kingdom.", + ] + response = await client.post( + f"{PREFIX}/rerank", + json={"model": MODEL, "query": query, "documents": documents, "top_k": -1}, + ) + assert response.status_code == 422 + + response = await client.post( + f"{PREFIX}/rerank", + json={"model": MODEL, "query": query, "documents": documents, "top_k": 0}, + ) + assert response.status_code == 422 @pytest.mark.anyio async def test_reranker_cant_embed_or_classify(client): diff --git a/libs/infinity_emb/tests/unit_test/test_engine.py b/libs/infinity_emb/tests/unit_test/test_engine.py index 0f39af5a..3541fffa 100644 --- a/libs/infinity_emb/tests/unit_test/test_engine.py +++ b/libs/infinity_emb/tests/unit_test/test_engine.py @@ -98,6 +98,50 @@ async def test_engine_reranker_torch_opt(engine): np.testing.assert_almost_equal(rankings[:3], [0.83, 0.085, 0.028], decimal=2) +@pytest.mark.anyio +@pytest.mark.parametrize("engine", [InferenceEngine.torch, InferenceEngine.optimum]) +async def test_engine_reranker_top_k(engine): + model_unpatched = CrossEncoder( + "mixedbread-ai/mxbai-rerank-xsmall-v1", + ) + query = "Where is Paris?" 
+ documents = [ + "Paris is the capital of France.", + "Berlin is the capital of Germany.", + "You can now purchase my favorite dish", + ] + engine = AsyncEmbeddingEngine.from_args( + EngineArgs( + model_name_or_path="mixedbread-ai/mxbai-rerank-xsmall-v1", + engine=engine, + model_warmup=False, + ) + ) + + query_docs = [(query, doc) for doc in documents] + + async with engine: + rankings, usage = await engine.rerank(query=query, docs=documents, top_k=None) + assert len(rankings) == len(documents) + + async with engine: + rankings, usage = await engine.rerank(query=query, docs=documents, top_k=-1) + assert len(rankings) == len(documents) + + async with engine: + rankings, usage = await engine.rerank(query=query, docs=documents, top_k=0) + assert len(rankings) == len(documents) + + async with engine: + rankings, usage = await engine.rerank(query=query, docs=documents, top_k=3) + assert len(rankings) == 3 + + async with engine: + rankings, usage = await engine.rerank(query=query, docs=documents, top_k=len(documents)+10**9) + assert len(rankings) == len(documents) + + + @pytest.mark.anyio async def test_async_api_torch_CLASSIFY(): sentences = ["This is awesome.", "I am depressed."]