diff --git a/src/argilla/server/server.py b/src/argilla/server/server.py index def025bf53..ca106a0a63 100644 --- a/src/argilla/server/server.py +++ b/src/argilla/server/server.py @@ -29,6 +29,7 @@ from fastapi import FastAPI from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import RedirectResponse from pydantic import ConfigError from argilla import __version__ as argilla_version @@ -202,7 +203,18 @@ async def check_telemetry(): version=str(argilla_version), ) -app = FastAPI() + +@argilla_app.get("/docs", include_in_schema=False) +async def redirect_docs(): + return RedirectResponse(url=f"{settings.base_url}api/docs") + + +@argilla_app.get("/api", include_in_schema=False) +async def redirect_api(): + return RedirectResponse(url=f"{settings.base_url}api/docs") + + +app = FastAPI(docs_url=None) app.mount(settings.base_url, argilla_app) configure_app_logging(app) diff --git a/tests/conftest.py b/tests/conftest.py index f8495bff2b..c9b1d52432 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -35,6 +35,12 @@ def telemetry_track_data(mocker): return spy +@pytest.fixture(scope="session") +def test_client(): + with TestClient(app) as client: + yield client + + @pytest.fixture def mocked_client( monkeypatch, diff --git a/tests/server/test_api.py b/tests/server/test_api.py index 72743a4762..f0b87ec894 100644 --- a/tests/server/test_api.py +++ b/tests/server/test_api.py @@ -20,6 +20,7 @@ TextClassificationRecord, ) from argilla.server.commons.models import TaskStatus, TaskType +from starlette.testclient import TestClient def create_some_data_for_text_classification( @@ -109,3 +110,13 @@ def uri_2_path(uri: str): p = urlparse(uri) return os.path.abspath(os.path.join(p.netloc, p.path)) + + +def test_docs_redirect(test_client: TestClient): + response = test_client.get("/docs", follow_redirects=False) + assert response.status_code == 307 + assert response.next_request.url.path == "/api/docs" + + response = test_client.get("/api", follow_redirects=False) + assert response.status_code == 307 + assert response.next_request.url.path == "/api/docs"