Skip to content

Commit

Permalink
feat: retrieve vectors when fetching records (#4063)
Browse files Browse the repository at this point in the history
# Description

This PR adds support for the `include=vectors` and
`include=vectors:vector_settings_id_01,vector_settings_id_02` parameters
in the following endpoints:
* `GET /api/v1/datasets/:dataset_id/records`
* `GET /api/v1/me/datasets/:dataset_id/records`
* `POST /api/v1/me/datasets/:dataset_id/records/search`

Possible improvements to this PR:
- [x] Use vector names instead of vector settings ids for `include` like
`include=vectors:vector_name_01,vector_name_02`.
- [ ] Add support for `include` parameter to `GET
/api/v1/records/:record_id` endpoint.

Closes #4051 

**Type of change**

- [x] New feature (non-breaking change which adds functionality)

**How Has This Been Tested**

- [x] Running tests locally.

**Checklist**

- [ ] I added relevant documentation
- [x] follows the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [x] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: gabrielmbmb <[email protected]>
  • Loading branch information
jfcalvo and gabrielmbmb authored Oct 30, 2023
1 parent c743532 commit 9dd3522
Show file tree
Hide file tree
Showing 7 changed files with 532 additions and 21 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ These are the section headers that we use:
- Added `GET /api/v1/datasets/:dataset_id/vectors-settings` endpoint for listing the vectors settings for a dataset. ([#3776](https://github.com/argilla-io/argilla/pull/3776))
- Added `DELETE /api/v1/vectors-settings/:vector_settings_id` endpoint for deleting a vector settings. ([#3776](https://github.com/argilla-io/argilla/pull/3776))
- Added `GET /api/v1/records/:record_id` endpoint to get a specific record. ([#4039](https://github.com/argilla-io/argilla/pull/4039))
- Added support to include vectors for `GET /api/v1/datasets/:dataset_id/records` endpoint response using `include` query param. ([#4063](https://github.com/argilla-io/argilla/pull/4063))
- Added support to include vectors for `GET /api/v1/me/datasets/:dataset_id/records` endpoint response using `include` query param. ([#4063](https://github.com/argilla-io/argilla/pull/4063))
- Added support to include vectors for `POST /api/v1/me/datasets/:dataset_id/records/search` endpoint response using `include` query param. ([#4063](https://github.com/argilla-io/argilla/pull/4063))

### Changed

Expand Down
16 changes: 11 additions & 5 deletions src/argilla/server/apis/v1/handlers/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from argilla.server.contexts import accounts, datasets
from argilla.server.database import get_async_db
from argilla.server.enums import MetadataPropertyType, RecordInclude, RecordSortField, ResponseStatusFilter, SortOrder
from argilla.server.enums import MetadataPropertyType, RecordSortField, ResponseStatusFilter, SortOrder
from argilla.server.models import Dataset as DatasetModel
from argilla.server.models import ResponseStatus, User
from argilla.server.policies import DatasetPolicyV1, MetadataPropertyPolicyV1, authorize, is_authorized
Expand All @@ -43,6 +43,7 @@
Question,
QuestionCreate,
Questions,
RecordIncludeParam,
Records,
RecordsCreate,
RecordsUpdate,
Expand Down Expand Up @@ -79,6 +80,10 @@

router = APIRouter(tags=["datasets"])

# Parser for the `include` query parameter, built with `parse_query_param` and
# using `RecordIncludeParam` as its model. The record handlers in this module
# consume it through `Depends(parse_record_include_param)`.
# NOTE(review): presumably yields None when `include` is absent (handlers type
# it as Optional[RecordIncludeParam]) -- confirm against `parse_query_param`.
parse_record_include_param = parse_query_param(
name="include", help="Relationships to include in the response", model=RecordIncludeParam
)


async def _get_dataset(
db: AsyncSession,
Expand Down Expand Up @@ -234,7 +239,7 @@ async def _filter_records_using_search_engine(
offset: int,
user: Optional[User] = None,
response_statuses: Optional[List[ResponseStatusFilter]] = None,
include: Optional[List[RecordInclude]] = None,
include: Optional[RecordIncludeParam] = None,
sort_by_query_param: Optional[Dict[str, str]] = None,
) -> Tuple[List["Record"], int]:
search_responses = await _get_search_responses(
Expand Down Expand Up @@ -289,6 +294,7 @@ async def list_current_user_datasets(
dataset_list = current_user.datasets
else:
dataset_list = await datasets.list_datasets_by_workspace_id(db, workspace_id)

return Datasets(items=dataset_list)


Expand Down Expand Up @@ -364,7 +370,7 @@ async def list_current_user_dataset_records(
dataset_id: UUID,
metadata: MetadataQueryParams = Depends(),
sort_by_query_param: SortByQueryParamParsed,
include: List[RecordInclude] = Query([], description="Relationships to include in the response"),
include: Optional[RecordIncludeParam] = Depends(parse_record_include_param),
response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"),
offset: int = 0,
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE),
Expand Down Expand Up @@ -409,7 +415,7 @@ async def list_dataset_records(
dataset_id: UUID,
metadata: MetadataQueryParams = Depends(),
sort_by_query_param: SortByQueryParamParsed,
include: List[RecordInclude] = Query([], description="Relationships to include in the response"),
include: Optional[RecordIncludeParam] = Depends(parse_record_include_param),
response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"),
offset: int = 0,
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE),
Expand Down Expand Up @@ -709,7 +715,7 @@ async def search_dataset_records(
query: SearchRecordsQuery,
metadata: MetadataQueryParams = Depends(),
sort_by_query_param: SortByQueryParamParsed,
include: List[RecordInclude] = Query([]),
include: Optional[RecordIncludeParam] = Depends(parse_record_include_param),
response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"),
offset: int = Query(0, ge=0),
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE),
Expand Down
53 changes: 37 additions & 16 deletions src/argilla/server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import sqlalchemy
from fastapi.encoders import jsonable_encoder
from sqlalchemy import and_, func, or_, select
from sqlalchemy import Select, and_, func, or_, select
from sqlalchemy.orm import contains_eager, joinedload, selectinload

from argilla.server.contexts import accounts
Expand Down Expand Up @@ -59,6 +59,7 @@

from argilla.server.schemas.v1.datasets import (
DatasetUpdate,
RecordIncludeParam,
RecordsUpdate,
VectorSettingsCreate,
)
Expand Down Expand Up @@ -357,34 +358,33 @@ async def get_record_by_id(
selectinload(Record.dataset).selectinload(Dataset.questions),
selectinload(Record.dataset).selectinload(Dataset.metadata_properties),
)

if with_suggestions:
query = query.options(selectinload(Record.suggestions))

result = await db.execute(query)

return result.scalar_one_or_none()


async def get_records_by_ids(
db: "AsyncSession",
dataset_id: UUID,
records_ids: List[UUID],
include: Optional[List[RecordInclude]] = None,
include: Optional["RecordIncludeParam"] = None,
user_id: Optional[UUID] = None,
) -> List[Record]:
if include is None:
include = []

query = select(Record).filter(Record.dataset_id == dataset_id, Record.id.in_(records_ids))

if RecordInclude.responses in include:
if user_id:
if include and include.with_responses:
if not user_id:
query = query.options(joinedload(Record.responses))
else:
query = query.outerjoin(
Response, and_(Response.record_id == Record.id, Response.user_id == user_id)
).options(contains_eager(Record.responses))
else:
query = query.options(joinedload(Record.responses))

if RecordInclude.suggestions in include:
query = query.options(joinedload(Record.suggestions))
query = await _configure_query_relationships(query, include_params=include)

result = await db.execute(query)
records = result.unique().scalars().all()
Expand All @@ -396,11 +396,34 @@ async def get_records_by_ids(
return ordered_records


async def _configure_query_relationships(
    query: "Select", include_params: Optional["RecordIncludeParam"] = None
) -> "Select":
    """Attach loader options for the suggestions/vectors relationships to ``query``.

    Args:
        query: The ``Record`` select statement to extend.
        include_params: Parsed ``include`` parameter. When falsy, ``query`` is
            returned unchanged.

    Returns:
        The query, extended as follows:
          * suggestions requested -> ``joinedload(Record.suggestions)``;
          * all vectors requested -> ``joinedload(Record.vectors)``;
          * otherwise, if specific vector names were requested -> an outer join
            on ``Vector`` restricted to the vector settings whose name is in
            ``include_params.vectors``, populated via ``contains_eager``.
    """
    if include_params:
        if include_params.with_suggestions:
            query = query.options(joinedload(Record.suggestions))

        if include_params.with_all_vectors:
            query = query.options(joinedload(Record.vectors))
        elif include_params.with_some_vector:
            # Resolve the requested vector settings names to their ids.
            settings_ids = (
                select(VectorSettings.id).filter(VectorSettings.name.in_(include_params.vectors)).subquery()
            )
            # Outer join so records without any matching vector are still returned.
            join_condition = and_(
                Vector.record_id == Record.id,
                Vector.vector_settings_id.in_(settings_ids),
            )
            query = query.outerjoin(Vector, join_condition).options(contains_eager(Record.vectors))

    return query


async def list_records_by_dataset_id(
db: "AsyncSession",
dataset_id: UUID,
user_id: Optional[UUID] = None,
include: List[RecordInclude] = [],
include: Optional["RecordIncludeParam"] = None,
response_statuses: List[ResponseStatusFilter] = [],
offset: int = 0,
limit: int = LIST_RECORDS_LIMIT,
Expand Down Expand Up @@ -433,12 +456,10 @@ async def list_records_by_dataset_id(
if response_status_filter_expressions:
records_query = records_query.filter(or_(*response_status_filter_expressions))

if RecordInclude.responses in include:
if include and include.with_responses:
records_query = records_query.options(contains_eager(Record.responses))

if RecordInclude.suggestions in include:
records_query = records_query.options(joinedload(Record.suggestions))

records_query = await _configure_query_relationships(query=records_query, include_params=include)
records_query = records_query.order_by(Record.inserted_at.asc()).offset(offset).limit(limit)
result_records = await db.execute(records_query)

Expand Down
1 change: 1 addition & 0 deletions src/argilla/server/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class UserRole(str, Enum):
# Names of the record relationships that can be listed in the `include` query
# parameter. `vectors` is the member added by this change.
class RecordInclude(str, Enum):
responses = "responses"
suggestions = "suggestions"
vectors = "vectors"


class QuestionType(str, Enum):
Expand Down
52 changes: 52 additions & 0 deletions src/argilla/server/schemas/v1/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,10 +423,19 @@ class RecordGetterDict(GetterDict):
def get(self, key: str, default: Any) -> Any:
    """Resolve ``key`` on the wrapped ORM object for schema construction.

    * ``metadata`` is read from the object's ``metadata_`` attribute
      (``None`` when the attribute is missing).
    * ``responses``, ``suggestions`` and ``vectors`` yield ``default`` when
      ``is_relationship_loaded`` reports the relationship as not loaded.
    * A loaded ``vectors`` relationship is returned as a dict mapping each
      vector's settings name to its value.
    * Everything else is delegated to the parent getter.
    """
    record = self._obj

    if key == "metadata":
        return getattr(record, "metadata_", None)

    if key in ("responses", "suggestions", "vectors"):
        if not record.is_relationship_loaded(key):
            return default
        if key == "vectors":
            return {item.vector_settings.name: item.value for item in record.vectors}

    return super().get(key, default)


Expand All @@ -439,6 +448,7 @@ class Record(BaseModel):
# response: Optional[Response]
responses: Optional[List[Response]]
suggestions: Optional[List[Suggestion]]
vectors: Optional[Dict[str, List[float]]]
inserted_at: datetime
updated_at: datetime

Expand Down Expand Up @@ -508,6 +518,48 @@ class RecordsUpdate(BaseModel):
)


class RecordIncludeParam(BaseModel):
    """Parsed value of the ``include`` query parameter for record endpoints.

    ``relationships`` is populated from the ``keys`` alias and lists the plain
    relationships requested (``responses``, ``suggestions``, ``vectors``).
    ``vectors`` is populated from the ``vectors`` alias and lists the specific
    vector names requested. The ``with_*`` properties expose what must be loaded.
    """

    relationships: Optional[List[RecordInclude]] = PydanticField(None, alias="keys")
    vectors: Optional[List[str]] = PydanticField(None, alias="vectors")

    @root_validator
    def check(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject asking for all vectors and for named vectors in the same request.

        Raises:
            HTTPException: 422 when ``vectors`` appears as a plain relationship
                and specific vector names are given as well.
        """
        relationships = values.get("relationships")
        vectors = values.get("vectors")

        named_vectors_given = vectors is not None and len(vectors) > 0
        if relationships and named_vectors_given and RecordInclude.vectors in relationships:
            # TODO: raise ValueError instead once v1 has an exception handler for it.
            raise HTTPException(
                status_code=422,
                detail="'include' query param cannot have both 'vectors' and 'vectors:vector_settings_name_1,vectors_settings_name_2,...'",
            )

        return values

    @property
    def with_responses(self) -> bool:
        """Whether ``responses`` is among the requested relationships."""
        if not self._has_relationships:
            return False
        return RecordInclude.responses in self.relationships

    @property
    def with_suggestions(self) -> bool:
        """Whether ``suggestions`` is among the requested relationships."""
        if not self._has_relationships:
            return False
        return RecordInclude.suggestions in self.relationships

    @property
    def with_all_vectors(self) -> bool:
        """Whether ``vectors`` was requested without naming specific vectors."""
        if not self._has_relationships or self.vectors:
            return False
        return RecordInclude.vectors in self.relationships

    @property
    def with_some_vector(self) -> bool:
        """Whether at least one specific vector name was requested."""
        names = self.vectors
        return names is not None and len(names) > 0

    @property
    def _has_relationships(self):
        """Whether any relationships list was provided at all."""
        return self.relationships is not None


NT = TypeVar("NT", int, float)


Expand Down
Loading

0 comments on commit 9dd3522

Please sign in to comment.