Skip to content

Commit

Permalink
feat: update suggestion from record on search engine (#4339)
Browse files Browse the repository at this point in the history
# Description

Update a suggestion from the record document on search engine when the
suggestion is updated using the records update suggestion endpoint.

Ref #4230 

**Type of change**

- [x] New feature (non-breaking change which adds functionality)

**How Has This Been Tested**

- [x] Running tests locally.
- [x] Checking manually that updating a suggestion is affecting filters
later on the UI.

**Checklist**

- [ ] I added relevant documentation
- [x] follows the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)
  • Loading branch information
jfcalvo authored and davidberenstein1957 committed Nov 29, 2023
1 parent bbed3b0 commit bb38ba0
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 9 deletions.
3 changes: 2 additions & 1 deletion src/argilla/server/apis/v1/handlers/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ async def get_record_suggestions(
async def upsert_suggestion(
*,
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
record_id: UUID,
suggestion_create: SuggestionCreate,
current_user: User = Security(auth.get_current_user),
Expand All @@ -161,7 +162,7 @@ async def upsert_suggestion(
# TODO: We should split API v1 into different FastAPI apps so we can customize error management.
# After mapping ValueError to 422 errors for API v1 then we can remove this try except.
try:
return await datasets.upsert_suggestion(db, record, question, suggestion_create)
return await datasets.upsert_suggestion(db, search_engine, record, question, suggestion_create)
except ValueError as err:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(err))

Expand Down
35 changes: 29 additions & 6 deletions src/argilla/server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1078,16 +1078,39 @@ async def get_suggestion_by_record_id_and_question_id(
return result.scalar_one_or_none()


async def _preload_suggestion_relationships_before_index(db: "AsyncSession", suggestion: Suggestion) -> None:
await db.execute(
select(Suggestion)
.filter_by(id=suggestion.id)
.options(
selectinload(Suggestion.record).selectinload(Record.dataset),
selectinload(Suggestion.question),
)
)


async def upsert_suggestion(
db: "AsyncSession", record: Record, question: Question, suggestion_create: "SuggestionCreate"
db: "AsyncSession",
search_engine: SearchEngine,
record: Record,
question: Question,
suggestion_create: "SuggestionCreate",
) -> Suggestion:
question.parsed_settings.check_response(suggestion_create)

return await Suggestion.upsert(
db,
schema=SuggestionCreateWithRecordId(record_id=record.id, **suggestion_create.dict()),
constraints=[Suggestion.record_id, Suggestion.question_id],
)
async with db.begin_nested():
suggestion = await Suggestion.upsert(
db,
schema=SuggestionCreateWithRecordId(record_id=record.id, **suggestion_create.dict()),
constraints=[Suggestion.record_id, Suggestion.question_id],
autocommit=False,
)
await _preload_suggestion_relationships_before_index(db, suggestion)
await search_engine.update_record_suggestion(suggestion)

await db.commit()

return suggestion


async def delete_suggestions(db: "AsyncSession", record: Record, suggestions_ids: List[UUID]) -> None:
Expand Down
6 changes: 5 additions & 1 deletion src/argilla/server/search_engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
SimilarityOrder,
SortOrder,
)
from argilla.server.models import Dataset, MetadataProperty, Record, Response, User, Vector, VectorSettings
from argilla.server.models import Dataset, MetadataProperty, Record, Response, Suggestion, User, Vector, VectorSettings

__all__ = [
"SearchEngine",
Expand Down Expand Up @@ -311,6 +311,10 @@ async def update_record_response(self, response: Response):
async def delete_record_response(self, response: Response):
pass

@abstractmethod
async def update_record_suggestion(self, suggestion: Suggestion):
pass

@abstractmethod
async def search(
self,
Expand Down
11 changes: 11 additions & 0 deletions src/argilla/server/search_engine/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,17 @@ async def delete_record_response(self, response: Response):
index_name, id=record.id, body={"script": f'ctx._source["responses"].remove("{response.user.username}")'}
)

async def update_record_suggestion(self, suggestion: Suggestion):
index_name = await self._get_index_or_raise(suggestion.record.dataset)

es_suggestions = self._map_record_suggestions_to_es([suggestion])

await self._update_document_request(
index_name,
id=suggestion.record_id,
body={"doc": {"suggestions": es_suggestions}},
)

async def set_records_vectors(self, dataset: Dataset, vectors: Iterable[Vector]):
index_name = await self._get_index_or_raise(dataset)

Expand Down
4 changes: 3 additions & 1 deletion tests/unit/server/api/v1/test_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,7 +1214,7 @@ async def test_create_record_suggestion(
assert (await db.execute(select(func.count(Suggestion.id)))).scalar() == 1

async def test_create_record_suggestion_update(
self, async_client: "AsyncClient", db: "AsyncSession", owner_auth_header: dict
self, async_client: "AsyncClient", db: "AsyncSession", mock_search_engine: SearchEngine, owner_auth_header: dict
):
dataset = await DatasetFactory.create()
question = await TextQuestionFactory.create(dataset=dataset)
Expand All @@ -1239,6 +1239,8 @@ async def test_create_record_suggestion_update(
}
assert (await db.execute(select(func.count(Suggestion.id)))).scalar() == 1

mock_search_engine.update_record_suggestion.assert_called_once_with(suggestion)

@pytest.mark.parametrize(
"payload",
[
Expand Down

0 comments on commit bb38ba0

Please sign in to comment.