feat: add metadata_filters to filter_by method (#3834)
# Description

This PR adds the `pydantic.BaseModel` schemas for the `MetadataFilters`, i.e.
`TermsMetadataFilter`, `IntegerMetadataFilter`, and `FloatMetadataFilter`, and adds the
`metadata_filters` argument to the `filter_by` method of the `RemoteFeedbackDataset`, so that
records can be filtered based on the pre-defined conditions for the metadata properties
defined in the `FeedbackDataset` in Argilla.

Closes #3835
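
For illustration, a minimal usage sketch of the new argument (the connection details, dataset, workspace, and metadata property names are hypothetical, and the filter constructor arguments are assumed from the schemas added in this PR, so they may differ slightly):

```python
import argilla as rg

# Hypothetical connection details
rg.init(api_url="http://localhost:6900", api_key="owner.apikey")

# Hypothetical remote dataset with metadata properties already defined
dataset = rg.FeedbackDataset.from_argilla(name="my-dataset", workspace="my-workspace")

# Filter by response status and/or by the metadata properties defined in the dataset
filtered = dataset.filter_by(
    response_status=["submitted"],
    metadata_filters=[
        rg.TermsMetadataFilter(name="group", values=["group-a", "group-b"]),  # assumed constructor args
        rg.IntegerMetadataFilter(name="tokens", ge=100, le=512),              # assumed constructor args
    ],
)

# Filtered records are fetched lazily from Argilla
for record in filtered.records:
    print(record.metadata)
```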

**Type of change**

- [X] New feature (non-breaking change which adds functionality)
- [X] Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**

- [x] Add unit tests for `TermsMetadataFilter`, `IntegerMetadataFilter`, and `FloatMetadataFilter`
- [ ] Add integration tests for the `filter_by` method with the arg `metadata_filters` -> On hold because Elasticsearch indexing is not working properly when the tests are triggered
- [ ] Add integration tests for the `get_records` function in the SDK with the arg `metadata_filters` -> On hold because Elasticsearch indexing is not working properly when the tests are triggered

**Checklist**

- [ ] I added relevant documentation
- [X] My code follows the style guidelines of this project
- [X] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [X] My changes generate no new warnings
- [X] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [X] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: Gabriel Martín Blázquez <[email protected]>
Co-authored-by: Francisco Aranda <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
4 people authored Sep 28, 2023
1 parent 1ee07e7 commit b308368
Showing 21 changed files with 529 additions and 116 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -21,6 +21,7 @@ These are the section headers that we use:
- Added new endpoint `POST /api/v1/datasets/:dataset_id/metadata-properties` for dataset metadata property creation ([#3813](https://github.com/argilla-io/argilla/pull/3813))
- Added new endpoint `GET /api/v1/datasets/:dataset_id/metadata-properties` for listing dataset metadata property ([#3813](https://github.com/argilla-io/argilla/pull/3813))
- Added `TermsMetadataProperty`, `IntegerMetadataProperty` and `FloatMetadataProperty` classes allowing to define metadata properties for a `FeedbackDataset` ([#3818](https://github.com/argilla-io/argilla/pull/3818)).
- Added `metadata_filters` to `filter_by` method in `RemoteFeedbackDataset` to filter based on metadata i.e. `TermsMetadataFilter`, `IntegerMetadataFilter`, and `FloatMetadataFilter` ([#3834](https://github.com/argilla-io/argilla/pull/3834)).

### Changed

6 changes: 6 additions & 0 deletions src/argilla/__init__.py
@@ -78,13 +78,16 @@
from argilla.feedback import (
FeedbackDataset,
FeedbackRecord,
FloatMetadataFilter,
FloatMetadataProperty,
IntegerMetadataFilter,
IntegerMetadataProperty,
LabelQuestion,
MultiLabelQuestion,
RankingQuestion,
RatingQuestion,
ResponseSchema,
TermsMetadataFilter,
TermsMetadataProperty,
TextField,
TextQuestion,
@@ -115,6 +118,9 @@
"IntegerMetadataProperty",
"FloatMetadataProperty",
"TermsMetadataProperty",
"TermsMetadataFilter",
"IntegerMetadataFilter",
"FloatMetadataFilter",
],
"client.api": [
"copy",
32 changes: 17 additions & 15 deletions src/argilla/client/feedback/dataset/mixins.py
@@ -71,7 +71,8 @@


class ArgillaMixin:
def __delete_dataset(self: "FeedbackDataset", client: "httpx.Client", id: UUID) -> None:
@staticmethod
def __delete_dataset(client: "httpx.Client", id: UUID) -> None:
try:
datasets_api_v1.delete_dataset(client=client, id=id)
except Exception as e:
@@ -97,7 +98,7 @@ def __add_fields(self: "FeedbackDataset", client: "httpx.Client", id: UUID) -> L
new_field = datasets_api_v1.add_field(client=client, id=id, field=field.to_server_payload()).parsed
fields.append(ArgillaMixin._parse_to_remote_field(new_field))
except Exception as e:
self.__delete_dataset(client=client, id=id)
ArgillaMixin.__delete_dataset(client=client, id=id)
raise Exception(
f"Failed while adding the field '{field.name}' to the `FeedbackDataset` in Argilla with"
f" exception: {e}"
@@ -135,7 +136,7 @@ def __add_questions(
).parsed
questions.append(ArgillaMixin._parse_to_remote_question(new_question))
except Exception as e:
self.__delete_dataset(client=client, id=id)
ArgillaMixin.__delete_dataset(client=client, id=id)
raise Exception(
f"Failed while adding the question '{question.name}' to the `FeedbackDataset` in Argilla with"
f" exception: {e}"
@@ -175,18 +176,19 @@ def __add_metadata_properties(
).parsed
metadata_properties.append(ArgillaMixin._parse_to_remote_metadata_property(new_metadata_property))
except Exception as e:
self.__delete_dataset(client=client, id=id)
ArgillaMixin.__delete_dataset(client=client, id=id)
raise Exception(
f"Failed while adding the metadata property '{metadata_property.name}' to the `FeedbackDataset` in"
f" Argilla with exception: {e}"
) from e
return metadata_properties

def __publish_dataset(self: "FeedbackDataset", client: "httpx.Client", id: UUID) -> None:
@staticmethod
def __publish_dataset(client: "httpx.Client", id: UUID) -> None:
try:
datasets_api_v1.publish_dataset(client=client, id=id)
except Exception as e:
self.__delete_dataset(client=client, id=id)
ArgillaMixin.__delete_dataset(client=client, id=id)
raise Exception(f"Failed while publishing the `FeedbackDataset` in Argilla with exception: {e}") from e

def __push_records(
@@ -212,13 +214,13 @@ def __push_records(
],
)
except Exception as e:
self.__delete_dataset(client=client, id=id)
ArgillaMixin.__delete_dataset(client=client, id=id)
raise Exception(
f"Failed while adding the records to the `FeedbackDataset` in Argilla with exception: {e}"
) from e

def push_to_argilla(
self: "FeedbackDataset",
self: Union["FeedbackDataset", "ArgillaMixin"],
name: str,
workspace: Optional[Union[str, Workspace]] = None,
show_progress: bool = False,
@@ -267,7 +269,7 @@ def push_to_argilla(

metadata_properties = self.__add_metadata_properties(client=httpx_client, id=argilla_id)

self.__publish_dataset(client=httpx_client, id=argilla_id)
ArgillaMixin.__publish_dataset(client=httpx_client, id=argilla_id)

self.__push_records(
client=httpx_client, id=argilla_id, show_progress=show_progress, question_name_to_id=question_name_to_id
@@ -351,9 +353,9 @@ def from_argilla(
)
)

fields = cls.__get_fields(client=httpx_client, id=existing_dataset.id)
questions = cls.__get_questions(client=httpx_client, id=existing_dataset.id)
metadata_properties = cls.__get_metadata_properties(client=httpx_client, id=existing_dataset.id)
fields = ArgillaMixin.__get_fields(client=httpx_client, id=existing_dataset.id)
questions = ArgillaMixin.__get_questions(client=httpx_client, id=existing_dataset.id)
metadata_properties = ArgillaMixin.__get_metadata_properties(client=httpx_client, id=existing_dataset.id)

return RemoteFeedbackDataset(
client=httpx_client,
@@ -411,8 +413,8 @@ def list(cls: Type["FeedbackDataset"], workspace: Optional[str] = None) -> List[
workspace=workspace if workspace is not None else Workspace.from_id(dataset.workspace_id),
created_at=dataset.inserted_at,
updated_at=dataset.updated_at,
fields=cls.__get_fields(client=httpx_client, id=dataset.id),
questions=cls.__get_questions(client=httpx_client, id=dataset.id),
fields=ArgillaMixin.__get_fields(client=httpx_client, id=dataset.id),
questions=ArgillaMixin.__get_questions(client=httpx_client, id=dataset.id),
guidelines=dataset.guidelines or None,
)
for dataset in datasets
@@ -421,7 +423,7 @@ def list(cls: Type["FeedbackDataset"], workspace: Optional[str] = None) -> List[

class UnificationMixin:
def unify_responses(
self,
self: "FeedbackDataset",
question: Union[str, LabelQuestion, MultiLabelQuestion, RatingQuestion],
strategy: Union[
str, LabelQuestionStrategy, MultiLabelQuestionStrategy, RatingQuestionStrategy, RankingQuestionStrategy
7 changes: 5 additions & 2 deletions src/argilla/client/feedback/dataset/remote/base.py
@@ -87,11 +87,13 @@ def _fetch_records(self, offset: int, limit: int) -> "FeedbackRecordsModel":
pass

@abstractmethod
def add(self) -> None:
def add(
self, records: Union["FeedbackRecord", Dict[str, Any], List[Union["FeedbackRecord", Dict[str, Any]]]], **kwargs
) -> None:
pass

@abstractmethod
def delete(self) -> None:
def delete(self, records: List[RemoteFeedbackRecord]) -> None:
pass


@@ -273,6 +275,7 @@ def pull(self) -> "FeedbackDataset":
fields=self.fields,
questions=self.questions,
guidelines=self.guidelines,
metadata_properties=self.metadata_properties,
)
instance.add_records(
records=[record.to_local() for record in self._records],
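
To illustrate the `pull()` change above, a hedged sketch: the local copy created from a remote dataset now also carries over the metadata properties (assuming the pulled `FeedbackDataset` exposes them via a `metadata_properties` attribute, as the constructor argument suggests):

```python
# `dataset` is assumed to be a RemoteFeedbackDataset, e.g. from FeedbackDataset.from_argilla(...)
local_dataset = dataset.pull()

# With this change, the pulled FeedbackDataset keeps the metadata properties
# (TermsMetadataProperty, IntegerMetadataProperty, FloatMetadataProperty)
# that were defined on the remote dataset in Argilla.
for metadata_property in local_dataset.metadata_properties:
    print(metadata_property.name)
```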
24 changes: 18 additions & 6 deletions src/argilla/client/feedback/dataset/remote/dataset.py
@@ -32,6 +32,7 @@

import httpx

from argilla.client.feedback.schemas.metadata import MetadataFilters
from argilla.client.feedback.schemas.types import (
AllowedRemoteFieldTypes,
AllowedRemoteMetadataPropertyTypes,
@@ -155,21 +156,33 @@ def __init__(
)

def filter_by(
self, response_status: Union[FeedbackResponseStatusFilter, List[FeedbackResponseStatusFilter]]
self,
*,
response_status: Optional[Union[FeedbackResponseStatusFilter, List[FeedbackResponseStatusFilter]]] = None,
metadata_filters: Optional[Union["MetadataFilters", List["MetadataFilters"]]] = None,
) -> FilteredRemoteFeedbackDataset:
"""Filters the current `RemoteFeedbackDataset` based on the `response_status` of
the responses of the records in Argilla. This method creates a new class instance
of `FilteredRemoteFeedbackDataset` with the given filters.
Args:
response_status: the response status/es to filter the dataset by. Can be
one of: draft, pending, submitted, and discarded.
one of: draft, pending, submitted, and discarded. Defaults to `None`.
metadata_filters: the metadata filters to filter the dataset by. Can be
one of: `TermsMetadataFilter`, `IntegerMetadataFilter`, and
`FloatMetadataFilter`. Defaults to `None`.
Returns:
A new instance of `FilteredRemoteFeedbackDataset` with the given filters.
"""
if not isinstance(response_status, list):
if not response_status and not metadata_filters:
raise ValueError("At least one of `response_status` or `metadata_filters` must be provided.")
if response_status and not isinstance(response_status, list):
response_status = [response_status]
if metadata_filters and not isinstance(metadata_filters, list):
metadata_filters = [metadata_filters]

# accessing records later
return FilteredRemoteFeedbackDataset(
client=self._client,
id=self.id,
@@ -180,9 +193,8 @@ def filter_by(
fields=self.fields,
questions=self.questions,
guidelines=self.guidelines,
filters={
"response_status": [status.value if hasattr(status, "value") else status for status in response_status]
},
response_status=response_status,
metadata_filters=metadata_filters,
)

@allowed_for_roles(roles=[UserRole.owner, UserRole.admin])
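
A short sketch of the validation added above: both arguments are now keyword-only, at least one of them must be provided, and single values are normalised into lists (reusing the hypothetical `dataset` handle and metadata property name from the sketch in the description):

```python
# Calling `filter_by` with no filters at all is rejected
try:
    dataset.filter_by()
except ValueError as error:
    print(error)  # At least one of `response_status` or `metadata_filters` must be provided.

# A single filter that is not wrapped in a list is accepted and normalised internally
filtered = dataset.filter_by(
    metadata_filters=rg.TermsMetadataFilter(name="group", values=["group-a"]),  # assumed constructor args
)
```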
42 changes: 33 additions & 9 deletions src/argilla/client/feedback/dataset/remote/filtered.py
@@ -11,12 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from argilla.client.feedback.dataset.remote.base import RemoteFeedbackDatasetBase, RemoteFeedbackRecordsBase
from argilla.client.feedback.schemas.metadata import MetadataFilters
from argilla.client.sdk.v1.datasets import api as datasets_api_v1
from argilla.client.sdk.v1.datasets.models import FeedbackRecordsModel, FeedbackResponseStatusFilter

if TYPE_CHECKING:
from uuid import UUID
@@ -26,19 +28,38 @@
from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset
from argilla.client.feedback.schemas.records import FeedbackRecord
from argilla.client.feedback.schemas.remote.records import RemoteFeedbackRecord
from argilla.client.feedback.schemas.types import AllowedRemoteFieldTypes, AllowedRemoteQuestionTypes
from argilla.client.sdk.v1.datasets.models import FeedbackRecordsModel
from argilla.client.workspaces import Workspace


class FilteredRemoteFeedbackRecords(RemoteFeedbackRecordsBase):
def __init__(self, dataset: "RemoteFeedbackDataset", filters: Dict[str, Any]) -> None:
def __init__(
self,
dataset: "RemoteFeedbackDataset",
response_status: Optional[List["FeedbackResponseStatusFilter"]] = None,
metadata_filters: Optional[List["MetadataFilters"]] = None,
) -> None:
super().__init__(dataset=dataset)

self._filters = filters
self._response_status = (
[
status.value if hasattr(status, "value") else FeedbackResponseStatusFilter(status).value
for status in response_status
]
if response_status
else None
)
self._metadata_filters = (
[metadata_filter.query_string for metadata_filter in metadata_filters] if metadata_filters else None
)

def __len__(self) -> None:
raise NotImplementedError("`__len__` does not work for filtered datasets.")
warnings.warn(
"The `records` of a filtered dataset in Argilla are being lazily loaded"
" and len computation may add undesirable extra computation. You can fetch"
"records using\n`ds.pull()`\nor iterate over results to know the length of the result:\n"
"`records = [r for r in ds.records]\n",
stacklevel=1,
)

def _fetch_records(self, offset: int, limit: int) -> "FeedbackRecordsModel":
"""Fetches a batch of records from Argilla."""
@@ -47,7 +68,8 @@ def _fetch_records(self, offset: int, limit: int) -> "FeedbackRecordsModel":
id=self._dataset.id,
offset=offset,
limit=limit,
**self._filters,
response_status=self._response_status,
metadata_filters=self._metadata_filters,
).parsed

def add(
@@ -76,7 +98,8 @@ def __init__(
fields: List["AllowedRemoteFieldTypes"],
questions: List["AllowedRemoteQuestionTypes"],
guidelines: Optional[str] = None,
filters: Dict[str, Any] = {},
response_status: Optional[List["FeedbackResponseStatusFilter"]] = None,
metadata_filters: Optional[List["MetadataFilters"]] = None,
) -> None:
super().__init__(
client=client,
@@ -89,7 +112,8 @@
questions=questions,
guidelines=guidelines,
# kwargs
filters=filters,
response_status=response_status,
metadata_filters=metadata_filters,
)

def delete(self) -> None:
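
Following the `__len__` change above, a hedged sketch of working with a filtered dataset: its records are loaded lazily, so the warning suggests materialising them (or pulling the dataset) instead of relying on a length computation:

```python
filtered = dataset.filter_by(response_status=["submitted"])

# Materialise the filtered records explicitly, as the warning message suggests
records = [record for record in filtered.records]
print(len(records))

# ... or pull the filtered dataset into a local FeedbackDataset
local_copy = filtered.pull()
```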
29 changes: 23 additions & 6 deletions src/argilla/client/feedback/dataset/remote/mixins.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import TYPE_CHECKING, Iterator, List, Union

from argilla.client.feedback.constants import FETCHING_BATCH_SIZE
@@ -36,10 +37,17 @@ def __getitem__(
Returns:
Either the record of the given index, or a list with the records at the given indexes.
"""
try:
num_records = len(self)
except NotImplementedError:
num_records = None
if not isinstance(key, int):
raise NotImplementedError(f"`key`={key} is not supported for this dataset. Only `int` is supported.")

offsets = []
limit = None
num_records = len(self)
if isinstance(key, slice):

if isinstance(key, slice) and num_records is not None:
start, stop, step = key.indices(num_records)
if step is not None and step != 1:
raise ValueError("When providing a `slice` just `step=None` or `step=1` are allowed.")
@@ -64,10 +72,11 @@
offsets[-1] = stop - (stop % FETCHING_BATCH_SIZE) + 1
limits[-1] = (stop % FETCHING_BATCH_SIZE) - 1
elif isinstance(key, int):
if key < 0:
key += num_records
if key < 0 or key >= num_records:
raise IndexError(f"Index {key} is out of range, dataset has {num_records} records.")
if num_records is not None:
if key < 0:
key += num_records
if key < 0 or key >= num_records:
raise IndexError(f"Index {key} is out of range, dataset has {num_records} records.")
offsets = [key]
limits = [1]
else:
@@ -76,6 +85,8 @@
records = []
for offset, limit in zip(offsets, limits):
fetched_records = self._fetch_records(offset=offset, limit=limit)
if len(fetched_records.items) == 0:
break
records.extend(
[
RemoteFeedbackRecord.from_api(
Expand All @@ -84,6 +95,10 @@ def __getitem__(
for record in fetched_records.items
]
)
if len(records) == 0:
raise IndexError(
"No records were found in the dataset in Argilla for the given index(es) and/or filter(s) if any."
)
return records[0] if isinstance(key, int) else records

@allowed_for_roles(roles=[UserRole.owner, UserRole.admin])
Expand All @@ -94,6 +109,8 @@ def __iter__(
current_batch = 0
while True:
batch = self._fetch_records(offset=FETCHING_BATCH_SIZE * current_batch, limit=FETCHING_BATCH_SIZE)
if len(batch.items) == 0:
break
for record in batch.items:
yield RemoteFeedbackRecord.from_api(
record, question_id_to_name=self._question_id_to_name, client=self._client
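
A brief sketch of the indexing behaviour handled above: integer indices and slices are supported on remote records, batches are fetched under the hood, and an `IndexError` is raised when no records match the given index and/or filters (again reusing the hypothetical handles from the earlier sketches):

```python
first_record = dataset.records[0]    # single record by integer index
first_batch = dataset.records[0:10]  # slice, fetched in batches internally

try:
    _ = dataset.records[10_000_000]  # hypothetical index beyond the dataset size
except IndexError as error:
    print(error)
```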