Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: delete suggestion from record on search engine #4336

Merged
merged 4 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/argilla/server/apis/v1/handlers/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ async def upsert_suggestion(
async def delete_record_suggestions(
*,
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
record_id: UUID,
current_user: User = Security(auth.get_current_user),
ids: str = Query(..., description="A comma separated list with the IDs of the suggestions to be removed"),
Expand All @@ -195,7 +196,7 @@ async def delete_record_suggestions(
detail=f"Cannot delete more than {DELETE_RECORD_SUGGESTIONS_LIMIT} suggestions at once",
)

await datasets.delete_suggestions(db, record, suggestion_ids)
await datasets.delete_suggestions(db, search_engine, record, suggestion_ids)


@router.delete("/records/{record_id}", response_model=RecordSchema, response_model_exclude_unset=True)
Expand Down
4 changes: 3 additions & 1 deletion src/argilla/server/apis/v1/handlers/suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from argilla.server.models import Suggestion, User
from argilla.server.policies import SuggestionPolicyV1, authorize
from argilla.server.schemas.v1.suggestions import Suggestion as SuggestionSchema
from argilla.server.search_engine import SearchEngine, get_search_engine
from argilla.server.security import auth

router = APIRouter(tags=["suggestions"])
Expand All @@ -41,6 +42,7 @@ async def _get_suggestion(db: "AsyncSession", suggestion_id: UUID) -> Suggestion
async def delete_suggestion(
*,
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
suggestion_id: UUID,
current_user: User = Security(auth.get_current_user),
):
Expand All @@ -49,6 +51,6 @@ async def delete_suggestion(
await authorize(current_user, SuggestionPolicyV1.delete(suggestion))

try:
return await datasets.delete_suggestion(db, suggestion)
return await datasets.delete_suggestion(db, search_engine, suggestion)
except ValueError as err:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(err))
60 changes: 54 additions & 6 deletions src/argilla/server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,21 @@
# limitations under the License.
import copy
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Set, Tuple, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
Union,
)
from uuid import UUID

import sqlalchemy
Expand Down Expand Up @@ -1113,22 +1127,56 @@ async def upsert_suggestion(
return suggestion


async def delete_suggestions(db: "AsyncSession", record: Record, suggestions_ids: List[UUID]) -> None:
async def delete_suggestions(
db: "AsyncSession", search_engine: SearchEngine, record: Record, suggestions_ids: List[UUID]
) -> None:
params = [Suggestion.id.in_(suggestions_ids), Suggestion.record_id == record.id]
await Suggestion.delete_many(db=db, params=params)
suggestions = await list_suggestions_by_id_and_record_id(db, suggestions_ids, record.id)

async with db.begin_nested():
await Suggestion.delete_many(db=db, params=params, autocommit=False)
for suggestion in suggestions:
await search_engine.delete_record_suggestion(suggestion)

await db.commit()


async def get_suggestion_by_id(db: "AsyncSession", suggestion_id: "UUID") -> Union[Suggestion, None]:
result = await db.execute(
select(Suggestion)
.filter_by(id=suggestion_id)
.options(selectinload(Suggestion.record).selectinload(Record.dataset))
.options(
selectinload(Suggestion.record).selectinload(Record.dataset),
selectinload(Suggestion.question),
)
)

return result.scalar_one_or_none()


async def delete_suggestion(db: "AsyncSession", suggestion: Suggestion) -> Suggestion:
return await suggestion.delete(db)
async def list_suggestions_by_id_and_record_id(
db: "AsyncSession", suggestion_ids: List[UUID], record_id: UUID
) -> Sequence[Suggestion]:
result = await db.execute(
select(Suggestion)
.filter(Suggestion.record_id == record_id, Suggestion.id.in_(suggestion_ids))
.options(
selectinload(Suggestion.record).selectinload(Record.dataset),
selectinload(Suggestion.question),
)
)

return result.scalars().all()


async def delete_suggestion(db: "AsyncSession", search_engine: SearchEngine, suggestion: Suggestion) -> Suggestion:
async with db.begin_nested():
suggestion = await suggestion.delete(db, autocommit=False)
await search_engine.delete_record_suggestion(suggestion)

await db.commit()

return suggestion


async def get_metadata_property_by_id(db: "AsyncSession", metadata_property_id: UUID) -> Optional[MetadataProperty]:
Expand Down
6 changes: 4 additions & 2 deletions src/argilla/server/search_engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
Generic,
Iterable,
List,
Literal,
Optional,
Type,
TypeVar,
Expand All @@ -36,7 +35,6 @@
from argilla.server.enums import (
MetadataPropertyType,
RecordSortField,
ResponseStatus,
ResponseStatusFilter,
SimilarityOrder,
SortOrder,
Expand Down Expand Up @@ -315,6 +313,10 @@ async def delete_record_response(self, response: Response):
async def update_record_suggestion(self, suggestion: Suggestion):
pass

@abstractmethod
async def delete_record_suggestion(self, suggestion: Suggestion):
pass

@abstractmethod
async def search(
self,
Expand Down
9 changes: 9 additions & 0 deletions src/argilla/server/search_engine/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,15 @@ async def update_record_suggestion(self, suggestion: Suggestion):
body={"doc": {"suggestions": es_suggestions}},
)

async def delete_record_suggestion(self, suggestion: Suggestion):
index_name = await self._get_index_or_raise(suggestion.record.dataset)

await self._update_document_request(
index_name,
id=suggestion.record_id,
body={"script": f'ctx._source["suggestions"].remove("{suggestion.question.name}")'},
)

async def set_records_vectors(self, dataset: Dataset, vectors: Iterable[Vector]):
index_name = await self._get_index_or_raise(dataset)

Expand Down
6 changes: 5 additions & 1 deletion tests/unit/server/api/v1/test_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Type
from unittest.mock import call
from uuid import UUID, uuid4

import pytest
Expand Down Expand Up @@ -1342,7 +1343,7 @@ async def test_delete_record_non_existent(self, async_client: "AsyncClient", own

@pytest.mark.parametrize("role", [UserRole.admin, UserRole.owner])
async def test_delete_record_suggestions(
self, async_client: "AsyncClient", db: "AsyncSession", role: UserRole
self, async_client: "AsyncClient", db: "AsyncSession", mock_search_engine: SearchEngine, role: UserRole
) -> None:
dataset = await DatasetFactory.create()
user = await UserFactory.create(workspaces=[dataset.workspace], role=role)
Expand All @@ -1363,6 +1364,9 @@ async def test_delete_record_suggestions(
assert response.status_code == 204
assert (await db.execute(select(func.count(Suggestion.id)))).scalar() == 0

expected_calls = [call(suggestion) for suggestion in suggestions]
mock_search_engine.delete_record_suggestion.assert_has_calls(expected_calls)

async def test_delete_record_suggestions_with_no_ids(
self, async_client: "AsyncClient", owner_auth_header: dict
) -> None:
Expand Down
7 changes: 6 additions & 1 deletion tests/unit/server/api/v1/test_suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pytest
from argilla._constants import API_KEY_HEADER_NAME
from argilla.server.models import Suggestion, UserRole
from argilla.server.search_engine import SearchEngine
from sqlalchemy import func, select

from tests.factories import SuggestionFactory, UserFactory
Expand All @@ -30,7 +31,9 @@
@pytest.mark.asyncio
class TestSuiteSuggestions:
@pytest.mark.parametrize("role", [UserRole.admin, UserRole.owner])
async def test_delete_suggestion(self, async_client: "AsyncClient", db: "AsyncSession", role: UserRole) -> None:
async def test_delete_suggestion(
self, async_client: "AsyncClient", mock_search_engine: SearchEngine, db: "AsyncSession", role: UserRole
) -> None:
suggestion = await SuggestionFactory.create()
user = await UserFactory.create(role=role, workspaces=[suggestion.record.dataset.workspace])

Expand All @@ -50,6 +53,8 @@ async def test_delete_suggestion(self, async_client: "AsyncClient", db: "AsyncSe
}
assert (await db.execute(select(func.count(Suggestion.id)))).scalar() == 0

mock_search_engine.delete_record_suggestion.assert_called_once_with(suggestion)

async def test_delete_suggestion_non_existent(self, async_client: "AsyncClient", owner_auth_header: dict) -> None:
response = await async_client.delete(f"/api/v1/suggestions/{uuid4()}", headers=owner_auth_header)

Expand Down
Loading