diff --git a/src/argilla/server/apis/v1/handlers/records.py b/src/argilla/server/apis/v1/handlers/records.py index 3656d279779..4144dc1bb47 100644 --- a/src/argilla/server/apis/v1/handlers/records.py +++ b/src/argilla/server/apis/v1/handlers/records.py @@ -138,6 +138,7 @@ async def get_record_suggestions( async def upsert_suggestion( *, db: AsyncSession = Depends(get_async_db), + search_engine: SearchEngine = Depends(get_search_engine), record_id: UUID, suggestion_create: SuggestionCreate, current_user: User = Security(auth.get_current_user), @@ -161,7 +162,7 @@ async def upsert_suggestion( # TODO: We should split API v1 into different FastAPI apps so we can customize error management. # After mapping ValueError to 422 errors for API v1 then we can remove this try except. try: - return await datasets.upsert_suggestion(db, record, question, suggestion_create) + return await datasets.upsert_suggestion(db, search_engine, record, question, suggestion_create) except ValueError as err: raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(err)) diff --git a/src/argilla/server/contexts/datasets.py b/src/argilla/server/contexts/datasets.py index ab9ce53e6df..5851d0d93a1 100644 --- a/src/argilla/server/contexts/datasets.py +++ b/src/argilla/server/contexts/datasets.py @@ -1078,16 +1078,39 @@ async def get_suggestion_by_record_id_and_question_id( return result.scalar_one_or_none() +async def _preload_suggestion_relationships_before_index(db: "AsyncSession", suggestion: Suggestion) -> None: + await db.execute( + select(Suggestion) + .filter_by(id=suggestion.id) + .options( + selectinload(Suggestion.record).selectinload(Record.dataset), + selectinload(Suggestion.question), + ) + ) + + async def upsert_suggestion( - db: "AsyncSession", record: Record, question: Question, suggestion_create: "SuggestionCreate" + db: "AsyncSession", + search_engine: SearchEngine, + record: Record, + question: Question, + suggestion_create: "SuggestionCreate", ) -> Suggestion: question.parsed_settings.check_response(suggestion_create) - return await Suggestion.upsert( - db, - schema=SuggestionCreateWithRecordId(record_id=record.id, **suggestion_create.dict()), - constraints=[Suggestion.record_id, Suggestion.question_id], - ) + async with db.begin_nested(): + suggestion = await Suggestion.upsert( + db, + schema=SuggestionCreateWithRecordId(record_id=record.id, **suggestion_create.dict()), + constraints=[Suggestion.record_id, Suggestion.question_id], + autocommit=False, + ) + await _preload_suggestion_relationships_before_index(db, suggestion) + await search_engine.update_record_suggestion(suggestion) + + await db.commit() + + return suggestion async def delete_suggestions(db: "AsyncSession", record: Record, suggestions_ids: List[UUID]) -> None: diff --git a/src/argilla/server/search_engine/base.py b/src/argilla/server/search_engine/base.py index fb77fcd2082..5d0b4360477 100644 --- a/src/argilla/server/search_engine/base.py +++ b/src/argilla/server/search_engine/base.py @@ -41,7 +41,7 @@ SimilarityOrder, SortOrder, ) -from argilla.server.models import Dataset, MetadataProperty, Record, Response, User, Vector, VectorSettings +from argilla.server.models import Dataset, MetadataProperty, Record, Response, Suggestion, User, Vector, VectorSettings __all__ = [ "SearchEngine", @@ -311,6 +311,10 @@ async def update_record_response(self, response: Response): async def delete_record_response(self, response: Response): pass + @abstractmethod + async def update_record_suggestion(self, suggestion: Suggestion): + pass + @abstractmethod async def search( self, diff --git a/src/argilla/server/search_engine/commons.py b/src/argilla/server/search_engine/commons.py index 92a77d9c07f..c4743025035 100644 --- a/src/argilla/server/search_engine/commons.py +++ b/src/argilla/server/search_engine/commons.py @@ -319,6 +319,17 @@ async def delete_record_response(self, response: Response): index_name, id=record.id, body={"script": f'ctx._source["responses"].remove("{response.user.username}")'} ) + async def update_record_suggestion(self, suggestion: Suggestion): + index_name = await self._get_index_or_raise(suggestion.record.dataset) + + es_suggestions = self._map_record_suggestions_to_es([suggestion]) + + await self._update_document_request( + index_name, + id=suggestion.record_id, + body={"doc": {"suggestions": es_suggestions}}, + ) + async def set_records_vectors(self, dataset: Dataset, vectors: Iterable[Vector]): index_name = await self._get_index_or_raise(dataset) diff --git a/tests/unit/server/api/v1/test_records.py b/tests/unit/server/api/v1/test_records.py index 2fb50de1b08..be405f5d92d 100644 --- a/tests/unit/server/api/v1/test_records.py +++ b/tests/unit/server/api/v1/test_records.py @@ -1214,7 +1214,7 @@ async def test_create_record_suggestion( assert (await db.execute(select(func.count(Suggestion.id)))).scalar() == 1 async def test_create_record_suggestion_update( - self, async_client: "AsyncClient", db: "AsyncSession", owner_auth_header: dict + self, async_client: "AsyncClient", db: "AsyncSession", mock_search_engine: SearchEngine, owner_auth_header: dict ): dataset = await DatasetFactory.create() question = await TextQuestionFactory.create(dataset=dataset) @@ -1239,6 +1239,8 @@ async def test_create_record_suggestion_update( } assert (await db.execute(select(func.count(Suggestion.id)))).scalar() == 1 + mock_search_engine.update_record_suggestion.assert_called_once_with(suggestion) + @pytest.mark.parametrize( "payload", [