diff --git a/src/argilla/server/apis/v1/handlers/records.py b/src/argilla/server/apis/v1/handlers/records.py index 4144dc1bb4..a78a25617d 100644 --- a/src/argilla/server/apis/v1/handlers/records.py +++ b/src/argilla/server/apis/v1/handlers/records.py @@ -175,6 +175,7 @@ async def upsert_suggestion( async def delete_record_suggestions( *, db: AsyncSession = Depends(get_async_db), + search_engine: SearchEngine = Depends(get_search_engine), record_id: UUID, current_user: User = Security(auth.get_current_user), ids: str = Query(..., description="A comma separated list with the IDs of the suggestions to be removed"), @@ -195,7 +196,7 @@ async def delete_record_suggestions( detail=f"Cannot delete more than {DELETE_RECORD_SUGGESTIONS_LIMIT} suggestions at once", ) - await datasets.delete_suggestions(db, record, suggestion_ids) + await datasets.delete_suggestions(db, search_engine, record, suggestion_ids) @router.delete("/records/{record_id}", response_model=RecordSchema, response_model_exclude_unset=True) diff --git a/src/argilla/server/apis/v1/handlers/suggestions.py b/src/argilla/server/apis/v1/handlers/suggestions.py index 2f9fd4cdaf..8ea66f6bea 100644 --- a/src/argilla/server/apis/v1/handlers/suggestions.py +++ b/src/argilla/server/apis/v1/handlers/suggestions.py @@ -22,6 +22,7 @@ from argilla.server.models import Suggestion, User from argilla.server.policies import SuggestionPolicyV1, authorize from argilla.server.schemas.v1.suggestions import Suggestion as SuggestionSchema +from argilla.server.search_engine import SearchEngine, get_search_engine from argilla.server.security import auth router = APIRouter(tags=["suggestions"]) @@ -41,6 +42,7 @@ async def _get_suggestion(db: "AsyncSession", suggestion_id: UUID) -> Suggestion async def delete_suggestion( *, db: AsyncSession = Depends(get_async_db), + search_engine: SearchEngine = Depends(get_search_engine), suggestion_id: UUID, current_user: User = Security(auth.get_current_user), ): @@ -49,6 +51,6 @@ async def delete_suggestion( await authorize(current_user, SuggestionPolicyV1.delete(suggestion)) try: - return await datasets.delete_suggestion(db, suggestion) + return await datasets.delete_suggestion(db, search_engine, suggestion) except ValueError as err: raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(err)) diff --git a/src/argilla/server/contexts/datasets.py b/src/argilla/server/contexts/datasets.py index 5851d0d93a..e57cc5898a 100644 --- a/src/argilla/server/contexts/datasets.py +++ b/src/argilla/server/contexts/datasets.py @@ -13,7 +13,21 @@ # limitations under the License. import copy from datetime import datetime -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Set, Tuple, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Literal, + Optional, + Sequence, + Set, + Tuple, + TypeVar, + Union, +) from uuid import UUID import sqlalchemy @@ -1113,22 +1127,56 @@ async def upsert_suggestion( return suggestion -async def delete_suggestions(db: "AsyncSession", record: Record, suggestions_ids: List[UUID]) -> None: +async def delete_suggestions( + db: "AsyncSession", search_engine: SearchEngine, record: Record, suggestions_ids: List[UUID] +) -> None: params = [Suggestion.id.in_(suggestions_ids), Suggestion.record_id == record.id] - await Suggestion.delete_many(db=db, params=params) + suggestions = await list_suggestions_by_id_and_record_id(db, suggestions_ids, record.id) + + async with db.begin_nested(): + await Suggestion.delete_many(db=db, params=params, autocommit=False) + for suggestion in suggestions: + await search_engine.delete_record_suggestion(suggestion) + + await db.commit() async def get_suggestion_by_id(db: "AsyncSession", suggestion_id: "UUID") -> Union[Suggestion, None]: result = await db.execute( select(Suggestion) .filter_by(id=suggestion_id) - .options(selectinload(Suggestion.record).selectinload(Record.dataset)) + .options( + selectinload(Suggestion.record).selectinload(Record.dataset), + selectinload(Suggestion.question), + ) ) + return result.scalar_one_or_none() -async def delete_suggestion(db: "AsyncSession", suggestion: Suggestion) -> Suggestion: - return await suggestion.delete(db) +async def list_suggestions_by_id_and_record_id( + db: "AsyncSession", suggestion_ids: List[UUID], record_id: UUID +) -> Sequence[Suggestion]: + result = await db.execute( + select(Suggestion) + .filter(Suggestion.record_id == record_id, Suggestion.id.in_(suggestion_ids)) + .options( + selectinload(Suggestion.record).selectinload(Record.dataset), + selectinload(Suggestion.question), + ) + ) + + return result.scalars().all() + + +async def delete_suggestion(db: "AsyncSession", search_engine: SearchEngine, suggestion: Suggestion) -> Suggestion: + async with db.begin_nested(): + suggestion = await suggestion.delete(db, autocommit=False) + await search_engine.delete_record_suggestion(suggestion) + + await db.commit() + + return suggestion async def get_metadata_property_by_id(db: "AsyncSession", metadata_property_id: UUID) -> Optional[MetadataProperty]: diff --git a/src/argilla/server/search_engine/base.py b/src/argilla/server/search_engine/base.py index 5d0b436047..c54963dd96 100644 --- a/src/argilla/server/search_engine/base.py +++ b/src/argilla/server/search_engine/base.py @@ -22,7 +22,6 @@ Generic, Iterable, List, - Literal, Optional, Type, TypeVar, @@ -36,7 +35,6 @@ from argilla.server.enums import ( MetadataPropertyType, RecordSortField, - ResponseStatus, ResponseStatusFilter, SimilarityOrder, SortOrder, @@ -315,6 +313,10 @@ async def delete_record_response(self, response: Response): async def update_record_suggestion(self, suggestion: Suggestion): pass + @abstractmethod + async def delete_record_suggestion(self, suggestion: Suggestion): + pass + @abstractmethod async def search( self, diff --git a/src/argilla/server/search_engine/commons.py b/src/argilla/server/search_engine/commons.py index c474302503..e7038caae6 100644 --- a/src/argilla/server/search_engine/commons.py +++ b/src/argilla/server/search_engine/commons.py @@ -330,6 +330,15 @@ async def update_record_suggestion(self, suggestion: Suggestion): body={"doc": {"suggestions": es_suggestions}}, ) + async def delete_record_suggestion(self, suggestion: Suggestion): + index_name = await self._get_index_or_raise(suggestion.record.dataset) + + await self._update_document_request( + index_name, + id=suggestion.record_id, + body={"script": f'ctx._source["suggestions"].remove("{suggestion.question.name}")'}, + ) + async def set_records_vectors(self, dataset: Dataset, vectors: Iterable[Vector]): index_name = await self._get_index_or_raise(dataset) diff --git a/tests/unit/server/api/v1/test_records.py b/tests/unit/server/api/v1/test_records.py index be405f5d92..17eab922ae 100644 --- a/tests/unit/server/api/v1/test_records.py +++ b/tests/unit/server/api/v1/test_records.py @@ -14,6 +14,7 @@ from datetime import datetime from typing import TYPE_CHECKING, Any, Callable, Type +from unittest.mock import call from uuid import UUID, uuid4 import pytest @@ -1342,7 +1343,7 @@ async def test_delete_record_non_existent(self, async_client: "AsyncClient", own @pytest.mark.parametrize("role", [UserRole.admin, UserRole.owner]) async def test_delete_record_suggestions( - self, async_client: "AsyncClient", db: "AsyncSession", role: UserRole + self, async_client: "AsyncClient", db: "AsyncSession", mock_search_engine: SearchEngine, role: UserRole ) -> None: dataset = await DatasetFactory.create() user = await UserFactory.create(workspaces=[dataset.workspace], role=role) @@ -1363,6 +1364,9 @@ async def test_delete_record_suggestions( assert response.status_code == 204 assert (await db.execute(select(func.count(Suggestion.id)))).scalar() == 0 + expected_calls = [call(suggestion) for suggestion in suggestions] + mock_search_engine.delete_record_suggestion.assert_has_calls(expected_calls) + async def test_delete_record_suggestions_with_no_ids( self, async_client: "AsyncClient", owner_auth_header: dict ) -> None: diff --git a/tests/unit/server/api/v1/test_suggestions.py b/tests/unit/server/api/v1/test_suggestions.py index 134fae6c45..e2f33020e1 100644 --- a/tests/unit/server/api/v1/test_suggestions.py +++ b/tests/unit/server/api/v1/test_suggestions.py @@ -18,6 +18,7 @@ import pytest from argilla._constants import API_KEY_HEADER_NAME from argilla.server.models import Suggestion, UserRole +from argilla.server.search_engine import SearchEngine from sqlalchemy import func, select from tests.factories import SuggestionFactory, UserFactory @@ -30,7 +31,9 @@ @pytest.mark.asyncio class TestSuiteSuggestions: @pytest.mark.parametrize("role", [UserRole.admin, UserRole.owner]) - async def test_delete_suggestion(self, async_client: "AsyncClient", db: "AsyncSession", role: UserRole) -> None: + async def test_delete_suggestion( + self, async_client: "AsyncClient", mock_search_engine: SearchEngine, db: "AsyncSession", role: UserRole + ) -> None: suggestion = await SuggestionFactory.create() user = await UserFactory.create(role=role, workspaces=[suggestion.record.dataset.workspace]) @@ -50,6 +53,8 @@ async def test_delete_suggestion(self, async_client: "AsyncClient", db: "AsyncSe } assert (await db.execute(select(func.count(Suggestion.id)))).scalar() == 0 + mock_search_engine.delete_record_suggestion.assert_called_once_with(suggestion) + async def test_delete_suggestion_non_existent(self, async_client: "AsyncClient", owner_auth_header: dict) -> None: response = await async_client.delete(f"/api/v1/suggestions/{uuid4()}", headers=owner_auth_header)