Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: retrieve vectors when fetching records #4063

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ These are the section headers that we use:
- Added `GET /api/v1/datasets/:dataset_id/vectors-settings` endpoint for listing the vectors settings for a dataset. ([#3776](https://github.com/argilla-io/argilla/pull/3776))
- Added `DELETE /api/v1/vectors-settings/:vector_settings_id` endpoint for deleting a vector settings. ([#3776](https://github.com/argilla-io/argilla/pull/3776))
- Added `GET /api/v1/records/:record_id` endpoint to get a specific record. ([#4039](https://github.com/argilla-io/argilla/pull/4039))
- Added support to include vectors for `GET /api/v1/datasets/:dataset_id/records` endpoint response using `include` query param. ([#4063](https://github.com/argilla-io/argilla/pull/4063))
- Added support to include vectors for `GET /api/v1/me/datasets/:dataset_id/records` endpoint response using `include` query param. ([#4063](https://github.com/argilla-io/argilla/pull/4063))
- Added support to include vectors for `POST /api/v1/me/datasets/:dataset_id/records/search` endpoint response using `include` query param. ([#4063](https://github.com/argilla-io/argilla/pull/4063))

### Changed

Expand Down
16 changes: 11 additions & 5 deletions src/argilla/server/apis/v1/handlers/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from argilla.server.contexts import accounts, datasets
from argilla.server.database import get_async_db
from argilla.server.enums import MetadataPropertyType, RecordInclude, RecordSortField, ResponseStatusFilter, SortOrder
from argilla.server.enums import MetadataPropertyType, RecordSortField, ResponseStatusFilter, SortOrder
from argilla.server.models import Dataset as DatasetModel
from argilla.server.models import ResponseStatus, User
from argilla.server.policies import DatasetPolicyV1, MetadataPropertyPolicyV1, authorize, is_authorized
Expand All @@ -43,6 +43,7 @@
Question,
QuestionCreate,
Questions,
RecordIncludeParam,
Records,
RecordsCreate,
RecordsUpdate,
Expand Down Expand Up @@ -79,6 +80,10 @@

router = APIRouter(tags=["datasets"])

# Shared FastAPI dependency for the `include` query param of the record endpoints.
# NOTE(review): judging by its arguments, `parse_query_param` presumably parses the raw
# query value into a `RecordIncludeParam` instance — confirm against its definition.
parse_record_include_param = parse_query_param(
    name="include", help="Relationships to include in the response", model=RecordIncludeParam
)


async def _get_dataset(
db: AsyncSession,
Expand Down Expand Up @@ -234,7 +239,7 @@ async def _filter_records_using_search_engine(
offset: int,
user: Optional[User] = None,
response_statuses: Optional[List[ResponseStatusFilter]] = None,
include: Optional[List[RecordInclude]] = None,
include: Optional[RecordIncludeParam] = None,
sort_by_query_param: Optional[Dict[str, str]] = None,
) -> Tuple[List["Record"], int]:
search_responses = await _get_search_responses(
Expand Down Expand Up @@ -289,6 +294,7 @@ async def list_current_user_datasets(
dataset_list = current_user.datasets
else:
dataset_list = await datasets.list_datasets_by_workspace_id(db, workspace_id)

return Datasets(items=dataset_list)


Expand Down Expand Up @@ -364,7 +370,7 @@ async def list_current_user_dataset_records(
dataset_id: UUID,
metadata: MetadataQueryParams = Depends(),
sort_by_query_param: SortByQueryParamParsed,
include: List[RecordInclude] = Query([], description="Relationships to include in the response"),
include: Optional[RecordIncludeParam] = Depends(parse_record_include_param),
response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"),
offset: int = 0,
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE),
Expand Down Expand Up @@ -409,7 +415,7 @@ async def list_dataset_records(
dataset_id: UUID,
metadata: MetadataQueryParams = Depends(),
sort_by_query_param: SortByQueryParamParsed,
include: List[RecordInclude] = Query([], description="Relationships to include in the response"),
include: Optional[RecordIncludeParam] = Depends(parse_record_include_param),
response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"),
offset: int = 0,
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE),
Expand Down Expand Up @@ -709,7 +715,7 @@ async def search_dataset_records(
query: SearchRecordsQuery,
metadata: MetadataQueryParams = Depends(),
sort_by_query_param: SortByQueryParamParsed,
include: List[RecordInclude] = Query([]),
include: Optional[RecordIncludeParam] = Depends(parse_record_include_param),
response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"),
offset: int = Query(0, ge=0),
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE),
Expand Down
53 changes: 37 additions & 16 deletions src/argilla/server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import sqlalchemy
from fastapi.encoders import jsonable_encoder
from sqlalchemy import and_, func, or_, select
from sqlalchemy import Select, and_, func, or_, select
from sqlalchemy.orm import contains_eager, joinedload, selectinload

from argilla.server.contexts import accounts
Expand Down Expand Up @@ -59,6 +59,7 @@

from argilla.server.schemas.v1.datasets import (
DatasetUpdate,
RecordIncludeParam,
RecordsUpdate,
VectorSettingsCreate,
)
Expand Down Expand Up @@ -357,34 +358,33 @@ async def get_record_by_id(
selectinload(Record.dataset).selectinload(Dataset.questions),
selectinload(Record.dataset).selectinload(Dataset.metadata_properties),
)

if with_suggestions:
query = query.options(selectinload(Record.suggestions))

result = await db.execute(query)

return result.scalar_one_or_none()


async def get_records_by_ids(
db: "AsyncSession",
dataset_id: UUID,
records_ids: List[UUID],
include: Optional[List[RecordInclude]] = None,
include: Optional["RecordIncludeParam"] = None,
user_id: Optional[UUID] = None,
) -> List[Record]:
if include is None:
include = []

query = select(Record).filter(Record.dataset_id == dataset_id, Record.id.in_(records_ids))

if RecordInclude.responses in include:
if user_id:
if include and include.with_responses:
if not user_id:
query = query.options(joinedload(Record.responses))
else:
query = query.outerjoin(
Response, and_(Response.record_id == Record.id, Response.user_id == user_id)
).options(contains_eager(Record.responses))
else:
query = query.options(joinedload(Record.responses))

if RecordInclude.suggestions in include:
query = query.options(joinedload(Record.suggestions))
query = await _configure_query_relationships(query, include_params=include)

result = await db.execute(query)
records = result.unique().scalars().all()
Expand All @@ -396,11 +396,34 @@ async def get_records_by_ids(
return ordered_records


async def _configure_query_relationships(
    query: "Select", include_params: Optional["RecordIncludeParam"] = None
) -> "Select":
    """Add the suggestion and vector loading options requested by ``include_params``.

    Args:
        query: Select statement over ``Record`` to extend.
        include_params: Parsed ``include`` query param. When falsy, ``query`` is
            returned untouched.

    Returns:
        The select statement with the extra loader options / joins applied.
    """
    if not include_params:
        return query

    if include_params.with_suggestions:
        query = query.options(joinedload(Record.suggestions))

    # Requesting every vector takes precedence over requesting vectors by name.
    if include_params.with_all_vectors:
        return query.options(joinedload(Record.vectors))

    if not include_params.with_some_vector:
        return query

    # Only vectors whose settings name was explicitly requested are joined, so
    # `contains_eager` populates `Record.vectors` with that filtered subset.
    requested_settings_ids = (
        select(VectorSettings.id).filter(VectorSettings.name.in_(include_params.vectors)).subquery()
    )
    vector_join_condition = and_(
        Vector.record_id == Record.id,
        Vector.vector_settings_id.in_(requested_settings_ids),
    )
    return query.outerjoin(Vector, vector_join_condition).options(contains_eager(Record.vectors))


async def list_records_by_dataset_id(
db: "AsyncSession",
dataset_id: UUID,
user_id: Optional[UUID] = None,
include: List[RecordInclude] = [],
include: Optional["RecordIncludeParam"] = None,
response_statuses: List[ResponseStatusFilter] = [],
offset: int = 0,
limit: int = LIST_RECORDS_LIMIT,
Expand Down Expand Up @@ -433,12 +456,10 @@ async def list_records_by_dataset_id(
if response_status_filter_expressions:
records_query = records_query.filter(or_(*response_status_filter_expressions))

if RecordInclude.responses in include:
if include and include.with_responses:
records_query = records_query.options(contains_eager(Record.responses))

if RecordInclude.suggestions in include:
records_query = records_query.options(joinedload(Record.suggestions))

records_query = await _configure_query_relationships(query=records_query, include_params=include)
records_query = records_query.order_by(Record.inserted_at.asc()).offset(offset).limit(limit)
result_records = await db.execute(records_query)

Expand Down
1 change: 1 addition & 0 deletions src/argilla/server/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class UserRole(str, Enum):
class RecordInclude(str, Enum):
    # Names of the record relationships that can be requested through the `include` query param.
    responses = "responses"
    suggestions = "suggestions"
    vectors = "vectors"


class QuestionType(str, Enum):
Expand Down
52 changes: 52 additions & 0 deletions src/argilla/server/schemas/v1/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,10 +423,19 @@ class RecordGetterDict(GetterDict):
def get(self, key: str, default: Any) -> Any:
    """Resolve ``key`` on the wrapped record for schema construction.

    ``metadata`` is read from the ``metadata_`` attribute of the wrapped object.
    ``responses``, ``suggestions`` and ``vectors`` resolve to ``default`` when the
    corresponding relationship has not been loaded. Loaded vectors are returned
    as a dict mapping each vector settings name to the vector value. Any other
    key is delegated to the parent getter.
    """
    record = self._obj

    if key == "metadata":
        return getattr(record, "metadata_", None)

    # Relationships are only exposed when they were loaded with the record.
    if key in ("responses", "suggestions", "vectors") and not record.is_relationship_loaded(key):
        return default

    if key == "vectors":
        return {vector.vector_settings.name: vector.value for vector in record.vectors}

    return super().get(key, default)


Expand All @@ -439,6 +448,7 @@ class Record(BaseModel):
# response: Optional[Response]
responses: Optional[List[Response]]
suggestions: Optional[List[Suggestion]]
vectors: Optional[Dict[str, List[float]]]
inserted_at: datetime
updated_at: datetime

Expand Down Expand Up @@ -508,6 +518,48 @@ class RecordsUpdate(BaseModel):
)


class RecordIncludeParam(BaseModel):
    """Parsed value of the ``include`` query param.

    ``relationships`` (alias ``keys``) holds the requested relationship names, and
    ``vectors`` holds the names of the individual vector settings requested.
    Asking for all vectors and for named vectors at the same time is rejected.
    """

    relationships: Optional[List[RecordInclude]] = PydanticField(None, alias="keys")
    vectors: Optional[List[str]] = PydanticField(None, alias="vectors")

    @root_validator
    def check(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject params that request both all vectors and specific named vectors."""
        requested_relationships = values.get("relationships")
        named_vectors = values.get("vectors")

        all_vectors_requested = bool(requested_relationships) and RecordInclude.vectors in requested_relationships
        some_vectors_requested = named_vectors is not None and len(named_vectors) > 0

        if all_vectors_requested and some_vectors_requested:
            # TODO: once we have an exception handler for ValueError in v1, raise a
            # ValueError here instead of an HTTPException.
            raise HTTPException(
                status_code=422,
                detail="'include' query param cannot have both 'vectors' and 'vectors:vector_settings_name_1,vectors_settings_name_2,...'",
            )

        return values

    @property
    def with_responses(self) -> bool:
        """Whether responses were requested."""
        if not self._has_relationships:
            return False
        return RecordInclude.responses in self.relationships

    @property
    def with_suggestions(self) -> bool:
        """Whether suggestions were requested."""
        if not self._has_relationships:
            return False
        return RecordInclude.suggestions in self.relationships

    @property
    def with_all_vectors(self) -> bool:
        """Whether vectors were requested without naming specific vector settings."""
        if self.vectors or not self._has_relationships:
            return False
        return RecordInclude.vectors in self.relationships

    @property
    def with_some_vector(self) -> bool:
        """Whether at least one vector settings name was explicitly requested."""
        return bool(self.vectors)

    @property
    def _has_relationships(self) -> bool:
        """Whether any relationships list was provided at all."""
        return self.relationships is not None


NT = TypeVar("NT", int, float)


Expand Down
Loading
Loading