Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: update suggestion from record on search engine #4339

Merged
merged 2 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/argilla/server/apis/v1/handlers/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ async def get_record_suggestions(
async def upsert_suggestion(
*,
db: AsyncSession = Depends(get_async_db),
search_engine: SearchEngine = Depends(get_search_engine),
record_id: UUID,
suggestion_create: SuggestionCreate,
current_user: User = Security(auth.get_current_user),
Expand All @@ -161,7 +162,7 @@ async def upsert_suggestion(
# TODO: We should split API v1 into different FastAPI apps so we can customize error management.
# After mapping ValueError to 422 errors for API v1 then we can remove this try except.
try:
return await datasets.upsert_suggestion(db, record, question, suggestion_create)
return await datasets.upsert_suggestion(db, search_engine, record, question, suggestion_create)
except ValueError as err:
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(err))

Expand Down
35 changes: 29 additions & 6 deletions src/argilla/server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1078,16 +1078,39 @@ async def get_suggestion_by_record_id_and_question_id(
return result.scalar_one_or_none()


async def _preload_suggestion_relationships_before_index(db: "AsyncSession", suggestion: Suggestion) -> None:
await db.execute(
select(Suggestion)
.filter_by(id=suggestion.id)
.options(
selectinload(Suggestion.record).selectinload(Record.dataset),
selectinload(Suggestion.question),
)
)


async def upsert_suggestion(
db: "AsyncSession", record: Record, question: Question, suggestion_create: "SuggestionCreate"
db: "AsyncSession",
search_engine: SearchEngine,
record: Record,
question: Question,
suggestion_create: "SuggestionCreate",
) -> Suggestion:
question.parsed_settings.check_response(suggestion_create)

return await Suggestion.upsert(
db,
schema=SuggestionCreateWithRecordId(record_id=record.id, **suggestion_create.dict()),
constraints=[Suggestion.record_id, Suggestion.question_id],
)
async with db.begin_nested():
suggestion = await Suggestion.upsert(
db,
schema=SuggestionCreateWithRecordId(record_id=record.id, **suggestion_create.dict()),
constraints=[Suggestion.record_id, Suggestion.question_id],
autocommit=False,
)
await _preload_suggestion_relationships_before_index(db, suggestion)
await search_engine.update_record_suggestion(suggestion)

await db.commit()

return suggestion


async def delete_suggestions(db: "AsyncSession", record: Record, suggestions_ids: List[UUID]) -> None:
Expand Down
6 changes: 5 additions & 1 deletion src/argilla/server/search_engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
SimilarityOrder,
SortOrder,
)
from argilla.server.models import Dataset, MetadataProperty, Record, Response, User, Vector, VectorSettings
from argilla.server.models import Dataset, MetadataProperty, Record, Response, Suggestion, User, Vector, VectorSettings

__all__ = [
"SearchEngine",
Expand Down Expand Up @@ -311,6 +311,10 @@ async def update_record_response(self, response: Response):
async def delete_record_response(self, response: Response):
pass

@abstractmethod
async def update_record_suggestion(self, suggestion: Suggestion):
pass

@abstractmethod
async def search(
self,
Expand Down
11 changes: 11 additions & 0 deletions src/argilla/server/search_engine/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,17 @@ async def delete_record_response(self, response: Response):
index_name, id=record.id, body={"script": f'ctx._source["responses"].remove("{response.user.username}")'}
)

async def update_record_suggestion(self, suggestion: Suggestion):
index_name = await self._get_index_or_raise(suggestion.record.dataset)

es_suggestions = self._map_record_suggestions_to_es([suggestion])

await self._update_document_request(
index_name,
id=suggestion.record_id,
body={"doc": {"suggestions": es_suggestions}},
)

async def set_records_vectors(self, dataset: Dataset, vectors: Iterable[Vector]):
index_name = await self._get_index_or_raise(dataset)

Expand Down
4 changes: 3 additions & 1 deletion tests/unit/server/api/v1/test_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,7 +1214,7 @@ async def test_create_record_suggestion(
assert (await db.execute(select(func.count(Suggestion.id)))).scalar() == 1

async def test_create_record_suggestion_update(
self, async_client: "AsyncClient", db: "AsyncSession", owner_auth_header: dict
self, async_client: "AsyncClient", db: "AsyncSession", mock_search_engine: SearchEngine, owner_auth_header: dict
):
dataset = await DatasetFactory.create()
question = await TextQuestionFactory.create(dataset=dataset)
Expand All @@ -1239,6 +1239,8 @@ async def test_create_record_suggestion_update(
}
assert (await db.execute(select(func.count(Suggestion.id)))).scalar() == 1

mock_search_engine.update_record_suggestion.assert_called_once_with(suggestion)

@pytest.mark.parametrize(
"payload",
[
Expand Down
Loading