Skip to content

Commit

Permalink
Tests: Add base64 test for audio endpoint (#414)
Browse files Browse the repository at this point in the history
* add base64 test for audio endpoint

* use audio url constant

* add failure test for base64

* remove print statement
  • Loading branch information
wirthual authored Oct 11, 2024
1 parent 6716258 commit 61432aa
Showing 1 changed file with 53 additions and 5 deletions.
58 changes: 53 additions & 5 deletions libs/infinity_emb/tests/end_to_end/test_torch_audio.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import base64

import numpy as np
import pytest
import requests
import torch
from asgi_lifespan import LifespanManager
from fastapi import status
Expand Down Expand Up @@ -46,7 +50,7 @@ async def test_model_route(client):

@pytest.mark.anyio
async def test_audio_single(client):
audio_url = "https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav"
audio_url = pytest.AUDIO_SAMPLE_URL

response = await client.post(
f"{PREFIX}/embeddings_audio",
Expand Down Expand Up @@ -80,7 +84,7 @@ async def test_audio_single_text_only(client):

@pytest.mark.anyio
async def test_meta(client, helpers):
audio_url = "https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav"
audio_url = pytest.AUDIO_SAMPLE_URL

text_input = ["a beep", "a horse", "a fish"]
audio_input = [audio_url]
Expand Down Expand Up @@ -119,9 +123,7 @@ async def test_meta(client, helpers):
async def test_audio_multiple(client):
for route in [f"{PREFIX}/embeddings_audio", f"{PREFIX}/embeddings"]:
for no_of_audios in [1, 5, 10]:
audio_urls = [
"https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav"
] * no_of_audios
audio_urls = [pytest.AUDIO_SAMPLE_URL] * no_of_audios

response = await client.post(
route,
Expand All @@ -141,6 +143,52 @@ async def test_audio_multiple(client):
assert len(rdata_results[0]["embedding"]) > 0


@pytest.mark.anyio
async def test_audio_base64(client):
bytes_downloaded = requests.get(pytest.AUDIO_SAMPLE_URL).content
base_64_audio = base64.b64encode(bytes_downloaded).decode("utf-8")

response = await client.post(
f"{PREFIX}/embeddings_audio",
json={
"model": MODEL,
"input": [
"data:audio/wav;base64," + base_64_audio,
pytest.AUDIO_SAMPLE_URL,
],
},
)
assert response.status_code == 200
rdata = response.json()
assert "model" in rdata
assert "usage" in rdata
rdata_results = rdata["data"]
assert rdata_results[0]["object"] == "embedding"
assert len(rdata_results[0]["embedding"]) > 0

np.testing.assert_array_equal(
rdata_results[0]["embedding"], rdata_results[1]["embedding"]
)


@pytest.mark.anyio
async def test_audio_base64_fail(client):
base_64_audio = "somethingsomething"

response = await client.post(
f"{PREFIX}/embeddings_audio",
json={
"model": MODEL,
"input": [
"data:audio/wav;base64," + base_64_audio,
pytest.AUDIO_SAMPLE_URL,
],
},
)

assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY


@pytest.mark.anyio
async def test_audio_fail(client):
for route in [f"{PREFIX}/embeddings_audio", f"{PREFIX}/embeddings"]:
Expand Down

0 comments on commit 61432aa

Please sign in to comment.