From e8a8f8939ceea006d76566f34e69ea8fd37c82dd Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 7 Feb 2023 17:27:39 +0100
Subject: [PATCH 01/45] [pre-commit.ci] pre-commit autoupdate (#2299)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

updates:
- [github.com/psf/black: 22.12.0 → 23.1.0](https://github.com/psf/black/compare/22.12.0...23.1.0)
- [github.com/pycqa/isort: 5.11.5 → 5.12.0](https://github.com/pycqa/isort/compare/5.11.5...5.12.0)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .pre-commit-config.yaml | 4 ++--
 scripts/migrations/es_migration_25042021.py | 1 -
 src/argilla/client/apis/metrics.py | 1 -
 src/argilla/client/apis/search.py | 1 -
 src/argilla/client/apis/status.py | 1 -
 src/argilla/client/datasets.py | 1 +
 src/argilla/client/sdk/commons/api.py | 1 -
 src/argilla/client/sdk/commons/errors.py | 2 --
 src/argilla/client/sdk/datasets/api.py | 1 -
 src/argilla/client/sdk/text_classification/api.py | 2 --
 src/argilla/logging.py | 2 --
 src/argilla/monitoring/_flair.py | 1 -
 src/argilla/monitoring/_spacy.py | 1 -
 src/argilla/monitoring/_transformers.py | 3 +--
 src/argilla/monitoring/asgi.py | 2 +-
 src/argilla/server/apis/v0/handlers/datasets.py | 2 --
 src/argilla/server/apis/v0/handlers/metrics.py | 3 ---
 src/argilla/server/apis/v0/handlers/records_update.py | 2 --
 src/argilla/server/apis/v0/handlers/text2text.py | 2 --
 src/argilla/server/apis/v0/handlers/text_classification.py | 7 -------
 .../v0/handlers/text_classification_dataset_settings.py | 3 ---
 .../server/apis/v0/handlers/token_classification.py | 2 --
 .../v0/handlers/token_classification_dataset_settings.py | 2 --
 src/argilla/server/apis/v0/models/text2text.py | 1 -
 src/argilla/server/apis/v0/models/text_classification.py | 4 ----
 src/argilla/server/apis/v0/models/token_classification.py | 2 --
 .../server/apis/v0/validators/text_classification.py | 1 -
 src/argilla/server/commons/config.py | 3 ---
 src/argilla/server/commons/models.py | 1 -
 src/argilla/server/commons/telemetry.py | 2 --
 src/argilla/server/daos/backend/client_adapters/factory.py | 1 -
 .../server/daos/backend/client_adapters/opensearch.py | 3 ---
 src/argilla/server/daos/backend/generic_elastic.py | 2 --
 .../server/daos/backend/mappings/token_classification.py | 1 -
 src/argilla/server/daos/backend/metrics/base.py | 7 ++++---
 .../server/daos/backend/metrics/text_classification.py | 1 -
 src/argilla/server/daos/backend/query_helpers.py | 1 -
 src/argilla/server/daos/backend/search/model.py | 2 --
 src/argilla/server/daos/backend/search/query_builder.py | 3 ---
 src/argilla/server/daos/datasets.py | 3 ---
 src/argilla/server/daos/models/records.py | 3 ---
 src/argilla/server/daos/records.py | 2 --
 src/argilla/server/errors/base_errors.py | 1 -
 src/argilla/server/security/auth_provider/base.py | 2 +-
 src/argilla/server/server.py | 1 -
 src/argilla/server/services/datasets.py | 3 ---
 src/argilla/server/services/search/model.py | 1 -
 src/argilla/server/services/search/service.py | 1 -
 src/argilla/server/services/storage/service.py | 2 --
 src/argilla/server/services/tasks/text2text/service.py | 1 -
 .../tasks/text_classification/labeling_rules_service.py | 1 -
 .../server/services/tasks/text_classification/model.py | 3 ---
 .../server/services/tasks/text_classification/service.py | 1 -
 .../server/services/tasks/token_classification/metrics.py | 1 -
 .../server/services/tasks/token_classification/model.py | 4 ----
 .../server/services/tasks/token_classification/service.py | 1 -
 tests/client/functional_tests/test_record_update.py | 1 -
 tests/client/functional_tests/test_scan_raw_records.py | 1 -
 tests/client/sdk/commons/test_client.py | 2 --
 tests/client/test_api.py | 4 ----
 tests/client/test_client_errors.py | 1 -
 tests/conftest.py | 2 --
 tests/functional_tests/search/test_search_service.py | 1 -
 tests/functional_tests/test_log_for_text_classification.py | 2 --
 .../functional_tests/test_log_for_token_classification.py | 1 -
 tests/labeling/text_classification/test_rule.py | 5 -----
 tests/monitoring/test_flair_monitoring.py | 3 +--
 tests/monitoring/test_monitor.py | 1 -
 tests/monitoring/test_transformers_monitoring.py | 2 --
 tests/server/backend/test_query_builder.py | 1 -
 tests/server/daos/models/test_records.py | 1 -
 tests/server/info/test_api.py | 2 --
 tests/server/security/test_model.py | 2 --
 tests/server/test_errors.py | 1 -
 tests/server/text2text/test_model.py | 1 -
 tests/server/text_classification/test_model.py | 3 ---
 tests/server/token_classification/test_model.py | 4 ----
 77 files changed, 11 insertions(+), 143 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index af7daea006..f67cae6ff1 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -20,13 +20,13 @@ repos:
 # - --remove-header

   - repo: https://github.com/psf/black
-    rev: 22.12.0
+    rev: 23.1.0
     hooks:
       - id: black
         additional_dependencies: ["click==8.0.4"]

   - repo: https://github.com/pycqa/isort
-    rev: 5.11.5
+    rev: 5.12.0
     hooks:
       - id: isort
diff --git a/scripts/migrations/es_migration_25042021.py b/scripts/migrations/es_migration_25042021.py
index 01d4ec2a14..63776fe977 100644
--- a/scripts/migrations/es_migration_25042021.py
+++ b/scripts/migrations/es_migration_25042021.py
@@ -103,7 +103,6 @@ def map_doc_2_action(

 if __name__ == "__main__":
-
     client = Elasticsearch(hosts=settings.elasticsearch)

     for dataset in settings.migration_datasets:
diff --git a/src/argilla/client/apis/metrics.py b/src/argilla/client/apis/metrics.py
index 50bc63dc5c..abb4cf88d2 100644
--- a/src/argilla/client/apis/metrics.py
+++ b/src/argilla/client/apis/metrics.py
@@ -19,7 +19,6 @@ class MetricsAPI(AbstractApi):
-
     _API_URL_PATTERN = "/api/datasets/{task}/{name}/metrics/{metric}:summary"

     def metric_summary(
diff --git a/src/argilla/client/apis/search.py b/src/argilla/client/apis/search.py
index 9ac808ad25..f8a1b66a98 100644
--- a/src/argilla/client/apis/search.py
+++ b/src/argilla/client/apis/search.py
@@ -37,7 +37,6 @@ class SearchResults:

 class Search(AbstractApi):
-
     _API_URL_PATTERN = "/api/datasets/{name}/{task}:search"

     def search_records(
diff --git a/src/argilla/client/apis/status.py b/src/argilla/client/apis/status.py
index 4d29c03afd..ce5d42372e 100644
--- a/src/argilla/client/apis/status.py
+++ b/src/argilla/client/apis/status.py
@@ -26,7 +26,6 @@

 @dataclasses.dataclass(frozen=True)
 class ApiInfo:
-
     version: Optional[str] = None
diff --git a/src/argilla/client/datasets.py b/src/argilla/client/datasets.py
index 0a05463d8a..3791292571 100644
--- a/src/argilla/client/datasets.py
+++ b/src/argilla/client/datasets.py
@@ -1137,6 +1137,7 @@ def __only_annotations__(self, data) -> bool:

     def _to_datasets_dict(self) -> Dict:
         """Helper method to put token classification records in a `datasets.Dataset`"""
+
         # create a dict first, where we make the necessary transformations
         def entities_to_dict(
             entities: Optional[
diff --git a/src/argilla/client/sdk/commons/api.py b/src/argilla/client/sdk/commons/api.py
index
1974d443a7..9ccc3eb3bb 100644 --- a/src/argilla/client/sdk/commons/api.py +++ b/src/argilla/client/sdk/commons/api.py @@ -103,7 +103,6 @@ async def async_bulk( def build_bulk_response( response: httpx.Response, name: str, body: Any ) -> Response[BulkResponse]: - if 200 <= response.status_code < 400: return Response( status_code=response.status_code, diff --git a/src/argilla/client/sdk/commons/errors.py b/src/argilla/client/sdk/commons/errors.py index bb38fa0432..003226f40a 100644 --- a/src/argilla/client/sdk/commons/errors.py +++ b/src/argilla/client/sdk/commons/errors.py @@ -64,7 +64,6 @@ def __str__(self): class ArApiResponseError(BaseClientError): - HTTP_STATUS: int def __init__(self, **ctx): @@ -105,7 +104,6 @@ class ValidationApiError(ArApiResponseError): HTTP_STATUS = 422 def __init__(self, client_ctx, params, **ctx): - for error in params.get("errors", []): current_level = client_ctx for loc in error["loc"]: diff --git a/src/argilla/client/sdk/datasets/api.py b/src/argilla/client/sdk/datasets/api.py index 08ab036e8f..b836b949d0 100644 --- a/src/argilla/client/sdk/datasets/api.py +++ b/src/argilla/client/sdk/datasets/api.py @@ -89,7 +89,6 @@ def delete_dataset( def _build_response( response: httpx.Response, name: str ) -> Response[Union[Dataset, ErrorMessage, HTTPValidationError]]: - if response.status_code == 200: parsed_response = Dataset(**response.json()) return Response( diff --git a/src/argilla/client/sdk/text_classification/api.py b/src/argilla/client/sdk/text_classification/api.py index 5699566eba..d6cd7b64a5 100644 --- a/src/argilla/client/sdk/text_classification/api.py +++ b/src/argilla/client/sdk/text_classification/api.py @@ -97,7 +97,6 @@ def fetch_dataset_labeling_rules( client: AuthenticatedClient, name: str, ) -> Response[Union[List[LabelingRule], HTTPValidationError, ErrorMessage]]: - url = "{}/api/datasets/TextClassification/{name}/labeling/rules".format( client.base_url, name=name ) @@ -118,7 +117,6 @@ def dataset_rule_metrics( query: str, label: str, ) -> Response[Union[LabelingRuleMetricsSummary, HTTPValidationError, ErrorMessage]]: - url = "{}/api/datasets/TextClassification/{name}/labeling/rules/{query}/metrics?label={label}".format( client.base_url, name=name, query=query, label=label ) diff --git a/src/argilla/logging.py b/src/argilla/logging.py index 076467aa7e..94b1cc249e 100644 --- a/src/argilla/logging.py +++ b/src/argilla/logging.py @@ -84,7 +84,6 @@ def __init__(self, *args, **kwargs): self.emit = lambda record: None def emit(self, record: logging.LogRecord): - try: level = logger.level(record.levelname).name except AttributeError: @@ -100,7 +99,6 @@ def emit(self, record: logging.LogRecord): def configure_logging(): - """Normalizes logging configuration for argilla and its dependencies""" intercept_handler = LoguruLoggerHandler() if not intercept_handler.is_available: diff --git a/src/argilla/monitoring/_flair.py b/src/argilla/monitoring/_flair.py index b9deeb6bf4..6330b9fc78 100644 --- a/src/argilla/monitoring/_flair.py +++ b/src/argilla/monitoring/_flair.py @@ -87,7 +87,6 @@ def flair_monitor( sample_rate: float, log_interval: float, ) -> Optional[SequenceTagger]: - return FlairMonitor( pl, api=api, diff --git a/src/argilla/monitoring/_spacy.py b/src/argilla/monitoring/_spacy.py index a238223e56..a1c80833ac 100644 --- a/src/argilla/monitoring/_spacy.py +++ b/src/argilla/monitoring/_spacy.py @@ -64,7 +64,6 @@ def doc2token_classification( def _prepare_log_data( self, docs_info: Tuple[Doc, Optional[Dict[str, Any]]] ) -> Dict[str, Any]: - return dict( 
records=[ self.doc2token_classification( diff --git a/src/argilla/monitoring/_transformers.py b/src/argilla/monitoring/_transformers.py index 8c6a1d4300..46fe973f26 100644 --- a/src/argilla/monitoring/_transformers.py +++ b/src/argilla/monitoring/_transformers.py @@ -56,7 +56,6 @@ def _prepare_log_data( data: List[Tuple[str, Dict[str, Any], List[LabelPrediction]]], multi_label: bool = False, ) -> Dict[str, Any]: - agent = self.model_config.name_or_path records = [] @@ -100,7 +99,7 @@ def __call__( sequences: Union[str, List[str]], candidate_labels: List[str], *args, - **kwargs + **kwargs, ): metadata = (kwargs.pop("metadata", None) or {}).copy() hypothesis_template = kwargs.get("hypothesis_template", "@default") diff --git a/src/argilla/monitoring/asgi.py b/src/argilla/monitoring/asgi.py index 48d0dcbfcb..4b2119be43 100644 --- a/src/argilla/monitoring/asgi.py +++ b/src/argilla/monitoring/asgi.py @@ -112,7 +112,7 @@ def __init__( agent: Optional[str] = None, tags: Dict[str, str] = None, *args, - **kwargs + **kwargs, ): BaseHTTPMiddleware.__init__(self, *args, **kwargs) diff --git a/src/argilla/server/apis/v0/handlers/datasets.py b/src/argilla/server/apis/v0/handlers/datasets.py index bb66e9dbbc..01d7d75478 100644 --- a/src/argilla/server/apis/v0/handlers/datasets.py +++ b/src/argilla/server/apis/v0/handlers/datasets.py @@ -69,7 +69,6 @@ async def create_dataset( datasets: DatasetsService = Depends(DatasetsService.get_instance), user: User = Security(auth.get_user, scopes=["create:datasets"]), ) -> Dataset: - owner = user.check_workspace(ws_params.workspace) dataset_class = TasksFactory.get_task_dataset(request.task) @@ -114,7 +113,6 @@ def update_dataset( service: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> Dataset: - found_ds = service.find_by_name( user=current_user, name=name, workspace=ds_params.workspace ) diff --git a/src/argilla/server/apis/v0/handlers/metrics.py b/src/argilla/server/apis/v0/handlers/metrics.py index 51c3659506..96e0364697 100644 --- a/src/argilla/server/apis/v0/handlers/metrics.py +++ b/src/argilla/server/apis/v0/handlers/metrics.py @@ -29,7 +29,6 @@ class MetricInfo(BaseModel): - id: str = Field(description="The metric id") name: str = Field(description="The metric name") description: Optional[str] = Field( @@ -39,7 +38,6 @@ class MetricInfo(BaseModel): @dataclass class MetricSummaryParams: - interval: Optional[float] = Query( default=None, gt=0.0, @@ -53,7 +51,6 @@ class MetricSummaryParams: def configure_router(router: APIRouter, cfg: TaskConfig): - base_metrics_endpoint = f"/{cfg.task}/{{name}}/metrics" new_base_metrics_endpoint = f"/{{name}}/{cfg.task}/metrics" diff --git a/src/argilla/server/apis/v0/handlers/records_update.py b/src/argilla/server/apis/v0/handlers/records_update.py index 87d25dde54..f77d391e53 100644 --- a/src/argilla/server/apis/v0/handlers/records_update.py +++ b/src/argilla/server/apis/v0/handlers/records_update.py @@ -36,7 +36,6 @@ def configure_router(router: APIRouter): - RecordType = Union[ TextClassificationRecord, TokenClassificationRecord, @@ -61,7 +60,6 @@ async def partial_update_dataset_record( storage: RecordsStorageService = Depends(RecordsStorageService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> RecordType: - dataset = service.find_by_name( user=current_user, name=name, diff --git a/src/argilla/server/apis/v0/handlers/text2text.py b/src/argilla/server/apis/v0/handlers/text2text.py index 2f34a6fb3b..d0ca90bb62 100644 --- 
a/src/argilla/server/apis/v0/handlers/text2text.py +++ b/src/argilla/server/apis/v0/handlers/text2text.py @@ -78,7 +78,6 @@ async def bulk_records( datasets: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> BulkResponse: - task = task_type owner = current_user.check_workspace(common_params.workspace) try: @@ -129,7 +128,6 @@ def search_records( datasets: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> Text2TextSearchResults: - search = search or Text2TextSearchRequest() query = search.query or Text2TextQuery() dataset = datasets.find_by_name( diff --git a/src/argilla/server/apis/v0/handlers/text_classification.py b/src/argilla/server/apis/v0/handlers/text_classification.py index 857318415e..414529dbbf 100644 --- a/src/argilla/server/apis/v0/handlers/text_classification.py +++ b/src/argilla/server/apis/v0/handlers/text_classification.py @@ -96,7 +96,6 @@ async def bulk_records( validator: DatasetValidator = Depends(DatasetValidator.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> BulkResponse: - task = task_type owner = current_user.check_workspace(common_params.workspace) try: @@ -325,7 +324,6 @@ async def list_labeling_rules( ), current_user: User = Security(auth.get_user, scopes=[]), ) -> List[LabelingRule]: - dataset = datasets.find_by_name( user=current_user, name=name, @@ -357,7 +355,6 @@ async def create_rule( datasets: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> LabelingRule: - dataset = datasets.find_by_name( user=current_user, name=name, @@ -398,7 +395,6 @@ async def compute_rule_metrics( datasets: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> LabelingRuleMetricsSummary: - dataset = datasets.find_by_name( user=current_user, name=name, @@ -454,7 +450,6 @@ async def delete_labeling_rule( datasets: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> None: - dataset = datasets.find_by_name( user=current_user, name=name, @@ -484,7 +479,6 @@ async def get_rule( datasets: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> LabelingRule: - dataset = datasets.find_by_name( user=current_user, name=name, @@ -518,7 +512,6 @@ async def update_rule( datasets: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> LabelingRule: - dataset = datasets.find_by_name( user=current_user, name=name, diff --git a/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py b/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py index fc9fad8ec4..e40c97e079 100644 --- a/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py +++ b/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py @@ -36,7 +36,6 @@ def configure_router(router: APIRouter): - task = TaskType.text_classification base_endpoint = f"/{task}/{{name}}/settings" new_base_endpoint = f"/{{name}}/{task}/settings" @@ -57,7 +56,6 @@ async def get_dataset_settings( datasets: DatasetsService = Depends(DatasetsService.get_instance), user: User = Security(auth.get_user, scopes=["read:dataset.settings"]), ) -> TextClassificationSettings: - found_ds = 
datasets.find_by_name( user=user, name=name, @@ -90,7 +88,6 @@ async def save_settings( validator: DatasetValidator = Depends(DatasetValidator.get_instance), user: User = Security(auth.get_user, scopes=["write:dataset.settings"]), ) -> TextClassificationSettings: - found_ds = datasets.find_by_name( user=user, name=name, diff --git a/src/argilla/server/apis/v0/handlers/token_classification.py b/src/argilla/server/apis/v0/handlers/token_classification.py index afa59342f7..940775a132 100644 --- a/src/argilla/server/apis/v0/handlers/token_classification.py +++ b/src/argilla/server/apis/v0/handlers/token_classification.py @@ -91,7 +91,6 @@ async def bulk_records( validator: DatasetValidator = Depends(DatasetValidator.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> BulkResponse: - task = task_type owner = current_user.check_workspace(common_params.workspace) try: @@ -153,7 +152,6 @@ def search_records( datasets: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> TokenClassificationSearchResults: - search = search or TokenClassificationSearchRequest() query = search.query or TokenClassificationQuery() diff --git a/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py b/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py index 24e9048003..4756aeb9da 100644 --- a/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py +++ b/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py @@ -56,7 +56,6 @@ async def get_dataset_settings( datasets: DatasetsService = Depends(DatasetsService.get_instance), user: User = Security(auth.get_user, scopes=["read:dataset.settings"]), ) -> TokenClassificationSettings: - found_ds = datasets.find_by_name( user=user, name=name, @@ -89,7 +88,6 @@ async def save_settings( validator: DatasetValidator = Depends(DatasetValidator.get_instance), user: User = Security(auth.get_user, scopes=["write:dataset.settings"]), ) -> TokenClassificationSettings: - found_ds = datasets.find_by_name( user=user, name=name, diff --git a/src/argilla/server/apis/v0/models/text2text.py b/src/argilla/server/apis/v0/models/text2text.py index ff8845a8eb..e2730f46e3 100644 --- a/src/argilla/server/apis/v0/models/text2text.py +++ b/src/argilla/server/apis/v0/models/text2text.py @@ -51,7 +51,6 @@ def sort_sentences_by_score(cls, sentences: List[Text2TextPrediction]): class Text2TextRecordInputs(BaseRecordInputs[Text2TextAnnotation]): - text: str diff --git a/src/argilla/server/apis/v0/models/text_classification.py b/src/argilla/server/apis/v0/models/text_classification.py index cabe2f9c7a..84f1868f78 100644 --- a/src/argilla/server/apis/v0/models/text_classification.py +++ b/src/argilla/server/apis/v0/models/text_classification.py @@ -76,7 +76,6 @@ def initialize_labels(cls, values): class CreateLabelingRule(UpdateLabelingRule): - query: str = Field(description="The es rule query") @validator("query") @@ -109,7 +108,6 @@ class TextClassificationAnnotation(_TextClassificationAnnotation): class TextClassificationRecordInputs(BaseRecordInputs[TextClassificationAnnotation]): - inputs: Dict[str, Union[str, List[str]]] multi_label: bool = False explanation: Optional[Dict[str, List[TokenAttributions]]] = None @@ -122,7 +120,6 @@ class TextClassificationRecord( class TextClassificationBulkRequest(UpdateDatasetRequest): - records: List[TextClassificationRecordInputs] @validator("records") @@ -138,7 +135,6 @@ def 
check_multi_label_integrity(cls, records: List[TextClassificationRecord]): class TextClassificationQuery(ServiceBaseRecordsQuery): - predicted_as: List[str] = Field(default_factory=list) annotated_as: List[str] = Field(default_factory=list) score: Optional[ScoreRange] = Field(default=None) diff --git a/src/argilla/server/apis/v0/models/token_classification.py b/src/argilla/server/apis/v0/models/token_classification.py index 8e9bbbf47c..63eaaa9632 100644 --- a/src/argilla/server/apis/v0/models/token_classification.py +++ b/src/argilla/server/apis/v0/models/token_classification.py @@ -42,7 +42,6 @@ class TokenClassificationAnnotation(_TokenClassificationAnnotation): class TokenClassificationRecordInputs(BaseRecordInputs[TokenClassificationAnnotation]): - text: str = Field() tokens: List[str] = Field(min_items=1) # TODO(@frascuchon): Delete this field and all related logic @@ -73,7 +72,6 @@ class TokenClassificationBulkRequest(UpdateDatasetRequest): class TokenClassificationQuery(ServiceBaseRecordsQuery): - predicted_as: List[str] = Field(default_factory=list) annotated_as: List[str] = Field(default_factory=list) score: Optional[ScoreRange] = Field(default=None) diff --git a/src/argilla/server/apis/v0/validators/text_classification.py b/src/argilla/server/apis/v0/validators/text_classification.py index 02787fa1fc..e225749bed 100644 --- a/src/argilla/server/apis/v0/validators/text_classification.py +++ b/src/argilla/server/apis/v0/validators/text_classification.py @@ -38,7 +38,6 @@ # TODO(@frascuchon): Move validator and its models to the service layer class DatasetValidator: - _INSTANCE = None def __init__(self, datasets: DatasetsService, metrics: MetricsService): diff --git a/src/argilla/server/commons/config.py b/src/argilla/server/commons/config.py index 988c18359b..b12dfc368c 100644 --- a/src/argilla/server/commons/config.py +++ b/src/argilla/server/commons/config.py @@ -34,7 +34,6 @@ class TaskConfig(BaseModel): class TasksFactory: - __REGISTERED_TASKS__ = dict() @classmethod @@ -46,7 +45,6 @@ def register_task( record_class: Type[ServiceRecord], metrics: Optional[Type[ServiceBaseTaskMetrics]] = None, ): - cls.__REGISTERED_TASKS__[task_type] = TaskConfig( task=task_type, dataset=dataset_class, @@ -99,7 +97,6 @@ def find_task_metric( def find_task_metrics( cls, task: TaskType, metric_ids: Set[str] ) -> List[ServiceBaseMetric]: - if not metric_ids: return [] diff --git a/src/argilla/server/commons/models.py b/src/argilla/server/commons/models.py index ff51220f44..c0d82925ad 100644 --- a/src/argilla/server/commons/models.py +++ b/src/argilla/server/commons/models.py @@ -23,7 +23,6 @@ class TaskStatus(str, Enum): class TaskType(str, Enum): - text_classification = "TextClassification" token_classification = "TokenClassification" text2text = "Text2Text" diff --git a/src/argilla/server/commons/telemetry.py b/src/argilla/server/commons/telemetry.py index 24e0a86a0e..3f39929033 100644 --- a/src/argilla/server/commons/telemetry.py +++ b/src/argilla/server/commons/telemetry.py @@ -51,7 +51,6 @@ def _configure_analytics(disable_send: bool = False) -> Client: @dataclasses.dataclass class _TelemetryClient: - client: Client __INSTANCE__: "_TelemetryClient" = None @@ -76,7 +75,6 @@ def get(cls): return cls.__INSTANCE__ def __post_init__(self): - from argilla import __version__ self.__server_id__ = uuid.UUID(int=uuid.getnode()) diff --git a/src/argilla/server/daos/backend/client_adapters/factory.py b/src/argilla/server/daos/backend/client_adapters/factory.py index ad145f0eba..2ab8ac2d33 100644 --- 
a/src/argilla/server/daos/backend/client_adapters/factory.py +++ b/src/argilla/server/daos/backend/client_adapters/factory.py @@ -38,7 +38,6 @@ def get( retry_on_timeout: bool = True, max_retries: int = 5, ) -> IClientAdapter: - ( client_class, support_vector_search, diff --git a/src/argilla/server/daos/backend/client_adapters/opensearch.py b/src/argilla/server/daos/backend/client_adapters/opensearch.py index b71323a401..088b48da84 100644 --- a/src/argilla/server/daos/backend/client_adapters/opensearch.py +++ b/src/argilla/server/daos/backend/client_adapters/opensearch.py @@ -41,7 +41,6 @@ @dataclasses.dataclass class OpenSearchClient(IClientAdapter): - index_shards: int config_backend: Dict[str, Any] @@ -110,7 +109,6 @@ def search_docs( enable_highlight: bool = True, routing: str = None, ) -> Dict[str, Any]: - with self.error_handling(index=index): highlight = self.highlight if enable_highlight else None es_query = self.query_builder.map_2_es_query( @@ -782,7 +780,6 @@ def _normalize_document( highlight: Optional[HighlightParser] = None, is_phrase_query: bool = True, ): - data = { **document["_source"], "id": document["_id"], diff --git a/src/argilla/server/daos/backend/generic_elastic.py b/src/argilla/server/daos/backend/generic_elastic.py index 888b555552..9a1bdd103b 100644 --- a/src/argilla/server/daos/backend/generic_elastic.py +++ b/src/argilla/server/daos/backend/generic_elastic.py @@ -184,7 +184,6 @@ def search_records( exclude_fields: List[str] = None, enable_highlight: bool = True, ) -> Tuple[int, List[Dict[str, Any]]]: - index = dataset_records_index(id) if not sort.sort_by and sort.shuffle is False: @@ -248,7 +247,6 @@ def create_dataset( vectors_cfg: Optional[Dict[str, Any]] = None, force_recreate: bool = False, ) -> None: - _mappings = self._common_records_mappings task_mappings = self.get_task_mapping(task).copy() for k in task_mappings: diff --git a/src/argilla/server/daos/backend/mappings/token_classification.py b/src/argilla/server/daos/backend/mappings/token_classification.py index c51f2dee49..6685560968 100644 --- a/src/argilla/server/daos/backend/mappings/token_classification.py +++ b/src/argilla/server/daos/backend/mappings/token_classification.py @@ -38,7 +38,6 @@ class TokenTagMetrics(BaseModel): class TokenMetrics(BaseModel): - idx: int value: str char_start: int diff --git a/src/argilla/server/daos/backend/metrics/base.py b/src/argilla/server/daos/backend/metrics/base.py index 5f9ed46df8..80122add68 100644 --- a/src/argilla/server/daos/backend/metrics/base.py +++ b/src/argilla/server/daos/backend/metrics/base.py @@ -125,7 +125,6 @@ def _build_aggregation(self, interval: Optional[float] = None) -> Dict[str, Any] @dataclasses.dataclass class TermsAggregation(ElasticsearchMetric): - field: str = None script: Union[str, Dict[str, Any]] = None fixed_size: Optional[int] = None @@ -206,7 +205,10 @@ def _build_aggregation( ) -> Dict[str, Any]: field = text_field or self.default_field terms_id = f"{self.id}_{field}" if text_field else self.id - return TermsAggregation(id=terms_id, field=field,).aggregation_request( + return TermsAggregation( + id=terms_id, + field=field, + ).aggregation_request( size=size or self.DEFAULT_WORDCOUNT_SIZE )[terms_id] @@ -223,7 +225,6 @@ def aggregation_request( index: str, size: int = None, ) -> List[Dict[str, Any]]: - schema = client.get_property_type( index=index, property_name="metadata", diff --git a/src/argilla/server/daos/backend/metrics/text_classification.py b/src/argilla/server/daos/backend/metrics/text_classification.py index 
70fffc4ed7..fa59630558 100644 --- a/src/argilla/server/daos/backend/metrics/text_classification.py +++ b/src/argilla/server/daos/backend/metrics/text_classification.py @@ -48,7 +48,6 @@ class LabelingRulesMetric(ElasticsearchMetric): def _build_aggregation( self, rule_query: str, labels: Optional[List[str]] ) -> Dict[str, Any]: - annotated_records_filter = filters.exists_field("annotated_as") rule_query_filter = filters.text_query(rule_query) aggr_filters = { diff --git a/src/argilla/server/daos/backend/query_helpers.py b/src/argilla/server/daos/backend/query_helpers.py index 731edd6b8f..1c57c5ccee 100644 --- a/src/argilla/server/daos/backend/query_helpers.py +++ b/src/argilla/server/daos/backend/query_helpers.py @@ -297,7 +297,6 @@ def histogram_aggregation( script: Union[str, Dict[str, Any]] = None, interval: float = 0.1, ): - assert field_name or script, "Either field name or script must be provided" if script: diff --git a/src/argilla/server/daos/backend/search/model.py b/src/argilla/server/daos/backend/search/model.py index 395bc702e6..fcaad0e5fa 100644 --- a/src/argilla/server/daos/backend/search/model.py +++ b/src/argilla/server/daos/backend/search/model.py @@ -26,7 +26,6 @@ class SortOrder(str, Enum): class QueryRange(BaseModel): - range_from: float = Field(default=0.0, alias="from") range_to: float = Field(default=None, alias="to") @@ -77,7 +76,6 @@ class VectorSearch(BaseModel): class BaseRecordsQuery(BaseQuery): - query_text: Optional[str] = None advanced_query_dsl: bool = False diff --git a/src/argilla/server/daos/backend/search/query_builder.py b/src/argilla/server/daos/backend/search/query_builder.py index ac47a55c08..a05ed872b9 100644 --- a/src/argilla/server/daos/backend/search/query_builder.py +++ b/src/argilla/server/daos/backend/search/query_builder.py @@ -32,7 +32,6 @@ class HighlightParser: - _SEARCH_KEYWORDS_FIELD = "search_keywords" __HIGHLIGHT_PRE_TAG__ = "<@@-ar-key>" @@ -196,7 +195,6 @@ def map_2_es_query( id_from: Optional[str] = None, shuffle: bool = False, ) -> Dict[str, Any]: - if query and query.raw_query: es_query = {"query": query.raw_query} else: @@ -252,7 +250,6 @@ def map_2_es_sort_configuration( schema: Optional[Dict[str, Any]] = None, sort: Optional[SortConfig] = None, ) -> Optional[List[Dict[str, Any]]]: - if not sort: return None diff --git a/src/argilla/server/daos/datasets.py b/src/argilla/server/daos/datasets.py index 512e201835..c45f8ddd20 100644 --- a/src/argilla/server/daos/datasets.py +++ b/src/argilla/server/daos/datasets.py @@ -117,7 +117,6 @@ def update_dataset( self, dataset: DatasetDB, ) -> DatasetDB: - self._es.update_dataset_document( id=dataset.id, document=self._dataset_to_es_doc(dataset) ) @@ -133,7 +132,6 @@ def find_by_name( as_dataset_class: Type[DatasetDB] = BaseDatasetDB, task: Optional[str] = None, ) -> Optional[DatasetDB]: - dataset_id = BaseDatasetDB.build_dataset_id( name=name, owner=owner, @@ -220,7 +218,6 @@ def save_settings( dataset: DatasetDB, settings: DatasetSettingsDB, ) -> BaseDatasetSettingsDB: - self._configure_vectors(dataset, settings) self._es.update_dataset_document( id=dataset.id, diff --git a/src/argilla/server/daos/models/records.py b/src/argilla/server/daos/models/records.py index 4bdd0b18ea..43776057e6 100644 --- a/src/argilla/server/daos/models/records.py +++ b/src/argilla/server/daos/models/records.py @@ -29,7 +29,6 @@ class DaoRecordsSearch(BaseModel): - query: Optional[BaseRecordsQuery] = None sort: SortConfig = Field(default_factory=SortConfig) @@ -119,7 +118,6 @@ def update_annotation(values, 
annotation_field: str): @root_validator() def prepare_record_for_db(cls, values): - values = cls.update_annotation(values, "prediction") values = cls.update_annotation(values, "annotation") @@ -239,7 +237,6 @@ def dict(self, *args, **kwargs) -> "DictStrAny": class BaseRecordDB(BaseRecordInDB, Generic[AnnotationDB]): - # Read only ones metrics: Dict[str, Any] = Field(default_factory=dict) search_keywords: Optional[List[str]] = None diff --git a/src/argilla/server/daos/records.py b/src/argilla/server/daos/records.py index 4de51d78f6..1ab091d885 100644 --- a/src/argilla/server/daos/records.py +++ b/src/argilla/server/daos/records.py @@ -152,7 +152,6 @@ def search_records( exclude_fields: List[str] = None, highligth_results: bool = True, ) -> DaoRecordsSearchResults: - try: search = search or DaoRecordsSearch() @@ -261,7 +260,6 @@ async def get_record_by_id( dataset: DatasetDB, id: str, ) -> Optional[Dict[str, Any]]: - return self._es.find_record_by_id( dataset_id=dataset.id, record_id=id, diff --git a/src/argilla/server/errors/base_errors.py b/src/argilla/server/errors/base_errors.py index db63f5826f..ab69d7ff7b 100644 --- a/src/argilla/server/errors/base_errors.py +++ b/src/argilla/server/errors/base_errors.py @@ -19,7 +19,6 @@ class ServerError(Exception): - HTTP_STATUS: int = status.HTTP_500_INTERNAL_SERVER_ERROR @classmethod diff --git a/src/argilla/server/security/auth_provider/base.py b/src/argilla/server/security/auth_provider/base.py index 2c5aa7ecb9..268cd9beb7 100644 --- a/src/argilla/server/security/auth_provider/base.py +++ b/src/argilla/server/security/auth_provider/base.py @@ -32,6 +32,6 @@ async def get_user( self, security_scopes: SecurityScopes, api_key: Optional[str] = Depends(api_key_header), - **kwargs + **kwargs, ) -> User: raise NotImplementedError() diff --git a/src/argilla/server/server.py b/src/argilla/server/server.py index 1dd7667ccd..7e1ece9e26 100644 --- a/src/argilla/server/server.py +++ b/src/argilla/server/server.py @@ -162,7 +162,6 @@ async def setup_elasticsearch(): def configure_app_security(app: FastAPI): - if hasattr(auth, "router"): app.include_router(auth.router) diff --git a/src/argilla/server/services/datasets.py b/src/argilla/server/services/datasets.py index c5f834fb68..2a3bce6f9a 100644 --- a/src/argilla/server/services/datasets.py +++ b/src/argilla/server/services/datasets.py @@ -44,7 +44,6 @@ class ServiceBaseDatasetSettings(BaseDatasetSettingsDB): class DatasetsService: - _INSTANCE: "DatasetsService" = None @classmethod @@ -197,7 +196,6 @@ def copy_dataset( copy_tags: Dict[str, Any] = None, copy_metadata: Dict[str, Any] = None, ) -> ServiceDataset: - dataset_workspace = copy_workspace or dataset.owner dataset_workspace = user.check_workspace(dataset_workspace) @@ -253,7 +251,6 @@ async def get_settings( async def save_settings( self, user: User, dataset: ServiceDataset, settings: ServiceDatasetSettings ) -> ServiceDatasetSettings: - self.__dao__.save_settings(dataset=dataset, settings=settings) return settings diff --git a/src/argilla/server/services/search/model.py b/src/argilla/server/services/search/model.py index de02d46356..076d751f99 100644 --- a/src/argilla/server/services/search/model.py +++ b/src/argilla/server/services/search/model.py @@ -48,7 +48,6 @@ class ServiceScoreRange(ServiceQueryRange): class ServiceBaseSearchResultsAggregations(BaseModel): - predicted_as: Dict[str, int] = Field(default_factory=dict) annotated_as: Dict[str, int] = Field(default_factory=dict) annotated_by: Dict[str, int] = Field(default_factory=dict) diff --git 
a/src/argilla/server/services/search/service.py b/src/argilla/server/services/search/service.py index 8097be98a4..b543a50cf7 100644 --- a/src/argilla/server/services/search/service.py +++ b/src/argilla/server/services/search/service.py @@ -67,7 +67,6 @@ def search( exclude_metrics: bool = True, metrics: Optional[List[ServiceMetric]] = None, ) -> ServiceSearchResults: - if record_from > 0: metrics = None diff --git a/src/argilla/server/services/storage/service.py b/src/argilla/server/services/storage/service.py index 2866fb37ba..27897efe11 100644 --- a/src/argilla/server/services/storage/service.py +++ b/src/argilla/server/services/storage/service.py @@ -37,7 +37,6 @@ class DeleteRecordsOut: class RecordsStorageService: - _INSTANCE: "RecordsStorageService" = None @classmethod @@ -115,7 +114,6 @@ async def update_record( record: ServiceRecord, **data, ) -> ServiceRecord: - if data.get("metadata"): record.metadata = { **(record.metadata or {}), diff --git a/src/argilla/server/services/tasks/text2text/service.py b/src/argilla/server/services/tasks/text2text/service.py index 870eca8788..48f5af61b3 100644 --- a/src/argilla/server/services/tasks/text2text/service.py +++ b/src/argilla/server/services/tasks/text2text/service.py @@ -81,7 +81,6 @@ def search( size: int = 100, exclude_metrics: bool = True, ) -> ServiceSearchResults: - metrics = TasksFactory.find_task_metrics( dataset.task, metric_ids={ diff --git a/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py b/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py index 63bf5b77d8..9b0b4871c0 100644 --- a/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py +++ b/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py @@ -42,7 +42,6 @@ class LabelingRuleSummary(BaseModel): class LabelingService: - _INSTANCE = None @classmethod diff --git a/src/argilla/server/services/tasks/text_classification/model.py b/src/argilla/server/services/tasks/text_classification/model.py index 9b8a272846..2450d9731d 100644 --- a/src/argilla/server/services/tasks/text_classification/model.py +++ b/src/argilla/server/services/tasks/text_classification/model.py @@ -71,7 +71,6 @@ def strip_query(cls, query: str) -> str: class ServiceTextClassificationDataset(ServiceBaseDataset): - task: TaskType = Field(default=TaskType.text_classification, const=True) rules: List[ServiceLabelingRule] = Field(default_factory=list) @@ -300,7 +299,6 @@ def flatten_text(cls, text: Dict[str, Any]): def _labels_from_annotation( cls, annotation: TextClassificationAnnotation, multi_label: bool ) -> Union[List[str], List[int]]: - if not annotation: return [] @@ -334,7 +332,6 @@ def extended_fields(self) -> Dict[str, Any]: class ServiceTextClassificationQuery(ServiceBaseRecordsQuery): - predicted_as: List[str] = Field(default_factory=list) annotated_as: List[str] = Field(default_factory=list) score: Optional[ServiceScoreRange] = Field(default=None) diff --git a/src/argilla/server/services/tasks/text_classification/service.py b/src/argilla/server/services/tasks/text_classification/service.py index 0176bd6b25..e6351d13e9 100644 --- a/src/argilla/server/services/tasks/text_classification/service.py +++ b/src/argilla/server/services/tasks/text_classification/service.py @@ -215,7 +215,6 @@ def _is_dataset_multi_label( def get_labeling_rules( self, dataset: ServiceTextClassificationDataset ) -> Iterable[ServiceLabelingRule]: - return self.__labeling__.list_rules(dataset) def add_labeling_rule( diff 
--git a/src/argilla/server/services/tasks/token_classification/metrics.py b/src/argilla/server/services/tasks/token_classification/metrics.py index a6b70288ae..e210b0960f 100644 --- a/src/argilla/server/services/tasks/token_classification/metrics.py +++ b/src/argilla/server/services/tasks/token_classification/metrics.py @@ -270,7 +270,6 @@ def build_tokens_metrics( record: ServiceTokenClassificationRecord, tags: Optional[List[str]] = None, ) -> List[TokenMetrics]: - return [ TokenMetrics( idx=token_idx, diff --git a/src/argilla/server/services/tasks/token_classification/model.py b/src/argilla/server/services/tasks/token_classification/model.py index fb2491d0de..926783ecbd 100644 --- a/src/argilla/server/services/tasks/token_classification/model.py +++ b/src/argilla/server/services/tasks/token_classification/model.py @@ -77,7 +77,6 @@ class ServiceTokenClassificationAnnotation(ServiceBaseAnnotation): class ServiceTokenClassificationRecord( ServiceBaseRecord[ServiceTokenClassificationAnnotation] ): - tokens: List[str] = Field(min_items=1) text: str = Field() _raw_text: Optional[str] = Field(alias="raw_text") @@ -87,7 +86,6 @@ class ServiceTokenClassificationRecord( _predicted: Optional[PredictionStatus] = Field(alias="predicted") def extended_fields(self) -> Dict[str, Any]: - return { **super().extended_fields(), # See ../service/service.py @@ -144,7 +142,6 @@ def span_utils(self) -> SpanUtils: @property def predicted(self) -> Optional[PredictionStatus]: if self.annotation and self.prediction: - annotated_entities = self.annotation.entities predicted_entities = self.prediction.entities if len(annotated_entities) != len(predicted_entities): @@ -223,7 +220,6 @@ class Config: class ServiceTokenClassificationQuery(ServiceBaseRecordsQuery): - predicted_as: List[str] = Field(default_factory=list) annotated_as: List[str] = Field(default_factory=list) score: Optional[ServiceScoreRange] = Field(default=None) diff --git a/src/argilla/server/services/tasks/token_classification/service.py b/src/argilla/server/services/tasks/token_classification/service.py index c422ec2ce0..534ff6920b 100644 --- a/src/argilla/server/services/tasks/token_classification/service.py +++ b/src/argilla/server/services/tasks/token_classification/service.py @@ -77,7 +77,6 @@ def search( size: int = 100, exclude_metrics: bool = True, ) -> ServiceSearchResults: - """ Run a search in a dataset diff --git a/tests/client/functional_tests/test_record_update.py b/tests/client/functional_tests/test_record_update.py index d2046c1f8a..345149ca69 100644 --- a/tests/client/functional_tests/test_record_update.py +++ b/tests/client/functional_tests/test_record_update.py @@ -34,7 +34,6 @@ def test_partial_record_update( mocked_client, gutenberg_spacy_ner, ): - expected_id = "00c27206-da48-4fc3-aab7-4b730628f8ac" record = record_data_by_id( diff --git a/tests/client/functional_tests/test_scan_raw_records.py b/tests/client/functional_tests/test_scan_raw_records.py index 967b8851c5..f58b18006a 100644 --- a/tests/client/functional_tests/test_scan_raw_records.py +++ b/tests/client/functional_tests/test_scan_raw_records.py @@ -28,7 +28,6 @@ def test_scan_records( gutenberg_spacy_ner, fields, ): - import pandas as pd import argilla as rg diff --git a/tests/client/sdk/commons/test_client.py b/tests/client/sdk/commons/test_client.py index 1fe621df33..5b1dd8bedd 100644 --- a/tests/client/sdk/commons/test_client.py +++ b/tests/client/sdk/commons/test_client.py @@ -32,7 +32,6 @@ def test_wrong_hostname_values( url: str, raises_error: bool, ): - if 
raises_error: with pytest.raises(Exception): Client(base_url=url) @@ -42,7 +41,6 @@ def test_wrong_hostname_values( def test_http_calls(mocked_client): - rb_api = active_api() data = rb_api.http_client.get("/api/_info") assert data.get("version"), data diff --git a/tests/client/test_api.py b/tests/client/test_api.py index 23d02253f6..4f80a66deb 100644 --- a/tests/client/test_api.py +++ b/tests/client/test_api.py @@ -208,7 +208,6 @@ def test_log_records_with_too_long_text(mocked_client): def test_not_found_response(mocked_client): - with pytest.raises(NotFoundApiError): api.load(name="not-found") @@ -227,7 +226,6 @@ def test_log_without_name(mocked_client): def test_log_passing_empty_records_list(mocked_client): - with pytest.raises( InputValueError, match="Empty record list has been passed as argument.", @@ -285,7 +283,6 @@ def raise_http_error(*args, **kwargs): ) with pytest.raises(BaseClientError): - try: future.result() finally: @@ -639,7 +636,6 @@ def test_token_classification_spans(span, valid): def test_load_text2text(mocked_client, supported_vector_search): - vectors = {"bert_uncased": [1.2, 3.4, 6.4, 6.4]} records = [] diff --git a/tests/client/test_client_errors.py b/tests/client/test_client_errors.py index d0bf562067..a19ff3dd12 100644 --- a/tests/client/test_client_errors.py +++ b/tests/client/test_client_errors.py @@ -18,7 +18,6 @@ def test_unauthorized_response_error(mocked_client): - with pytest.raises(UnauthorizedApiError, match="Could not validate credentials"): import argilla as ar diff --git a/tests/conftest.py b/tests/conftest.py index 72f6d86241..a9b28296e0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -33,7 +33,6 @@ @pytest.fixture def telemetry_track_data(mocker): - client = telemetry._TelemetryClient.get() if client: # Disable sending data for tests @@ -48,7 +47,6 @@ def mocked_client( monkeypatch, telemetry_track_data, ) -> SecuredClient: - with TestClient(app, raise_server_exceptions=False) as _client: client_ = SecuredClient(_client) diff --git a/tests/functional_tests/search/test_search_service.py b/tests/functional_tests/search/test_search_service.py index 1b09a87d7c..92eb8062bb 100644 --- a/tests/functional_tests/search/test_search_service.py +++ b/tests/functional_tests/search/test_search_service.py @@ -132,7 +132,6 @@ def test_query_builder_with_nested( def test_failing_metrics(service, mocked_client): - dataset = Dataset( name="test_failing_metrics", owner=argilla.get_workspace(), diff --git a/tests/functional_tests/test_log_for_text_classification.py b/tests/functional_tests/test_log_for_text_classification.py index 3ece320078..0a85ec53e8 100644 --- a/tests/functional_tests/test_log_for_text_classification.py +++ b/tests/functional_tests/test_log_for_text_classification.py @@ -185,7 +185,6 @@ def test_log_data_with_vectors_and_update_ko(mocked_client: SecuredClient): def test_log_data_in_several_workspaces(mocked_client: SecuredClient): - workspace = "test-ws" dataset = "test_log_data_in_several_workspaces" text = "This is a text" @@ -314,7 +313,6 @@ def test_dynamics_metadata(mocked_client): def test_log_with_bulk_error(mocked_client): - dataset = "test_log_with_bulk_error" ar.delete(dataset) try: diff --git a/tests/functional_tests/test_log_for_token_classification.py b/tests/functional_tests/test_log_for_token_classification.py index 08dddb73fe..f4bea51e2c 100644 --- a/tests/functional_tests/test_log_for_token_classification.py +++ b/tests/functional_tests/test_log_for_token_classification.py @@ -54,7 +54,6 @@ def 
test_log_with_empty_tokens_list(mocked_client): def test_call_metrics_with_no_api_client_initialized(mocked_client): - for metric in ALL_METRICS: if metric == entity_consistency: continue diff --git a/tests/labeling/text_classification/test_rule.py b/tests/labeling/text_classification/test_rule.py index 00d2c37f85..3c1d29a326 100644 --- a/tests/labeling/text_classification/test_rule.py +++ b/tests/labeling/text_classification/test_rule.py @@ -178,7 +178,6 @@ def test_create_rules_with_update( def test_load_rules(mocked_client, log_dataset): - mocked_client.post( f"/api/datasets/TextClassification/{log_dataset}/labeling/rules", json={"query": "a query", "label": "LALA"}, @@ -191,7 +190,6 @@ def test_load_rules(mocked_client, log_dataset): def test_add_rules(mocked_client, log_dataset): - expected_rules = [ Rule(query="a query", label="La La"), Rule(query="another query", label="La La"), @@ -209,7 +207,6 @@ def test_add_rules(mocked_client, log_dataset): def test_delete_rules(mocked_client, log_dataset): - rules = [ Rule(query="a query", label="La La"), Rule(query="another query", label="La La"), @@ -235,7 +232,6 @@ def test_delete_rules(mocked_client, log_dataset): def test_update_rules(mocked_client, log_dataset): - rules = [ Rule(query="a query", label="La La"), Rule(query="another query", label="La La"), @@ -316,7 +312,6 @@ def test_copy_dataset_with_rules(mocked_client, log_dataset): ], ) def test_rule_metrics(mocked_client, log_dataset, rule, expected_metrics): - delete_rule_silently(mocked_client, log_dataset, rule) mocked_client.post( diff --git a/tests/monitoring/test_flair_monitoring.py b/tests/monitoring/test_flair_monitoring.py index e944c08aed..e8cb51b6ad 100644 --- a/tests/monitoring/test_flair_monitoring.py +++ b/tests/monitoring/test_flair_monitoring.py @@ -15,7 +15,6 @@ def test_flair_monitoring(mocked_client, monkeypatch): - from flair.data import Sentence from flair.models import SequenceTagger @@ -52,7 +51,7 @@ def test_flair_monitoring(mocked_client, monkeypatch): assert record.tokens == [token.text for token in sentence.tokens] assert len(record.prediction) == len(detected_labels) - for ((label, start, end, score), span) in zip(record.prediction, detected_labels): + for (label, start, end, score), span in zip(record.prediction, detected_labels): assert label == span.value assert start == span.span.start_pos assert end == span.span.end_pos diff --git a/tests/monitoring/test_monitor.py b/tests/monitoring/test_monitor.py index 95c28ea9e1..18240551af 100644 --- a/tests/monitoring/test_monitor.py +++ b/tests/monitoring/test_monitor.py @@ -38,7 +38,6 @@ def test_monitor_with_non_supported_model(): def test_monitor_non_supported_huggingface_model(): with warnings.catch_warnings(record=True) as warning_list: - from transformers import ( AutoModelForTokenClassification, AutoTokenizer, diff --git a/tests/monitoring/test_transformers_monitoring.py b/tests/monitoring/test_transformers_monitoring.py index 3ea8eb28cf..ee0487c98f 100644 --- a/tests/monitoring/test_transformers_monitoring.py +++ b/tests/monitoring/test_transformers_monitoring.py @@ -242,7 +242,6 @@ def test_monitor_zero_short_passing_labels_keyword_arg( mocked_monitor, dataset, ): - argilla.delete(dataset) predictions = mocked_monitor( text, @@ -307,7 +306,6 @@ def test_monitor_zero_shot_with_text_array( mocked_monitor, dataset, ): - argilla.delete(dataset) predictions = mocked_monitor( [text], candidate_labels=labels, hypothesis_template=hypothesis diff --git a/tests/server/backend/test_query_builder.py 
b/tests/server/backend/test_query_builder.py index edadba6a9e..2ceef02733 100644 --- a/tests/server/backend/test_query_builder.py +++ b/tests/server/backend/test_query_builder.py @@ -61,7 +61,6 @@ ], ) def test_build_sort_configuration(index_schema, sort_cfg, expected_sort): - builder = EsQueryBuilder() es_sort = builder.map_2_es_sort_configuration( diff --git a/tests/server/daos/models/test_records.py b/tests/server/daos/models/test_records.py index 21d8980d2a..e2542c804f 100644 --- a/tests/server/daos/models/test_records.py +++ b/tests/server/daos/models/test_records.py @@ -21,7 +21,6 @@ def test_metadata_limit(): - long_value = "a" * (settings.metadata_field_length + 1) short_value = "a" * (settings.metadata_field_length - 1) diff --git a/tests/server/info/test_api.py b/tests/server/info/test_api.py index 6c22ead5e5..d25b143555 100644 --- a/tests/server/info/test_api.py +++ b/tests/server/info/test_api.py @@ -18,7 +18,6 @@ def test_api_info(mocked_client): - response = mocked_client.get("/api/_info") assert response.status_code == 200 @@ -28,7 +27,6 @@ def test_api_info(mocked_client): def test_api_status(mocked_client): - response = mocked_client.get("/api/_status") assert response.status_code == 200 diff --git a/tests/server/security/test_model.py b/tests/server/security/test_model.py index 88ff875bdc..4c4428ed06 100644 --- a/tests/server/security/test_model.py +++ b/tests/server/security/test_model.py @@ -64,7 +64,6 @@ def test_check_non_provided_workspaces(): def test_check_user_workspaces(): - a_ws = "A-workspace" expected_workspaces = [a_ws, "B-ws"] user = User(username="test-user", workspaces=[a_ws, "B-ws", "C-ws"]) @@ -76,7 +75,6 @@ def test_check_user_workspaces(): def test_default_workspace(): - user = User(username="admin") assert user.default_workspace == "admin" diff --git a/tests/server/test_errors.py b/tests/server/test_errors.py index 838cca3956..0090ab2e31 100644 --- a/tests/server/test_errors.py +++ b/tests/server/test_errors.py @@ -16,7 +16,6 @@ def test_generic_error(): - err = GenericServerError(error=ValueError("this is an error")) assert ( str(err) diff --git a/tests/server/text2text/test_model.py b/tests/server/text2text/test_model.py index 5bead4fd23..c8fa2fd30d 100644 --- a/tests/server/text2text/test_model.py +++ b/tests/server/text2text/test_model.py @@ -22,7 +22,6 @@ def test_sentences_sorted_by_score(): - record = Text2TextRecord( text="The inpu2 text", prediction=Text2TextAnnotation( diff --git a/tests/server/text_classification/test_model.py b/tests/server/text_classification/test_model.py index ebbc61af47..18a967a637 100644 --- a/tests/server/text_classification/test_model.py +++ b/tests/server/text_classification/test_model.py @@ -134,7 +134,6 @@ def test_model_with_annotations(): def test_single_label_with_multiple_annotation(): - with pytest.raises( ValidationError, match="Single label record must include only one annotation label", @@ -213,7 +212,6 @@ def test_score_integrity(): def test_prediction_ok_cases(): - data = { "multi_label": True, "inputs": {"data": "My cool data"}, @@ -330,7 +328,6 @@ def test_validate_without_labels_for_single_label(annotation): def test_query_with_uncovered_by_rules(): - query = TextClassificationQuery(uncovered_by_rules=["query", "other*"]) assert EsQueryBuilder._to_es_query(query) == { diff --git a/tests/server/token_classification/test_model.py b/tests/server/token_classification/test_model.py index 2f710c7d84..11701da917 100644 --- a/tests/server/token_classification/test_model.py +++ 
b/tests/server/token_classification/test_model.py
@@ -31,7 +31,6 @@ def test_char_position():
-
     with pytest.raises(
         ValidationError,
         match="End character cannot be placed before the starting character,"
@@ -68,7 +67,6 @@ def test_fix_substrings():


 def test_entities_with_spaces():
-
     text = "This is a great space"
     ServiceTokenClassificationRecord(
         text=text,
@@ -284,7 +282,6 @@ def test_annotated_without_entities():


 def test_adjust_spans():
-
     text = "A text with some empty spaces that could bring not cleany annotated spans"
     record = ServiceTokenClassificationRecord(
         text=text,
@@ -338,7 +335,6 @@ def test_whitespace_in_tokens():


 def test_predicted_ok_ko_computation():
-
     text = "A text with some empty spaces that could bring not cleanly annotated spans"
     record = ServiceTokenClassificationRecord(
         text=text,

From a09c74d7cdb05df4f823d419898df03b3a392f8c Mon Sep 17 00:00:00 2001
From: Daniel Vila Suero
Date: Tue, 7 Feb 2023 18:16:34 +0100
Subject: [PATCH 02/45] Add deploy to readme (#2307)

# Description

Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change.

Closes #

**Type of change**

(Please delete options that are not relevant. Remember to title the PR according to the type of change)

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Refactor (change restructuring the codebase without changing functionality)
- [ ] Improvement (change adding some improvement to an existing functionality)
- [ ] Documentation update

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And ideally, reference `tests`)

- [ ] Test A
- [ ] Test B

**Checklist**

- [ ] I have merged the original branch into my forked branch
- [ ] I added relevant documentation
- [ ] follows the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] I added comments to my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works

---
 README.md | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index c7f402073e..8e9f7af9a2 100644
--- a/README.md
+++ b/README.md
@@ -16,20 +16,20 @@ CI
+
+
+

Open-source framework for data-centric NLP

Data Labeling, curation, and Inference Store

Designed for MLOps & Feedback Loops

-
 > 🆕 🔥 Play with Argilla UI with this [live-demo](https://argilla-live-demo.hf.space) powered by Hugging Face Spaces (login:`argilla`, password:`1234`)

 > 🆕 🔥 Since `1.2.0` Argilla supports vector search for finding the most similar records to a given one. This feature uses vector or semantic search combined with more traditional search (keyword and filter based). Learn more on this [deep-dive guide](https://docs.argilla.io/en/latest/guides/features/semantic-search.html)

 ![imagen](https://user-images.githubusercontent.com/1107111/204772677-facee627-9b3b-43ca-8533-bbc9b4e2d0aa.png)
-

From 2bc97c62380bb46aad6439323588051c005202f6 Mon Sep 17 00:00:00 2001
From: Daniel Vila Suero
Date: Tue, 7 Feb 2023 18:19:35 +0100
Subject: [PATCH 03/45] Update README.md (#2308)

# Description

Please include a summary of the changes and the related issue. Please also include relevant motivation and context. List any dependencies that are required for this change.

Closes #

**Type of change**

(Please delete options that are not relevant. Remember to title the PR according to the type of change)

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Refactor (change restructuring the codebase without changing functionality)
- [ ] Improvement (change adding some improvement to an existing functionality)
- [ ] Documentation update

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And ideally, reference `tests`)

- [ ] Test A
- [ ] Test B

**Checklist**

- [ ] I have merged the original branch into my forked branch
- [ ] I added relevant documentation
- [ ] follows the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] I added comments to my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works

---
 README.md | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 8e9f7af9a2..ff816f08ae 100644
--- a/README.md
+++ b/README.md
@@ -16,16 +16,16 @@ CI
-
+

Open-source framework for data-centric NLP

-

Data Labeling, curation, and Inference Store

-

Designed for MLOps & Feedback Loops

+

Data Labeling for MLOps & Feedback Loops

-> 🆕 🔥 Play with Argilla UI with this [live-demo](https://argilla-live-demo.hf.space) powered by Hugging Face Spaces (login:`argilla`, password:`1234`)
+> 🆕 🔥 Deploy [Argilla on Spaces](https://huggingface.co/new-space?template=argilla/argilla-template-space)

 > 🆕 🔥 Since `1.2.0` Argilla supports vector search for finding the most similar records to a given one. This feature uses vector or semantic search combined with more traditional search (keyword and filter based). Learn more on this [deep-dive guide](https://docs.argilla.io/en/latest/guides/features/semantic-search.html)

From d3542748b1fbaa9b9b1d56f6e5b70027fded3a90 Mon Sep 17 00:00:00 2001
From: Francisco Aranda
Date: Thu, 9 Feb 2023 17:11:52 +0100
Subject: [PATCH 04/45] chore: Upgrade package version

---
 frontend/package.json   | 2 +-
 src/argilla/_version.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/frontend/package.json b/frontend/package.json
index 7a87342bbf..9fd9999c0e 100644
--- a/frontend/package.json
+++ b/frontend/package.json
@@ -1,6 +1,6 @@
 {
   "name": "argilla",
-  "version": "1.3.0-dev0",
+  "version": "1.4.0-dev0",
   "private": true,
   "eslintIgnore": [
     "node_modules/**/*",
diff --git a/src/argilla/_version.py b/src/argilla/_version.py
index 90d7c9110a..7208766f3a 100644
--- a/src/argilla/_version.py
+++ b/src/argilla/_version.py
@@ -13,4 +13,4 @@
 # limitations under the License.
 # coding: utf-8

-version = "1.3.0-dev0"
+version = "1.4.0-dev0"

From 5f0627cf523fdfcb2efeb7ae5663c3e9a845ecb9 Mon Sep 17 00:00:00 2001
From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Date: Tue, 14 Feb 2023 09:35:45 +0100
Subject: [PATCH 05/45] ci: Replace `isort` by `ruff` in `pre-commit` (#2325)

Hello!

# Pull Request overview

* Replaced `isort` by `ruff` in `pre-commit`
  * This has fixed the import sorting throughout the repo.
* Manually went through `ruff` errors and fixed a bunch:
  * Unused imports
  * Added `if TYPE_CHECKING:` statements
  * Placed `# noqa: ` where the warned-about statement is the desired behaviour
* Add basic `ruff` configuration to `pyproject.toml`

## Details

This PR focuses on replacing `isort` by `ruff` in `pre-commit`. The motivation for this is that:

* `isort` frequently breaks. I have experienced 2 separate occasions in the last few months alone where the latest `isort` release has broken my CI runs in NLTK and SetFit.
* `isort` is no longer supported for Python 3.7, whereas Argilla still supports 3.7 for now.
* `ruff` is absurdly fast, I actually can't believe how quick it is.

This PR consists of 3 commits at this time, and I would advise looking at them commit-by-commit rather than at the PR as a whole. I'll also explain each commit individually.

## [Add ruff basic configuration](https://github.com/argilla-io/argilla/commit/497420e7097b7039d479df6bf431ea73c370f90b)

I've added a basic configuration for [`ruff`](https://github.com/charliermarsh/ruff), a very efficient linter. I recommend the following commands:

```
# Get all [F]ailures and [E]rrors
ruff .

# See all import sort errors
ruff . --select I

# Fix all import sort errors
ruff . --select I --fix
```

## [Remove unused imports, apply some noqa's, add TYPE_CHECKING](https://github.com/argilla-io/argilla/commit/f219acbaba8f4b7fffaea70e8e1e9d1fe3a28475)

The unused imports speak for themselves. As for the `noqa`'s: `ruff`, like most linters, respects the `# noqa` (no quality assurance) keyword. I've used the keyword in various locations where linters would warn, but the behaviour is actually correct.
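To make the two codes used most in this diff concrete, here is a minimal, hypothetical sketch (the `compat` module name and `parse_or_none` helper are illustrative only, not code from this repository):

```python
# compat.py -- hypothetical re-export module.
# `json_loads` is never used in this file, so ruff would report F401
# (unused import); the noqa keeps the re-export deliberately.
from json import loads as json_loads  # noqa: F401


def parse_or_none(raw: str):
    """Return the parsed JSON payload, or None if `raw` is not valid JSON."""
    try:
        return json_loads(raw)
    except ValueError as err:  # noqa: F841 -- `err` is bound but deliberately unused
        return None
```

The same `F841` pattern appears later in this diff, where `httpx.ConnectError` is caught only so that `wrap_error` can be called on the client's base URL.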
As a result, the output of `ruff .` now only points to questionable code.

Lastly, I added `TYPE_CHECKING` in some locations. If type hints refer to objects that do not need to be imported at run-time, then it's common to type hint like `arr: "numpy.ndarray"`. However, IDEs won't understand what `arr` is. Python provides `TYPE_CHECKING`, which can be used to conditionally import code *only* when type checking. As a result, the code block is not actually executed in practice, but including it allows IDEs to better support development. See an example here:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import numpy

def func(arr: "numpy.ndarray") -> None:
    ...
```

## [Replace isort with ruff in CI](https://github.com/argilla-io/argilla/commit/e05f30efb65c9e9b3921a551e1f4b581fcc14db4)

I've replaced `isort` (which is both [broken in 5.11.*](https://github.com/PyCQA/isort/issues/2077) and [does not support Python 3.8 in 5.12.*](https://github.com/PyCQA/isort/releases/tag/5.12.0)) with `ruff` in the CI, using both `--select I` to only select `isort` warnings and `--fix` to immediately fix the warnings. Then I ran `pre-commit run --all-files` to fix the ~67 outstanding issues in the repository.

---

**Type of change**

- [x] Refactor (change restructuring the codebase without changing functionality)

**How Has This Been Tested**

I verified that the behaviour did not change using `pytest tests`.

**Checklist**

- [x] I have merged the original branch into my forked branch
- [ ] I added relevant documentation
- [x] follows the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I added comments to my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works

- Tom Aarsen

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .pre-commit-config.yaml | 10 ++++--
 pyproject.toml | 31 +++++++++++++++++++
 scripts/load_data.py | 5 ++-
 src/argilla/__init__.py | 4 +--
 src/argilla/client/apis/datasets.py | 2 +-
 src/argilla/client/apis/search.py | 2 +-
 src/argilla/client/datasets.py | 11 +++++--
 src/argilla/client/sdk/client.py | 4 +--
 src/argilla/client/sdk/commons/api.py | 2 +-
 src/argilla/client/sdk/text2text/api.py | 1 -
 .../client/sdk/text_classification/api.py | 2 +-
 .../client/sdk/token_classification/api.py | 1 -
 src/argilla/client/sdk/users/api.py | 3 --
 .../text_classification/label_models.py | 11 +++---
 src/argilla/listeners/listener.py | 6 ++--
 src/argilla/listeners/models.py | 5 ++-
 src/argilla/metrics/__init__.py | 4 +--
 src/argilla/monitoring/asgi.py | 2 +-
 src/argilla/monitoring/base.py | 1 -
 .../server/apis/v0/handlers/datasets.py | 2 +-
 .../server/apis/v0/handlers/metrics.py | 4 +--
 .../server/apis/v0/handlers/records_update.py | 11 ++-----
 .../apis/v0/handlers/token_classification.py | 1 -
 .../server/apis/v0/models/text2text.py | 3 +-
 .../apis/v0/models/text_classification.py | 7 ++---
 .../server/daos/backend/metrics/base.py | 5 ++-
 .../daos/backend/search/query_builder.py | 2 +-
 src/argilla/server/daos/models/datasets.py | 2 +-
 src/argilla/server/daos/models/records.py | 2 +-
 src/argilla/server/errors/api_errors.py | 2 +-
 src/argilla/server/errors/base_errors.py | 2 +-
 src/argilla/server/server.py | 1 -
 src/argilla/server/services/datasets.py | 2 +-
 .../server/services/storage/service.py | 4 +--
 .../server/services/tasks/text2text/models.py | 4 ----
 .../services/tasks/text2text/service.py | 3 +-
 .../tasks/text_classification/model.py | 2 +-
 src/argilla/utils/span_utils.py | 2 +-
 tests/client/apis/test_base.py | 1 -
 tests/client/conftest.py | 3 +-
 .../functional_tests/test_record_update.py | 1 -
 .../functional_tests/test_scan_raw_records.py | 5 +--
 tests/client/sdk/commons/api.py | 3 +-
 tests/client/sdk/commons/test_client.py | 1 -
 tests/client/sdk/conftest.py | 3 +-
 tests/client/sdk/datasets/test_api.py | 9 +-----
 tests/client/sdk/datasets/test_models.py | 1 -
 tests/client/sdk/text2text/test_models.py | 1 -
 .../sdk/text_classification/test_models.py | 1 -
 .../sdk/token_classification/test_models.py | 1 -
 tests/client/sdk/users/test_api.py | 1 -
 tests/client/test_api.py | 6 ++--
 tests/client/test_asgi.py | 9 +++---
 tests/client/test_client_errors.py | 1 -
 tests/client/test_dataset.py | 3 +-
 tests/client/test_models.py | 4 +--
 tests/conftest.py | 4 +--
 tests/datasets/test_datasets.py | 3 +-
 .../test_delete_records_from_datasets.py | 1 -
 .../datasets/test_update_record.py | 1 -
 .../search/test_search_service.py | 3 +-
 .../test_log_for_text_classification.py | 6 ++--
 .../test_log_for_token_classification.py | 4 +--
 tests/helpers.py | 5 ++-
 .../text_classification/test_label_errors.py | 5 ++-
 .../text_classification/test_label_models.py | 2 --
 .../labeling/text_classification/test_rule.py | 1 -
 .../text_classification/test_weak_labels.py | 3 +-
 tests/listeners/test_listener.py | 3 +-
 tests/metrics/test_common_metrics.py | 3 +-
 tests/metrics/test_token_classification.py | 3 +-
 tests/monitoring/test_base_monitor.py | 2 +-
 tests/monitoring/test_flair_monitoring.py | 3 +-
 .../test_transformers_monitoring.py | 3 +-
 tests/server/backend/test_query_builder.py | 1 -
 tests/server/commons/test_records_dao.py | 1 -
 tests/server/commons/test_settings.py | 3 +-
 tests/server/commons/test_telemetry.py | 3 +-
 tests/server/daos/models/test_records.py | 1 -
 tests/server/datasets/test_api.py | 3 +-
 tests/server/datasets/test_dao.py | 1 -
 tests/server/datasets/test_model.py | 3 +-
 .../datasets/test_get_record.py | 1 -
 tests/server/info/test_api.py | 2 +-
 tests/server/security/test_dao.py | 1 -
 tests/server/security/test_model.py | 3 +-
 tests/server/security/test_provider.py | 3 +-
 tests/server/security/test_service.py | 1 -
 tests/server/test_app.py | 1 -
 tests/server/text2text/test_api.py | 3 +-
 tests/server/text_classification/test_api.py | 2 +-
 .../text_classification/test_api_rules.py | 1 -
 .../server/text_classification/test_model.py | 3 +-
 tests/server/token_classification/test_api.py | 2 +-
 .../server/token_classification/test_model.py | 3 +-
 tests/utils/test_span_utils.py | 1 -
 tests/utils/test_utils.py | 1 -
 97 files changed, 139 insertions(+), 182 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f67cae6ff1..5ba8658505 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -25,10 +25,14 @@ repos:
       - id: black
         additional_dependencies: ["click==8.0.4"]

-  - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
+  - repo: https://github.com/charliermarsh/ruff-pre-commit
+    rev: v0.0.244
     hooks:
-      - id: isort
+      # Simulate isort via (the much faster) ruff
+      - id: ruff
+        args:
+          - --select=I
+          - --fix

   - repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook
     rev: v9.4.0
diff --git a/pyproject.toml b/pyproject.toml
index 5052f1d547..f28608d42e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -108,3 +108,34 @@ exclude_lines = [

 [tool.isort]
 profile = "black"
+
+[tool.ruff]
+# Ignore line length violations
+ignore = ["E501"] + +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", +] + +[tool.ruff.per-file-ignores] +# Ignore imported but unused; +"__init__.py" = ["F401"] diff --git a/scripts/load_data.py b/scripts/load_data.py index 0295872157..06405d7650 100644 --- a/scripts/load_data.py +++ b/scripts/load_data.py @@ -15,12 +15,11 @@ import sys import time +import argilla as rg import pandas as pd import requests -from datasets import load_dataset - -import argilla as rg from argilla.labeling.text_classification import Rule, add_rules +from datasets import load_dataset class LoadDatasets: diff --git a/src/argilla/__init__.py b/src/argilla/__init__.py index c9fb4f6787..1dbd654c9f 100644 --- a/src/argilla/__init__.py +++ b/src/argilla/__init__.py @@ -47,12 +47,10 @@ read_datasets, read_pandas, ) - from argilla.client.models import ( - TextGenerationRecord, # TODO Remove TextGenerationRecord - ) from argilla.client.models import ( Text2TextRecord, TextClassificationRecord, + TextGenerationRecord, # TODO Remove TextGenerationRecord TokenAttributions, TokenClassificationRecord, ) diff --git a/src/argilla/client/apis/datasets.py b/src/argilla/client/apis/datasets.py index 55079c7a95..5fff021362 100644 --- a/src/argilla/client/apis/datasets.py +++ b/src/argilla/client/apis/datasets.py @@ -15,7 +15,7 @@ import warnings from dataclasses import dataclass from datetime import datetime -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union from pydantic import BaseModel, Field diff --git a/src/argilla/client/apis/search.py b/src/argilla/client/apis/search.py index f8a1b66a98..07de6bb71e 100644 --- a/src/argilla/client/apis/search.py +++ b/src/argilla/client/apis/search.py @@ -13,7 +13,7 @@ # limitations under the License. 
 import dataclasses
-from typing import List, Optional, Union
+from typing import List, Optional

 from argilla.client.apis import AbstractApi
 from argilla.client.models import Record
diff --git a/src/argilla/client/datasets.py b/src/argilla/client/datasets.py
index 944e7d93f2..952a3a96ef 100644
--- a/src/argilla/client/datasets.py
+++ b/src/argilla/client/datasets.py
@@ -16,7 +16,7 @@
 import logging
 import random
 import uuid
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union

 import pandas as pd
 from pkg_resources import parse_version
@@ -32,6 +32,11 @@
 from argilla.client.sdk.datasets.models import TaskType
 from argilla.utils.span_utils import SpanUtils

+if TYPE_CHECKING:
+    import datasets
+    import pandas
+    import spacy
+
 _LOGGER = logging.getLogger(__name__)
@@ -60,7 +65,7 @@ def _requires_spacy(func):
     @functools.wraps(func)
     def check_if_spacy_installed(*args, **kwargs):
         try:
-            import spacy
+            import spacy  # noqa: F401
         except ModuleNotFoundError:
             raise ModuleNotFoundError(
                 f"'spacy' must be installed to use `{func.__name__}`"
@@ -1007,7 +1012,7 @@ def from_datasets(
         for row in dataset:
             # TODO: fails with a KeyError if no tokens column is present and no mapping is indicated
             if not row["tokens"]:
-                _LOGGER.warning(f"Ignoring row with no tokens.")
+                _LOGGER.warning("Ignoring row with no tokens.")
                 continue

             if row.get("tags"):
diff --git a/src/argilla/client/sdk/client.py b/src/argilla/client/sdk/client.py
index ee244c7697..be4d2983b9 100644
--- a/src/argilla/client/sdk/client.py
+++ b/src/argilla/client/sdk/client.py
@@ -119,7 +119,7 @@ async def inner_async(self, *args, **kwargs):
         try:
             result = await func(self, *args, **kwargs)
             return result
-        except httpx.ConnectError as err:
+        except httpx.ConnectError as err:  # noqa: F841
             return wrap_error(self.base_url)

     @functools.wraps(func)
     def inner(self, *args, **kwargs):
         try:
             result = func(self, *args, **kwargs)
             return result
-        except httpx.ConnectError as err:
+        except httpx.ConnectError as err:  # noqa: F841
             return wrap_error(self.base_url)

     is_coroutine = inspect.iscoroutinefunction(func)
diff --git a/src/argilla/client/sdk/commons/api.py b/src/argilla/client/sdk/commons/api.py
index 9ccc3eb3bb..3810aebdf1 100644
--- a/src/argilla/client/sdk/commons/api.py
+++ b/src/argilla/client/sdk/commons/api.py
@@ -126,7 +126,7 @@ def build_data_response(
         parsed_record = json.loads(r)
         try:
             parsed_response = data_type(**parsed_record)
-        except Exception as err:
+        except Exception as err:  # noqa: F841
             raise GenericApiError(**parsed_record) from None
         parsed_responses.append(parsed_response)
     return Response(
diff --git a/src/argilla/client/sdk/text2text/api.py b/src/argilla/client/sdk/text2text/api.py
index 2baab48725..5b5aca9568 100644
--- a/src/argilla/client/sdk/text2text/api.py
+++ b/src/argilla/client/sdk/text2text/api.py
@@ -33,7 +33,6 @@ def data(
     limit: Optional[int] = None,
     id_from: Optional[str] = None,
 ) -> Response[Union[List[Text2TextRecord], HTTPValidationError, ErrorMessage]]:
-
     path = f"/api/datasets/{name}/Text2Text/data"
     params = build_param_dict(id_from, limit)
diff --git a/src/argilla/client/sdk/text_classification/api.py b/src/argilla/client/sdk/text_classification/api.py
index 8a5161fe4a..39c603411e 100644
--- a/src/argilla/client/sdk/text_classification/api.py
+++ b/src/argilla/client/sdk/text_classification/api.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, List, Optional, Union
+from typing import List, Optional, Union

 import httpx
diff --git a/src/argilla/client/sdk/token_classification/api.py b/src/argilla/client/sdk/token_classification/api.py
index f5a95ba36d..10a91d0f59 100644
--- a/src/argilla/client/sdk/token_classification/api.py
+++ b/src/argilla/client/sdk/token_classification/api.py
@@ -39,7 +39,6 @@ def data(
 ) -> Response[
     Union[List[TokenClassificationRecord], HTTPValidationError, ErrorMessage]
 ]:
-
     path = f"/api/datasets/{name}/TokenClassification/data"
     params = build_param_dict(id_from, limit)
diff --git a/src/argilla/client/sdk/users/api.py b/src/argilla/client/sdk/users/api.py
index 964b6f4480..c22d704b13 100644
--- a/src/argilla/client/sdk/users/api.py
+++ b/src/argilla/client/sdk/users/api.py
@@ -12,10 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import httpx
-
 from argilla.client.sdk.client import AuthenticatedClient
-from argilla.client.sdk.commons.errors_handler import handle_response_error
 from argilla.client.sdk.users.models import User
diff --git a/src/argilla/labeling/text_classification/label_models.py b/src/argilla/labeling/text_classification/label_models.py
index 27f2e742ea..7186c46b9e 100644
--- a/src/argilla/labeling/text_classification/label_models.py
+++ b/src/argilla/labeling/text_classification/label_models.py
@@ -20,7 +20,6 @@
 import numpy as np

 from argilla import DatasetForTextClassification, TextClassificationRecord
-from argilla.client.datasets import Dataset
 from argilla.labeling.text_classification.weak_labels import WeakLabels, WeakMultiLabels

 _LOGGER = logging.getLogger(__name__)
@@ -368,7 +367,7 @@ def score(
             MissingAnnotationError: If the ``weak_labels`` do not contain annotated records.
         """
         try:
-            import sklearn
+            import sklearn  # noqa: F401
         except ModuleNotFoundError:
             raise ModuleNotFoundError(
                 "'sklearn' must be installed to compute the metrics! "
@@ -501,7 +500,7 @@ def __init__(
         self, weak_labels: WeakLabels, verbose: bool = True, device: str = "cpu"
     ):
         try:
-            import snorkel
+            import snorkel  # noqa: F401
         except ModuleNotFoundError:
             raise ModuleNotFoundError(
                 "'snorkel' must be installed to use the `Snorkel` label model! "
@@ -764,8 +763,8 @@ class FlyingSquid(LabelModel):

     def __init__(self, weak_labels: WeakLabels, **kwargs):
         try:
-            import flyingsquid
-            import pgmpy
+            import flyingsquid  # noqa: F401
+            import pgmpy  # noqa: F401
         except ModuleNotFoundError:
             raise ModuleNotFoundError(
                 "'flyingsquid' must be installed to use the `FlyingSquid` label model!"
@@ -1024,7 +1023,7 @@ def score(
             MissingAnnotationError: If the ``weak_labels`` do not contain annotated records.
         """
         try:
-            import sklearn
+            import sklearn  # noqa: F401
         except ModuleNotFoundError:
             raise ModuleNotFoundError(
                 "'sklearn' must be installed to compute the metrics! "
" diff --git a/src/argilla/listeners/listener.py b/src/argilla/listeners/listener.py index 77e48e7896..d836bbf5e6 100644 --- a/src/argilla/listeners/listener.py +++ b/src/argilla/listeners/listener.py @@ -98,7 +98,7 @@ def catch_exceptions_decorator(job_func): def wrapper(*args, **kwargs): try: return job_func(*args, **kwargs) - except: + except: # noqa: E722 import traceback print(traceback.format_exc()) @@ -208,7 +208,7 @@ def __listener_iteration_job__(self, *args, **kwargs): self._LOGGER.debug(f"Evaluate condition with arguments: {condition_args}") if self.condition(*condition_args): - self._LOGGER.debug(f"Condition passed! Running action...") + self._LOGGER.debug("Condition passed! Running action...") return self.__run_action__(ctx, *args, **kwargs) def __compute_metrics__(self, current_api, dataset, query: str) -> Metrics: @@ -235,7 +235,7 @@ def __run_action__(self, ctx: Optional[RGListenerContext] = None, *args, **kwarg ) self._LOGGER.debug(f"Running action with arguments: {action_args}") return self.action(*args, *action_args, **kwargs) - except: + except: # noqa: E722 import traceback print(traceback.format_exc()) diff --git a/src/argilla/listeners/models.py b/src/argilla/listeners/models.py index c61f603177..5c276b9a3f 100644 --- a/src/argilla/listeners/models.py +++ b/src/argilla/listeners/models.py @@ -13,12 +13,15 @@ # limitations under the License. import dataclasses -from typing import Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union from prodict import Prodict from argilla.client.models import Record +if TYPE_CHECKING: + from argilla.listeners import RGDatasetListener + @dataclasses.dataclass class Search: diff --git a/src/argilla/metrics/__init__.py b/src/argilla/metrics/__init__.py index 53ef179b65..df7a911c61 100644 --- a/src/argilla/metrics/__init__.py +++ b/src/argilla/metrics/__init__.py @@ -19,15 +19,13 @@ entity_consistency, entity_density, entity_labels, -) -from .token_classification import f1 as ner_f1 -from .token_classification import ( mention_length, token_capitalness, token_frequency, token_length, tokens_length, ) +from .token_classification import f1 as ner_f1 __all__ = [ text_length, diff --git a/src/argilla/monitoring/asgi.py b/src/argilla/monitoring/asgi.py index 4b2119be43..f861eee958 100644 --- a/src/argilla/monitoring/asgi.py +++ b/src/argilla/monitoring/asgi.py @@ -21,7 +21,7 @@ from argilla.monitoring.base import BaseMonitor try: - import starlette + import starlette # noqa: F401 except ModuleNotFoundError: raise ModuleNotFoundError( "'starlette' must be installed to use the middleware feature! " diff --git a/src/argilla/monitoring/base.py b/src/argilla/monitoring/base.py index 2e8002de46..59fc8976d4 100644 --- a/src/argilla/monitoring/base.py +++ b/src/argilla/monitoring/base.py @@ -13,7 +13,6 @@ # limitations under the License. 
 import atexit
-import dataclasses
 import logging
 import random
 import threading
diff --git a/src/argilla/server/apis/v0/handlers/datasets.py b/src/argilla/server/apis/v0/handlers/datasets.py
index 01d7d75478..0d1470f0f9 100644
--- a/src/argilla/server/apis/v0/handlers/datasets.py
+++ b/src/argilla/server/apis/v0/handlers/datasets.py
@@ -64,7 +64,7 @@ async def list_datasets(
     description="Create a new dataset",
 )
 async def create_dataset(
-    request: CreateDatasetRequest = Body(..., description=f"The request dataset info"),
+    request: CreateDatasetRequest = Body(..., description="The request dataset info"),
     ws_params: CommonTaskHandlerDependencies = Depends(),
     datasets: DatasetsService = Depends(DatasetsService.get_instance),
     user: User = Security(auth.get_user, scopes=["create:datasets"]),
diff --git a/src/argilla/server/apis/v0/handlers/metrics.py b/src/argilla/server/apis/v0/handlers/metrics.py
index 96e0364697..68b92bbb37 100644
--- a/src/argilla/server/apis/v0/handlers/metrics.py
+++ b/src/argilla/server/apis/v0/handlers/metrics.py
@@ -58,7 +58,7 @@ def configure_router(router: APIRouter, cfg: TaskConfig):
         path=base_metrics_endpoint,
         new_path=new_base_metrics_endpoint,
         router_method=router.get,
-        operation_id=f"get_dataset_metrics",
+        operation_id="get_dataset_metrics",
         name="get_dataset_metrics",
     )
     def get_dataset_metrics(
@@ -84,7 +84,7 @@ def get_dataset_metrics(
         path=base_metrics_endpoint + "/{metric}:summary",
         new_path=new_base_metrics_endpoint + "/{metric}:summary",
         router_method=router.post,
-        operation_id=f"metric_summary",
+        operation_id="metric_summary",
         name="metric_summary",
     )
     def metric_summary(
diff --git a/src/argilla/server/apis/v0/handlers/records_update.py b/src/argilla/server/apis/v0/handlers/records_update.py
index f77d391e53..b2034265ad 100644
--- a/src/argilla/server/apis/v0/handlers/records_update.py
+++ b/src/argilla/server/apis/v0/handlers/records_update.py
@@ -14,17 +14,12 @@

 from typing import Any, Dict, Optional, Union

-from fastapi import APIRouter, Depends, Query, Security
+from fastapi import APIRouter, Depends, Security
 from pydantic import BaseModel

-from argilla.client.sdk.token_classification.models import TokenClassificationQuery
-from argilla.server.apis.v0.helpers import deprecate_endpoint
 from argilla.server.apis.v0.models.commons.params import CommonTaskHandlerDependencies
-from argilla.server.apis.v0.models.text2text import Text2TextQuery, Text2TextRecord
-from argilla.server.apis.v0.models.text_classification import (
-    TextClassificationQuery,
-    TextClassificationRecord,
-)
+from argilla.server.apis.v0.models.text2text import Text2TextRecord
+from argilla.server.apis.v0.models.text_classification import TextClassificationRecord
 from argilla.server.apis.v0.models.token_classification import TokenClassificationRecord
 from argilla.server.commons.config import TasksFactory
 from argilla.server.commons.models import TaskStatus
diff --git a/src/argilla/server/apis/v0/handlers/token_classification.py b/src/argilla/server/apis/v0/handlers/token_classification.py
index 940775a132..7a21310e1d 100644
--- a/src/argilla/server/apis/v0/handlers/token_classification.py
+++ b/src/argilla/server/apis/v0/handlers/token_classification.py
@@ -23,7 +23,6 @@
     metrics,
     token_classification_dataset_settings,
 )
-from argilla.server.apis.v0.helpers import deprecate_endpoint
 from argilla.server.apis.v0.models.commons.model import BulkResponse
 from argilla.server.apis.v0.models.commons.params import (
     CommonTaskHandlerDependencies,
diff --git a/src/argilla/server/apis/v0/models/text2text.py b/src/argilla/server/apis/v0/models/text2text.py
index e2730f46e3..5345f95746 100644
--- a/src/argilla/server/apis/v0/models/text2text.py
+++ b/src/argilla/server/apis/v0/models/text2text.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from datetime import datetime
 from typing import Dict, List, Optional

 from pydantic import BaseModel, Field, validator
@@ -27,7 +26,7 @@
     SortableField,
 )
 from argilla.server.apis.v0.models.datasets import UpdateDatasetRequest
-from argilla.server.commons.models import PredictionStatus, TaskType
+from argilla.server.commons.models import PredictionStatus
 from argilla.server.services.metrics.models import CommonTasksMetrics
 from argilla.server.services.search.model import (
     ServiceBaseRecordsQuery,
diff --git a/src/argilla/server/apis/v0/models/text_classification.py b/src/argilla/server/apis/v0/models/text_classification.py
index 84f1868f78..b2886fe56b 100644
--- a/src/argilla/server/apis/v0/models/text_classification.py
+++ b/src/argilla/server/apis/v0/models/text_classification.py
@@ -14,7 +14,7 @@
 # limitations under the License.

 from datetime import datetime
-from typing import Any, Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union

 from pydantic import BaseModel, Field, root_validator, validator
@@ -39,14 +39,11 @@
 )
 from argilla.server.services.tasks.text_classification.model import (
     ServiceTextClassificationDataset,
-)
-from argilla.server.services.tasks.text_classification.model import (
-    ServiceTextClassificationQuery as _TextClassificationQuery,
+    TokenAttributions,
 )
 from argilla.server.services.tasks.text_classification.model import (
     TextClassificationAnnotation as _TextClassificationAnnotation,
 )
-from argilla.server.services.tasks.text_classification.model import TokenAttributions


 class UpdateLabelingRule(BaseModel):
diff --git a/src/argilla/server/daos/backend/metrics/base.py b/src/argilla/server/daos/backend/metrics/base.py
index 80122add68..256ffcdbf1 100644
--- a/src/argilla/server/daos/backend/metrics/base.py
+++ b/src/argilla/server/daos/backend/metrics/base.py
@@ -13,11 +13,14 @@
 # limitations under the License.

 import dataclasses
-from typing import Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

 from argilla.server.daos.backend.query_helpers import aggregations
 from argilla.server.helpers import unflatten_dict

+if TYPE_CHECKING:
+    from argilla.server.daos.backend.client_adapters.base import IClientAdapter
+

 @dataclasses.dataclass
 class ElasticsearchMetric:
diff --git a/src/argilla/server/daos/backend/search/query_builder.py b/src/argilla/server/daos/backend/search/query_builder.py
index a05ed872b9..76d9e989ff 100644
--- a/src/argilla/server/daos/backend/search/query_builder.py
+++ b/src/argilla/server/daos/backend/search/query_builder.py
@@ -277,7 +277,7 @@ def map_2_es_sort_configuration(
         es_sort = []
         for sortable_field in sort.sort_by or [SortableField(id="id")]:
             if valid_fields:
-                if not sortable_field.id.split(".")[0] in valid_fields:
+                if sortable_field.id.split(".")[0] not in valid_fields:
                     raise AssertionError(
                         f"Wrong sort id {sortable_field.id}. Valid values are: "
Valid values are: " f"{[str(v) for v in valid_fields]}" diff --git a/src/argilla/server/daos/models/datasets.py b/src/argilla/server/daos/models/datasets.py index 22a57e2019..027e4a7011 100644 --- a/src/argilla/server/daos/models/datasets.py +++ b/src/argilla/server/daos/models/datasets.py @@ -46,7 +46,7 @@ def id(self) -> str: """The dataset id. Compounded by owner and name""" return self.build_dataset_id(self.name, self.owner) - def dict(self, *args, **kwargs) -> "DictStrAny": + def dict(self, *args, **kwargs) -> Dict[str, Any]: """ Extends base component dict extending object properties and user defined extended fields diff --git a/src/argilla/server/daos/models/records.py b/src/argilla/server/daos/models/records.py index 43776057e6..117d39581e 100644 --- a/src/argilla/server/daos/models/records.py +++ b/src/argilla/server/daos/models/records.py @@ -225,7 +225,7 @@ def extended_fields(self) -> Dict[str, Any]: "score": self.scores, } - def dict(self, *args, **kwargs) -> "DictStrAny": + def dict(self, *args, **kwargs) -> Dict[str, Any]: """ Extends base component dict extending object properties and user defined extended fields diff --git a/src/argilla/server/errors/api_errors.py b/src/argilla/server/errors/api_errors.py index 01fb75c81f..1cc0f6eb50 100644 --- a/src/argilla/server/errors/api_errors.py +++ b/src/argilla/server/errors/api_errors.py @@ -15,7 +15,7 @@ import logging from typing import Any, Dict -from fastapi import HTTPException, Request, status +from fastapi import HTTPException, Request from fastapi.exception_handlers import http_exception_handler from pydantic import BaseModel diff --git a/src/argilla/server/errors/base_errors.py b/src/argilla/server/errors/base_errors.py index ab69d7ff7b..36ed5bd700 100644 --- a/src/argilla/server/errors/base_errors.py +++ b/src/argilla/server/errors/base_errors.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Type, Union +from typing import Any, Optional, Type, Union import pydantic from starlette import status diff --git a/src/argilla/server/server.py b/src/argilla/server/server.py index 7e1ece9e26..67e146d670 100644 --- a/src/argilla/server/server.py +++ b/src/argilla/server/server.py @@ -16,7 +16,6 @@ """ This module configures the global fastapi application """ -import fileinput import glob import inspect import logging diff --git a/src/argilla/server/services/datasets.py b/src/argilla/server/services/datasets.py index 2a3bce6f9a..d67914d354 100644 --- a/src/argilla/server/services/datasets.py +++ b/src/argilla/server/services/datasets.py @@ -146,7 +146,7 @@ def delete(self, user: User, dataset: ServiceDataset): self.__dao__.delete_dataset(dataset) else: raise ForbiddenOperationError( - f"You don't have the necessary permissions to delete this dataset. " + "You don't have the necessary permissions to delete this dataset. " "Only dataset creators or administrators can delete datasets" ) diff --git a/src/argilla/server/services/storage/service.py b/src/argilla/server/services/storage/service.py index 27897efe11..703adaab55 100644 --- a/src/argilla/server/services/storage/service.py +++ b/src/argilla/server/services/storage/service.py @@ -13,7 +13,7 @@ # limitations under the License. 
 import dataclasses
-from typing import Any, Dict, List, Optional, Type
+from typing import List, Optional, Type

 from fastapi import Depends
@@ -94,7 +94,7 @@ async def delete_records(
         else:
             if not user.is_superuser() and user.username != dataset.created_by:
                 raise ForbiddenOperationError(
-                    f"You don't have the necessary permissions to delete records on this dataset. "
+                    "You don't have the necessary permissions to delete records on this dataset. "
                     "Only dataset creators or administrators can delete datasets"
                 )
diff --git a/src/argilla/server/services/tasks/text2text/models.py b/src/argilla/server/services/tasks/text2text/models.py
index bccc2b912d..f8d1e8438c 100644
--- a/src/argilla/server/services/tasks/text2text/models.py
+++ b/src/argilla/server/services/tasks/text2text/models.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from datetime import datetime
 from typing import Any, Dict, List, Optional

 from pydantic import BaseModel, Field
@@ -21,9 +20,7 @@
 from argilla.server.services.datasets import ServiceBaseDataset
 from argilla.server.services.search.model import (
     ServiceBaseRecordsQuery,
-    ServiceBaseSearchResultsAggregations,
     ServiceScoreRange,
-    ServiceSearchResults,
 )
 from argilla.server.services.tasks.commons import (
     ServiceBaseAnnotation,
@@ -91,4 +88,3 @@ class ServiceText2TextQuery(ServiceBaseRecordsQuery):

 class ServiceText2TextDataset(ServiceBaseDataset):
     task: TaskType = Field(default=TaskType.text2text, const=True)
-    pass
diff --git a/src/argilla/server/services/tasks/text2text/service.py b/src/argilla/server/services/tasks/text2text/service.py
index 48f5af61b3..a83e219153 100644
--- a/src/argilla/server/services/tasks/text2text/service.py
+++ b/src/argilla/server/services/tasks/text2text/service.py
@@ -13,12 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Iterable, List, Optional, Type
+from typing import Iterable, List, Optional

 from fastapi import Depends

 from argilla.server.commons.config import TasksFactory
-from argilla.server.services.metrics.models import ServiceBaseTaskMetrics
 from argilla.server.services.search.model import (
     ServiceSearchResults,
     ServiceSortableField,
diff --git a/src/argilla/server/services/tasks/text_classification/model.py b/src/argilla/server/services/tasks/text_classification/model.py
index 2450d9731d..a94a2632a4 100644
--- a/src/argilla/server/services/tasks/text_classification/model.py
+++ b/src/argilla/server/services/tasks/text_classification/model.py
@@ -99,7 +99,7 @@ def check_label_length(cls, class_label):
         assert 1 <= len(class_label) <= DEFAULT_MAX_KEYWORD_LENGTH, (
             f"Class name '{class_label}' exceeds max length of {DEFAULT_MAX_KEYWORD_LENGTH}"
             if len(class_label) > DEFAULT_MAX_KEYWORD_LENGTH
-            else f"Class name must not be empty"
+            else "Class name must not be empty"
         )
         return class_label
diff --git a/src/argilla/utils/span_utils.py b/src/argilla/utils/span_utils.py
index 201039cb4d..4a236c6083 100644
--- a/src/argilla/utils/span_utils.py
+++ b/src/argilla/utils/span_utils.py
@@ -111,7 +111,7 @@ def validate(self, spans: List[Tuple[str, int, int]]):

         if misaligned_spans_errors:
             spans = "\n".join(misaligned_spans_errors)
-            message += f"Following entity spans are not aligned with provided tokenization\n"
+            message += "Following entity spans are not aligned with provided tokenization\n"
             message += f"Spans:\n{spans}\n"
             message += f"Tokens:\n{self.tokens}"
diff --git a/tests/client/apis/test_base.py b/tests/client/apis/test_base.py
index 189969dbb6..1101d9849a 100644
--- a/tests/client/apis/test_base.py
+++ b/tests/client/apis/test_base.py
@@ -13,7 +13,6 @@
 # limitations under the License.

 import pytest
-
 from argilla.client import api
 from argilla.client.apis import AbstractApi, api_compatibility
 from argilla.client.sdk._helpers import handle_response_error
diff --git a/tests/client/conftest.py b/tests/client/conftest.py
index ae01679272..ab95abdd6f 100644
--- a/tests/client/conftest.py
+++ b/tests/client/conftest.py
@@ -15,10 +15,9 @@
 import datetime
 from typing import List

-import pytest
-
 import argilla
 import argilla as ar
+import pytest
 from argilla.client.sdk.datasets.models import TaskType
 from argilla.client.sdk.text2text.models import (
     CreationText2TextRecord,
diff --git a/tests/client/functional_tests/test_record_update.py b/tests/client/functional_tests/test_record_update.py
index 345149ca69..8c52cb4058 100644
--- a/tests/client/functional_tests/test_record_update.py
+++ b/tests/client/functional_tests/test_record_update.py
@@ -13,7 +13,6 @@
 # limitations under the License.

 import pytest
-
 from argilla.client.api import active_api
 from argilla.client.sdk.commons.errors import NotFoundApiError
diff --git a/tests/client/functional_tests/test_scan_raw_records.py b/tests/client/functional_tests/test_scan_raw_records.py
index f58b18006a..7af6af9b6f 100644
--- a/tests/client/functional_tests/test_scan_raw_records.py
+++ b/tests/client/functional_tests/test_scan_raw_records.py
@@ -13,8 +13,6 @@
 # limitations under the License.
 import pytest
-
-import argilla
 from argilla.client.api import active_api
 from argilla.client.sdk.token_classification.models import TokenClassificationRecord
@@ -28,9 +26,8 @@ def test_scan_records(
     gutenberg_spacy_ner,
     fields,
 ):
-    import pandas as pd
-
     import argilla as rg
+    import pandas as pd

     data = active_api().datasets.scan(
         name=gutenberg_spacy_ner,
diff --git a/tests/client/sdk/commons/api.py b/tests/client/sdk/commons/api.py
index 92eeaa9ef6..b7ca635486 100644
--- a/tests/client/sdk/commons/api.py
+++ b/tests/client/sdk/commons/api.py
@@ -14,8 +14,6 @@
 # limitations under the License.
 import httpx
 import pytest
-from httpx import Response as HttpxResponse
-
 from argilla.client.sdk.commons.api import (
     build_bulk_response,
     build_data_response,
@@ -29,6 +27,7 @@
     ValidationError,
 )
 from argilla.client.sdk.text_classification.models import TextClassificationRecord
+from httpx import Response as HttpxResponse


 def test_text2text_bulk(sdk_client, mocked_client, bulk_text2text_data, monkeypatch):
diff --git a/tests/client/sdk/commons/test_client.py b/tests/client/sdk/commons/test_client.py
index 5b1dd8bedd..cb801f12ef 100644
--- a/tests/client/sdk/commons/test_client.py
+++ b/tests/client/sdk/commons/test_client.py
@@ -13,7 +13,6 @@
 # limitations under the License.

 import pytest
-
 from argilla.client.api import active_api
 from argilla.client.sdk.client import Client
diff --git a/tests/client/sdk/conftest.py b/tests/client/sdk/conftest.py
index 7186d02816..75ef26042b 100644
--- a/tests/client/sdk/conftest.py
+++ b/tests/client/sdk/conftest.py
@@ -16,9 +16,8 @@
 from datetime import datetime
 from typing import Any, Dict, List

-import pytest
-
 import argilla as ar
+import pytest
 from argilla._constants import DEFAULT_API_KEY
 from argilla.client.sdk.client import AuthenticatedClient
 from argilla.client.sdk.text2text.models import (
diff --git a/tests/client/sdk/datasets/test_api.py b/tests/client/sdk/datasets/test_api.py
index ec46434227..7a5fba08d0 100644
--- a/tests/client/sdk/datasets/test_api.py
+++ b/tests/client/sdk/datasets/test_api.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 import httpx
 import pytest
-
 from argilla._constants import DEFAULT_API_KEY
 from argilla.client.sdk.client import AuthenticatedClient
 from argilla.client.sdk.commons.errors import (
@@ -22,14 +21,8 @@
     NotFoundApiError,
     ValidationApiError,
 )
-from argilla.client.sdk.commons.models import (
-    ErrorMessage,
-    HTTPValidationError,
-    Response,
-    ValidationError,
-)
 from argilla.client.sdk.datasets.api import _build_response, get_dataset
-from argilla.client.sdk.datasets.models import Dataset, TaskType
+from argilla.client.sdk.datasets.models import Dataset
 from argilla.client.sdk.text_classification.models import TextClassificationBulkData
diff --git a/tests/client/sdk/datasets/test_models.py b/tests/client/sdk/datasets/test_models.py
index a10ea4acd6..c490db0e9b 100644
--- a/tests/client/sdk/datasets/test_models.py
+++ b/tests/client/sdk/datasets/test_models.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import pytest
-
 from argilla.client.sdk.datasets.models import Dataset, TaskType
 from argilla.server.apis.v0.models.datasets import Dataset as ServerDataset
diff --git a/tests/client/sdk/text2text/test_models.py b/tests/client/sdk/text2text/test_models.py
index 811d56f61a..57f2e70178 100644
--- a/tests/client/sdk/text2text/test_models.py
+++ b/tests/client/sdk/text2text/test_models.py
@@ -16,7 +16,6 @@
 from datetime import datetime

 import pytest
-
 from argilla.client.models import Text2TextRecord
 from argilla.client.sdk.text2text.models import (
     CreationText2TextRecord,
diff --git a/tests/client/sdk/text_classification/test_models.py b/tests/client/sdk/text_classification/test_models.py
index 24bee1eed1..9bebb747c5 100644
--- a/tests/client/sdk/text_classification/test_models.py
+++ b/tests/client/sdk/text_classification/test_models.py
@@ -16,7 +16,6 @@
 from datetime import datetime

 import pytest
-
 from argilla.client.models import TextClassificationRecord, TokenAttributions
 from argilla.client.sdk.text_classification.models import (
     ClassPrediction,
diff --git a/tests/client/sdk/token_classification/test_models.py b/tests/client/sdk/token_classification/test_models.py
index bf0af8e502..4e419b07d2 100644
--- a/tests/client/sdk/token_classification/test_models.py
+++ b/tests/client/sdk/token_classification/test_models.py
@@ -16,7 +16,6 @@
 from datetime import datetime

 import pytest
-
 from argilla.client.models import TokenClassificationRecord
 from argilla.client.sdk.token_classification.models import (
     CreationTokenClassificationRecord,
diff --git a/tests/client/sdk/users/test_api.py b/tests/client/sdk/users/test_api.py
index 34c053ff8a..8eb8f2b3c7 100644
--- a/tests/client/sdk/users/test_api.py
+++ b/tests/client/sdk/users/test_api.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import pytest
-
 from argilla.client.sdk.client import AuthenticatedClient
 from argilla.client.sdk.commons.errors import BaseClientError, UnauthorizedApiError
 from argilla.client.sdk.users.api import whoami
diff --git a/tests/client/test_api.py b/tests/client/test_api.py
index 4f80a66deb..0c52482c27 100644
--- a/tests/client/test_api.py
+++ b/tests/client/test_api.py
@@ -15,14 +15,13 @@
 import concurrent.futures
 import datetime
 from time import sleep
-from typing import Any, Iterable, List
+from typing import Any, Iterable

+import argilla as ar
 import datasets
 import httpx
 import pandas as pd
 import pytest
-
-import argilla as ar
 from argilla._constants import (
     _OLD_WORKSPACE_HEADER_NAME,
     DEFAULT_API_KEY,
@@ -46,6 +45,7 @@
 from argilla.server.apis.v0.models.text_classification import (
     TextClassificationSearchResults,
 )
+
 from tests.helpers import SecuredClient
 from tests.server.test_api import create_some_data_for_text_classification
diff --git a/tests/client/test_asgi.py b/tests/client/test_asgi.py
index 41565b7aa8..465e20136d 100644
--- a/tests/client/test_asgi.py
+++ b/tests/client/test_asgi.py
@@ -16,17 +16,16 @@
 import time
 from typing import Any, Dict

-from fastapi import FastAPI
-from starlette.applications import Starlette
-from starlette.responses import JSONResponse, PlainTextResponse
-from starlette.testclient import TestClient
-
 import argilla
 from argilla.monitoring.asgi import (
     ArgillaLogHTTPMiddleware,
     text_classification_mapper,
     token_classification_mapper,
 )
+from fastapi import FastAPI
+from starlette.applications import Starlette
+from starlette.responses import JSONResponse, PlainTextResponse
+from starlette.testclient import TestClient


 def test_argilla_middleware_for_text_classification(
diff --git a/tests/client/test_client_errors.py b/tests/client/test_client_errors.py
index a19ff3dd12..f9f0cd4703 100644
--- a/tests/client/test_client_errors.py
+++ b/tests/client/test_client_errors.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import pytest
-
 from argilla.client.sdk.commons.errors import UnauthorizedApiError
diff --git a/tests/client/test_dataset.py b/tests/client/test_dataset.py
index e335e419b3..83ae429af0 100644
--- a/tests/client/test_dataset.py
+++ b/tests/client/test_dataset.py
@@ -17,12 +17,11 @@
 import sys
 from time import sleep

+import argilla as ar
 import datasets
 import pandas as pd
 import pytest
 import spacy
-
-import argilla as ar
 from argilla.client.datasets import (
     DatasetBase,
     DatasetForTokenClassification,
diff --git a/tests/client/test_models.py b/tests/client/test_models.py
index 6d5d4b052b..75d37bde11 100644
--- a/tests/client/test_models.py
+++ b/tests/client/test_models.py
@@ -19,15 +19,13 @@
 import numpy
 import pandas as pd
 import pytest
-from pydantic import ValidationError
-
-from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH
 from argilla.client.models import (
     Text2TextRecord,
     TextClassificationRecord,
     TokenClassificationRecord,
     _Validators,
 )
+from pydantic import ValidationError


 @pytest.mark.parametrize(
diff --git a/tests/conftest.py b/tests/conftest.py
index a9b28296e0..33738a6db6 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -15,7 +15,6 @@
 import httpx
 import pytest
 from _pytest.logging import LogCaptureFixture
-
 from argilla.client.sdk.users import api as users_api
 from argilla.server.commons import telemetry

@@ -23,10 +22,9 @@
     from loguru import logger
 except ModuleNotFoundError:
     logger = None
-from starlette.testclient import TestClient
-
 from argilla import app
 from argilla.client.api import active_api
+from starlette.testclient import TestClient

 from .helpers import SecuredClient
diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py
index 3edfdccf73..4499d25964 100644
--- a/tests/datasets/test_datasets.py
+++ b/tests/datasets/test_datasets.py
@@ -12,9 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import pytest
-
 import argilla as ar
+import pytest
 from argilla import TextClassificationSettings, TokenClassificationSettings
 from argilla.client import api
 from argilla.client.sdk.commons.errors import ForbiddenApiError
diff --git a/tests/functional_tests/datasets/test_delete_records_from_datasets.py b/tests/functional_tests/datasets/test_delete_records_from_datasets.py
index 5279ebb567..238c358338 100644
--- a/tests/functional_tests/datasets/test_delete_records_from_datasets.py
+++ b/tests/functional_tests/datasets/test_delete_records_from_datasets.py
@@ -15,7 +15,6 @@
 import time

 import pytest
-
 from argilla.client.sdk.commons.errors import ForbiddenApiError
diff --git a/tests/functional_tests/datasets/test_update_record.py b/tests/functional_tests/datasets/test_update_record.py
index 8ffd74fabc..7eb050a4c5 100644
--- a/tests/functional_tests/datasets/test_update_record.py
+++ b/tests/functional_tests/datasets/test_update_record.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import pytest
-
 from argilla.server.apis.v0.models.text2text import Text2TextRecord
 from argilla.server.apis.v0.models.text_classification import TextClassificationRecord
 from argilla.server.apis.v0.models.token_classification import TokenClassificationRecord
diff --git a/tests/functional_tests/search/test_search_service.py b/tests/functional_tests/search/test_search_service.py
index 92eb8062bb..9fc0b9b439 100644
--- a/tests/functional_tests/search/test_search_service.py
+++ b/tests/functional_tests/search/test_search_service.py
@@ -12,9 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import pytest
-
 import argilla
+import pytest
 from argilla.server.apis.v0.models.commons.model import ScoreRange
 from argilla.server.apis.v0.models.datasets import Dataset
 from argilla.server.apis.v0.models.text_classification import (
diff --git a/tests/functional_tests/test_log_for_text_classification.py b/tests/functional_tests/test_log_for_text_classification.py
index 0a85ec53e8..5e2a16f8bb 100644
--- a/tests/functional_tests/test_log_for_text_classification.py
+++ b/tests/functional_tests/test_log_for_text_classification.py
@@ -12,16 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import pytest
-
 import argilla as ar
+import pytest
 from argilla.client.sdk.commons.errors import (
     BadRequestApiError,
     GenericApiError,
     ValidationApiError,
 )
 from argilla.server.settings import settings
-from tests.client.conftest import SUPPORTED_VECTOR_SEARCH, supported_vector_search
+
+from tests.client.conftest import SUPPORTED_VECTOR_SEARCH
 from tests.helpers import SecuredClient
diff --git a/tests/functional_tests/test_log_for_token_classification.py b/tests/functional_tests/test_log_for_token_classification.py
index f4bea51e2c..2df8c37c5c 100644
--- a/tests/functional_tests/test_log_for_token_classification.py
+++ b/tests/functional_tests/test_log_for_token_classification.py
@@ -12,14 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import pytest
-
 import argilla
+import pytest
 from argilla import TokenClassificationRecord
 from argilla.client import api
 from argilla.client.sdk.commons.errors import NotFoundApiError
 from argilla.metrics import __all__ as ALL_METRICS
 from argilla.metrics import entity_consistency
+
 from tests.client.conftest import SUPPORTED_VECTOR_SEARCH
 from tests.helpers import SecuredClient
diff --git a/tests/helpers.py b/tests/helpers.py
index 3999376e6f..4246a8fd45 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -14,14 +14,13 @@

 from typing import List

-from fastapi import FastAPI
-from starlette.testclient import TestClient
-
 from argilla._constants import API_KEY_HEADER_NAME, WORKSPACE_HEADER_NAME
 from argilla.client.api import active_api
 from argilla.server.security import auth
 from argilla.server.security.auth_provider.local.settings import settings
 from argilla.server.security.auth_provider.local.users.model import UserInDB
+from fastapi import FastAPI
+from starlette.testclient import TestClient


 class SecuredClient:
diff --git a/tests/labeling/text_classification/test_label_errors.py b/tests/labeling/text_classification/test_label_errors.py
index 71933a9828..f33f9d8ab6 100644
--- a/tests/labeling/text_classification/test_label_errors.py
+++ b/tests/labeling/text_classification/test_label_errors.py
@@ -14,11 +14,9 @@
 # limitations under the License.
 import sys

+import argilla as ar
 import cleanlab
 import pytest
-from pkg_resources import parse_version
-
-import argilla as ar
 from argilla.labeling.text_classification import find_label_errors
 from argilla.labeling.text_classification.label_errors import (
     MissingPredictionError,
@@ -26,6 +24,7 @@
     SortBy,
     _construct_s_and_psx,
 )
+from pkg_resources import parse_version


 @pytest.fixture(
diff --git a/tests/labeling/text_classification/test_label_models.py b/tests/labeling/text_classification/test_label_models.py
index c21bac70a2..0cde14796f 100644
--- a/tests/labeling/text_classification/test_label_models.py
+++ b/tests/labeling/text_classification/test_label_models.py
@@ -13,11 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import sys
-from types import SimpleNamespace

 import numpy as np
 import pytest
-
 from argilla import TextClassificationRecord
 from argilla.labeling.text_classification import (
     FlyingSquid,
diff --git a/tests/labeling/text_classification/test_rule.py b/tests/labeling/text_classification/test_rule.py
index 3c1d29a326..c46c7cdf3e 100644
--- a/tests/labeling/text_classification/test_rule.py
+++ b/tests/labeling/text_classification/test_rule.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 import httpx
 import pytest
-
 from argilla import load
 from argilla.client.models import TextClassificationRecord
 from argilla.client.sdk.text_classification.models import (
diff --git a/tests/labeling/text_classification/test_weak_labels.py b/tests/labeling/text_classification/test_weak_labels.py
index 0ad4cf6e97..8a42263881 100644
--- a/tests/labeling/text_classification/test_weak_labels.py
+++ b/tests/labeling/text_classification/test_weak_labels.py
@@ -18,8 +18,6 @@
 import numpy as np
 import pandas as pd
 import pytest
-from pandas.testing import assert_frame_equal
-
 from argilla import TextClassificationRecord
 from argilla.client.sdk.text_classification.models import (
     CreationTextClassificationRecord,
@@ -35,6 +33,7 @@
     NoRulesFoundError,
     WeakLabelsBase,
 )
+from pandas.testing import assert_frame_equal


 @pytest.fixture
diff --git a/tests/listeners/test_listener.py b/tests/listeners/test_listener.py
index 1aa6bca55a..8549036869 100644
--- a/tests/listeners/test_listener.py
+++ b/tests/listeners/test_listener.py
@@ -15,9 +15,8 @@
 import time
 from typing import List

-import pytest
-
 import argilla as ar
+import pytest
 from argilla import RGListenerContext, listener
 from argilla.client.models import Record
diff --git a/tests/metrics/test_common_metrics.py b/tests/metrics/test_common_metrics.py
index b1ac68084c..4cbb6f80f6 100644
--- a/tests/metrics/test_common_metrics.py
+++ b/tests/metrics/test_common_metrics.py
@@ -12,10 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import pytest
-
 import argilla
 import argilla as ar
+import pytest
 from argilla.metrics.commons import keywords, records_status, text_length
diff --git a/tests/metrics/test_token_classification.py b/tests/metrics/test_token_classification.py
index 443b7a17c8..89ee11c961 100644
--- a/tests/metrics/test_token_classification.py
+++ b/tests/metrics/test_token_classification.py
@@ -12,10 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import pytest
-
 import argilla
 import argilla as ar
+import pytest
 from argilla.metrics import entity_consistency
 from argilla.metrics.token_classification import (
     Annotations,
diff --git a/tests/monitoring/test_base_monitor.py b/tests/monitoring/test_base_monitor.py
index 96dc9c349d..f3baa32730 100644
--- a/tests/monitoring/test_base_monitor.py
+++ b/tests/monitoring/test_base_monitor.py
@@ -13,7 +13,7 @@
 # limitations under the License.

 from time import sleep
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List

 from argilla import TextClassificationRecord
 from argilla.client.api import Api, active_api
diff --git a/tests/monitoring/test_flair_monitoring.py b/tests/monitoring/test_flair_monitoring.py
index e8cb51b6ad..18aa38201e 100644
--- a/tests/monitoring/test_flair_monitoring.py
+++ b/tests/monitoring/test_flair_monitoring.py
@@ -15,11 +15,10 @@


 def test_flair_monitoring(mocked_client, monkeypatch):
+    import argilla as ar
     from flair.data import Sentence
     from flair.models import SequenceTagger

-    import argilla as ar
-
     dataset = "test_flair_monitoring"
     model = "flair/ner-english"
diff --git a/tests/monitoring/test_transformers_monitoring.py b/tests/monitoring/test_transformers_monitoring.py
index ee0487c98f..225f606518 100644
--- a/tests/monitoring/test_transformers_monitoring.py
+++ b/tests/monitoring/test_transformers_monitoring.py
@@ -14,9 +14,8 @@
 from time import sleep
 from typing import List, Union

-import pytest
-
 import argilla
+import pytest
 from argilla import TextClassificationRecord
diff --git a/tests/server/backend/test_query_builder.py b/tests/server/backend/test_query_builder.py
index 2ceef02733..0276da15eb 100644
--- a/tests/server/backend/test_query_builder.py
+++ b/tests/server/backend/test_query_builder.py
@@ -13,7 +13,6 @@
 # limitations under the License.

 import pytest
-
 from argilla.server.daos.backend.search.model import (
     SortableField,
     SortConfig,
diff --git a/tests/server/commons/test_records_dao.py b/tests/server/commons/test_records_dao.py
index 2dd227267a..f86d68d084 100644
--- a/tests/server/commons/test_records_dao.py
+++ b/tests/server/commons/test_records_dao.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import pytest
-
 from argilla.server.commons.models import TaskType
 from argilla.server.daos.backend import GenericElasticEngineBackend
 from argilla.server.daos.models.datasets import BaseDatasetDB
diff --git a/tests/server/commons/test_settings.py b/tests/server/commons/test_settings.py
index 563ebdd3ce..9380f00eba 100644
--- a/tests/server/commons/test_settings.py
+++ b/tests/server/commons/test_settings.py
@@ -15,9 +15,8 @@
 import os

 import pytest
-from pydantic import ValidationError
-
 from argilla.server.settings import ApiSettings
+from pydantic import ValidationError


 @pytest.mark.parametrize("bad_namespace", ["Badns", "bad-ns", "12-bad-ns", "@bad"])
diff --git a/tests/server/commons/test_telemetry.py b/tests/server/commons/test_telemetry.py
index e892042649..f3ce903918 100644
--- a/tests/server/commons/test_telemetry.py
+++ b/tests/server/commons/test_telemetry.py
@@ -15,11 +15,10 @@
 import uuid

 import pytest
-from fastapi import Request
-
 from argilla.server.commons import telemetry
 from argilla.server.commons.models import TaskType
 from argilla.server.errors import ServerError
+from fastapi import Request

 mock_request = Request(scope={"type": "http", "headers": {}})
diff --git a/tests/server/daos/models/test_records.py b/tests/server/daos/models/test_records.py
index e2542c804f..5ebac261fc 100644
--- a/tests/server/daos/models/test_records.py
+++ b/tests/server/daos/models/test_records.py
@@ -15,7 +15,6 @@
 import warnings

 import pytest
-
 from argilla.server.daos.models.records import BaseRecordInDB
 from argilla.server.settings import settings
diff --git a/tests/server/datasets/test_api.py b/tests/server/datasets/test_api.py
index 416cf16275..2d0ebebf0b 100644
--- a/tests/server/datasets/test_api.py
+++ b/tests/server/datasets/test_api.py
@@ -19,6 +19,7 @@
     TextClassificationBulkRequest,
 )
 from argilla.server.commons.models import TaskType
+
 from tests.helpers import SecuredClient


@@ -97,7 +98,7 @@ def test_fetch_dataset_using_workspaces(mocked_client: SecuredClient):
     assert response.status_code == 409, response.json()

     response = mocked_client.post(
-        f"/api/datasets",
+        "/api/datasets",
         json=request,
     )
diff --git a/tests/server/datasets/test_dao.py b/tests/server/datasets/test_dao.py
index 2b208c73fd..962df52564 100644
--- a/tests/server/datasets/test_dao.py
+++ b/tests/server/datasets/test_dao.py
@@ -14,7 +14,6 @@
 # limitations under the License.

 import pytest
-
 from argilla.server.commons.models import TaskType
 from argilla.server.daos.backend import GenericElasticEngineBackend
 from argilla.server.daos.datasets import DatasetsDAO
diff --git a/tests/server/datasets/test_model.py b/tests/server/datasets/test_model.py
index b664522d78..165203cd7f 100644
--- a/tests/server/datasets/test_model.py
+++ b/tests/server/datasets/test_model.py
@@ -14,10 +14,9 @@
 # limitations under the License.

 import pytest
-from pydantic import ValidationError
-
 from argilla.server.apis.v0.models.datasets import CreateDatasetRequest
 from argilla.server.commons.models import TaskType
+from pydantic import ValidationError


 @pytest.mark.parametrize(
diff --git a/tests/server/functional_tests/datasets/test_get_record.py b/tests/server/functional_tests/datasets/test_get_record.py
index 5bc7deba30..ee95849bef 100644
--- a/tests/server/functional_tests/datasets/test_get_record.py
+++ b/tests/server/functional_tests/datasets/test_get_record.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
import pytest - from argilla.server.apis.v0.models.text2text import ( Text2TextBulkRequest, Text2TextRecord, diff --git a/tests/server/info/test_api.py b/tests/server/info/test_api.py index d25b143555..41b6c23518 100644 --- a/tests/server/info/test_api.py +++ b/tests/server/info/test_api.py @@ -36,7 +36,7 @@ def test_api_status(mocked_client): assert info.version == argilla_version # Checking to not get the error dictionary service.py includes whenever something goes wrong - assert not "error" in info.elasticsearch + assert "error" not in info.elasticsearch # Checking that the first key into mem_info dictionary has a nont-none value assert "rss" in info.mem_info is not None diff --git a/tests/server/security/test_dao.py b/tests/server/security/test_dao.py index cda8ea5d1b..41e03f5bc3 100644 --- a/tests/server/security/test_dao.py +++ b/tests/server/security/test_dao.py @@ -13,7 +13,6 @@ # limitations under the License. import pytest - from argilla._constants import DEFAULT_API_KEY from argilla.server.security.auth_provider.local.users.service import create_users_dao diff --git a/tests/server/security/test_model.py b/tests/server/security/test_model.py index 4c4428ed06..757349d6a0 100644 --- a/tests/server/security/test_model.py +++ b/tests/server/security/test_model.py @@ -13,10 +13,9 @@ # limitations under the License. import pytest -from pydantic import ValidationError - from argilla.server.errors import EntityNotFoundError from argilla.server.security.model import User +from pydantic import ValidationError @pytest.mark.parametrize("email", ["my@email.com", "infra@recogn.ai"]) diff --git a/tests/server/security/test_provider.py b/tests/server/security/test_provider.py index f06b64dcf7..6d169e74dc 100644 --- a/tests/server/security/test_provider.py +++ b/tests/server/security/test_provider.py @@ -13,12 +13,11 @@ # limitations under the License. import pytest -from fastapi.security import SecurityScopes - from argilla._constants import DEFAULT_API_KEY from argilla.server.security.auth_provider.local.provider import ( create_local_auth_provider, ) +from fastapi.security import SecurityScopes localAuth = create_local_auth_provider() security_Scopes = SecurityScopes diff --git a/tests/server/security/test_service.py b/tests/server/security/test_service.py index 9bc5f315a3..8985b556a2 100644 --- a/tests/server/security/test_service.py +++ b/tests/server/security/test_service.py @@ -13,7 +13,6 @@ # limitations under the License. 
import pytest - from argilla._constants import DEFAULT_API_KEY from argilla.server.security.auth_provider.local.users.dao import create_users_dao from argilla.server.security.auth_provider.local.users.service import UsersService diff --git a/tests/server/test_app.py b/tests/server/test_app.py index 21d336d081..06e6ca9f05 100644 --- a/tests/server/test_app.py +++ b/tests/server/test_app.py @@ -16,7 +16,6 @@ from importlib import reload import pytest - from argilla.server import app diff --git a/tests/server/text2text/test_api.py b/tests/server/text2text/test_api.py index a5f8c36e73..c5db714f9d 100644 --- a/tests/server/text2text/test_api.py +++ b/tests/server/text2text/test_api.py @@ -14,14 +14,13 @@ from typing import List, Optional import pytest - from argilla.server.apis.v0.models.commons.model import BulkResponse from argilla.server.apis.v0.models.text2text import ( Text2TextBulkRequest, - Text2TextRecord, Text2TextRecordInputs, Text2TextSearchResults, ) + from tests.client.conftest import SUPPORTED_VECTOR_SEARCH diff --git a/tests/server/text_classification/test_api.py b/tests/server/text_classification/test_api.py index 5ffb8c5300..6791614471 100644 --- a/tests/server/text_classification/test_api.py +++ b/tests/server/text_classification/test_api.py @@ -16,7 +16,6 @@ from datetime import datetime import pytest - from argilla.server.apis.v0.models.commons.model import BulkResponse from argilla.server.apis.v0.models.datasets import Dataset from argilla.server.apis.v0.models.text_classification import ( @@ -28,6 +27,7 @@ TextClassificationSearchResults, ) from argilla.server.commons.models import PredictionStatus + from tests.client.conftest import SUPPORTED_VECTOR_SEARCH diff --git a/tests/server/text_classification/test_api_rules.py b/tests/server/text_classification/test_api_rules.py index 6c95ed80c0..464ab427af 100644 --- a/tests/server/text_classification/test_api_rules.py +++ b/tests/server/text_classification/test_api_rules.py @@ -13,7 +13,6 @@ # limitations under the License. import pytest - from argilla.server.apis.v0.models.text_classification import ( CreateLabelingRule, LabelingRule, diff --git a/tests/server/text_classification/test_model.py b/tests/server/text_classification/test_model.py index 18a967a637..87e344cb02 100644 --- a/tests/server/text_classification/test_model.py +++ b/tests/server/text_classification/test_model.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import pytest -from pydantic import ValidationError - from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH from argilla.server.apis.v0.models.text_classification import ( TextClassificationAnnotation, @@ -27,6 +25,7 @@ ClassPrediction, ServiceTextClassificationRecord, ) +from pydantic import ValidationError def test_flatten_metadata(): diff --git a/tests/server/token_classification/test_api.py b/tests/server/token_classification/test_api.py index 29ee17ec24..8260300012 100644 --- a/tests/server/token_classification/test_api.py +++ b/tests/server/token_classification/test_api.py @@ -15,7 +15,6 @@ from typing import Callable import pytest - from argilla.server.apis.v0.models.commons.model import BulkResponse, SortableField from argilla.server.apis.v0.models.token_classification import ( TokenClassificationBulkRequest, @@ -24,6 +23,7 @@ TokenClassificationSearchRequest, TokenClassificationSearchResults, ) + from tests.client.conftest import SUPPORTED_VECTOR_SEARCH diff --git a/tests/server/token_classification/test_model.py b/tests/server/token_classification/test_model.py index 11701da917..8d0c12bbff 100644 --- a/tests/server/token_classification/test_model.py +++ b/tests/server/token_classification/test_model.py @@ -14,8 +14,6 @@ # limitations under the License. import pytest -from pydantic import ValidationError - from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH from argilla.server.apis.v0.models.token_classification import ( TokenClassificationAnnotation, @@ -28,6 +26,7 @@ EntitySpan, ServiceTokenClassificationRecord, ) +from pydantic import ValidationError def test_char_position(): diff --git a/tests/utils/test_span_utils.py b/tests/utils/test_span_utils.py index eb53a0b038..e3ddbdc6cf 100644 --- a/tests/utils/test_span_utils.py +++ b/tests/utils/test_span_utils.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest - from argilla.utils.span_utils import SpanUtils diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 47804bd72e..85fe24e9f5 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest - from argilla.utils import LazyargillaModule From 91a77ad85244cd885fafd40c38251ecdf2e37147 Mon Sep 17 00:00:00 2001 From: Daniel Vila Suero Date: Tue, 14 Feb 2023 18:29:43 +0100 Subject: [PATCH 06/45] Docs: Update readme with quickstart section and new links to guides (#2333) # Description - [x] Documentation update --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- README.md | 81 ++++++++----------------------------------------------- 1 file changed, 11 insertions(+), 70 deletions(-) diff --git a/README.md b/README.md index 87474b4e3c..b034e12f15 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ ### Advanced NLP labeling -- Programmatic labeling using [weak supervision](https://docs.argilla.io/en/latest/guides/techniques/weak_supervision.html). Built-in label models (Snorkel, Flyingsquid) +- Programmatic labeling using [rules and weak supervision](https://docs.argilla.io/en/latest/guides/programmatic_labeling_with_rules.html). 
Built-in label models (Snorkel, Flyingsquid) - [Bulk-labeling](https://docs.argilla.io/en/latest/reference/webapp/features.html#bulk-annotate) and [search-driven annotation](https://docs.argilla.io/en/latest/guides/features/queries.html) - Iterate on training data with any [pre-trained model](https://docs.argilla.io/en/latest/tutorials/libraries/huggingface.html) or [library](https://docs.argilla.io/en/latest/tutorials/libraries/libraries.html) - Efficiently review and refine annotations in the UI and with Python @@ -71,93 +71,34 @@ ### Monitoring - Close the gap between production data and data collection activities -- [Auto-monitoring](https://docs.argilla.io/en/latest/guides/steps/3_deploying.html) for [major NLP libraries and pipelines](https://docs.argilla.io/en/latest/tutorials/libraries/libraries.html) (spaCy, Hugging Face, FlairNLP) +- [Auto-monitoring](https://docs.argilla.io/en/latest/guides/log_load_and_prepare_data.html) for [major NLP libraries and pipelines](https://docs.argilla.io/en/latest/tutorials/libraries/libraries.html) (spaCy, Hugging Face, FlairNLP) - [ASGI middleware](https://docs.argilla.io/en/latest/tutorials/notebooks/deploying-texttokenclassification-fastapi.html) for HTTP endpoints -- Argilla Metrics to understand data and model issues, [like entity consistency for NER models](https://docs.argilla.io/en/latest/guides/steps/4_monitoring.html) +- Argilla Metrics to understand data and model issues, [like entity consistency for NER models](https://docs.argilla.io/en/latest/guides/measure_datasets_with_metrics.html) - Integrated with Kibana for custom dashboards ### Team workspaces - Bring different users and roles into the NLP data and model lifecycles -- Organize data collection, review and monitoring into different [workspaces](https://docs.argilla.io/en/latest/getting_started/installation/user_management.html#workspace) +- Organize data collection, review and monitoring into different [workspaces](https://docs.argilla.io/en/latest/getting_started/installation/configurations/user_management.html) - Manage workspace access for different users ## Quickstart -Argilla is composed of a `Python Server` with Elasticsearch as the database layer, and a `Python Client` to create and manage datasets. +👋 Welcome! If you have just discovered Argilla this is the best place to get started. Argilla is composed of: -To get started you need to **install the client and the server** with `pip`: -```bash - -pip install "argilla[server]" - -``` - -Then you need to **run [Elasticsearch (ES)](https://www.elastic.co/elasticsearch)**. - -The simplest way is to use`Docker` by running: - -```bash +* Argilla Client: a powerful Python library for reading and writing data into Argilla, using all the libraries you love (transformers, spaCy, datasets, and any other). -docker run -d --name elasticsearch-for-argilla --network argilla-net -p 9200:9200 -p 9300:9300 -e "ES_JAVA_OPTS=-Xms512m -Xmx512m" -e "discovery.type=single-node" docker.elastic.co/elasticsearch/elasticsearch:8.5.3 - -``` -> :information_source: **Check [the docs](https://docs.argilla.io/en/latest/getting_started/quickstart.html) for further options and configurations for Elasticsearch.** - -Finally you can **launch the server**: - -```bash - -python -m argilla - -``` -> :information_source: The most common error message after this step is related to the Elasticsearch instance not running. Make sure your Elasticsearch instance is running on http://localhost:9200/. 
If you already have an Elasticsearch instance or cluster, you point the server to its URL by using [ENV variables](#) +* Argilla Server and UI: the API and UI for data annotation and curation. +To get started you need to: -🎉 You can now access Argilla UI pointing your browser at http://localhost:6900/. +1. Launch the Argilla Server and UI. -**The default username and password are** `argilla` **and** `1234`. - -Your workspace will contain no datasets. So let's use the `datasets` library to create our first datasets! - -First, you need to install `datasets`: -```bash - -pip install datasets - -``` - -Then go to your Python IDE of choice and run: -```python - -import pandas as pd -import argilla as rg -from datasets import load_dataset - -# load dataset from the hub -dataset = load_dataset("argilla/gutenberg_spacy-ner", split="train") - -# read in dataset, assuming its a dataset for text classification -dataset_rg = rg.read_datasets(dataset, task="TokenClassification") - -# log the dataset to the Argilla web app -rg.log(dataset_rg, "gutenberg_spacy-ner") - -# load dataset from json -my_dataframe = pd.read_json( - "https://raw.githubusercontent.com/recognai/datasets/main/sst-sentimentclassification.json") - -# convert pandas dataframe to DatasetForTextClassification -dataset_rg = rg.DatasetForTextClassification.from_pandas(my_dataframe) - -# log the dataset to the Argilla web app -rg.log(dataset_rg, name="sst-sentimentclassification") -``` +2. Pick a tutorial and start rocking with Argilla using Jupyter Notebooks, or Google Colab. -This will create two datasets that you can use to do a quick tour of the core features of Argilla. +To get started follow the steps [on the Quickstart docs page](https://docs.argilla.io/en/latest/getting_started/quickstart.html). > 🚒 **If you find issues, get direct support from the team and other community members on the [Slack Community](https://join.slack.com/t/rubrixworkspace/shared_invite/zt-whigkyjn-a3IUJLD7gDbTZ0rKlvcJ5g)** -For getting started with your own use cases, [go to the docs](https://docs.argilla.io). ## Principles - **Open**: Argilla is free, open-source, and 100% compatible with major NLP libraries (Hugging Face transformers, spaCy, Stanford Stanza, Flair, etc.). In fact, you can **use and combine your preferred libraries** without implementing any specific interface. From 6d2885f8760747dfe6c0e4918a37203dc7c1d248 Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Wed, 15 Feb 2023 14:33:33 +0100 Subject: [PATCH 07/45] Enhancement: Also validate records on assignment of variables (#2337) Closes #2257 Hello! ## Pull Request overview * Add `validate_assignment = True` to the `pydantic` Config to also apply validation on assignment e.g. via `record.prediction = ...`. * Add tests to ensure the correct behaviour. ## What was wrong? See #2257 for details on the issue. The core issue is that the following script would fail with a server-side error on `rg.log`: ```python import argilla as rg record = rg.TextClassificationRecord( text="Hello world, this is me!", prediction=[("LABEL1", 0.8), ("LABEL2", 0.2)], annotation="LABEL1", multi_label=False, ) rg.log(record, "test_item_assignment") record.prediction = "rubbish" rg.log(record, "test_item_assignment") ```
Click to see failure logs using the develop branch

```
2023-02-14 11:19:40.794 | ERROR    | argilla.client.client:__log_internal__:103 - Cannot log data in dataset 'test_item_assignment'
Error: ValueError
Details: not enough values to unpack (expected 2, got 1)
Traceback (most recent call last):
  File "c:/code/argilla/demo_2257.py", line 13, in <module>
    rg.log(record, "test_item_assignment")
  File "C:\code\argilla\src\argilla\client\api.py", line 157, in log
    background=background,
  File "C:\code\argilla\src\argilla\client\client.py", line 305, in log
    return future.result()
  File "C:\Users\tom\.conda\envs\argilla\lib\concurrent\futures\_base.py", line 435, in result
    return self.__get_result()
  File "C:\Users\tom\.conda\envs\argilla\lib\concurrent\futures\_base.py", line 384, in __get_result
    raise self._exception
  File "C:\code\argilla\src\argilla\client\client.py", line 107, in __log_internal__
    raise ex
  File "C:\code\argilla\src\argilla\client\client.py", line 99, in __log_internal__
    return await api.log_async(*args, **kwargs)
  File "C:\code\argilla\src\argilla\client\client.py", line 389, in log_async
    records=[creation_class.from_client(r) for r in chunk],
  File "C:\code\argilla\src\argilla\client\client.py", line 389, in <listcomp>
    records=[creation_class.from_client(r) for r in chunk],
  File "C:\code\argilla\src\argilla\client\sdk\text_classification\models.py", line 65, in from_client
    for label, score in record.prediction
  File "C:\code\argilla\src\argilla\client\sdk\text_classification\models.py", line 64, in <listcomp>
    ClassPrediction(**{"class": label, "score": score})
ValueError: not enough values to unpack (expected 2, got 1)
  0%|          | 0/1 [00:00<?, ?it/s]
```

This is unnecessary, as we can create a client-side error using our validators, too.

## The fix

The fix is as simple as instructing `pydantic` to also trigger the validation on assignment, e.g. on `record.prediction = ...`.

## Relevant documentation

See `validate_assignment` in https://docs.pydantic.dev/usage/model_config/#options.

## After the fix

Using the same script as before, the error now becomes:

```
Traceback (most recent call last):
  File "c:/code/argilla/demo_2257.py", line 11, in <module>
    record.prediction = "rubbish"
  File "C:\code\argilla\src\argilla\client\models.py", line 280, in __setattr__
    super().__setattr__(name, value)
  File "pydantic\main.py", line 445, in pydantic.main.BaseModel.__setattr__
pydantic.error_wrappers.ValidationError: 1 validation error for TextClassificationRecord
prediction
  value is not a valid list (type=type_error.list)
```

Which triggers directly on the assignment rather than only when the record is logged.

---

**Type of change**

- [x] Improvement (change adding some improvement to an existing functionality)

**How Has This Been Tested**

All existing tests passed with my changes. Beyond that, I added some new tests of my own. These tests fail on the `develop` branch, but pass on this branch. This is because previously it was allowed to perform `record.prediction = "rubbish"` as no validation was executed on it.
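For reference, the flag can be exercised in isolation with plain `pydantic` v1 (matching the `error_wrappers.ValidationError` output shown above); the `Item` model below is a made-up stand-in for Argilla's record classes, just to sketch the behaviour:

```python
from typing import List, Optional

from pydantic import BaseModel, ValidationError


class Item(BaseModel):
    # Made-up stand-in for TextClassificationRecord, only for illustration.
    text: str
    prediction: Optional[List[str]] = None

    class Config:
        # Without this flag pydantic validates only in __init__; with it,
        # assignments such as `item.prediction = ...` are validated too.
        validate_assignment = True


item = Item(text="Hello world, this is me!")

try:
    item.prediction = "rubbish"  # not a valid list -> rejected at assignment time
except ValidationError as error:
    print(error)
```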
**Checklist** - [x] I have merged the original branch into my forked branch - [ ] I added relevant documentation - [x] follows the style guidelines of this project - [x] I did a self-review of my code - [x] I added comments to my code - [ ] I made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works --- - Tom Aarsen --- src/argilla/client/models.py | 2 ++ tests/client/test_models.py | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/src/argilla/client/models.py b/src/argilla/client/models.py index 609a92762b..845e7f8eb0 100644 --- a/src/argilla/client/models.py +++ b/src/argilla/client/models.py @@ -121,7 +121,9 @@ def _check_and_update_status(cls, values): return values class Config: + # https://docs.pydantic.dev/usage/model_config/#options extra = "forbid" + validate_assignment = True class BulkResponse(BaseModel): diff --git a/tests/client/test_models.py b/tests/client/test_models.py index 75d37bde11..bf3d72c366 100644 --- a/tests/client/test_models.py +++ b/tests/client/test_models.py @@ -306,3 +306,24 @@ class MockRecord(_Validators): def test_text2text_prediction_validator(prediction, expected): record = Text2TextRecord(text="mock", prediction=prediction) assert record.prediction == expected + + +@pytest.mark.parametrize( + "record", + [ + TextClassificationRecord(text="This is a test"), + TokenClassificationRecord( + text="This is a test", tokens="This is a test".split() + ), + Text2TextRecord(text="This is a test"), + ], +) +def test_record_validation_on_assignment(record): + with pytest.raises(ValidationError): + record.prediction = "rubbish" + + with pytest.raises(ValidationError): + record.annotation = [("rubbish",)] + + with pytest.raises(ValidationError): + record.vectors = "rubbish" From 179ffb92c73fba9b785503ee8c61f2bea290787a Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Wed, 15 Feb 2023 16:13:42 +0100 Subject: [PATCH 08/45] Enhancement: Distinguish between error message and context in validation of spans (#2329) # Description I expanded slightly on the error message provided when providing spans that do not match the tokenization. Consider the following example script: ```python import argilla as rg record = rg.TokenClassificationRecord( text = "This is my text", tokens=["This", "is", "my", "text"], prediction=[("ORG", 0, 6), ("PER", 8, 14)], ) ``` The (truncated) output on the `develop` branch: ``` ValueError: Following entity spans are not aligned with provided tokenization Spans: - [This i] defined in This is my ... - [my tex] defined in ...s is my text Tokens: ['This', 'is', 'my', 'text'] ``` The distinction between `defined in` and `This is` is unclear. I've worked on this. The (truncated) output after this PR: ``` ValueError: Following entity spans are not aligned with provided tokenization Spans: - [This i] defined in 'This is my ...' - [my tex] defined in '...s is my text' Tokens: ['This', 'is', 'my', 'text'] ``` Note the additional `'`. Note that the changes rely on `repr`, so if the snippet contains `'` itself, it uses `"` instead, e.g.: ```python import argilla as rg record = rg.TokenClassificationRecord( text = "This is Tom's text", tokens=["This", "is", "Tom", "'s", "text"], prediction=[("ORG", 0, 6), ("PER", 8, 16)], ) ``` ``` ValueError: Following entity spans are not aligned with provided tokenization Spans: - [This i] defined in 'This is Tom...' 
- [Tom's te] defined in "...s is Tom's text" Tokens: ['This', 'is', 'Tom', "'s", 'text'] ``` **Type of change** - [x] Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** Modified the relevant tests, ensured they worked. **Checklist** - [x] I have merged the original branch into my forked branch - [ ] I added relevant documentation - [x] follows the style guidelines of this project - [x] I did a self-review of my code - [ ] I added comments to my code - [ ] I made corresponding changes to the documentation - [ ] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works - Tom Aarsen --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/argilla/utils/span_utils.py | 9 ++------- tests/client/test_api.py | 2 +- tests/utils/test_span_utils.py | 4 ++-- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/src/argilla/utils/span_utils.py b/src/argilla/utils/span_utils.py index 4a236c6083..bed9e172f7 100644 --- a/src/argilla/utils/span_utils.py +++ b/src/argilla/utils/span_utils.py @@ -93,13 +93,8 @@ def validate(self, spans: List[Tuple[str, int, int]]): self._start_to_token_idx.get(char_start), self._end_to_token_idx.get(char_end), ): - message = f"- [{self.text[char_start:char_end]}] defined in " - if char_start - 5 > 0: - message += "..." - message += self.text[max(char_start - 5, 0) : char_end + 5] - if char_end + 5 < len(self.text): - message += "..." - + span_str = self.text[char_start:char_end] + message = f"{span} - {repr(span_str)}" misaligned_spans_errors.append(message) if not_valid_spans_errors or misaligned_spans_errors: diff --git a/tests/client/test_api.py b/tests/client/test_api.py index 0c52482c27..12e66b634c 100644 --- a/tests/client/test_api.py +++ b/tests/client/test_api.py @@ -625,7 +625,7 @@ def test_token_classification_spans(span, valid): with pytest.raises( ValueError, match="Following entity spans are not aligned with provided tokenization\n" - r"Spans:\n- \[s\] defined in Esto es...\n" + r"Spans:\n\('test', 1, 2\) - 's'\n" r"Tokens:\n\['Esto', 'es', 'una', 'prueba'\]", ): ar.TokenClassificationRecord( diff --git a/tests/utils/test_span_utils.py b/tests/utils/test_span_utils.py index e3ddbdc6cf..3cde38131c 100644 --- a/tests/utils/test_span_utils.py +++ b/tests/utils/test_span_utils.py @@ -67,7 +67,7 @@ def test_validate_misaligned_spans(): with pytest.raises( ValueError, match="Following entity spans are not aligned with provided tokenization\n" - r"Spans:\n- \[test \] defined in test this.\n" + r"Spans:\n\('mock', 0, 5\) - 'test '\n" r"Tokens:\n\['test', 'this', '.'\]", ): span_utils.validate([("mock", 0, 5)]) @@ -79,7 +79,7 @@ def test_validate_not_valid_and_misaligned_spans(): ValueError, match=r"Following entity spans are not valid: \[\('mock', 2, 1\)\]\n" "Following entity spans are not aligned with provided tokenization\n" - r"Spans:\n- \[test \] defined in test this.\n" + r"Spans:\n\('mock', 0, 5\) - 'test '\n" r"Tokens:\n\['test', 'this', '.'\]", ): span_utils.validate([("mock", 2, 1), ("mock", 0, 5)]) From c83ec9e75560294d8ab58ad86c4b4b1c347b5491 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Thu, 16 Feb 2023 10:39:37 +0100 Subject: [PATCH 09/45] ci: Setup black line-length in toml file (#2352) Upgrade the black line-length parameter in the project toml file. 
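For a sense of what the sweeping reformat below amounts to: black now rewraps code against a 120-character budget instead of its 88-character default. A minimal sketch of the effect (the toy dataset objects are made up; the comprehension mirrors one of the `datasets.py` hunks in this patch):

```python
class FakeDataset:
    """Made-up stand-in for a `datasets.Dataset`, only for this sketch."""

    column_names = ["text", "label", "html"]


dataset = FakeDataset()
record_fields = ["text", "label"]
extra_allowed_columns = ["id"]

# What black emits at its default 88-character limit:
not_supported_columns = [
    col
    for col in dataset.column_names
    if col not in record_fields + extra_allowed_columns
]

# What it settles on once `line-length = 120` is set in pyproject.toml,
# which is why so many wrapped expressions below collapse to one line:
not_supported_columns = [col for col in dataset.column_names if col not in record_fields + extra_allowed_columns]

print(not_supported_columns)  # -> ['html']
```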
--------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- pyproject.toml | 3 + scripts/load_data.py | 16 +- scripts/migrations/es_migration_25042021.py | 4 +- src/argilla/client/apis/datasets.py | 15 +- src/argilla/client/apis/search.py | 4 +- src/argilla/client/apis/status.py | 4 +- src/argilla/client/client.py | 67 ++---- src/argilla/client/datasets.py | 195 ++++------------- src/argilla/client/models.py | 49 ++--- src/argilla/client/sdk/client.py | 5 +- src/argilla/client/sdk/commons/api.py | 16 +- src/argilla/client/sdk/commons/errors.py | 9 +- .../client/sdk/commons/errors_handler.py | 4 +- src/argilla/client/sdk/commons/models.py | 4 +- src/argilla/client/sdk/datasets/api.py | 4 +- src/argilla/client/sdk/metrics/api.py | 4 +- src/argilla/client/sdk/metrics/models.py | 4 +- src/argilla/client/sdk/text2text/models.py | 5 +- .../client/sdk/text_classification/api.py | 16 +- .../client/sdk/text_classification/models.py | 30 +-- .../client/sdk/token_classification/api.py | 8 +- .../client/sdk/token_classification/models.py | 18 +- .../text_classification/label_errors.py | 30 +-- .../text_classification/label_models.py | 200 +++++------------- .../labeling/text_classification/rule.py | 24 +-- .../text_classification/weak_labels.py | 183 ++++------------ src/argilla/listeners/listener.py | 22 +- src/argilla/metrics/commons.py | 12 +- src/argilla/metrics/helpers.py | 11 +- src/argilla/metrics/models.py | 9 +- .../metrics/token_classification/metrics.py | 32 +-- src/argilla/monitoring/_flair.py | 6 +- src/argilla/monitoring/_spacy.py | 8 +- src/argilla/monitoring/_transformers.py | 8 +- src/argilla/monitoring/asgi.py | 36 +--- src/argilla/monitoring/base.py | 4 +- src/argilla/monitoring/model_monitor.py | 3 +- .../server/apis/v0/handlers/datasets.py | 20 +- .../server/apis/v0/handlers/metrics.py | 4 +- .../server/apis/v0/handlers/text2text.py | 12 +- .../apis/v0/handlers/text_classification.py | 64 ++---- .../text_classification_dataset_settings.py | 12 +- .../apis/v0/handlers/token_classification.py | 20 +- .../token_classification_dataset_settings.py | 12 +- src/argilla/server/apis/v0/handlers/users.py | 8 +- src/argilla/server/apis/v0/helpers.py | 4 +- .../server/apis/v0/models/commons/model.py | 4 +- .../server/apis/v0/models/commons/params.py | 17 +- .../server/apis/v0/models/text2text.py | 4 +- .../apis/v0/models/text_classification.py | 23 +- .../apis/v0/models/token_classification.py | 8 +- .../apis/v0/validators/text_classification.py | 12 +- .../v0/validators/token_classification.py | 22 +- src/argilla/server/commons/config.py | 8 +- src/argilla/server/commons/telemetry.py | 17 +- .../backend/client_adapters/opensearch.py | 23 +- .../server/daos/backend/generic_elastic.py | 6 +- .../server/daos/backend/mappings/helpers.py | 6 +- .../server/daos/backend/metrics/base.py | 8 +- .../server/daos/backend/metrics/datasets.py | 4 +- .../backend/metrics/text_classification.py | 16 +- .../backend/metrics/token_classification.py | 19 +- .../server/daos/backend/query_helpers.py | 55 ++--- .../server/daos/backend/search/model.py | 3 +- .../daos/backend/search/query_builder.py | 23 +- src/argilla/server/daos/datasets.py | 33 +-- src/argilla/server/daos/models/records.py | 16 +- src/argilla/server/daos/records.py | 12 +- src/argilla/server/errors/base_errors.py | 12 +- src/argilla/server/helpers.py | 8 +- src/argilla/server/responses/api_responses.py | 4 +- src/argilla/server/routes.py | 4 +- .../security/auth_provider/local/provider.py | 16 +- 
.../security/auth_provider/local/settings.py | 8 +- .../security/auth_provider/local/users/dao.py | 4 +- src/argilla/server/security/model.py | 6 +- src/argilla/server/server.py | 10 +- src/argilla/server/services/datasets.py | 36 +--- src/argilla/server/services/info.py | 4 +- src/argilla/server/services/search/service.py | 8 +- .../server/services/storage/service.py | 4 +- .../server/services/tasks/commons/models.py | 4 +- .../server/services/tasks/text2text/models.py | 12 +- .../labeling_rules_service.py | 20 +- .../tasks/text_classification/metrics.py | 26 +-- .../tasks/text_classification/model.py | 59 ++---- .../tasks/text_classification/service.py | 37 +--- .../tasks/token_classification/metrics.py | 50 ++--- .../tasks/token_classification/model.py | 33 +-- .../tasks/token_classification/service.py | 8 +- src/argilla/utils/span_utils.py | 8 +- src/argilla/utils/utils.py | 16 +- tests/client/conftest.py | 47 +--- tests/client/sdk/commons/api.py | 12 +- tests/client/sdk/conftest.py | 19 +- tests/client/sdk/datasets/test_models.py | 8 +- tests/client/sdk/text2text/test_models.py | 4 +- .../sdk/text_classification/test_models.py | 20 +- tests/client/sdk/users/test_api.py | 8 +- tests/client/test_api.py | 16 +- tests/client/test_dataset.py | 167 ++++----------- tests/client/test_init.py | 4 +- tests/client/test_models.py | 88 ++------ tests/datasets/test_datasets.py | 4 +- .../test_delete_records_from_datasets.py | 14 +- .../search/test_search_service.py | 20 +- .../test_log_for_text_classification.py | 13 +- .../test_log_for_token_classification.py | 8 +- .../text_classification/test_label_errors.py | 45 +--- .../text_classification/test_label_models.py | 109 +++------- .../labeling/text_classification/test_rule.py | 16 +- .../text_classification/test_weak_labels.py | 198 ++++------------- tests/listeners/test_listener.py | 4 +- tests/metrics/test_common_metrics.py | 4 +- tests/monitoring/test_monitor.py | 6 +- .../test_transformers_monitoring.py | 9 +- tests/server/backend/test_query_builder.py | 8 +- tests/server/commons/test_telemetry.py | 4 +- tests/server/datasets/test_api.py | 15 +- tests/server/metrics/test_api.py | 16 +- tests/server/security/test_model.py | 12 +- tests/server/security/test_provider.py | 4 +- tests/server/test_errors.py | 5 +- tests/server/text2text/test_api.py | 4 +- tests/server/text_classification/test_api.py | 37 +--- .../text_classification/test_api_rules.py | 88 ++------ .../text_classification/test_api_settings.py | 12 +- .../server/text_classification/test_model.py | 24 +-- tests/server/token_classification/test_api.py | 17 +- .../token_classification/test_api_settings.py | 12 +- .../server/token_classification/test_model.py | 12 +- tests/utils/test_span_utils.py | 8 +- tests/utils/test_utils.py | 8 +- 133 files changed, 750 insertions(+), 2343 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f28608d42e..42133cb2c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,3 +139,6 @@ exclude = [ [tool.ruff.per-file-ignores] # Ignore imported but unused; "__init__.py" = ["F401"] + +[tool.black] +line-length = 120 diff --git a/scripts/load_data.py b/scripts/load_data.py index 06405d7650..46ae50a959 100644 --- a/scripts/load_data.py +++ b/scripts/load_data.py @@ -64,9 +64,7 @@ def load_news_text_summarization(): rg.log( dataset_rg, name="news-text-summarization", - tags={ - "description": "A text summarization dataset with news pieces and their predicted summaries." 
- }, + tags={"description": "A text summarization dataset with news pieces and their predicted summaries."}, ) @staticmethod @@ -80,18 +78,14 @@ def load_news_programmatic_labeling(): ) # Define labeling schema to avoid UI user modification - settings = rg.TextClassificationSettings( - label_schema={"World", "Sports", "Sci/Tech", "Business"} - ) + settings = rg.TextClassificationSettings(label_schema={"World", "Sports", "Sci/Tech", "Business"}) rg.configure_dataset(name="news-programmatic-labeling", settings=settings) # Log the dataset rg.log( dataset_rg, name="news-programmatic-labeling", - tags={ - "description": "The AG News with programmatic labeling rules (see weak labeling mode in the UI)." - }, + tags={"description": "The AG News with programmatic labeling rules (see weak labeling mode in the UI)."}, ) # Define queries and patterns for each category (using Elasticsearch DSL) @@ -103,9 +97,7 @@ def load_news_programmatic_labeling(): ] # Define rules - rules = [ - Rule(query=term, label=label) for terms, label in queries for term in terms - ] + rules = [Rule(query=term, label=label) for terms, label in queries for term in terms] # Add rules to the dataset add_rules(dataset="news-programmatic-labeling", rules=rules) diff --git a/scripts/migrations/es_migration_25042021.py b/scripts/migrations/es_migration_25042021.py index 63776fe977..1f5a04d602 100644 --- a/scripts/migrations/es_migration_25042021.py +++ b/scripts/migrations/es_migration_25042021.py @@ -47,9 +47,7 @@ def batcher(iterable, n, fillvalue=None): return zip_longest(*args, fillvalue=fillvalue) -def map_doc_2_action( - index: str, doc: Dict[str, Any], task: TaskType -) -> Optional[Dict[str, Any]]: +def map_doc_2_action(index: str, doc: Dict[str, Any], task: TaskType) -> Optional[Dict[str, Any]]: """Configures bulk action""" doc_data = doc["_source"] new_record = { diff --git a/src/argilla/client/apis/datasets.py b/src/argilla/client/apis/datasets.py index 5fff021362..c887300cb1 100644 --- a/src/argilla/client/apis/datasets.py +++ b/src/argilla/client/apis/datasets.py @@ -174,9 +174,7 @@ def scan( """ - url = ( - f"{self._API_PREFIX}/{name}/records/:search?limit={self.DEFAULT_SCAN_SIZE}" - ) + url = f"{self._API_PREFIX}/{name}/records/:search?limit={self.DEFAULT_SCAN_SIZE}" query = self._parse_query(query=query) if limit == 0: @@ -278,13 +276,10 @@ def delete_records( def __save_settings__(self, dataset: _DatasetApiModel, settings: Settings): if __TASK_TO_SETTINGS__.get(dataset.task) != type(settings): raise ValueError( - f"The provided settings type {type(settings)} cannot be applied to dataset." - " Task type mismatch" + f"The provided settings type {type(settings)} cannot be applied to dataset." 
" Task type mismatch" ) - settings_ = self._SettingsApiModel( - label_schema={"labels": [label for label in settings.label_schema]} - ) + settings_ = self._SettingsApiModel(label_schema={"labels": [label for label in settings.label_schema]}) with api_compatibility(self, min_version=self.__SETTINGS_MIN_API_VERSION__): self.http_client.put( @@ -305,9 +300,7 @@ def load_settings(self, name: str) -> Optional[Settings]: dataset = self.find_by_name(name) try: with api_compatibility(self, min_version=self.__SETTINGS_MIN_API_VERSION__): - response = self.http_client.get( - f"{self._API_PREFIX}/{dataset.task}/{dataset.name}/settings" - ) + response = self.http_client.get(f"{self._API_PREFIX}/{dataset.task}/{dataset.name}/settings") return __TASK_TO_SETTINGS__.get(dataset.task).from_dict(response) except NotFoundApiError: return None diff --git a/src/argilla/client/apis/search.py b/src/argilla/client/apis/search.py index 07de6bb71e..7567c78030 100644 --- a/src/argilla/client/apis/search.py +++ b/src/argilla/client/apis/search.py @@ -80,7 +80,5 @@ def search_records( return SearchResults( total=response["total"], - records=[ - record_class.parse_obj(r).to_client() for r in response["records"] - ], + records=[record_class.parse_obj(r).to_client() for r in response["records"]], ) diff --git a/src/argilla/client/apis/status.py b/src/argilla/client/apis/status.py index d03dcae852..87529e5da4 100644 --- a/src/argilla/client/apis/status.py +++ b/src/argilla/client/apis/status.py @@ -63,9 +63,7 @@ def __enter__(self): if api_version.is_devrelease: api_version = parse(api_version.base_version) if not api_version >= self._min_version: - raise ApiCompatibilityError( - str(self._min_version), api_version=api_version - ) + raise ApiCompatibilityError(str(self._min_version), api_version=api_version) pass def __exit__( diff --git a/src/argilla/client/client.py b/src/argilla/client/client.py index 5a3974ede3..ef36f35800 100644 --- a/src/argilla/client/client.py +++ b/src/argilla/client/client.py @@ -100,16 +100,12 @@ async def __log_internal__(api: "Argilla", *args, **kwargs): except Exception as ex: dataset = kwargs["name"] _LOGGER.error( - f"\nCannot log data in dataset '{dataset}'\n" - f"Error: {type(ex).__name__}\n" - f"Details: {ex}" + f"\nCannot log data in dataset '{dataset}'\n" f"Error: {type(ex).__name__}\n" f"Details: {ex}" ) raise ex def log(self, *args, **kwargs) -> Future: - return asyncio.run_coroutine_threadsafe( - self.__log_internal__(self.__api__, *args, **kwargs), self.__loop__ - ) + return asyncio.run_coroutine_threadsafe(self.__log_internal__(self.__api__, *args, **kwargs), self.__loop__) class Argilla: @@ -172,10 +168,7 @@ def __del__(self): def client(self) -> AuthenticatedClient: """The underlying authenticated HTTP client""" warnings.warn( - message=( - "This prop will be removed in next release. " - "Please use the http_client prop instead." - ), + message=("This prop will be removed in next release. 
" "Please use the http_client prop instead."), category=UserWarning, ) return self._client @@ -217,10 +210,7 @@ def set_workspace(self, workspace: str): if workspace != self.get_workspace(): if workspace == self._user.username: self._client.headers.pop(WORKSPACE_HEADER_NAME, workspace) - elif ( - self._user.workspaces is not None - and workspace not in self._user.workspaces - ): + elif self._user.workspaces is not None and workspace not in self._user.workspaces: raise Exception(f"Wrong provided workspace {workspace}") self._client.headers[WORKSPACE_HEADER_NAME] = workspace self._client.headers[_OLD_WORKSPACE_HEADER_NAME] = workspace @@ -370,10 +360,7 @@ async def log_async( bulk_class = Text2TextBulkData creation_class = CreationText2TextRecord else: - raise InputValueError( - f"Unknown record type {record_type}. Available values are" - f" {Record.__args__}" - ) + raise InputValueError(f"Unknown record type {record_type}. Available values are" f" {Record.__args__}") processed, failed = 0, 0 progress_bar = tqdm(total=len(records), disable=not verbose) @@ -398,18 +385,11 @@ async def log_async( # TODO: improve logging policy in library if verbose: - _LOGGER.info( - f"Processed {processed} records in dataset {name}. Failed: {failed}" - ) + _LOGGER.info(f"Processed {processed} records in dataset {name}. Failed: {failed}") workspace = self.get_workspace() - if ( - not workspace - ): # Just for backward comp. with datasets with no workspaces + if not workspace: # Just for backward comp. with datasets with no workspaces workspace = "-" - print( - f"{processed} records logged to" - f" {self._client.base_url}/datasets/{workspace}/{name}" - ) + print(f"{processed} records logged to" f" {self._client.base_url}/datasets/{workspace}/{name}") # Creating a composite BulkResponse with the total processed and failed return BulkResponse(dataset=name, processed=processed, failed=failed) @@ -488,8 +468,7 @@ def load( raise ValueError( "The argument `as_pandas` is deprecated and will be removed in a future" " version. Please adapt your code accordingly. ", - "If you want a pandas DataFrame do" - " `rg.load('my_dataset').to_pandas()`.", + "If you want a pandas DataFrame do" " `rg.load('my_dataset').to_pandas()`.", ) try: @@ -525,9 +504,7 @@ def load( def dataset_metrics(self, name: str) -> List[MetricInfo]: response = datasets_api.get_dataset(self._client, name) - response = metrics_api.get_dataset_metrics( - self._client, name=name, task=response.parsed.task - ) + response = metrics_api.get_dataset_metrics(self._client, name=name, task=response.parsed.task) return response.parsed @@ -572,9 +549,7 @@ def add_dataset_labeling_rules(self, dataset: str, rules: List[LabelingRule]): rule=rule, ) except AlreadyExistsApiError: - _LOGGER.warning( - f"Rule {rule} already exists. Please, update the rule instead." - ) + _LOGGER.warning(f"Rule {rule} already exists. 
Please, update the rule instead.") except Exception as ex: _LOGGER.warning(f"Cannot create rule {rule}: {ex}") @@ -593,36 +568,26 @@ def update_dataset_labeling_rules( ) except NotFoundApiError: _LOGGER.info(f"Rule {rule} does not exists, creating...") - text_classification_api.add_dataset_labeling_rule( - self._client, name=dataset, rule=rule - ) + text_classification_api.add_dataset_labeling_rule(self._client, name=dataset, rule=rule) except Exception as ex: _LOGGER.warning(f"Cannot update rule {rule}: {ex}") def delete_dataset_labeling_rules(self, dataset: str, rules: List[LabelingRule]): for rule in rules: try: - text_classification_api.delete_dataset_labeling_rule( - self._client, name=dataset, rule=rule - ) + text_classification_api.delete_dataset_labeling_rule(self._client, name=dataset, rule=rule) except Exception as ex: _LOGGER.warning(f"Cannot delete rule {rule}: {ex}") """Deletes the dataset labeling rules""" for rule in rules: - text_classification_api.delete_dataset_labeling_rule( - self._client, name=dataset, rule=rule - ) + text_classification_api.delete_dataset_labeling_rule(self._client, name=dataset, rule=rule) def fetch_dataset_labeling_rules(self, dataset: str) -> List[LabelingRule]: - response = text_classification_api.fetch_dataset_labeling_rules( - self._client, name=dataset - ) + response = text_classification_api.fetch_dataset_labeling_rules(self._client, name=dataset) return [LabelingRule.parse_obj(data) for data in response.parsed] - def rule_metrics_for_dataset( - self, dataset: str, rule: LabelingRule - ) -> LabelingRuleMetricsSummary: + def rule_metrics_for_dataset(self, dataset: str, rule: LabelingRule) -> LabelingRuleMetricsSummary: response = text_classification_api.dataset_rule_metrics( self._client, name=dataset, query=rule.query, label=rule.label ) diff --git a/src/argilla/client/datasets.py b/src/argilla/client/datasets.py index 952a3a96ef..35cf564883 100644 --- a/src/argilla/client/datasets.py +++ b/src/argilla/client/datasets.py @@ -103,9 +103,7 @@ def _record_init_args(cls) -> List[str]: def __init__(self, records: Optional[List[Record]] = None): if self._RECORD_TYPE is None: - raise NotImplementedError( - "A Dataset implementation has to define a `_RECORD_TYPE`!" - ) + raise NotImplementedError("A Dataset implementation has to define a `_RECORD_TYPE`!") self._records = records or [] if self._records: @@ -179,8 +177,7 @@ def to_datasets(self) -> "datasets.Dataset": del ds_dict["metadata"] dataset = datasets.Dataset.from_dict(ds_dict) _LOGGER.warning( - "The 'metadata' of the records were removed, since it was incompatible" - " with the 'datasets' format." + "The 'metadata' of the records were removed, since it was incompatible" " with the 'datasets' format." ) return dataset @@ -221,15 +218,10 @@ def _prepare_dataset_and_column_mapping( import datasets if isinstance(dataset, datasets.DatasetDict): - raise ValueError( - "`datasets.DatasetDict` are not supported. Please, select the dataset" - " split before." - ) + raise ValueError("`datasets.DatasetDict` are not supported. Please, select the dataset" " split before.") # clean column mappings - column_mapping = { - key: val for key, val in column_mapping.items() if val is not None - } + column_mapping = {key: val for key, val in column_mapping.items() if val is not None} cols_to_be_renamed, cols_to_be_joined = {}, {} for field, col in column_mapping.items(): @@ -263,9 +255,7 @@ def _remove_unsupported_columns( The dataset with unsupported columns removed. 
""" not_supported_columns = [ - col - for col in dataset.column_names - if col not in cls._record_init_args() + extra_columns + col for col in dataset.column_names if col not in cls._record_init_args() + extra_columns ] if not_supported_columns: @@ -279,9 +269,7 @@ def _remove_unsupported_columns( return dataset @staticmethod - def _join_datasets_columns_and_delete( - row: Dict[str, Any], columns: List[str] - ) -> Dict[str, Any]: + def _join_datasets_columns_and_delete(row: Dict[str, Any], columns: List[str]) -> Dict[str, Any]: """Joins columns of a `datasets.Dataset` row into a dict, and deletes the single columns. Updates the ``row`` dictionary! @@ -354,9 +342,7 @@ def from_pandas(cls, dataframe: pd.DataFrame) -> "Dataset": Returns: The imported records in a argilla Dataset. """ - not_supported_columns = [ - col for col in dataframe.columns if col not in cls._record_init_args() - ] + not_supported_columns = [col for col in dataframe.columns if col not in cls._record_init_args()] if not_supported_columns: _LOGGER.warning( "Following columns are not supported by the" @@ -461,14 +447,10 @@ def prepare_for_training( ) # check if train sizes sum up to 1 - assert (train_size + test_size) == 1, ValueError( - "`train_size` and `test_size` must sum to 1." - ) + assert (train_size + test_size) == 1, ValueError("`train_size` and `test_size` must sum to 1.") # check for annotations - assert any([rec.annotation for rec in self._records]), ValueError( - "Dataset has no annotations." - ) + assert any([rec.annotation for rec in self._records]), ValueError("Dataset has no annotations.") # shuffle records shuffled_records = self._records.copy() @@ -484,13 +466,10 @@ def prepare_for_training( # prepare for training for the right method if framework is Framework.TRANSFORMERS: - return self._prepare_for_training_with_transformers( - train_size=train_size, test_size=test_size, seed=seed - ) + return self._prepare_for_training_with_transformers(train_size=train_size, test_size=test_size, seed=seed) elif framework is Framework.SPACY and lang is None: raise ValueError( - "Please provide a spacy language model to prepare the" - " dataset for training with the spacy framework." + "Please provide a spacy language model to prepare the" " dataset for training with the spacy framework." ) elif framework in [Framework.SPACY, Framework.SPARK_NLP]: if train_size and test_size: @@ -503,12 +482,8 @@ def prepare_for_training( random_state=seed, ) if framework is Framework.SPACY: - train_docbin = self._prepare_for_training_with_spacy( - nlp=lang, records=records_train - ) - test_docbin = self._prepare_for_training_with_spacy( - nlp=lang, records=records_test - ) + train_docbin = self._prepare_for_training_with_spacy(nlp=lang, records=records_train) + test_docbin = self._prepare_for_training_with_spacy(nlp=lang, records=records_test) return train_docbin, test_docbin else: train_df = self._prepare_for_training_with_spark_nlp(records_train) @@ -517,18 +492,11 @@ def prepare_for_training( return train_df, test_df else: if framework is Framework.SPACY: - return self._prepare_for_training_with_spacy( - nlp=lang, records=shuffled_records - ) + return self._prepare_for_training_with_spacy(nlp=lang, records=shuffled_records) else: - return self._prepare_for_training_with_spark_nlp( - records=shuffled_records - ) + return self._prepare_for_training_with_spark_nlp(records=shuffled_records) else: - raise NotImplementedError( - f"Framework {framework} is not supported. 
Choose from:" - f" {list(Framework)}" - ) + raise NotImplementedError(f"Framework {framework} is not supported. Choose from:" f" {list(Framework)}") @_requires_spacy def _prepare_for_training_with_spacy( @@ -674,9 +642,7 @@ def from_datasets( records = [] for row in dataset: - row["inputs"] = cls._parse_inputs_field( - row, cols_to_be_joined.get("inputs") - ) + row["inputs"] = cls._parse_inputs_field(row, cols_to_be_joined.get("inputs")) if row.get("inputs") is not None and row.get("text") is not None: del row["text"] @@ -701,10 +667,7 @@ def from_datasets( if row.get("explanation"): row["explanation"] = ( { - key: [ - TokenAttributions(**tokattr_kwargs) - for tokattr_kwargs in val - ] + key: [TokenAttributions(**tokattr_kwargs) for tokattr_kwargs in val] for key, val in row["explanation"].items() } if row["explanation"] is not None @@ -712,9 +675,7 @@ def from_datasets( ) if cols_to_be_joined.get("metadata"): - row["metadata"] = cls._join_datasets_columns_and_delete( - row, cols_to_be_joined["metadata"] - ) + row["metadata"] = cls._join_datasets_columns_and_delete(row, cols_to_be_joined["metadata"]) records.append(TextClassificationRecord.parse_obj(row)) @@ -766,18 +727,13 @@ def _to_datasets_dict(self) -> Dict: ] elif key == "explanation": ds_dict[key] = [ - { - key: list(map(dict, tokattrs)) - for key, tokattrs in rec.explanation.items() - } + {key: list(map(dict, tokattrs)) for key, tokattrs in rec.explanation.items()} if rec.explanation is not None else None for rec in self._records ] elif key == "id": - ds_dict[key] = [ - None if rec.id is None else str(rec.id) for rec in self._records - ] + ds_dict[key] = [None if rec.id is None else str(rec.id) for rec in self._records] elif key == "metadata": ds_dict[key] = [getattr(rec, key) or None for rec in self._records] else: @@ -787,9 +743,7 @@ def _to_datasets_dict(self) -> Dict: @classmethod def _from_pandas(cls, dataframe: pd.DataFrame) -> "DatasetForTextClassification": - return cls( - [TextClassificationRecord(**row) for row in dataframe.to_dict("records")] - ) + return cls([TextClassificationRecord(**row) for row in dataframe.to_dict("records")]) @_requires_datasets def _prepare_for_training_with_transformers( @@ -800,12 +754,7 @@ def _prepare_for_training_with_transformers( ): import datasets - inputs_keys = { - key: None - for rec in self._records - for key in rec.inputs - if rec.annotation is not None - }.keys() + inputs_keys = {key: None for rec in self._records for key in rec.inputs if rec.annotation is not None}.keys() ds_dict = {**{key: [] for key in inputs_keys}, "label": []} for rec in self._records: @@ -832,9 +781,7 @@ def _prepare_for_training_with_transformers( "label": [class_label] if self._records[0].multi_label else class_label, } - ds = datasets.Dataset.from_dict( - ds_dict, features=datasets.Features(feature_dict) - ) + ds = datasets.Dataset.from_dict(ds_dict, features=datasets.Features(feature_dict)) if self._records[0].multi_label: from sklearn.preprocessing import MultiLabelBinarizer @@ -853,16 +800,12 @@ def _prepare_for_training_with_transformers( features=datasets.Features(feature_dict), ) if test_size: - ds = ds.train_test_split( - train_size=train_size, test_size=test_size, seed=seed - ) + ds = ds.train_test_split(train_size=train_size, test_size=test_size, seed=seed) return ds @_requires_spacy - def _prepare_for_training_with_spacy( - self, nlp: "spacy.Language", records: List[Record] - ) -> "spacy.tokens.DocBin": + def _prepare_for_training_with_spacy(self, nlp: "spacy.Language", records: List[Record]) -> 
"spacy.tokens.DocBin": from spacy.tokens import DocBin db = DocBin() @@ -889,9 +832,7 @@ def _prepare_for_training_with_spacy( return db - def _prepare_for_training_with_spark_nlp( - self, records: List[Record] - ) -> "pandas.DataFrame": + def _prepare_for_training_with_spark_nlp(self, records: List[Record]) -> "pandas.DataFrame": if records[0].multi_label: label_name = "labels" else: @@ -1016,9 +957,7 @@ def from_datasets( continue if row.get("tags"): - row["tags"] = cls._parse_datasets_column_with_classlabel( - row["tags"], dataset.features["tags"] - ) + row["tags"] = cls._parse_datasets_column_with_classlabel(row["tags"], dataset.features["tags"]) if row.get("prediction"): row["prediction"] = cls.__entities_to_tuple__(row["prediction"]) @@ -1027,9 +966,7 @@ def from_datasets( row["annotation"] = cls.__entities_to_tuple__(row["annotation"]) if cols_to_be_joined.get("metadata"): - row["metadata"] = cls._join_datasets_columns_and_delete( - row, cols_to_be_joined["metadata"] - ) + row["metadata"] = cls._join_datasets_columns_and_delete(row, cols_to_be_joined["metadata"]) records.append(TokenClassificationRecord.parse_obj(row)) @@ -1062,13 +999,7 @@ def _prepare_for_training_with_transformers( return datasets.Dataset.from_dict({}) class_tags = ["O"] - class_tags.extend( - [ - f"{pre}-{label}" - for label in sorted(self.__all_labels__()) - for pre in ["B", "I"] - ] - ) + class_tags.extend([f"{pre}-{label}" for label in sorted(self.__all_labels__()) for pre in ["B", "I"]]) class_tags = datasets.ClassLabel(names=class_tags) def spans2iob(example): @@ -1078,26 +1009,18 @@ def spans2iob(example): return class_tags.str2int(tags) - ds = ( - self.to_datasets() - .filter(self.__only_annotations__) - .map(lambda example: {"ner_tags": spans2iob(example)}) - ) + ds = self.to_datasets().filter(self.__only_annotations__).map(lambda example: {"ner_tags": spans2iob(example)}) new_features = ds.features.copy() new_features["ner_tags"] = [class_tags] ds = ds.cast(new_features) if train_size or test_size: - ds = ds.train_test_split( - train_size=train_size, test_size=test_size, seed=seed - ) + ds = ds.train_test_split(train_size=train_size, test_size=test_size, seed=seed) return ds @_requires_spacy - def _prepare_for_training_with_spacy( - self, nlp: "spacy.Language", records: List[Record] - ) -> "spacy.tokens.DocBin": + def _prepare_for_training_with_spacy(self, nlp: "spacy.Language", records: List[Record]) -> "spacy.tokens.DocBin": from spacy.tokens import DocBin db = DocBin() @@ -1128,9 +1051,7 @@ def _prepare_for_training_with_spacy( return db - def _prepare_for_training_with_spark_nlp( - self, records: List[Record] - ) -> "pandas.DataFrame": + def _prepare_for_training_with_spark_nlp(self, records: List[Record]) -> "pandas.DataFrame": for record in records: if record.id is None: record.id = str(uuid.uuid4()) @@ -1163,9 +1084,7 @@ def _to_datasets_dict(self) -> Dict: # create a dict first, where we make the necessary transformations def entities_to_dict( - entities: Optional[ - List[Union[Tuple[str, int, int, float], Tuple[str, int, int]]] - ] + entities: Optional[List[Union[Tuple[str, int, int, float], Tuple[str, int, int]]]] ) -> Optional[List[Dict[str, Union[str, int, float]]]]: if entities is None: return None @@ -1179,17 +1098,11 @@ def entities_to_dict( ds_dict = {} for key in self._RECORD_TYPE.__fields__: if key == "prediction": - ds_dict[key] = [ - entities_to_dict(rec.prediction) for rec in self._records - ] + ds_dict[key] = [entities_to_dict(rec.prediction) for rec in self._records] elif key == 
"annotation": - ds_dict[key] = [ - entities_to_dict(rec.annotation) for rec in self._records - ] + ds_dict[key] = [entities_to_dict(rec.annotation) for rec in self._records] elif key == "id": - ds_dict[key] = [ - None if rec.id is None else str(rec.id) for rec in self._records - ] + ds_dict[key] = [None if rec.id is None else str(rec.id) for rec in self._records] elif key == "metadata": ds_dict[key] = [getattr(rec, key) or None for rec in self._records] else: @@ -1210,9 +1123,7 @@ def __entities_to_tuple__( @classmethod def _from_pandas(cls, dataframe: pd.DataFrame) -> "DatasetForTokenClassification": - return cls( - [TokenClassificationRecord(**row) for row in dataframe.to_dict("records")] - ) + return cls([TokenClassificationRecord(**row) for row in dataframe.to_dict("records")]) @_prepend_docstring(Text2TextRecord) @@ -1299,9 +1210,7 @@ def from_datasets( row["prediction"] = cls._parse_prediction_field(row["prediction"]) if cols_to_be_joined.get("metadata"): - row["metadata"] = cls._join_datasets_columns_and_delete( - row, cols_to_be_joined["metadata"] - ) + row["metadata"] = cls._join_datasets_columns_and_delete(row, cols_to_be_joined["metadata"]) records.append(Text2TextRecord.parse_obj(row)) @@ -1337,15 +1246,11 @@ def pred_to_dict(pred: Union[str, Tuple[str, float]]): for key in self._RECORD_TYPE.__fields__: if key == "prediction": ds_dict[key] = [ - [pred_to_dict(pred) for pred in rec.prediction] - if rec.prediction is not None - else None + [pred_to_dict(pred) for pred in rec.prediction] if rec.prediction is not None else None for rec in self._records ] elif key == "id": - ds_dict[key] = [ - None if rec.id is None else str(rec.id) for rec in self._records - ] + ds_dict[key] = [None if rec.id is None else str(rec.id) for rec in self._records] elif key == "metadata": ds_dict[key] = [getattr(rec, key) or None for rec in self._records] else: @@ -1371,14 +1276,10 @@ def prepare_for_training(self, **kwargs) -> "datasets.Dataset": raise NotImplementedError -Dataset = Union[ - DatasetForTextClassification, DatasetForTokenClassification, DatasetForText2Text -] +Dataset = Union[DatasetForTextClassification, DatasetForTokenClassification, DatasetForText2Text] -def read_datasets( - dataset: "datasets.Dataset", task: Union[str, TaskType], **kwargs -) -> Dataset: +def read_datasets(dataset: "datasets.Dataset", task: Union[str, TaskType], **kwargs) -> Dataset: """Reads a datasets Dataset and returns a argilla Dataset Args: @@ -1431,9 +1332,7 @@ def read_datasets( return DatasetForTokenClassification.from_datasets(dataset, **kwargs) if task is TaskType.text2text: return DatasetForText2Text.from_datasets(dataset, **kwargs) - raise NotImplementedError( - "Reading a datasets Dataset is not implemented for the given task!" - ) + raise NotImplementedError("Reading a datasets Dataset is not implemented for the given task!") def read_pandas(dataframe: pd.DataFrame, task: Union[str, TaskType]) -> Dataset: @@ -1488,9 +1387,7 @@ def read_pandas(dataframe: pd.DataFrame, task: Union[str, TaskType]) -> Dataset: return DatasetForTokenClassification.from_pandas(dataframe) if task is TaskType.text2text: return DatasetForText2Text.from_pandas(dataframe) - raise NotImplementedError( - "Reading a pandas DataFrame is not implemented for the given task!" 
- ) + raise NotImplementedError("Reading a pandas DataFrame is not implemented for the given task!") class WrongRecordTypeError(Exception): diff --git a/src/argilla/client/models.py b/src/argilla/client/models.py index 845e7f8eb0..89a274865e 100644 --- a/src/argilla/client/models.py +++ b/src/argilla/client/models.py @@ -44,8 +44,7 @@ class Framework(Enum): @classmethod def _missing_(cls, value): raise ValueError( - f"{value} is not a valid {cls.__name__}, please select one of" - f" {list(cls._value2member_map_.keys())}" + f"{value} is not a valid {cls.__name__}, please select one of" f" {list(cls._value2member_map_.keys())}" ) @@ -68,8 +67,7 @@ def _check_value_length(cls, metadata): message = ( "Some metadata values could exceed the max length. For those cases," " values will be truncated by keeping only the last" - f" {DEFAULT_MAX_KEYWORD_LENGTH} characters. " - + _messages.ARGILLA_METADATA_FIELD_WARNING_MESSAGE + f" {DEFAULT_MAX_KEYWORD_LENGTH} characters. " + _messages.ARGILLA_METADATA_FIELD_WARNING_MESSAGE ) warnings.warn(message, UserWarning) @@ -114,9 +112,7 @@ def _nat_to_none_and_one_to_now(cls, v): @root_validator def _check_and_update_status(cls, values): """Updates the status if an annotation is provided and no status is specified.""" - values["status"] = values.get("status") or ( - "Default" if values.get("annotation") is None else "Validated" - ) + values["status"] = values.get("status") or ("Default" if values.get("annotation") is None else "Validated") return values @@ -261,10 +257,7 @@ def _check_text_and_inputs(cls, values): and values.get("inputs") is not None and values["text"] != values["inputs"].get("text") ): - raise ValueError( - "For a TextClassificationRecord you must provide either 'text' or" - " 'inputs'" - ) + raise ValueError("For a TextClassificationRecord you must provide either 'text' or" " 'inputs'") if values.get("text") is not None: values["inputs"] = dict(text=values["text"]) @@ -333,9 +326,7 @@ class TokenClassificationRecord(_Validators): text: Optional[str] = Field(None, min_length=1) tokens: Optional[Union[List[str], Tuple[str, ...]]] = None - prediction: Optional[ - List[Union[Tuple[str, int, int], Tuple[str, int, int, Optional[float]]]] - ] = None + prediction: Optional[List[Union[Tuple[str, int, int], Tuple[str, int, int, Optional[float]]]]] = None prediction_agent: Optional[str] = None annotation: Optional[List[Tuple[str, int, int]]] = None annotation_agent: Optional[str] = None @@ -358,16 +349,10 @@ def __init__( **data, ): if text is None and tokens is None: - raise AssertionError( - "Missing fields: At least one of `text` or `tokens` argument must be" - " provided!" - ) + raise AssertionError("Missing fields: At least one of `text` or `tokens` argument must be" " provided!") if (data.get("annotation") or data.get("prediction")) and text is None: - raise AssertionError( - "Missing field `text`: " - "char level spans must be provided with a raw text sentence" - ) + raise AssertionError("Missing field `text`: " "char level spans must be provided with a raw text sentence") if text is None: text = " ".join(tokens) @@ -392,9 +377,7 @@ def __setattr__(self, name: str, value: Any): raise AttributeError(f"You cannot assign a new value to `{name}`") super().__setattr__(name, value) - def _validate_spans( - self, spans: List[Tuple[str, int, int]] - ) -> List[Tuple[str, int, int]]: + def _validate_spans(self, spans: List[Tuple[str, int, int]]) -> List[Tuple[str, int, int]]: """Validates the entity spans with respect to the tokens. 
If necessary, also performs an automatic correction of the spans. @@ -427,17 +410,13 @@ def _normalize_tokens(cls, value): @validator("prediction") def _add_default_score( cls, - prediction: Optional[ - List[Union[Tuple[str, int, int], Tuple[str, int, int, Optional[float]]]] - ], + prediction: Optional[List[Union[Tuple[str, int, int], Tuple[str, int, int, Optional[float]]]]], ): """Adds the default score to the predictions if it is missing""" if prediction is None: return prediction return [ - (pred[0], pred[1], pred[2], 0.0) - if len(pred) == 3 - else (pred[0], pred[1], pred[2], pred[3] or 0.0) + (pred[0], pred[1], pred[2], 0.0) if len(pred) == 3 else (pred[0], pred[1], pred[2], pred[3] or 0.0) for pred in prediction ] @@ -492,9 +471,7 @@ def token_span(self, token_idx: int) -> Tuple[int, int]: raise IndexError(f"Token id {token_idx} out of bounds") return self._span_utils.token_to_char_idx[token_idx] - def spans2iob( - self, spans: Optional[List[Tuple[str, int, int]]] = None - ) -> Optional[List[str]]: + def spans2iob(self, spans: Optional[List[Tuple[str, int, int]]] = None) -> Optional[List[str]]: """DEPRECATED, please use the ``argilla.utils.SpanUtils.to_tags()`` method.""" warnings.warn( "'spans2iob' is deprecated and will be removed in a future version. Please" @@ -570,9 +547,7 @@ class Text2TextRecord(_Validators): search_keywords: Optional[List[str]] = None @validator("prediction") - def prediction_as_tuples( - cls, prediction: Optional[List[Union[str, Tuple[str, float]]]] - ): + def prediction_as_tuples(cls, prediction: Optional[List[Union[str, Tuple[str, float]]]]): """Preprocesses the predictions and wraps them in a tuple if needed""" if prediction is None: return prediction diff --git a/src/argilla/client/sdk/client.py b/src/argilla/client/sdk/client.py index be4d2983b9..9f925b6795 100644 --- a/src/argilla/client/sdk/client.py +++ b/src/argilla/client/sdk/client.py @@ -108,10 +108,7 @@ def __hash__(self): def with_httpx_error_handler(func): def wrap_error(base_url: str): - err_str = ( - f"Your Api endpoint at {base_url} is not available or not" - " responding." - ) + err_str = f"Your Api endpoint at {base_url} is not available or not" " responding."
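# The wrapped function below presumably catches httpx connection errors and reports them through wrap_error, so callers see a single BaseClientError; raising it "from None" hides the chained httpx traceback.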
raise BaseClientError(err_str) from None @functools.wraps(func) diff --git a/src/argilla/client/sdk/commons/api.py b/src/argilla/client/sdk/commons/api.py index 3810aebdf1..5c96eff3f8 100644 --- a/src/argilla/client/sdk/commons/api.py +++ b/src/argilla/client/sdk/commons/api.py @@ -50,9 +50,7 @@ } -def build_param_dict( - id_from: Optional[str], limit: Optional[int] -) -> Optional[Dict[str, Union[str, int]]]: +def build_param_dict(id_from: Optional[str], limit: Optional[int]) -> Optional[Dict[str, Union[str, int]]]: params = {} if id_from: params["id_from"] = id_from @@ -64,9 +62,7 @@ def build_param_dict( def bulk( client: AuthenticatedClient, name: str, - json_body: Union[ - TextClassificationBulkData, TokenClassificationBulkData, Text2TextBulkData - ], + json_body: Union[TextClassificationBulkData, TokenClassificationBulkData, Text2TextBulkData], ) -> Response[BulkResponse]: url = f"{client.base_url}/api/datasets/{name}/{_TASK_TO_ENDPOINT[type(json_body)]}:bulk" @@ -100,9 +96,7 @@ async def async_bulk( return build_bulk_response(response, name=name, body=json_body) -def build_bulk_response( - response: httpx.Response, name: str, body: Any -) -> Response[BulkResponse]: +def build_bulk_response(response: httpx.Response, name: str, body: Any) -> Response[BulkResponse]: if 200 <= response.status_code < 400: return Response( status_code=response.status_code, @@ -117,9 +111,7 @@ def build_bulk_response( T = TypeVar("T") -def build_data_response( - response: httpx.Response, data_type: Type[T] -) -> Response[List[T]]: +def build_data_response(response: httpx.Response, data_type: Type[T]) -> Response[List[T]]: if 200 <= response.status_code < 400: parsed_responses = [] for r in response.iter_lines(): diff --git a/src/argilla/client/sdk/commons/errors.py b/src/argilla/client/sdk/commons/errors.py index d3698c404a..ca7dfd056f 100644 --- a/src/argilla/client/sdk/commons/errors.py +++ b/src/argilla/client/sdk/commons/errors.py @@ -26,11 +26,7 @@ def __init__(self, message: str, response: Any): self.response = response def __str__(self): - return ( - f"\nUnexpected response: {self.message}" - "\nResponse content:" - f"\n{self.response}" - ) + return f"\nUnexpected response: {self.message}" "\nResponse content:" f"\n{self.response}" class InputValueError(BaseClientError): @@ -72,8 +68,7 @@ def __init__(self, **ctx): def __str__(self): return ( - f"Argilla server returned an error with http status: {self.HTTP_STATUS}" - + f"\nError details: [{self.ctx}]" + f"Argilla server returned an error with http status: {self.HTTP_STATUS}" + f"\nError details: [{self.ctx}]" ) diff --git a/src/argilla/client/sdk/commons/errors_handler.py b/src/argilla/client/sdk/commons/errors_handler.py index f1292095dd..9b425b24db 100644 --- a/src/argilla/client/sdk/commons/errors_handler.py +++ b/src/argilla/client/sdk/commons/errors_handler.py @@ -29,9 +29,7 @@ ) -def handle_response_error( - response: httpx.Response, parse_response: bool = True, **client_ctx -): +def handle_response_error(response: httpx.Response, parse_response: bool = True, **client_ctx): try: response_content = response.json() if parse_response else {} except JSONDecodeError: diff --git a/src/argilla/client/sdk/commons/models.py b/src/argilla/client/sdk/commons/models.py index 926e892cb1..2dff92592a 100644 --- a/src/argilla/client/sdk/commons/models.py +++ b/src/argilla/client/sdk/commons/models.py @@ -78,9 +78,7 @@ def datetime_to_isoformat(cls, v: Optional[datetime]): def _from_client_vectors(vectors: ClientVectors) -> SdkVectors: sdk_vectors = None if 
vectors: - sdk_vectors = { - name: VectorInfo(value=vector) for name, vector in vectors.items() - } + sdk_vectors = {name: VectorInfo(value=vector) for name, vector in vectors.items()} return sdk_vectors @staticmethod diff --git a/src/argilla/client/sdk/datasets/api.py b/src/argilla/client/sdk/datasets/api.py index b836b949d0..a254b93974 100644 --- a/src/argilla/client/sdk/datasets/api.py +++ b/src/argilla/client/sdk/datasets/api.py @@ -86,9 +86,7 @@ def delete_dataset( return handle_response_error(response, dataset=name) -def _build_response( - response: httpx.Response, name: str -) -> Response[Union[Dataset, ErrorMessage, HTTPValidationError]]: +def _build_response(response: httpx.Response, name: str) -> Response[Union[Dataset, ErrorMessage, HTTPValidationError]]: if response.status_code == 200: parsed_response = Dataset(**response.json()) return Response( diff --git a/src/argilla/client/sdk/metrics/api.py b/src/argilla/client/sdk/metrics/api.py index 30171dc458..c700c34970 100644 --- a/src/argilla/client/sdk/metrics/api.py +++ b/src/argilla/client/sdk/metrics/api.py @@ -32,9 +32,7 @@ def get_dataset_metrics( client: AuthenticatedClient, name: str, task: str ) -> Response[Union[List[MetricInfo], ErrorMessage, HTTPValidationError]]: - url = "{}/api/datasets/{task}/{name}/metrics".format( - client.base_url, name=name, task=task - ) + url = "{}/api/datasets/{task}/{name}/metrics".format(client.base_url, name=name, task=task) response = httpx.get( url=url, diff --git a/src/argilla/client/sdk/metrics/models.py b/src/argilla/client/sdk/metrics/models.py index 3e272966e1..dbcb60c5c8 100644 --- a/src/argilla/client/sdk/metrics/models.py +++ b/src/argilla/client/sdk/metrics/models.py @@ -22,6 +22,4 @@ class MetricInfo(BaseModel): id: str = Field(description="The metric id") name: str = Field(description="The metric name") - description: Optional[str] = Field( - default=None, description="The metric description" - ) + description: Optional[str] = Field(default=None, description="The metric description") diff --git a/src/argilla/client/sdk/text2text/models.py b/src/argilla/client/sdk/text2text/models.py index f8f563caee..7f067820d1 100644 --- a/src/argilla/client/sdk/text2text/models.py +++ b/src/argilla/client/sdk/text2text/models.py @@ -82,10 +82,7 @@ class Text2TextRecord(CreationText2TextRecord): def to_client(self) -> ClientText2TextRecord: return ClientText2TextRecord( text=self.text, - prediction=[ - (sentence.text, sentence.score) - for sentence in self.prediction.sentences - ] + prediction=[(sentence.text, sentence.score) for sentence in self.prediction.sentences] if self.prediction else None, prediction_agent=self.prediction.agent if self.prediction else None, diff --git a/src/argilla/client/sdk/text_classification/api.py b/src/argilla/client/sdk/text_classification/api.py index 39c603411e..50778ccd85 100644 --- a/src/argilla/client/sdk/text_classification/api.py +++ b/src/argilla/client/sdk/text_classification/api.py @@ -52,9 +52,7 @@ def data( params=params if params else None, json=request.dict() if request else {}, ) as response: - return build_data_response( - response=response, data_type=TextClassificationRecord - ) + return build_data_response(response=response, data_type=TextClassificationRecord) def add_dataset_labeling_rule( @@ -62,9 +60,7 @@ def add_dataset_labeling_rule( name: str, rule: LabelingRule, ) -> Response[Union[LabelingRule, HTTPValidationError, ErrorMessage]]: - url = "{}/api/datasets/{name}/TextClassification/labeling/rules".format( - client.base_url, name=name - ) 
+ url = "{}/api/datasets/{name}/TextClassification/labeling/rules".format(client.base_url, name=name) response = httpx.post( url=url, @@ -118,9 +114,7 @@ def fetch_dataset_labeling_rules( client: AuthenticatedClient, name: str, ) -> Response[Union[List[LabelingRule], HTTPValidationError, ErrorMessage]]: - url = "{}/api/datasets/TextClassification/{name}/labeling/rules".format( - client.base_url, name=name - ) + url = "{}/api/datasets/TextClassification/{name}/labeling/rules".format(client.base_url, name=name) response = httpx.get( url=url, @@ -149,6 +143,4 @@ def dataset_rule_metrics( timeout=client.get_timeout(), ) - return build_typed_response( - response, response_type_class=LabelingRuleMetricsSummary - ) + return build_typed_response(response, response_type_class=LabelingRuleMetricsSummary) diff --git a/src/argilla/client/sdk/text_classification/models.py b/src/argilla/client/sdk/text_classification/models.py index cfe016cb7b..2fc82b6762 100644 --- a/src/argilla/client/sdk/text_classification/models.py +++ b/src/argilla/client/sdk/text_classification/models.py @@ -60,24 +60,15 @@ def from_client(cls, record: ClientTextClassificationRecord): prediction = None if record.prediction is not None: prediction = TextClassificationAnnotation( - labels=[ - ClassPrediction(**{"class": label, "score": score}) - for label, score in record.prediction - ], + labels=[ClassPrediction(**{"class": label, "score": score}) for label, score in record.prediction], agent=record.prediction_agent or MACHINE_NAME, ) annotation = None if record.annotation is not None: - annotation_list = ( - record.annotation - if isinstance(record.annotation, list) - else [record.annotation] - ) + annotation_list = record.annotation if isinstance(record.annotation, list) else [record.annotation] annotation = TextClassificationAnnotation( - labels=[ - ClassPrediction(**{"class": label}) for label in annotation_list - ], + labels=[ClassPrediction(**{"class": label}) for label in annotation_list], agent=record.annotation_agent or MACHINE_NAME, ) @@ -101,11 +92,7 @@ class TextClassificationRecord(CreationTextClassificationRecord): def to_client(self) -> ClientTextClassificationRecord: """Returns the client model""" - annotations = ( - [label.class_label for label in self.annotation.labels] - if self.annotation - else None - ) + annotations = [label.class_label for label in self.annotation.labels] if self.annotation else None if annotations and not self.multi_label: annotations = annotations[0] @@ -116,9 +103,7 @@ def to_client(self) -> ClientTextClassificationRecord: multi_label=self.multi_label, status=self.status, metadata=self.metadata or {}, - prediction=[ - (label.class_label, label.score) for label in self.prediction.labels - ] + prediction=[(label.class_label, label.score) for label in self.prediction.labels] if self.prediction else None, prediction_agent=self.prediction.agent if self.prediction else None, @@ -126,10 +111,7 @@ def to_client(self) -> ClientTextClassificationRecord: annotation_agent=self.annotation.agent if self.annotation else None, vectors=self._to_client_vectors(self.vectors), explanation={ - key: [ - ClientTokenAttributions.parse_obj(attribution) - for attribution in attributions - ] + key: [ClientTokenAttributions.parse_obj(attribution) for attribution in attributions] for key, attributions in self.explanation.items() } if self.explanation diff --git a/src/argilla/client/sdk/token_classification/api.py b/src/argilla/client/sdk/token_classification/api.py index 10a91d0f59..9d9d8909fc 100644 --- 
a/src/argilla/client/sdk/token_classification/api.py +++ b/src/argilla/client/sdk/token_classification/api.py @@ -36,9 +36,7 @@ def data( request: Optional[TokenClassificationQuery] = None, limit: Optional[int] = None, id_from: Optional[str] = None, -) -> Response[ - Union[List[TokenClassificationRecord], HTTPValidationError, ErrorMessage] -]: +) -> Response[Union[List[TokenClassificationRecord], HTTPValidationError, ErrorMessage]]: path = f"/api/datasets/{name}/TokenClassification/data" params = build_param_dict(id_from, limit) @@ -48,6 +46,4 @@ def data( params=params if params else None, json=request.dict() if request else {}, ) as response: - return build_data_response( - response=response, data_type=TokenClassificationRecord - ) + return build_data_response(response=response, data_type=TokenClassificationRecord) diff --git a/src/argilla/client/sdk/token_classification/models.py b/src/argilla/client/sdk/token_classification/models.py index 8a549de39f..0bea10bba2 100644 --- a/src/argilla/client/sdk/token_classification/models.py +++ b/src/argilla/client/sdk/token_classification/models.py @@ -63,9 +63,7 @@ def from_client(cls, record: ClientTokenClassificationRecord): entities=[ EntitySpan(label=ent[0], start=ent[1], end=ent[2]) if len(ent) == 3 - else EntitySpan( - label=ent[0], start=ent[1], end=ent[2], score=ent[3] - ) + else EntitySpan(label=ent[0], start=ent[1], end=ent[2], score=ent[3]) for ent in record.prediction ], agent=record.prediction_agent or MACHINE_NAME, @@ -74,10 +72,7 @@ def from_client(cls, record: ClientTokenClassificationRecord): annotation = None if record.annotation is not None: annotation = TokenClassificationAnnotation( - entities=[ - EntitySpan(label=ent[0], start=ent[1], end=ent[2]) - for ent in record.annotation - ], + entities=[EntitySpan(label=ent[0], start=ent[1], end=ent[2]) for ent in record.annotation], agent=record.annotation_agent or MACHINE_NAME, ) @@ -102,16 +97,11 @@ def to_client(self) -> ClientTokenClassificationRecord: return ClientTokenClassificationRecord( text=self.text, tokens=self.tokens, - prediction=[ - (ent.label, ent.start, ent.end, ent.score) - for ent in self.prediction.entities - ] + prediction=[(ent.label, ent.start, ent.end, ent.score) for ent in self.prediction.entities] if self.prediction else None, prediction_agent=self.prediction.agent if self.prediction else None, - annotation=[ - (ent.label, ent.start, ent.end) for ent in self.annotation.entities - ] + annotation=[(ent.label, ent.start, ent.end) for ent in self.annotation.entities] if self.annotation else None, annotation_agent=self.annotation.agent if self.annotation else None, diff --git a/src/argilla/labeling/text_classification/label_errors.py b/src/argilla/labeling/text_classification/label_errors.py index e9a8583f50..2460c816cf 100644 --- a/src/argilla/labeling/text_classification/label_errors.py +++ b/src/argilla/labeling/text_classification/label_errors.py @@ -95,9 +95,7 @@ def find_label_errors( # select only records with prediction and annotation records = [rec for rec in records if rec.prediction and rec.annotation] if not records: - raise NoRecordsError( - "It seems that none of your records have a prediction AND annotation!" 
- ) + raise NoRecordsError("It seems that none of your records have a prediction AND annotation!") # check and update kwargs for get_noise_indices _check_and_update_kwargs(cleanlab.__version__, records[0], sort_by, kwargs) @@ -117,9 +115,7 @@ def find_label_errors( return records_with_label_errors -def _check_and_update_kwargs( - version: str, record: TextClassificationRecord, sort_by: SortBy, kwargs: Dict -): +def _check_and_update_kwargs(version: str, record: TextClassificationRecord, sort_by: SortBy, kwargs: Dict): """Helper function to check and update the kwargs passed on to cleanlab's `get_noise_indices`. Args: @@ -140,9 +136,7 @@ def _check_and_update_kwargs( if parse_version(version) < parse_version("2.0"): if "sorted_index_method" in kwargs: - raise ValueError( - "The 'sorted_index_method' kwarg is not supported, please use 'sort_by' instead." - ) + raise ValueError("The 'sorted_index_method' kwarg is not supported, please use 'sort_by' instead.") kwargs["sorted_index_method"] = "normalized_margin" if sort_by is SortBy.PREDICTION: kwargs["sorted_index_method"] = "prob_given_label" @@ -150,9 +144,7 @@ def _check_and_update_kwargs( kwargs["sorted_index_method"] = None else: if "return_indices_ranked_by" in kwargs: - raise ValueError( - "The 'return_indices_ranked_by' kwarg is not supported, please use 'sort_by' instead." - ) + raise ValueError("The 'return_indices_ranked_by' kwarg is not supported, please use 'sort_by' instead.") kwargs["return_indices_ranked_by"] = "normalized_margin" if sort_by is SortBy.PREDICTION: kwargs["return_indices_ranked_by"] = "self_confidence" @@ -188,20 +180,14 @@ def _construct_s_and_psx( labels.update(predictions[-1].keys()) labels_mapping = {label: i for i, label in enumerate(sorted(labels))} - s = ( - np.empty(len(records), dtype=object) - if records[0].multi_label - else np.zeros(len(records), dtype=np.short) - ) + s = np.empty(len(records), dtype=object) if records[0].multi_label else np.zeros(len(records), dtype=np.short) psx = np.zeros((len(records), len(labels)), dtype=np.float) for i, rec, pred in zip(range(len(records)), records, predictions): try: psx[i] = [pred[label] for label in labels_mapping] except KeyError as error: - raise MissingPredictionError( - f"It seems a prediction for {error} is missing in the following record: {rec}" - ) + raise MissingPredictionError(f"It seems a prediction for {error} is missing in the following record: {rec}") try: s[i] = ( @@ -210,9 +196,7 @@ def _construct_s_and_psx( else labels_mapping[rec.annotation] ) except KeyError as error: - raise MissingPredictionError( - f"It seems predictions are missing for the label {error}!" - ) + raise MissingPredictionError(f"It seems predictions are missing for the label {error}!") return s, psx diff --git a/src/argilla/labeling/text_classification/label_models.py b/src/argilla/labeling/text_classification/label_models.py index 7186c46b9e..705894549b 100644 --- a/src/argilla/labeling/text_classification/label_models.py +++ b/src/argilla/labeling/text_classification/label_models.py @@ -35,8 +35,7 @@ class TieBreakPolicy(Enum): @classmethod def _missing_(cls, value): raise ValueError( - f"{value} is not a valid {cls.__name__}, please select one of" - f" {list(cls._value2member_map_.keys())}" + f"{value} is not a valid {cls.__name__}, please select one of" f" {list(cls._value2member_map_.keys())}" ) @@ -133,12 +132,8 @@ def predict( Returns: A dataset of records that include the predictions of the label model. 
""" - wl_matrix = self._weak_labels.matrix( - has_annotation=None if include_annotated_records else False - ) - records = self._weak_labels.records( - has_annotation=None if include_annotated_records else False - ) + wl_matrix = self._weak_labels.matrix(has_annotation=None if include_annotated_records else False) + records = self._weak_labels.records(has_annotation=None if include_annotated_records else False) assert records, ValueError( "No records are being passed. Use `include_annotated_records` to include" @@ -178,9 +173,7 @@ def _compute_single_label_probs(self, wl_matrix: np.ndarray) -> np.ndarray: """ counts = np.column_stack( [ - np.count_nonzero( - wl_matrix == self._weak_labels.label2int[label], axis=1 - ) + np.count_nonzero(wl_matrix == self._weak_labels.label2int[label], axis=1) for label in self._weak_labels.labels ] ) @@ -228,34 +221,22 @@ def _make_single_label_records( tie = True # maybe skip record - if not include_abstentions and ( - tie and tie_break_policy is TieBreakPolicy.ABSTAIN - ): + if not include_abstentions and (tie and tie_break_policy is TieBreakPolicy.ABSTAIN): continue if not tie: - pred_for_rec = [ - (self._weak_labels.labels[idx], prob[idx]) - for idx in np.argsort(prob)[::-1] - ] + pred_for_rec = [(self._weak_labels.labels[idx], prob[idx]) for idx in np.argsort(prob)[::-1]] # resolve ties following the tie break policy elif tie_break_policy is TieBreakPolicy.ABSTAIN: pred_for_rec = None elif tie_break_policy is TieBreakPolicy.RANDOM: - random_idx = int(hashlib.sha1(f"{i}".encode()).hexdigest(), 16) % len( - equal_prob_idx - ) + random_idx = int(hashlib.sha1(f"{i}".encode()).hexdigest(), 16) % len(equal_prob_idx) for idx in equal_prob_idx: if idx == random_idx: prob[idx] += self._PROBABILITY_INCREASE_ON_TIE_BREAK else: - prob[idx] -= self._PROBABILITY_INCREASE_ON_TIE_BREAK / ( - len(equal_prob_idx) - 1 - ) - pred_for_rec = [ - (self._weak_labels.labels[idx], prob[idx]) - for idx in np.argsort(prob)[::-1] - ] + prob[idx] -= self._PROBABILITY_INCREASE_ON_TIE_BREAK / (len(equal_prob_idx) - 1) + pred_for_rec = [(self._weak_labels.labels[idx], prob[idx]) for idx in np.argsort(prob)[::-1]] else: raise NotImplementedError( f"The tie break policy '{tie_break_policy.value}' is not" @@ -321,10 +302,7 @@ def _make_multi_label_records( pred_for_rec = None if not all_abstained: - pred_for_rec = [ - (self._weak_labels.labels[i], prob[i]) - for i in np.argsort(prob)[::-1] - ] + pred_for_rec = [(self._weak_labels.labels[i], prob[i]) for i in np.argsort(prob)[::-1]] records_with_prediction.append(rec.copy(deep=True)) records_with_prediction[-1].prediction = pred_for_rec @@ -388,9 +366,7 @@ def score( probabilities = self._compute_single_label_probs(wl_matrix) - annotation, prediction = self._score_single_label( - probabilities, tie_break_policy - ) + annotation, prediction = self._score_single_label(probabilities, tie_break_policy) target_names = self._weak_labels.labels[: annotation.max() + 1] return classification_report( @@ -419,18 +395,13 @@ def _score_single_label( A tuple of the annotation and prediction array. """ # 1.e-8 is taken from the abs tolerance of np.isclose - is_max = ( - np.abs(probabilities.max(axis=1, keepdims=True) - probabilities) < 1.0e-8 - ) + is_max = np.abs(probabilities.max(axis=1, keepdims=True) - probabilities) < 1.0e-8 is_tie = is_max.sum(axis=1) > 1 prediction = np.argmax(is_max, axis=1) # we need to transform the indexes! 
annotation = np.array( - [ - self._weak_labels.labels.index(self._weak_labels.int2label[i]) - for i in self._weak_labels.annotation() - ], + [self._weak_labels.labels.index(self._weak_labels.int2label[i]) for i in self._weak_labels.annotation()], dtype=np.short, ) @@ -442,21 +413,16 @@ def _score_single_label( elif tie_break_policy is TieBreakPolicy.RANDOM: for i in np.nonzero(is_tie)[0]: equal_prob_idx = np.nonzero(is_max[i])[0] - random_idx = int(hashlib.sha1(f"{i}".encode()).hexdigest(), 16) % len( - equal_prob_idx - ) + random_idx = int(hashlib.sha1(f"{i}".encode()).hexdigest(), 16) % len(equal_prob_idx) prediction[i] = equal_prob_idx[random_idx] else: raise NotImplementedError( - f"The tie break policy '{tie_break_policy.value}' is not implemented" - " for MajorityVoter!" + f"The tie break policy '{tie_break_policy.value}' is not implemented" " for MajorityVoter!" ) return annotation, prediction - def _score_multi_label( - self, probabilities: np.ndarray - ) -> Tuple[np.ndarray, np.ndarray]: + def _score_multi_label(self, probabilities: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: """Helper method to compute scores for multi-label classifications. Args: @@ -496,9 +462,7 @@ class Snorkel(LabelModel): >>> records = label_model.predict() """ - def __init__( - self, weak_labels: WeakLabels, verbose: bool = True, device: str = "cpu" - ): + def __init__(self, weak_labels: WeakLabels, verbose: bool = True, device: str = "cpu"): try: import snorkel # noqa: F401 except ModuleNotFoundError: @@ -513,23 +477,18 @@ def __init__( # Check if we need to remap the weak labels to int mapping # Snorkel expects the abstain id to be -1 and the rest of the labels to be sequential - if self._weak_labels.label2int[None] != -1 or sorted( - self._weak_labels.int2label - ) != list(range(-1, self._weak_labels.cardinality)): + if self._weak_labels.label2int[None] != -1 or sorted(self._weak_labels.int2label) != list( + range(-1, self._weak_labels.cardinality) + ): self._need_remap = True self._weaklabelsInt2snorkelInt = { - self._weak_labels.label2int[label]: i - for i, label in enumerate([None] + self._weak_labels.labels, -1) + self._weak_labels.label2int[label]: i for i, label in enumerate([None] + self._weak_labels.labels, -1) } else: self._need_remap = False - self._weaklabelsInt2snorkelInt = { - i: i for i in range(-1, self._weak_labels.cardinality) - } + self._weaklabelsInt2snorkelInt = {i: i for i in range(-1, self._weak_labels.cardinality)} - self._snorkelInt2weaklabelsInt = { - val: key for key, val in self._weaklabelsInt2snorkelInt.items() - } + self._snorkelInt2weaklabelsInt = {val: key for key, val in self._weaklabelsInt2snorkelInt.items()} # instantiate Snorkel's label model self._model = SnorkelLabelModel( @@ -548,13 +507,9 @@ def fit(self, include_annotated_records: bool = False, **kwargs): They must not contain ``L_train``, the label matrix is provided automatically. """ if "L_train" in kwargs: - raise ValueError( - "Your kwargs must not contain 'L_train', it is provided automatically." 
- ) + raise ValueError("Your kwargs must not contain 'L_train', it is provided automatically.") - l_train = self._weak_labels.matrix( - has_annotation=None if include_annotated_records else False - ) + l_train = self._weak_labels.matrix(has_annotation=None if include_annotated_records else False) if self._need_remap: l_train = self._copy_and_remap(l_train) @@ -614,9 +569,7 @@ def predict( if isinstance(tie_break_policy, str): tie_break_policy = TieBreakPolicy(tie_break_policy) - l_pred = self._weak_labels.matrix( - has_annotation=None if include_annotated_records else False - ) + l_pred = self._weak_labels.matrix(has_annotation=None if include_annotated_records else False) if self._need_remap: l_pred = self._copy_and_remap(l_pred) @@ -630,9 +583,7 @@ def predict( # add predictions to records records_with_prediction = [] for rec, pred, prob in zip( - self._weak_labels.records( - has_annotation=None if include_annotated_records else False - ), + self._weak_labels.records(has_annotation=None if include_annotated_records else False), predictions, probabilities, ): @@ -652,15 +603,11 @@ def predict( if idx == pred: prob[idx] += self._PROBABILITY_INCREASE_ON_TIE_BREAK else: - prob[idx] -= self._PROBABILITY_INCREASE_ON_TIE_BREAK / ( - len(equal_prob_idx) - 1 - ) + prob[idx] -= self._PROBABILITY_INCREASE_ON_TIE_BREAK / (len(equal_prob_idx) - 1) pred_for_rec = [ ( - self._weak_labels.int2label[ - self._snorkelInt2weaklabelsInt[snorkel_idx] - ], + self._weak_labels.int2label[self._snorkelInt2weaklabelsInt[snorkel_idx]], prob[snorkel_idx], ) for snorkel_idx in np.argsort(prob)[::-1] @@ -713,8 +660,7 @@ def score( if self._weak_labels.annotation().size == 0: raise MissingAnnotationError( - "You need annotated records to compute scores/metrics for your label" - " model." + "You need annotated records to compute scores/metrics for your label" " model." ) l_pred = self._weak_labels.matrix(has_annotation=True) @@ -779,14 +725,10 @@ def __init__(self, weak_labels: WeakLabels, **kwargs): super().__init__(weak_labels) if len(self._weak_labels.rules) < 3: - raise TooFewRulesError( - "The FlyingSquid label model needs at least three (independent) rules!" - ) + raise TooFewRulesError("The FlyingSquid label model needs at least three (independent) rules!") if "m" in kwargs: - raise ValueError( - "Your kwargs must not contain 'm', it is provided automatically." - ) + raise ValueError("Your kwargs must not contain 'm', it is provided automatically.") self._init_kwargs = kwargs self._models: List[FlyingSquidLabelModel] = [] @@ -800,21 +742,15 @@ def fit(self, include_annotated_records: bool = False, **kwargs): `LabelModel.fit() `__ method. 
""" - wl_matrix = self._weak_labels.matrix( - has_annotation=None if include_annotated_records else False - ) + wl_matrix = self._weak_labels.matrix(has_annotation=None if include_annotated_records else False) models = [] # create a label model for each label (except for binary classification) # much of the implementation is taken from wrench: # https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/flyingsquid.py # If binary, we only need one model - for i in range( - 1 if self._weak_labels.cardinality == 2 else self._weak_labels.cardinality - ): - model = self._FlyingSquidLabelModel( - m=len(self._weak_labels.rules), **self._init_kwargs - ) + for i in range(1 if self._weak_labels.cardinality == 2 else self._weak_labels.cardinality): + model = self._FlyingSquidLabelModel(m=len(self._weak_labels.rules), **self._init_kwargs) wl_matrix_i = self._copy_and_transform_wl_matrix(wl_matrix, i) model.fit(L_train=wl_matrix_i, **kwargs) models.append(model) @@ -839,9 +775,7 @@ def _copy_and_transform_wl_matrix(self, weak_label_matrix: np.ndarray, i: int): """ wl_matrix_i = weak_label_matrix.copy() - target_mask = ( - wl_matrix_i == self._weak_labels.label2int[self._weak_labels.labels[i]] - ) + target_mask = wl_matrix_i == self._weak_labels.label2int[self._weak_labels.labels[i]] abstain_mask = wl_matrix_i == self._weak_labels.label2int[None] other_mask = (~target_mask) & (~abstain_mask) @@ -883,9 +817,7 @@ def predict( if isinstance(tie_break_policy, str): tie_break_policy = TieBreakPolicy(tie_break_policy) - wl_matrix = self._weak_labels.matrix( - has_annotation=None if include_annotated_records else False - ) + wl_matrix = self._weak_labels.matrix(has_annotation=None if include_annotated_records else False) probabilities = self._predict(wl_matrix, verbose) # add predictions to records @@ -893,9 +825,7 @@ def predict( for i, prob, rec in zip( range(len(probabilities)), probabilities, - self._weak_labels.records( - has_annotation=None if include_annotated_records else False - ), + self._weak_labels.records(has_annotation=None if include_annotated_records else False), ): # Check if model abstains, that is if the highest probability is assigned to more than one label # 1.e-8 is taken from the abs tolerance of np.isclose @@ -905,38 +835,25 @@ def predict( tie = True # maybe skip record - if not include_abstentions and ( - tie and tie_break_policy is TieBreakPolicy.ABSTAIN - ): + if not include_abstentions and (tie and tie_break_policy is TieBreakPolicy.ABSTAIN): continue if not tie: - pred_for_rec = [ - (self._weak_labels.labels[i], prob[i]) - for i in np.argsort(prob)[::-1] - ] + pred_for_rec = [(self._weak_labels.labels[i], prob[i]) for i in np.argsort(prob)[::-1]] # resolve ties following the tie break policy elif tie_break_policy is TieBreakPolicy.ABSTAIN: pred_for_rec = None elif tie_break_policy is TieBreakPolicy.RANDOM: - random_idx = int(hashlib.sha1(f"{i}".encode()).hexdigest(), 16) % len( - equal_prob_idx - ) + random_idx = int(hashlib.sha1(f"{i}".encode()).hexdigest(), 16) % len(equal_prob_idx) for idx in equal_prob_idx: if idx == random_idx: prob[idx] += self._PROBABILITY_INCREASE_ON_TIE_BREAK else: - prob[idx] -= self._PROBABILITY_INCREASE_ON_TIE_BREAK / ( - len(equal_prob_idx) - 1 - ) - pred_for_rec = [ - (self._weak_labels.labels[i], prob[i]) - for i in np.argsort(prob)[::-1] - ] + prob[idx] -= self._PROBABILITY_INCREASE_ON_TIE_BREAK / (len(equal_prob_idx) - 1) + pred_for_rec = [(self._weak_labels.labels[i], prob[i]) for i in np.argsort(prob)[::-1]] else: raise NotImplementedError( 
- f"The tie break policy '{tie_break_policy.value}' is not" - " implemented for FlyingSquid!" + f"The tie break policy '{tie_break_policy.value}' is not" " implemented for FlyingSquid!" ) records_with_prediction.append(rec.copy(deep=True)) @@ -962,26 +879,19 @@ def _predict(self, weak_label_matrix: np.ndarray, verbose: bool) -> np.ndarray: NotFittedError: If the label model was still not fitted. """ if not self._models: - raise NotFittedError( - "This FlyingSquid instance is not fitted yet. Call `fit` before using" - " this model." - ) + raise NotFittedError("This FlyingSquid instance is not fitted yet. Call `fit` before using" " this model.") # create predictions for each label if self._weak_labels.cardinality > 2: probas = np.zeros((len(weak_label_matrix), self._weak_labels.cardinality)) for i in range(self._weak_labels.cardinality): wl_matrix_i = self._copy_and_transform_wl_matrix(weak_label_matrix, i) - probas[:, i] = self._models[i].predict_proba( - L_matrix=wl_matrix_i, verbose=verbose - )[:, 0] + probas[:, i] = self._models[i].predict_proba(L_matrix=wl_matrix_i, verbose=verbose)[:, 0] probas = np.nan_to_num(probas, nan=-np.inf) # handle NaN probas = np.exp(probas) / np.sum(np.exp(probas), axis=1, keepdims=True) # if binary, we only have one model else: wl_matrix_i = self._copy_and_transform_wl_matrix(weak_label_matrix, 0) - probas = self._models[0].predict_proba( - L_matrix=wl_matrix_i, verbose=verbose - ) + probas = self._models[0].predict_proba(L_matrix=wl_matrix_i, verbose=verbose) return probas @@ -1038,18 +948,13 @@ def score( probabilities = self._predict(wl_matrix, verbose) # 1.e-8 is taken from the abs tolerance of np.isclose - is_max = ( - np.abs(probabilities.max(axis=1, keepdims=True) - probabilities) < 1.0e-8 - ) + is_max = np.abs(probabilities.max(axis=1, keepdims=True) - probabilities) < 1.0e-8 is_tie = is_max.sum(axis=1) > 1 prediction = np.argmax(is_max, axis=1) # we need to transform the indexes! annotation = np.array( - [ - self._weak_labels.labels.index(self._weak_labels.int2label[i]) - for i in self._weak_labels.annotation() - ], + [self._weak_labels.labels.index(self._weak_labels.int2label[i]) for i in self._weak_labels.annotation()], dtype=np.short, ) @@ -1061,14 +966,11 @@ def score( elif tie_break_policy is TieBreakPolicy.RANDOM: for i in np.nonzero(is_tie)[0]: equal_prob_idx = np.nonzero(is_max[i])[0] - random_idx = int(hashlib.sha1(f"{i}".encode()).hexdigest(), 16) % len( - equal_prob_idx - ) + random_idx = int(hashlib.sha1(f"{i}".encode()).hexdigest(), 16) % len(equal_prob_idx) prediction[i] = equal_prob_idx[random_idx] else: raise NotImplementedError( - f"The tie break policy '{tie_break_policy.value}' is not implemented" - " for FlyingSquid!" + f"The tie break policy '{tie_break_policy.value}' is not implemented" " for FlyingSquid!" 
) return classification_report( diff --git a/src/argilla/labeling/text_classification/rule.py b/src/argilla/labeling/text_classification/rule.py index 658e63f7fc..b845882863 100644 --- a/src/argilla/labeling/text_classification/rule.py +++ b/src/argilla/labeling/text_classification/rule.py @@ -89,22 +89,16 @@ def _convert_to_labeling_rule(self): def add_to_dataset(self, dataset: str): """Adds the rule to the given dataset""" - api.active_api().add_dataset_labeling_rules( - dataset, rules=[self._convert_to_labeling_rule()] - ) + api.active_api().add_dataset_labeling_rules(dataset, rules=[self._convert_to_labeling_rule()]) def remove_from_dataset(self, dataset: str): """Removes the rule from the given dataset""" - api.active_api().delete_dataset_labeling_rules( - dataset, rules=[self._convert_to_labeling_rule()] - ) + api.active_api().delete_dataset_labeling_rules(dataset, rules=[self._convert_to_labeling_rule()]) def update_at_dataset(self, dataset: str): """Updates the rule at the given dataset""" - api.active_api().update_dataset_labeling_rules( - dataset, rules=[self._convert_to_labeling_rule()] - ) + api.active_api().update_dataset_labeling_rules(dataset, rules=[self._convert_to_labeling_rule()]) def apply(self, dataset: str): """Apply the rule to a dataset and save matching ids of the records. @@ -140,15 +134,11 @@ def metrics(self, dataset: str) -> Dict[str, Union[int, float]]: "coverage": metrics.coverage, "annotated_coverage": metrics.coverage_annotated, "correct": int(metrics.correct) if metrics.correct is not None else None, - "incorrect": int(metrics.incorrect) - if metrics.incorrect is not None - else None, + "incorrect": int(metrics.incorrect) if metrics.incorrect is not None else None, "precision": metrics.precision if metrics.precision is not None else None, } - def __call__( - self, record: TextClassificationRecord - ) -> Optional[Union[str, List[str]]]: + def __call__(self, record: TextClassificationRecord) -> Optional[Union[str, List[str]]]: """Check if the given record is among the matching ids from the ``self.apply`` call. Args: @@ -161,9 +151,7 @@ def __call__( RuleNotAppliedError: If the rule was not applied to the dataset before. """ if self._matching_ids is None: - raise RuleNotAppliedError( - "Rule was still not applied. 
Please call `self.apply(dataset)` first.") try: self._matching_ids[record.id] diff --git a/src/argilla/labeling/text_classification/weak_labels.py b/src/argilla/labeling/text_classification/weak_labels.py index 21fbeeaf45..0cc6c11615 100644 --- a/src/argilla/labeling/text_classification/weak_labels.py +++ b/src/argilla/labeling/text_classification/weak_labels.py @@ -62,16 +62,12 @@ def __init__( query: Optional[str] = None, ): if not isinstance(dataset, str): - raise TypeError( - f"The name of the dataset must be a string, but you provided: {dataset}" - ) + raise TypeError(f"The name of the dataset must be a string, but you provided: {dataset}") self._dataset = dataset self._rules = rules or load_rules(dataset) if not self._rules: - raise NoRulesFoundError( - f"No rules were found in the given dataset '{dataset}'" - ) + raise NoRulesFoundError(f"No rules were found in the given dataset '{dataset}'") self._rules_index2name = { # covers our Rule class, snorkel's LabelingFunction class and arbitrary methods @@ -94,14 +90,10 @@ def __init__( f"Following rule names are duplicated x times: { {key: val for key, val in counts.items() if val > 1} }" " Please make sure to provide unique rule names." ) - self._rules_name2index = { - val: key for key, val in self._rules_index2name.items() - } + self._rules_name2index = {val: key for key, val in self._rules_index2name.items()} # load records and check compatibility - self._records: DatasetForTextClassification = load( - dataset, query=query, ids=ids - ) + self._records: DatasetForTextClassification = load(dataset, query=query, ids=ids) if not self._records: raise NoRecordsFoundError( f"No records found in dataset '{dataset}'" @@ -127,9 +119,7 @@ def cardinality(self) -> int: """The number of labels.""" raise NotImplementedError - def records( - self, has_annotation: Optional[bool] = None - ) -> List[TextClassificationRecord]: + def records(self, has_annotation: Optional[bool] = None) -> List[TextClassificationRecord]: """Returns the records corresponding to the weak label matrix Args: @@ -228,15 +218,10 @@ def _compute_overlaps_conflicts( Array of fractions of overlaps/conflicts for each rule, optionally normalized by their coverages. """ overlaps_or_conflicts = ( - has_weak_label - * np.repeat(has_overlaps_or_conflicts, len(self._rules)).reshape( - has_weak_label.shape - ) + has_weak_label * np.repeat(has_overlaps_or_conflicts, len(self._rules)).reshape(has_weak_label.shape) ).sum(axis=0) / len(self._records) # total - overlaps_or_conflicts = np.append( - overlaps_or_conflicts, has_overlaps_or_conflicts.sum() / len(self._records) - ) + overlaps_or_conflicts = np.append(overlaps_or_conflicts, has_overlaps_or_conflicts.sum() / len(self._records)) if normalize_by_coverage: # ignore division by 0 warnings, as we convert the nan back to 0.0 afterwards @@ -291,17 +276,15 @@ def extend_matrix( np.copy(embeddings).astype(np.float32), abstains, supports, gpu=gpu ) elif self._extension_queries is None: - raise ValueError( - "Embeddings are not optional the first time a matrix is extended." 
- ) + raise ValueError("Embeddings are not optional the first time a matrix is extended.") dists, nearest = self._extension_queries self._extended_matrix = np.copy(self._matrix) new_points = [(dists[i] > thresholds[i]) for i in range(self._matrix.shape[1])] for i in range(self._matrix.shape[1]): - self._extended_matrix[abstains[i][new_points[i]], i] = self._matrix[ - supports[i], i - ][nearest[i][new_points[i]]] + self._extended_matrix[abstains[i][new_points[i]], i] = self._matrix[supports[i], i][ + nearest[i][new_points[i]] + ] self._extend_matrix_postprocess() @@ -323,15 +306,11 @@ def _find_dists_and_nearest( faiss.normalize_L2(embeddings) embeddings_length = embeddings.shape[1] - label_fn_indexes = [ - faiss.IndexFlatIP(embeddings_length) for i in range(self._matrix.shape[1]) - ] + label_fn_indexes = [faiss.IndexFlatIP(embeddings_length) for i in range(self._matrix.shape[1])] if gpu: res = faiss.StandardGpuResources() - label_fn_indexes = [ - faiss.index_cpu_to_gpu(res, 0, x) for x in label_fn_indexes - ] + label_fn_indexes = [faiss.index_cpu_to_gpu(res, 0, x) for x in label_fn_indexes] for i in range(self._matrix.shape[1]): label_fn_indexes[i].add(embeddings[support[i]]) @@ -342,12 +321,8 @@ def _find_dists_and_nearest( faiss.normalize_L2(embs_query) dists_and_nearest.append(label_fn_indexes[i].search(embs_query, 1)) - dists = [ - dist_and_nearest[0].flatten() for dist_and_nearest in dists_and_nearest - ] - nearest = [ - dist_and_nearest[1].flatten() for dist_and_nearest in dists_and_nearest - ] + dists = [dist_and_nearest[0].flatten() for dist_and_nearest in dists_and_nearest] + nearest = [dist_and_nearest[1].flatten() for dist_and_nearest in dists_and_nearest] return dists, nearest @@ -453,9 +428,7 @@ def _apply_rules( rule.apply(self._dataset) # create weak label matrix, annotation array, final label2int - weak_label_matrix = np.empty( - (len(self._records), len(self._rules)), dtype=np.short - ) + weak_label_matrix = np.empty((len(self._records), len(self._rules)), dtype=np.short) annotation_array = np.empty(len(self._records), dtype=np.short) _label2int = {None: -1} if label2int is None else label2int if None not in _label2int: @@ -463,9 +436,7 @@ def _apply_rules( "Your provided `label2int` mapping does not contain the required abstention label `None`." ) - for n, record in tqdm( - enumerate(self._records), total=len(self._records), desc="Applying rules" - ): + for n, record in tqdm(enumerate(self._records), total=len(self._records), desc="Applying rules"): # FIRST: fill annotation array try: annotation = _label2int[record.annotation] @@ -535,9 +506,7 @@ def matrix(self, has_annotation: Optional[bool] = None) -> np.ndarray: Returns: The weak label matrix, or optionally just a part of it. 
""" - matrix = ( - self._matrix if self._extended_matrix is None else self._extended_matrix - ) + matrix = self._matrix if self._extended_matrix is None else self._extended_matrix if has_annotation is True: return matrix[self._annotation != self._label2int[None]] @@ -606,9 +575,7 @@ def summary( polarity = [ set( self._int2label[integer] - for integer in np.unique( - self.matrix()[:, i][self.matrix()[:, i] != self._label2int[None]] - ) + for integer in np.unique(self.matrix()[:, i][self.matrix()[:, i] != self._label2int[None]]) ) for i in range(len(self._rules)) ] @@ -623,9 +590,7 @@ def summary( # overlaps has_overlaps = has_weak_label.sum(axis=1) > 1 - overlaps = self._compute_overlaps_conflicts( - has_weak_label, has_overlaps, coverage, normalize_by_coverage - ) + overlaps = self._compute_overlaps_conflicts(has_weak_label, has_overlaps, coverage, normalize_by_coverage) # conflicts # TODO: For a lot of records (~1e6), this could become slow (~10s) ... a vectorized solution would be better. @@ -634,9 +599,7 @@ def summary( axis=1, arr=self.matrix(), ) - conflicts = self._compute_overlaps_conflicts( - has_weak_label, has_conflicts, coverage, normalize_by_coverage - ) + conflicts = self._compute_overlaps_conflicts(has_weak_label, has_conflicts, coverage, normalize_by_coverage) # index for the summary index = list(self._rules_name2index.keys()) + ["total"] @@ -645,13 +608,10 @@ def summary( has_annotation = annotation != self._label2int[None] if any(has_annotation): # annotated coverage - annotated_coverage = ( - has_weak_label[has_annotation].sum(axis=0) / has_annotation.sum() - ) + annotated_coverage = has_weak_label[has_annotation].sum(axis=0) / has_annotation.sum() annotated_coverage = np.append( annotated_coverage, - (has_weak_label[has_annotation].sum(axis=1) > 0).sum() - / has_annotation.sum(), + (has_weak_label[has_annotation].sum(axis=1) > 0).sum() / has_annotation.sum(), ) # correct/incorrect @@ -692,9 +652,7 @@ def _compute_correct_incorrect( self, has_weak_label: np.ndarray, annotation: np.ndarray ) -> Tuple[np.ndarray, np.ndarray]: """Helper method to compute the correctly and incorrectly predicted annotations by the rules""" - annotation_matrix = np.repeat(annotation, len(self._rules)).reshape( - self.matrix().shape - ) + annotation_matrix = np.repeat(annotation, len(self._rules)).reshape(self.matrix().shape) # correct correct_with_abstain = annotation_matrix == self.matrix() @@ -737,13 +695,8 @@ def show_records( # get rule mask if rules is not None: - rules = [ - self._rules_name2index[rule] if isinstance(rule, str) else rule - for rule in rules - ] - idx_by_rules = (self.matrix()[:, rules] != self._label2int[None]).sum( - axis=1 - ) == len(rules) + rules = [self._rules_name2index[rule] if isinstance(rule, str) else rule for rule in rules] + idx_by_rules = (self.matrix()[:, rules] != self._label2int[None]).sum(axis=1) == len(rules) else: idx_by_rules = np.ones_like(self._records).astype(bool) @@ -767,9 +720,7 @@ def change_mapping(self, label2int: Dict[str, int]): for label in self._label2int: # Check new label2int mapping if label not in label2int: - raise MissingLabelError( - f"The label '{label}' is missing in the new mapping." 
- ) + raise MissingLabelError(f"The label '{label}' is missing in the new mapping.") # compute masks label_masks[label] = self.matrix() == self._label2int[label] annotation_masks[label] = self._annotation == self._label2int[label] @@ -794,22 +745,18 @@ def extend_matrix( def _extend_matrix_preprocess(self) -> Tuple[List[np.ndarray], List[np.ndarray]]: abstains = [ - np.argwhere(self._matrix[:, i] == self._label2int[None]).flatten() - for i in range(self._matrix.shape[1]) + np.argwhere(self._matrix[:, i] == self._label2int[None]).flatten() for i in range(self._matrix.shape[1]) ] supports = [ - np.argwhere(self._matrix[:, i] != self._label2int[None]).flatten() - for i in range(self._matrix.shape[1]) + np.argwhere(self._matrix[:, i] != self._label2int[None]).flatten() for i in range(self._matrix.shape[1]) ] return abstains, supports def _extend_matrix_postprocess(self): """Keeps the rows of the original weak label matrix, for which at least one rule did not abstain.""" - recs_with_votes = np.argwhere( - (self._matrix != self._label2int[None]).sum(-1) > 0 - ).flatten() + recs_with_votes = np.argwhere((self._matrix != self._label2int[None]).sum(-1) > 0).flatten() self._extended_matrix[recs_with_votes] = self._matrix[recs_with_votes] @@ -866,26 +813,16 @@ def _apply_rules(self) -> Tuple[np.ndarray, np.ndarray, List[str]]: # we make two passes over the records: # FIRST: Get labels from rules and annotations annotations, weak_labels = [], [] - for record in tqdm( - self._records, total=len(self._records), desc="Applying rules" - ): - annotations.append( - record.annotation - if isinstance(record.annotation, list) - else [record.annotation] - ) + for record in tqdm(self._records, total=len(self._records), desc="Applying rules"): + annotations.append(record.annotation if isinstance(record.annotation, list) else [record.annotation]) weak_labels.append([np.atleast_1d(rule(record)) for rule in self._rules]) annotation_set = {ann for anns in annotations for ann in anns} - weak_label_set = { - wl for wl_record in weak_labels for wl_rule in wl_record for wl in wl_rule - } + weak_label_set = {wl for wl_record in weak_labels for wl_rule in wl_record for wl in wl_rule} labels = sorted(list(annotation_set.union(weak_label_set) - {None})) # create weak label matrix (3D), annotation matrix - weak_label_matrix = np.empty( - (len(self._records), len(self._rules), len(labels)), dtype=np.byte - ) + weak_label_matrix = np.empty((len(self._records), len(self._rules), len(labels)), dtype=np.byte) annotation_matrix = np.empty((len(self._records), len(labels)), dtype=np.byte) # SECOND: Fill arrays with weak labels @@ -939,9 +876,7 @@ def matrix(self, has_annotation: Optional[bool] = None) -> np.ndarray: Returns: The 3 dimensional weak label matrix, or optionally just a part of it.
""" - matrix = ( - self._matrix if self._extended_matrix is None else self._extended_matrix - ) + matrix = self._matrix if self._extended_matrix is None else self._extended_matrix if has_annotation is True: return matrix[self._annotation.sum(1) >= 0] @@ -1023,9 +958,7 @@ def summary( # overlaps has_overlaps = has_weak_label.sum(axis=1) > 1 - overlaps = self._compute_overlaps_conflicts( - has_weak_label, has_overlaps, coverage, normalize_by_coverage - ) + overlaps = self._compute_overlaps_conflicts(has_weak_label, has_overlaps, coverage, normalize_by_coverage) # index for the summary index = list(self._rules_name2index.keys()) + ["total"] @@ -1034,13 +967,10 @@ def summary( has_annotation = annotation.sum(1) >= 0 if any(has_annotation): # annotated coverage - annotated_coverage = ( - has_weak_label[has_annotation].sum(axis=0) / has_annotation.sum() - ) + annotated_coverage = has_weak_label[has_annotation].sum(axis=0) / has_annotation.sum() annotated_coverage = np.append( annotated_coverage, - (has_weak_label[has_annotation].sum(axis=1) > 0).sum() - / has_annotation.sum(), + (has_weak_label[has_annotation].sum(axis=1) > 0).sum() / has_annotation.sum(), ) # correct/incorrect @@ -1072,24 +1002,16 @@ def summary( index=index, ) - def _compute_correct_incorrect( - self, annotation: np.ndarray - ) -> Tuple[np.ndarray, np.ndarray]: + def _compute_correct_incorrect(self, annotation: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: """Helper method to compute the correctly and incorrectly predicted annotations by the rules""" # transform annotation to tensor - annotation = np.repeat(annotation, len(self._rules), axis=0).reshape( - self.matrix().shape - ) + annotation = np.repeat(annotation, len(self._rules), axis=0).reshape(self.matrix().shape) # correct, we don't want to count the "correct non predictions" correct = ((annotation == self.matrix()) & (self.matrix() == 1)).sum(2).sum(0) # incorrect, we don't want to count the "misses", since we focus on precision, not recall - incorrect = ( - ((annotation != self.matrix()) & (self.matrix() == 1) & (annotation != -1)) - .sum(2) - .sum(0) - ) + incorrect = ((annotation != self.matrix()) & (self.matrix() == 1) & (annotation != -1)).sum(2).sum(0) # add totals at the end return np.append(correct, correct.sum()), np.append(incorrect, incorrect.sum()) @@ -1120,13 +1042,8 @@ def show_records( # get rule mask if rules is not None: - rules = [ - self._rules_name2index[rule] if isinstance(rule, str) else rule - for rule in rules - ] - idx_by_rules = (self.matrix()[:, rules, :].sum(axis=2) >= 0).sum( - axis=1 - ) == len(rules) + rules = [self._rules_name2index[rule] if isinstance(rule, str) else rule for rule in rules] + idx_by_rules = (self.matrix()[:, rules, :].sum(axis=2) >= 0).sum(axis=1) == len(rules) else: idx_by_rules = np.ones_like(self._records).astype(bool) @@ -1135,9 +1052,7 @@ def show_records( return pd.DataFrame(map(lambda x: x.dict(), filtered_records)) - @_add_docstr( - WeakLabelsBase.extend_matrix.__doc__.format(class_name="WeakMultiLabels") - ) + @_add_docstr(WeakLabelsBase.extend_matrix.__doc__.format(class_name="WeakMultiLabels")) def extend_matrix( self, thresholds: Union[List[float], np.ndarray], @@ -1147,15 +1062,9 @@ def extend_matrix( super().extend_matrix(thresholds=thresholds, embeddings=embeddings, gpu=gpu) def _extend_matrix_preprocess(self) -> Tuple[List[np.ndarray], List[np.ndarray]]: - abstains = [ - np.argwhere(self._matrix[:, i].sum(-1) < 0).flatten() - for i in range(self._matrix.shape[1]) - ] + abstains = 
[np.argwhere(self._matrix[:, i].sum(-1) < 0).flatten() for i in range(self._matrix.shape[1])] - supports = [ - np.argwhere(self._matrix[:, i].sum(-1) >= 0).flatten() - for i in range(self._matrix.shape[1]) - ] + supports = [np.argwhere(self._matrix[:, i].sum(-1) >= 0).flatten() for i in range(self._matrix.shape[1])] return abstains, supports diff --git a/src/argilla/listeners/listener.py b/src/argilla/listeners/listener.py index d836bbf5e6..75ed7f6103 100644 --- a/src/argilla/listeners/listener.py +++ b/src/argilla/listeners/listener.py @@ -69,9 +69,7 @@ def formatted_query(self) -> Optional[str]: return None return self.query.format(**(self.query_params or {})) - __listener_job__: Optional[schedule.Job] = dataclasses.field( - init=False, default=None - ) + __listener_job__: Optional[schedule.Job] = dataclasses.field(init=False, default=None) __stop_schedule_event__ = None __current_thread__ = None __scheduler__ = schedule.Scheduler() @@ -120,13 +118,11 @@ def start(self, *action_args, **action_kwargs): if self.is_running(): raise ValueError("Listener is already running") - job_step = self.__catch_exceptions__(cancel_on_failure=True)( - self.__listener_iteration_job__ - ) + job_step = self.__catch_exceptions__(cancel_on_failure=True)(self.__listener_iteration_job__) - self.__listener_job__ = self.__scheduler__.every( - self.interval_in_seconds - ).seconds.do(job_step, *action_args, **action_kwargs) + self.__listener_job__ = self.__scheduler__.every(self.interval_in_seconds).seconds.do( + job_step, *action_args, **action_kwargs + ) class _ScheduleThread(threading.Thread): _WAIT_EVENT = threading.Event() @@ -183,9 +179,7 @@ def __listener_iteration_job__(self, *args, **kwargs): ctx = RGListenerContext( listener=self, query_params=self.query_params, - metrics=self.__compute_metrics__( - current_api, dataset, query=self.formatted_query - ), + metrics=self.__compute_metrics__(current_api, dataset, query=self.formatted_query), ) if self.condition is None: self._LOGGER.debug("No condition found! 
Running action...") @@ -230,9 +224,7 @@ def __run_action__(self, ctx: Optional[RGListenerContext] = None, *args, **kwarg try: action_args = [ctx] if ctx else [] if self.query_records: - action_args.insert( - 0, argilla.load(name=self.dataset, query=self.formatted_query) - ) + action_args.insert(0, argilla.load(name=self.dataset, query=self.formatted_query)) self._LOGGER.debug(f"Running action with arguments: {action_args}") return self.action(*args, *action_args, **kwargs) except: # noqa: E722 diff --git a/src/argilla/metrics/commons.py b/src/argilla/metrics/commons.py index a97f2d57ca..449c5343bf 100644 --- a/src/argilla/metrics/commons.py +++ b/src/argilla/metrics/commons.py @@ -41,9 +41,7 @@ def text_length(name: str, query: Optional[str] = None) -> MetricSummary: return MetricSummary.new_summary( data=metric.results, - visualization=lambda: helpers.histogram( - data=metric.results, title=metric.description - ), + visualization=lambda: helpers.histogram(data=metric.results, title=metric.description), ) @@ -65,15 +63,11 @@ def records_status(name: str, query: Optional[str] = None) -> MetricSummary: >>> summary.visualize() # will plot an histogram with results >>> summary.data # returns the raw result data """ - metric = api.active_api().compute_metric( - name, metric="status_distribution", query=query - ) + metric = api.active_api().compute_metric(name, metric="status_distribution", query=query) return MetricSummary.new_summary( data=metric.results, - visualization=lambda: helpers.bar( - data=metric.results, title=metric.description - ), + visualization=lambda: helpers.bar(data=metric.results, title=metric.description), ) diff --git a/src/argilla/metrics/helpers.py b/src/argilla/metrics/helpers.py index 098d599036..aa07aee14f 100644 --- a/src/argilla/metrics/helpers.py +++ b/src/argilla/metrics/helpers.py @@ -34,10 +34,7 @@ def bar(data: dict, title: str = "Bar", x_legend: str = "", y_legend: str = ""): return empty_visualization() keys, values = zip(*data.items()) - keys = [ - key.encode("unicode-escape").decode() if isinstance(key, str) else key - for key in keys - ] + keys = [key.encode("unicode-escape").decode() if isinstance(key, str) else key for key in keys] fig = go.Figure(data=go.Bar(y=values, x=keys)) fig.update_layout( title=title, @@ -138,11 +135,7 @@ def f1(data: Dict[str, float], title: str): row=1, col=2, ) - per_label = { - k: v - for k, v in data.items() - if all(key not in k for key in ["macro", "micro", "support"]) - } + per_label = {k: v for k, v in data.items() if all(key not in k for key in ["macro", "micro", "support"])} fig.add_bar( x=[k for k, v in per_label.items()], diff --git a/src/argilla/metrics/models.py b/src/argilla/metrics/models.py index 9fe9c462bb..5f9bcffc77 100644 --- a/src/argilla/metrics/models.py +++ b/src/argilla/metrics/models.py @@ -29,15 +29,10 @@ def visualize(self): try: return self._build_visualization() except ModuleNotFoundError: - warnings.warn( - "Please, install plotly in order to use this feature\n" - "%>pip install plotly" - ) + warnings.warn("Please, install plotly in order to use this feature\n" "%>pip install plotly") @classmethod - def new_summary( - cls, data: Dict[str, Any], visualization: Callable - ) -> "MetricSummary": + def new_summary(cls, data: Dict[str, Any], visualization: Callable) -> "MetricSummary": summary = cls(data=data) summary._build_visualization = visualization return summary diff --git a/src/argilla/metrics/token_classification/metrics.py b/src/argilla/metrics/token_classification/metrics.py index 
31e8b52b0d..b6591e26e2 100644 --- a/src/argilla/metrics/token_classification/metrics.py +++ b/src/argilla/metrics/token_classification/metrics.py @@ -22,9 +22,7 @@ from argilla.metrics.models import MetricSummary -def tokens_length( - name: str, query: Optional[str] = None, interval: int = 1 -) -> MetricSummary: +def tokens_length(name: str, query: Optional[str] = None, interval: int = 1) -> MetricSummary: """Computes the text length distribution measured in number of tokens. Args: @@ -42,9 +40,7 @@ def tokens_length( >>> summary.visualize() # will plot a histogram with results >>> summary.data # the raw histogram data with bins of size 5 """ - metric = api.active_api().compute_metric( - name, metric="tokens_length", query=query, interval=interval - ) + metric = api.active_api().compute_metric(name, metric="tokens_length", query=query, interval=interval) return MetricSummary.new_summary( data=metric.results, @@ -56,9 +52,7 @@ def tokens_length( ) -def token_frequency( - name: str, query: Optional[str] = None, tokens: int = 1000 -) -> MetricSummary: +def token_frequency(name: str, query: Optional[str] = None, tokens: int = 1000) -> MetricSummary: """Computes the token frequency distribution for a number of tokens. Args: @@ -76,9 +70,7 @@ def token_frequency( >>> summary.visualize() # will plot a histogram with results >>> summary.data # the top-50 tokens frequency """ - metric = api.active_api().compute_metric( - name, metric="token_frequency", query=query, size=tokens - ) + metric = api.active_api().compute_metric(name, metric="token_frequency", query=query, size=tokens) return MetricSummary.new_summary( data=metric.results, @@ -143,9 +135,7 @@ def token_capitalness(name: str, query: Optional[str] = None) -> MetricSummary: >>> summary.visualize() # will plot a histogram with results >>> summary.data # The token capitalness distribution """ - metric = api.active_api().compute_metric( - name, metric="token_capitalness", query=query - ) + metric = api.active_api().compute_metric(name, metric="token_capitalness", query=query) return MetricSummary.new_summary( data=metric.results, @@ -214,9 +204,7 @@ def mention_length( """ level = (level or "token").lower().strip() accepted_levels = ["token", "char"] - assert ( - level in accepted_levels - ), f"Unexpected value for level. Accepted values are {accepted_levels}" + assert level in accepted_levels, f"Unexpected value for level. 
Accepted values are {accepted_levels}" metric = api.active_api().compute_metric( name, @@ -421,9 +409,7 @@ def top_k_mentions( for mention in metric.results["mentions"]: entities = mention["entities"] if post_label_filter: - entities = [ - entity for entity in entities if entity["label"] in post_label_filter - ] + entities = [entity for entity in entities if entity["label"] in post_label_filter] if entities: mention["entities"] = entities filtered_mentions.append(mention) @@ -434,9 +420,7 @@ def top_k_mentions( for entity in mention["entities"]: label = entity["label"] mentions_for_label = entities.get(label, [0] * len(filtered_mentions)) - mentions_for_label[mention_values.index(mention["mention"])] = entity[ - "count" - ] + mentions_for_label[mention_values.index(mention["mention"])] = entity["count"] entities[label] = mentions_for_label return MetricSummary.new_summary( diff --git a/src/argilla/monitoring/_flair.py b/src/argilla/monitoring/_flair.py index 6330b9fc78..ad03da91ad 100644 --- a/src/argilla/monitoring/_flair.py +++ b/src/argilla/monitoring/_flair.py @@ -69,11 +69,7 @@ def predict(self, sentences: Union[List[Sentence], Sentence], *args, **kwargs): if not metadata: metadata = [{}] * len(sentences) - filtered_data = [ - (sentence, meta) - for sentence, meta in zip(sentences, metadata) - if self.is_record_accepted() - ] + filtered_data = [(sentence, meta) for sentence, meta in zip(sentences, metadata) if self.is_record_accepted()] if filtered_data: self._log_future = self.send_records(filtered_data) diff --git a/src/argilla/monitoring/_spacy.py b/src/argilla/monitoring/_spacy.py index a1c80833ac..195e42f0cb 100644 --- a/src/argilla/monitoring/_spacy.py +++ b/src/argilla/monitoring/_spacy.py @@ -61,14 +61,10 @@ def doc2token_classification( event_timestamp=datetime.utcnow(), ) - def _prepare_log_data( - self, docs_info: Tuple[Doc, Optional[Dict[str, Any]]] - ) -> Dict[str, Any]: + def _prepare_log_data(self, docs_info: Tuple[Doc, Optional[Dict[str, Any]]]) -> Dict[str, Any]: return dict( records=[ - self.doc2token_classification( - doc, agent=self.__wrapped__.path.name, metadata=metadata - ) + self.doc2token_classification(doc, agent=self.__wrapped__.path.name, metadata=metadata) for doc, metadata in docs_info ], name=self.dataset, diff --git a/src/argilla/monitoring/_transformers.py b/src/argilla/monitoring/_transformers.py index 46fe973f26..8ddcde1c8b 100644 --- a/src/argilla/monitoring/_transformers.py +++ b/src/argilla/monitoring/_transformers.py @@ -63,9 +63,7 @@ def _prepare_log_data( record = TextClassificationRecord( text=input_ if isinstance(input_, str) else None, inputs=input_ if not isinstance(input_, str) else None, - prediction=[ - (prediction.label, prediction.score) for prediction in predictions - ], + prediction=[(prediction.label, prediction.score) for prediction in predictions], prediction_agent=agent, metadata=metadata or {}, multi_label=multi_label, @@ -82,9 +80,7 @@ def _prepare_log_data( name=dataset_name, tags={ "name": self.model_config.name_or_path, - "transformers_version": self.fetch_transformers_version( - self.model_config - ), + "transformers_version": self.fetch_transformers_version(self.model_config), "model_type": self.model_config.model_type, "task": self.__model__.task, }, diff --git a/src/argilla/monitoring/asgi.py b/src/argilla/monitoring/asgi.py index f861eee958..deabe0ef47 100644 --- a/src/argilla/monitoring/asgi.py +++ b/src/argilla/monitoring/asgi.py @@ -56,9 +56,7 @@ def token_classification_mapper(inputs, outputs): tokens=tokens or 
_default_tokenization_pattern.split(text), prediction=[ (entity["label"], entity["start"], entity["end"]) - for entity in ( - outputs.get("entities") if isinstance(outputs, dict) else outputs - ) + for entity in (outputs.get("entities") if isinstance(outputs, dict) else outputs) ], event_timestamp=datetime.datetime.now(), ) @@ -67,12 +65,7 @@ def token_classification_mapper(inputs, outputs): def text_classification_mapper(inputs, outputs): return TextClassificationRecord( inputs=inputs, - prediction=[ - (label, score) - for label, score in zip( - outputs.get("labels", []), outputs.get("scores", []) - ) - ], + prediction=[(label, score) for label, score in zip(outputs.get("labels", []), outputs.get("scores", []))], event_timestamp=datetime.datetime.now(), ) @@ -167,16 +160,11 @@ async def dispatch( elif cached_request.method == "GET": inputs = cached_request.query_params._dict else: - raise NotImplementedError( - "Only request methods POST, PUT and GET are implemented." - ) + raise NotImplementedError("Only request methods POST, PUT and GET are implemented.") # Must obtain response from request response: Response = await call_next(cached_request) - if ( - not isinstance(response, (JSONResponse, StreamingResponse)) - or response.status_code >= 400 - ): + if not isinstance(response, (JSONResponse, StreamingResponse)) or response.status_code >= 400: return response new_response, outputs = await self._extract_response_content(response) @@ -186,9 +174,7 @@ async def dispatch( _logger.error("Cannot log to argilla", exc_info=ex) return await call_next(request) - async def _extract_response_content( - self, response: Response - ) -> Tuple[Response, List[Dict[str, Any]]]: + async def _extract_response_content(self, response: Response) -> Tuple[Response, List[Dict[str, Any]]]: """Extracts response body content from response and returns a new processable response""" body = b"" new_response = response @@ -206,23 +192,17 @@ async def _extract_response_content( body = response.body return new_response, json.loads(body) - def _prepare_argilla_data( - self, inputs: List[Dict[str, Any]], outputs: List[Dict[str, Any]], **tags - ): + def _prepare_argilla_data(self, inputs: List[Dict[str, Any]], outputs: List[Dict[str, Any]], **tags): # using the base monitor, we only need to provide the input data to the rg.log function # and the monitor will handle the sample rate, queue and argilla interaction try: records = self._records_mapper(inputs, outputs) - assert records, ValueError( - "The records_mapper returns and empty record list." - ) + assert records, ValueError("The records_mapper returns an empty record list.") if not isinstance(records, list): records = [records] except Exception as ex: records = [] - _logger.error( - "Cannot log to argilla. Error in records mapper.", exc_info=ex - ) + _logger.error("Cannot log to argilla. Error in records mapper.", exc_info=ex) for record in records: if self._monitor.agent is not None and not record.prediction_agent: diff --git a/src/argilla/monitoring/base.py b/src/argilla/monitoring/base.py index 59fc8976d4..86fb2491e7 100644 --- a/src/argilla/monitoring/base.py +++ b/src/argilla/monitoring/base.py @@ -187,9 +187,7 @@ def __init__( super().__init__(*args, **kwargs) assert dataset, "Missing dataset" - assert ( - 0.0 < sample_rate <= 1.0 - ), "Wrong sample rate. Set a value in (0, 1] range." + assert 0.0 < sample_rate <= 1.0, "Wrong sample rate. Set a value in (0, 1] range." 
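A note on the sample_rate asserted above: the monitors gate each incoming record through is_record_accepted() (the same method used in the flair monitor hunk earlier in this patch). A minimal standalone sketch of that gate, assuming a uniform random draw; the draw is an illustration, not the library's verified implementation:

    import random

    def is_record_accepted(sample_rate: float) -> bool:
        # Accept roughly `sample_rate` of all records; the assert above
        # guarantees the rate lies in the (0, 1] range.
        return random.random() <= sample_rate

For example, sample_rate=0.2 would log about one record in five to Argilla, while sample_rate=1.0 logs every record.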
self.dataset = dataset self.sample_rate = sample_rate diff --git a/src/argilla/monitoring/model_monitor.py b/src/argilla/monitoring/model_monitor.py index 0caca68e27..9d126329b8 100644 --- a/src/argilla/monitoring/model_monitor.py +++ b/src/argilla/monitoring/model_monitor.py @@ -62,7 +62,6 @@ def monitor( return model_monitor warnings.warn( - "The provided task model is not supported by monitoring module. " - "Predictions won't be logged into argilla" + "The provided task model is not supported by the monitoring module. " "Predictions won't be logged into argilla" ) return task_model diff --git a/src/argilla/server/apis/v0/handlers/datasets.py b/src/argilla/server/apis/v0/handlers/datasets.py index 0d1470f0f9..46adf21652 100644 --- a/src/argilla/server/apis/v0/handlers/datasets.py +++ b/src/argilla/server/apis/v0/handlers/datasets.py @@ -49,9 +49,7 @@ async def list_datasets( ) -> List[Dataset]: return service.list( user=current_user, - workspaces=[request_deps.workspace] - if request_deps.workspace is not None - else None, + workspaces=[request_deps.workspace] if request_deps.workspace is not None else None, ) @@ -113,9 +111,7 @@ def update_dataset( service: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> Dataset: - found_ds = service.find_by_name( - user=current_user, name=name, workspace=ds_params.workspace - ) + found_ds = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace) return service.update( user=current_user, @@ -159,9 +155,7 @@ def close_dataset( service: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ): - found_ds = service.find_by_name( - user=current_user, name=name, workspace=ds_params.workspace - ) + found_ds = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace) service.close(user=current_user, dataset=found_ds) @@ -175,9 +169,7 @@ def open_dataset( service: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ): - found_ds = service.find_by_name( - user=current_user, name=name, workspace=ds_params.workspace - ) + found_ds = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace) service.open(user=current_user, dataset=found_ds) @@ -194,9 +186,7 @@ def copy_dataset( service: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> Dataset: - found = service.find_by_name( - user=current_user, name=name, workspace=ds_params.workspace - ) + found = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace) return service.copy_dataset( user=current_user, dataset=found, diff --git a/src/argilla/server/apis/v0/handlers/metrics.py b/src/argilla/server/apis/v0/handlers/metrics.py index 68b92bbb37..3f4e0b6743 100644 --- a/src/argilla/server/apis/v0/handlers/metrics.py +++ b/src/argilla/server/apis/v0/handlers/metrics.py @@ -31,9 +31,7 @@ class MetricInfo(BaseModel): id: str = Field(description="The metric id") name: str = Field(description="The metric name") - description: Optional[str] = Field( - default=None, description="The metric description" - ) + description: Optional[str] = Field(default=None, description="The metric description") @dataclass diff --git a/src/argilla/server/apis/v0/handlers/text2text.py b/src/argilla/server/apis/v0/handlers/text2text.py index d0ca90bb62..cf6230c4ad 100644 --- 
a/src/argilla/server/apis/v0/handlers/text2text.py +++ b/src/argilla/server/apis/v0/handlers/text2text.py @@ -120,9 +120,7 @@ def search_records( name: str, search: Text2TextSearchRequest = None, common_params: CommonTaskHandlerDependencies = Depends(), - include_metrics: bool = Query( - False, description="If enabled, return related record metrics" - ), + include_metrics: bool = Query(False, description="If enabled, return related record metrics"), pagination: RequestPagination = Depends(), service: Text2TextService = Depends(Text2TextService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), @@ -149,9 +147,7 @@ def search_records( return Text2TextSearchResults( total=result.total, records=[Text2TextRecord.parse_obj(r) for r in result.records], - aggregations=Text2TextSearchAggregations.parse_obj(result.metrics) - if result.metrics - else None, + aggregations=Text2TextSearchAggregations.parse_obj(result.metrics) if result.metrics else None, ) def scan_data_response( @@ -183,9 +179,7 @@ def grouper(n, iterable, fillvalue=None): ) ) + "\n" - return StreamingResponseWithErrorHandling( - stream_generator(data_stream), media_type="application/json" - ) + return StreamingResponseWithErrorHandling(stream_generator(data_stream), media_type="application/json") @router.post( path=f"{base_endpoint}/data", diff --git a/src/argilla/server/apis/v0/handlers/text_classification.py b/src/argilla/server/apis/v0/handlers/text_classification.py index 414529dbbf..e5d4ca840b 100644 --- a/src/argilla/server/apis/v0/handlers/text_classification.py +++ b/src/argilla/server/apis/v0/handlers/text_classification.py @@ -89,9 +89,7 @@ async def bulk_records( name: str, bulk: TextClassificationBulkRequest, common_params: CommonTaskHandlerDependencies = Depends(), - service: TextClassificationService = Depends( - TextClassificationService.get_instance - ), + service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), validator: DatasetValidator = Depends(DatasetValidator.get_instance), current_user: User = Security(auth.get_user, scopes=[]), @@ -120,9 +118,7 @@ async def bulk_records( # TODO(@frascuchon): Validator should be applied in the service layer records = [ServiceTextClassificationRecord.parse_obj(r) for r in bulk.records] - await validator.validate_dataset_records( - user=current_user, dataset=dataset, records=records - ) + await validator.validate_dataset_records(user=current_user, dataset=dataset, records=records) result = await service.add_records( dataset=dataset, @@ -144,13 +140,9 @@ def search_records( name: str, search: TextClassificationSearchRequest = None, common_params: CommonTaskHandlerDependencies = Depends(), - include_metrics: bool = Query( - False, description="If enabled, return related record metrics" - ), + include_metrics: bool = Query(False, description="If enabled, return related record metrics"), pagination: RequestPagination = Depends(), - service: TextClassificationService = Depends( - TextClassificationService.get_instance - ), + service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> TextClassificationSearchResults: @@ -204,9 +196,7 @@ def search_records( return TextClassificationSearchResults( total=result.total, records=result.records, - 
aggregations=TextClassificationSearchAggregations.parse_obj(result.metrics) - if result.metrics - else None, + aggregations=TextClassificationSearchAggregations.parse_obj(result.metrics) if result.metrics else None, ) def scan_data_response( @@ -238,9 +228,7 @@ def grouper(n, iterable, fillvalue=None): ) ) + "\n" - return StreamingResponseWithErrorHandling( - stream_generator(data_stream), media_type="application/json" - ) + return StreamingResponseWithErrorHandling(stream_generator(data_stream), media_type="application/json") @router.post( f"{base_endpoint}/data", @@ -253,9 +241,7 @@ async def stream_data( common_params: CommonTaskHandlerDependencies = Depends(), id_from: Optional[str] = None, limit: Optional[int] = Query(None, description="Limit loaded records", gt=0), - service: TextClassificationService = Depends( - TextClassificationService.get_instance - ), + service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> StreamingResponse: @@ -319,9 +305,7 @@ async def list_labeling_rules( name: str, common_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), - service: TextClassificationService = Depends( - TextClassificationService.get_instance - ), + service: TextClassificationService = Depends(TextClassificationService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> List[LabelingRule]: dataset = datasets.find_by_name( @@ -332,9 +316,7 @@ async def list_labeling_rules( as_dataset_class=TasksFactory.get_task_dataset(task_type), ) - return [ - LabelingRule.parse_obj(rule) for rule in service.get_labeling_rules(dataset) - ] + return [LabelingRule.parse_obj(rule) for rule in service.get_labeling_rules(dataset)] @deprecate_endpoint( path=f"{new_base_endpoint}/labeling/rules", @@ -349,9 +331,7 @@ async def create_rule( name: str, rule: CreateLabelingRule, common_params: CommonTaskHandlerDependencies = Depends(), - service: TextClassificationService = Depends( - TextClassificationService.get_instance - ), + service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> LabelingRule: @@ -385,13 +365,9 @@ async def create_rule( async def compute_rule_metrics( name: str, query: str, - labels: Optional[List[str]] = Query( - None, description="Label related to query rule", alias="label" - ), + labels: Optional[List[str]] = Query(None, description="Label related to query rule", alias="label"), common_params: CommonTaskHandlerDependencies = Depends(), - service: TextClassificationService = Depends( - TextClassificationService.get_instance - ), + service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> LabelingRuleMetricsSummary: @@ -417,9 +393,7 @@ async def compute_rule_metrics( async def compute_dataset_rules_metrics( name: str, common_params: CommonTaskHandlerDependencies = Depends(), - service: TextClassificationService = Depends( - TextClassificationService.get_instance - ), + service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), current_user: 
User = Security(auth.get_user, scopes=[]), ) -> DatasetLabelingRulesMetricsSummary: @@ -444,9 +418,7 @@ async def delete_labeling_rule( name: str, query: str, common_params: CommonTaskHandlerDependencies = Depends(), - service: TextClassificationService = Depends( - TextClassificationService.get_instance - ), + service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> None: @@ -473,9 +445,7 @@ async def get_rule( name: str, query: str, common_params: CommonTaskHandlerDependencies = Depends(), - service: TextClassificationService = Depends( - TextClassificationService.get_instance - ), + service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> LabelingRule: @@ -506,9 +476,7 @@ async def update_rule( query: str, update: UpdateLabelingRule, common_params: CommonTaskHandlerDependencies = Depends(), - service: TextClassificationService = Depends( - TextClassificationService.get_instance - ), + service: TextClassificationService = Depends(TextClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> LabelingRule: diff --git a/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py b/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py index e40c97e079..70c6b327d5 100644 --- a/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py +++ b/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py @@ -63,9 +63,7 @@ async def get_dataset_settings( task=task, ) - settings = await datasets.get_settings( - user=user, dataset=found_ds, class_type=__svc_settings_class__ - ) + settings = await datasets.get_settings(user=user, dataset=found_ds, class_type=__svc_settings_class__) return TextClassificationSettings.parse_obj(settings) @deprecate_endpoint( @@ -79,9 +77,7 @@ async def get_dataset_settings( response_model=TextClassificationSettings, ) async def save_settings( - request: TextClassificationSettings = Body( - ..., description=f"The {task} dataset settings" - ), + request: TextClassificationSettings = Body(..., description=f"The {task} dataset settings"), name: str = DATASET_NAME_PATH_PARAM, ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), @@ -94,9 +90,7 @@ async def save_settings( task=task, workspace=ws_params.workspace, ) - await validator.validate_dataset_settings( - user=user, dataset=found_ds, settings=request - ) + await validator.validate_dataset_settings(user=user, dataset=found_ds, settings=request) settings = await datasets.save_settings( user=user, dataset=found_ds, diff --git a/src/argilla/server/apis/v0/handlers/token_classification.py b/src/argilla/server/apis/v0/handlers/token_classification.py index 7a21310e1d..b4a851b8de 100644 --- a/src/argilla/server/apis/v0/handlers/token_classification.py +++ b/src/argilla/server/apis/v0/handlers/token_classification.py @@ -83,9 +83,7 @@ async def bulk_records( name: str, bulk: TokenClassificationBulkRequest, common_params: CommonTaskHandlerDependencies = Depends(), - service: TokenClassificationService = Depends( - TokenClassificationService.get_instance - ), + service: 
TokenClassificationService = Depends(TokenClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), validator: DatasetValidator = Depends(DatasetValidator.get_instance), current_user: User = Security(auth.get_user, scopes=[]), @@ -145,9 +143,7 @@ def search_records( description="If enabled, return related record metrics", ), pagination: RequestPagination = Depends(), - service: TokenClassificationService = Depends( - TokenClassificationService.get_instance - ), + service: TokenClassificationService = Depends(TokenClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> TokenClassificationSearchResults: @@ -173,9 +169,7 @@ def search_records( return TokenClassificationSearchResults( total=results.total, records=[TokenClassificationRecord.parse_obj(r) for r in results.records], - aggregations=TokenClassificationAggregations.parse_obj(results.metrics) - if results.metrics - else None, + aggregations=TokenClassificationAggregations.parse_obj(results.metrics) if results.metrics else None, ) def scan_data_response( @@ -207,9 +201,7 @@ def grouper(n, iterable, fillvalue=None): ) ) + "\n" - return StreamingResponseWithErrorHandling( - stream_generator(data_stream), media_type="application/json" - ) + return StreamingResponseWithErrorHandling(stream_generator(data_stream), media_type="application/json") @router.post( path=f"{base_endpoint}/data", @@ -221,9 +213,7 @@ async def stream_data( query: Optional[TokenClassificationQuery] = None, common_params: CommonTaskHandlerDependencies = Depends(), limit: Optional[int] = Query(None, description="Limit loaded records", gt=0), - service: TokenClassificationService = Depends( - TokenClassificationService.get_instance - ), + service: TokenClassificationService = Depends(TokenClassificationService.get_instance), datasets: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), id_from: Optional[str] = None, diff --git a/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py b/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py index 4756aeb9da..a5ac2411f3 100644 --- a/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py +++ b/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py @@ -63,9 +63,7 @@ async def get_dataset_settings( task=task, ) - settings = await datasets.get_settings( - user=user, dataset=found_ds, class_type=__svc_settings_class__ - ) + settings = await datasets.get_settings(user=user, dataset=found_ds, class_type=__svc_settings_class__) return TokenClassificationSettings.parse_obj(settings) @deprecate_endpoint( @@ -79,9 +77,7 @@ async def get_dataset_settings( response_model=TokenClassificationSettings, ) async def save_settings( - request: TokenClassificationSettings = Body( - ..., description=f"The {task} dataset settings" - ), + request: TokenClassificationSettings = Body(..., description=f"The {task} dataset settings"), name: str = DATASET_NAME_PATH_PARAM, ws_params: CommonTaskHandlerDependencies = Depends(), datasets: DatasetsService = Depends(DatasetsService.get_instance), @@ -94,9 +90,7 @@ async def save_settings( task=task, workspace=ws_params.workspace, ) - await validator.validate_dataset_settings( - user=user, dataset=found_ds, settings=request - ) + await validator.validate_dataset_settings(user=user, dataset=found_ds, 
settings=request) settings = await datasets.save_settings( user=user, dataset=found_ds, diff --git a/src/argilla/server/apis/v0/handlers/users.py b/src/argilla/server/apis/v0/handlers/users.py index d48d89c367..3d17b69492 100644 --- a/src/argilla/server/apis/v0/handlers/users.py +++ b/src/argilla/server/apis/v0/handlers/users.py @@ -22,12 +22,8 @@ router = APIRouter(tags=["users"]) -@router.get( - "/me", response_model=User, response_model_exclude_none=True, operation_id="whoami" -) -async def whoami( - request: Request, current_user: User = Security(auth.get_user, scopes=[]) -): +@router.get("/me", response_model=User, response_model_exclude_none=True, operation_id="whoami") +async def whoami(request: Request, current_user: User = Security(auth.get_user, scopes=[])): """ User info endpoint diff --git a/src/argilla/server/apis/v0/helpers.py b/src/argilla/server/apis/v0/helpers.py index 55874f58f5..905a7ed421 100644 --- a/src/argilla/server/apis/v0/helpers.py +++ b/src/argilla/server/apis/v0/helpers.py @@ -15,9 +15,7 @@ from typing import Callable -def deprecate_endpoint( - path: str, new_path: str, router_method: Callable, *args, **kwargs -): +def deprecate_endpoint(path: str, new_path: str, router_method: Callable, *args, **kwargs): """ Helper decorator to deprecate a `path` endpoint adding the `new_path` endpoint. diff --git a/src/argilla/server/apis/v0/models/commons/model.py b/src/argilla/server/apis/v0/models/commons/model.py index 25a62c0657..f25289ab9a 100644 --- a/src/argilla/server/apis/v0/models/commons/model.py +++ b/src/argilla/server/apis/v0/models/commons/model.py @@ -67,9 +67,7 @@ class BulkResponse(BaseModel): failed: int = 0 -class BaseSearchResults( - GenericModel, Generic[_Record, ServiceSearchResultsAggregations] -): +class BaseSearchResults(GenericModel, Generic[_Record, ServiceSearchResultsAggregations]): total: int = 0 records: List[_Record] = Field(default_factory=list) aggregations: ServiceSearchResultsAggregations = None diff --git a/src/argilla/server/apis/v0/models/commons/params.py b/src/argilla/server/apis/v0/models/commons/params.py index 09a27b559f..4bd92aa4e2 100644 --- a/src/argilla/server/apis/v0/models/commons/params.py +++ b/src/argilla/server/apis/v0/models/commons/params.py @@ -23,9 +23,7 @@ ) from argilla.server.security.model import WORKSPACE_NAME_PATTERN -DATASET_NAME_PATH_PARAM = Path( - ..., regex=DATASET_NAME_REGEX_PATTERN, description="The dataset name" -) +DATASET_NAME_PATH_PARAM = Path(..., regex=DATASET_NAME_REGEX_PATTERN, description="The dataset name") @dataclass @@ -33,9 +31,7 @@ class RequestPagination: """Query pagination params""" limit: int = Query(50, gte=0, le=1000, description="Response records limit") - from_: int = Query( - 0, ge=0, le=10000, alias="from", description="Record sequence from" - ) + from_: int = Query(0, ge=0, le=10000, alias="from", description="Record sequence from") @dataclass @@ -60,14 +56,9 @@ class CommonTaskHandlerDependencies: @property def workspace(self) -> str: """Return read workspace. Query param prior to header param""" - workspace = ( - self.__workspace_param__ - or self.__workspace_header__ - or self.__old_workspace_header__ - ) + workspace = self.__workspace_param__ or self.__workspace_header__ or self.__old_workspace_header__ if workspace: assert WORKSPACE_NAME_PATTERN.match(workspace), ( - "Wrong workspace format. " - f"Workspace must match pattern {WORKSPACE_NAME_PATTERN.pattern}" + "Wrong workspace format. 
" f"Workspace must match pattern {WORKSPACE_NAME_PATTERN.pattern}" ) return workspace diff --git a/src/argilla/server/apis/v0/models/text2text.py b/src/argilla/server/apis/v0/models/text2text.py index 5345f95746..1caa2b107d 100644 --- a/src/argilla/server/apis/v0/models/text2text.py +++ b/src/argilla/server/apis/v0/models/text2text.py @@ -71,9 +71,7 @@ class Text2TextSearchAggregations(ServiceBaseSearchResultsAggregations): annotated_text: Dict[str, int] = Field(default_factory=dict) -class Text2TextSearchResults( - BaseSearchResults[Text2TextRecord, Text2TextSearchAggregations] -): +class Text2TextSearchResults(BaseSearchResults[Text2TextRecord, Text2TextSearchAggregations]): pass diff --git a/src/argilla/server/apis/v0/models/text_classification.py b/src/argilla/server/apis/v0/models/text_classification.py index b2886fe56b..a4aa5621b4 100644 --- a/src/argilla/server/apis/v0/models/text_classification.py +++ b/src/argilla/server/apis/v0/models/text_classification.py @@ -47,17 +47,12 @@ class UpdateLabelingRule(BaseModel): - label: Optional[str] = Field( - default=None, description="@Deprecated::The label associated with the rule." - ) + label: Optional[str] = Field(default=None, description="@Deprecated::The label associated with the rule.") labels: List[str] = Field( default_factory=list, - description="For multi label problems, a list of labels. " - "It will replace the `label` field", - ) - description: Optional[str] = Field( - None, description="A brief description of the rule" + description="For multi label problems, a list of labels. " "It will replace the `label` field", ) + description: Optional[str] = Field(None, description="A brief description of the rule") @root_validator def initialize_labels(cls, values): @@ -83,9 +78,7 @@ def strip_query(cls, query: str) -> str: class LabelingRule(CreateLabelingRule): author: str = Field(description="User who created the rule") - created_at: Optional[datetime] = Field( - default_factory=datetime.utcnow, description="Rule creation timestamp" - ) + created_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="Rule creation timestamp") class LabelingRuleMetricsSummary(_LabelingRuleMetricsSummary): @@ -110,9 +103,7 @@ class TextClassificationRecordInputs(BaseRecordInputs[TextClassificationAnnotati explanation: Optional[Dict[str, List[TokenAttributions]]] = None -class TextClassificationRecord( - TextClassificationRecordInputs, BaseRecord[TextClassificationAnnotation] -): +class TextClassificationRecord(TextClassificationRecordInputs, BaseRecord[TextClassificationAnnotation]): pass @@ -125,9 +116,7 @@ def check_multi_label_integrity(cls, records: List[TextClassificationRecord]): if records: multi_label = records[0].multi_label for record in records[1:]: - assert ( - multi_label == record.multi_label - ), "All records must be single/multi labelled" + assert multi_label == record.multi_label, "All records must be single/multi labelled" return records diff --git a/src/argilla/server/apis/v0/models/token_classification.py b/src/argilla/server/apis/v0/models/token_classification.py index 63eaaa9632..4ec6608a85 100644 --- a/src/argilla/server/apis/v0/models/token_classification.py +++ b/src/argilla/server/apis/v0/models/token_classification.py @@ -61,9 +61,7 @@ def check_text_content(cls, text: str): return text -class TokenClassificationRecord( - TokenClassificationRecordInputs, BaseRecord[TokenClassificationAnnotation] -): +class TokenClassificationRecord(TokenClassificationRecordInputs, 
BaseRecord[TokenClassificationAnnotation]): pass @@ -88,9 +86,7 @@ class TokenClassificationAggregations(ServiceBaseSearchResultsAggregations): mentions: Dict[str, Dict[str, int]] = Field(default_factory=dict) -class TokenClassificationSearchResults( - BaseSearchResults[TokenClassificationRecord, TokenClassificationAggregations] -): +class TokenClassificationSearchResults(BaseSearchResults[TokenClassificationRecord, TokenClassificationAggregations]): pass diff --git a/src/argilla/server/apis/v0/validators/text_classification.py b/src/argilla/server/apis/v0/validators/text_classification.py index e225749bed..c7a3209a64 100644 --- a/src/argilla/server/apis/v0/validators/text_classification.py +++ b/src/argilla/server/apis/v0/validators/text_classification.py @@ -54,9 +54,7 @@ def get_instance( cls._INSTANCE = cls(datasets, metrics=metrics) return cls._INSTANCE - async def validate_dataset_settings( - self, user: User, dataset: Dataset, settings: TextClassificationSettings - ): + async def validate_dataset_settings(self, user: User, dataset: Dataset, settings: TextClassificationSettings): if settings and settings.label_schema: results = self.__metrics__.summarize_metric( dataset=dataset, @@ -66,9 +64,7 @@ async def validate_dataset_settings( ) if results: labels = results.get("labels", []) - label_schema = set( - [label.name for label in settings.label_schema.labels] - ) + label_schema = set([label.name for label in settings.label_schema.labels]) for label in labels: if label not in label_schema: raise BadRequestError( @@ -87,9 +83,7 @@ async def validate_dataset_records( user=user, dataset=dataset, class_type=__svc_settings_class__ ) if settings and settings.label_schema: - label_schema = set( - [label.name for label in settings.label_schema.labels] - ) + label_schema = set([label.name for label in settings.label_schema.labels]) for r in records: if r.prediction: self.__check_label_classes__( diff --git a/src/argilla/server/apis/v0/validators/token_classification.py b/src/argilla/server/apis/v0/validators/token_classification.py index a5004c9dfa..52616d366c 100644 --- a/src/argilla/server/apis/v0/validators/token_classification.py +++ b/src/argilla/server/apis/v0/validators/token_classification.py @@ -55,9 +55,7 @@ def get_instance( cls._INSTANCE = cls(datasets, metrics=metrics) return cls._INSTANCE - async def validate_dataset_settings( - self, user: User, dataset: Dataset, settings: TokenClassificationSettings - ): + async def validate_dataset_settings(self, user: User, dataset: Dataset, settings: TokenClassificationSettings): if settings and settings.label_schema: results = self.__metrics__.summarize_metric( dataset=dataset, @@ -67,9 +65,7 @@ async def validate_dataset_settings( ) if results: labels = results.get("labels", []) - label_schema = set( - [label.name for label in settings.label_schema.labels] - ) + label_schema = set([label.name for label in settings.label_schema.labels]) for label in labels: if label not in label_schema: raise BadRequestError( @@ -84,15 +80,11 @@ async def validate_dataset_records( records: List[ServiceTokenClassificationRecord], ): try: - settings: TokenClassificationSettings = ( - await self.__datasets__.get_settings( - user=user, dataset=dataset, class_type=__svc_settings_class__ - ) + settings: TokenClassificationSettings = await self.__datasets__.get_settings( + user=user, dataset=dataset, class_type=__svc_settings_class__ ) if settings and settings.label_schema: - label_schema = set( - [label.name for label in settings.label_schema.labels] - ) + 
label_schema = set([label.name for label in settings.label_schema.labels]) for r in records: if r.prediction: @@ -103,9 +95,7 @@ async def validate_dataset_records( pass @staticmethod - def __check_label_entities__( - label_schema: Set[str], annotation: ServiceTokenClassificationAnnotation - ): + def __check_label_entities__(label_schema: Set[str], annotation: ServiceTokenClassificationAnnotation): if not annotation: return for entity in annotation.entities: diff --git a/src/argilla/server/commons/config.py b/src/argilla/server/commons/config.py index b12dfc368c..47d068fc39 100644 --- a/src/argilla/server/commons/config.py +++ b/src/argilla/server/commons/config.py @@ -85,18 +85,14 @@ def __get_task_config__(cls, task): return config @classmethod - def find_task_metric( - cls, task: TaskType, metric_id: str - ) -> Optional[ServiceBaseMetric]: + def find_task_metric(cls, task: TaskType, metric_id: str) -> Optional[ServiceBaseMetric]: metrics = cls.find_task_metrics(task, {metric_id}) if metrics: return metrics[0] raise EntityNotFoundError(name=metric_id, type=ServiceBaseMetric) @classmethod - def find_task_metrics( - cls, task: TaskType, metric_ids: Set[str] - ) -> List[ServiceBaseMetric]: + def find_task_metrics(cls, task: TaskType, metric_ids: Set[str]) -> List[ServiceBaseMetric]: if not metric_ids: return [] diff --git a/src/argilla/server/commons/telemetry.py b/src/argilla/server/commons/telemetry.py index 3f39929033..e00abd15c6 100644 --- a/src/argilla/server/commons/telemetry.py +++ b/src/argilla/server/commons/telemetry.py @@ -67,9 +67,7 @@ def get(cls): try: cls.__INSTANCE__ = cls(client=_configure_analytics()) except Exception as err: - logging.getLogger(__name__).warning( - f"Cannot initialize telemetry. Error: {err}. Disabling..." - ) + logging.getLogger(__name__).warning(f"Cannot initialize telemetry. Error: {err}. 
Disabling...") settings.enable_telemetry = False return None return cls.__INSTANCE__ @@ -88,9 +86,7 @@ def __post_init__(self): "version": __version__, } - def track_data( - self, action: str, data: Dict[str, Any], include_system_info: bool = True - ): + def track_data(self, action: str, data: Dict[str, Any], include_system_info: bool = True): event_data = data.copy() self.client.track( user_id=self.__server_id_str__, @@ -101,18 +97,13 @@ def track_data( def _process_request_info(request: Request): - return { - header: request.headers.get(header) - for header in ["user-agent", "accept-language"] - } + return {header: request.headers.get(header) for header in ["user-agent", "accept-language"]} async def track_error(error: ServerError, request: Request): client = _TelemetryClient.get() if client: - client.track_data( - "ServerErrorFound", {"code": error.code, **_process_request_info(request)} - ) + client.track_data("ServerErrorFound", {"code": error.code, **_process_request_info(request)}) async def track_bulk(task: TaskType, records: int): diff --git a/src/argilla/server/daos/backend/client_adapters/opensearch.py b/src/argilla/server/daos/backend/client_adapters/opensearch.py index 088b48da84..83629300cb 100644 --- a/src/argilla/server/daos/backend/client_adapters/opensearch.py +++ b/src/argilla/server/daos/backend/client_adapters/opensearch.py @@ -220,11 +220,7 @@ def compute_index_metric( } ) - filtered_params = { - argument: params[argument] - for argument in metric.metric_arg_names - if argument in params - } + filtered_params = {argument: params[argument] for argument in metric.metric_arg_names if argument in params} aggregations = metric.aggregation_request(**filtered_params) if not aggregations: @@ -250,9 +246,7 @@ def compute_index_metric( search_aggregations = search.get("aggregations", {}) if search_aggregations: - parsed_aggregations = query_helpers.parse_aggregations( - search_aggregations - ) + parsed_aggregations = query_helpers.parse_aggregations(search_aggregations) results.update(parsed_aggregations) return metric.aggregation_result(results.get(metric.id, results)) @@ -459,9 +453,7 @@ def is_subfield(key: str): return True return False - schema = { - key: value for key, value in schema.items() if not is_subfield(key) - } + schema = {key: value for key, value in schema.items() if not is_subfield(key)} return schema except IndexNotFoundError: @@ -688,11 +680,7 @@ def is_read_only_index(self, index: str) -> bool: allow_no_indices=True, flat_settings=True, ) - return ( - response[index]["settings"]["index.blocks.write"] == "true" - if response - else False - ) + return response[index]["settings"]["index.blocks.write"] == "true" if response else False def enable_read_only_index(self, index: str): return self._enable_or_disable_read_only_index( @@ -770,8 +758,7 @@ def _get_fields_schema( data = response return { - key: list(definition["mapping"].values())[0]["type"] - for key, definition in data["mappings"].items() + key: list(definition["mapping"].values())[0]["type"] for key, definition in data["mappings"].items() } def _normalize_document( diff --git a/src/argilla/server/daos/backend/generic_elastic.py b/src/argilla/server/daos/backend/generic_elastic.py index 9a1bdd103b..61d2ece88a 100644 --- a/src/argilla/server/daos/backend/generic_elastic.py +++ b/src/argilla/server/daos/backend/generic_elastic.py @@ -291,8 +291,7 @@ def _check_max_number_of_vectors(): if len(vector_names) > settings.vectors_fields_limit: raise BadRequestError( - detail=f"Cannot create more than 
{settings.vectors_fields_limit} " - "kind of vectors per dataset" + detail=f"Cannot create more than {settings.vectors_fields_limit} " "kinds of vectors per dataset" ) _check_max_number_of_vectors() @@ -489,8 +488,7 @@ def find_dataset( if len(docs) > 1: raise ValueError( - f"Ambiguous dataset info found for name {name}. " - "Please provide a valid owner/workspace" + f"Ambiguous dataset info found for name {name}. " "Please provide a valid owner/workspace" ) document = docs[0] return document diff --git a/src/argilla/server/daos/backend/mappings/helpers.py b/src/argilla/server/daos/backend/mappings/helpers.py index b0726f94a2..4b32da9244 100644 --- a/src/argilla/server/daos/backend/mappings/helpers.py +++ b/src/argilla/server/daos/backend/mappings/helpers.py @@ -185,11 +185,7 @@ def dynamic_metadata_text(): def dynamic_annotations_text(path: str): path = f"{path}.*" - return { - path: mappings.path_match_keyword_template( - path=path, enable_text_search_in_keywords=True - ) - } + return {path: mappings.path_match_keyword_template(path=path, enable_text_search_in_keywords=True)} def tasks_common_mappings(): diff --git a/src/argilla/server/daos/backend/metrics/base.py b/src/argilla/server/daos/backend/metrics/base.py index 256ffcdbf1..82bd60a1e7 100644 --- a/src/argilla/server/daos/backend/metrics/base.py +++ b/src/argilla/server/daos/backend/metrics/base.py @@ -37,9 +37,7 @@ def __post_init__(self): def get_function_arg_names(func): return func.__code__.co_varnames - def aggregation_request( - self, *args, **kwargs - ) -> Union[Dict[str, Any], List[Dict[str, Any]]]: + def aggregation_request(self, *args, **kwargs) -> Union[Dict[str, Any], List[Dict[str, Any]]]: """ Configures the summary es aggregation definition """ @@ -121,9 +119,7 @@ def _build_aggregation(self, interval: Optional[float] = None) -> Dict[str, Any] if self.fixed_interval: interval = self.fixed_interval - return aggregations.histogram_aggregation( - field_name=self.field, script=self.script, interval=interval - ) + return aggregations.histogram_aggregation(field_name=self.field, script=self.script, interval=interval) @dataclasses.dataclass diff --git a/src/argilla/server/daos/backend/metrics/datasets.py b/src/argilla/server/daos/backend/metrics/datasets.py index a673b7ba02..b261599aed 100644 --- a/src/argilla/server/daos/backend/metrics/datasets.py +++ b/src/argilla/server/daos/backend/metrics/datasets.py @@ -15,6 +15,4 @@ # All metrics related to the datasets index from argilla.server.daos.backend.metrics.base import TermsAggregation -METRICS = { - "all_workspaces": TermsAggregation(id="all_workspaces", field="owner.keyword") -} +METRICS = {"all_workspaces": TermsAggregation(id="all_workspaces", field="owner.keyword")} diff --git a/src/argilla/server/daos/backend/metrics/text_classification.py b/src/argilla/server/daos/backend/metrics/text_classification.py index fa59630558..8d3a83228c 100644 --- a/src/argilla/server/daos/backend/metrics/text_classification.py +++ b/src/argilla/server/daos/backend/metrics/text_classification.py @@ -29,9 +29,7 @@ def _build_aggregation(self, queries: List[str]) -> Dict[str, Any]: rules_filters = [filters.text_query(rule_query) for rule_query in queries] return aggregations.filters_aggregation( filters={ - "covered_records": filters.boolean_filter( - should_filters=rules_filters, minimum_should_match=1 - ), + "covered_records": filters.boolean_filter(should_filters=rules_filters, minimum_should_match=1), "annotated_covered_records": filters.boolean_filter( 
filter_query=filters.exists_field("annotated_as"), should_filters=rules_filters, @@ -45,9 +43,7 @@ def _build_aggregation(self, queries: List[str]) -> Dict[str, Any]: class LabelingRulesMetric(ElasticsearchMetric): id: str - def _build_aggregation( - self, rule_query: str, labels: Optional[List[str]] - ) -> Dict[str, Any]: + def _build_aggregation(self, rule_query: str, labels: Optional[List[str]]) -> Dict[str, Any]: annotated_records_filter = filters.exists_field("annotated_as") rule_query_filter = filters.text_query(rule_query) aggr_filters = { @@ -60,9 +56,7 @@ def _build_aggregation( if labels is not None: for label in labels: - rule_label_annotated_filter = filters.term_filter( - "annotated_as", value=label - ) + rule_label_annotated_filter = filters.term_filter("annotated_as", value=label) encoded_label = self._encode_label_name(label) aggr_filters.update( { @@ -99,9 +93,7 @@ def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, An aggregation_result = unflatten_dict(aggregation_result) results = { "covered_records": aggregation_result.pop("covered_records"), - "annotated_covered_records": aggregation_result.pop( - "annotated_covered_records" - ), + "annotated_covered_records": aggregation_result.pop("annotated_covered_records"), } all_correct = [] diff --git a/src/argilla/server/daos/backend/metrics/token_classification.py b/src/argilla/server/daos/backend/metrics/token_classification.py index 2c5dc27f0e..35e2f7bb11 100644 --- a/src/argilla/server/daos/backend/metrics/token_classification.py +++ b/src/argilla/server/daos/backend/metrics/token_classification.py @@ -57,11 +57,7 @@ def _inner_aggregation( self.compound_nested_field(self.labels_field), size=entity_size, ), - "count": { - "cardinality": { - "field": self.compound_nested_field(self.labels_field) - } - }, + "count": {"cardinality": {"field": self.compound_nested_field(self.labels_field)}}, "entities_variability_filter": { "bucket_selector": { "buckets_path": {"numLabels": "count"}, @@ -83,10 +79,7 @@ def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, An result = [ { "mention": mention, - "entities": [ - {"label": entity, "count": count} - for entity, count in mention_aggs["entities"].items() - ], + "entities": [{"label": entity, "count": count} for entity, count in mention_aggs["entities"].items()], } for mention, mention_aggs in aggregation_result.items() ] @@ -194,16 +187,12 @@ def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, An "predicted_mentions_distribution": NestedBidimensionalTermsAggregation( id="predicted_mentions_distribution", nested_path="metrics.predicted.mentions", - biterms=BidimensionalTermsAggregation( - id="bi-dimensional", field_x="label", field_y="value" - ), + biterms=BidimensionalTermsAggregation(id="bi-dimensional", field_x="label", field_y="value"), ), "annotated_mentions_distribution": NestedBidimensionalTermsAggregation( id="predicted_mentions_distribution", nested_path="metrics.annotated.mentions", - biterms=BidimensionalTermsAggregation( - id="bi-dimensional", field_x="label", field_y="value" - ), + biterms=BidimensionalTermsAggregation(id="bi-dimensional", field_x="label", field_y="value"), ), "predicted_top_k_mentions_consistency": TopKMentionsConsistency( id="predicted_top_k_mentions_consistency", diff --git a/src/argilla/server/daos/backend/query_helpers.py b/src/argilla/server/daos/backend/query_helpers.py index 1c57c5ccee..92a0f4bbd5 100644 --- a/src/argilla/server/daos/backend/query_helpers.py +++ 
b/src/argilla/server/daos/backend/query_helpers.py @@ -33,16 +33,11 @@ def resolve_mapping(info) -> Dict[str, Any]: return { "type": "nested", "include_in_root": True, - "properties": { - key: resolve_mapping(info) - for key, info in model_class.schema()["properties"].items() - }, + "properties": {key: resolve_mapping(info) for key, info in model_class.schema()["properties"].items()}, } -def parse_aggregations( - es_aggregations: Dict[str, Any] = None -) -> Optional[Dict[str, Any]]: +def parse_aggregations(es_aggregations: Dict[str, Any] = None) -> Optional[Dict[str, Any]]: """Transforms elasticsearch raw aggregations into a more friendly structure""" if es_aggregations is None: @@ -80,9 +75,7 @@ def parse_buckets(buckets: List[Dict[str, Any]]) -> Dict[str, Any]: key_metrics = {} for metric_key, metric in list(bucket.items()): if "buckets" in metric: - key_metrics.update( - {metric_key: parse_buckets(metric.get("buckets", []))} - ) + key_metrics.update({metric_key: parse_buckets(metric.get("buckets", []))}) else: metric_values = list(metric.values()) value = metric_values[0] if len(metric_values) == 1 else metric @@ -171,13 +164,7 @@ def metadata(metadata: Dict[str, Union[str, List[str]]]) -> List[Dict[str, Any]] return [] return [ - { - "terms": { - f"metadata.{key}": query_text - if isinstance(query_text, List) - else [query_text] - } - } + {"terms": {f"metadata.{key}": query_text if isinstance(query_text, List) else [query_text]}} for key, query_text in metadata.items() ] @@ -237,9 +224,7 @@ class aggregations: MAX_AGGREGATION_SIZE = 5000 # TODO: improve by setting env var @staticmethod - def nested_aggregation( - nested_path: str, inner_aggregation: Dict[str, Any] - ) -> Dict[str, Any]: + def nested_aggregation(nested_path: str, inner_aggregation: Dict[str, Any]) -> Dict[str, Any]: inner_meta = list(inner_aggregation.values())[0].get("meta", {}) return { "meta": { @@ -250,15 +235,11 @@ def nested_aggregation( } @staticmethod - def bidimentional_terms_aggregations( - field_name_x: str, field_name_y: str, size=DEFAULT_AGGREGATION_SIZE - ): + def bidimentional_terms_aggregations(field_name_x: str, field_name_y: str, size=DEFAULT_AGGREGATION_SIZE): return { **aggregations.terms_aggregation(field_name_x, size=size), "meta": {"kind": "2d-terms"}, - "aggs": { - field_name_y: aggregations.terms_aggregation(field_name_y, size=size) - }, + "aggs": {field_name_y: aggregations.terms_aggregation(field_name_y, size=size)}, } @staticmethod @@ -321,9 +302,7 @@ def custom_fields( ) -> Dict[str, Dict[str, Any]]: """Build a set of aggregations for a given field definition (extracted from index mapping)""" - def __resolve_aggregation_for_field_type( - field_type: str, field_name: str - ) -> Optional[Dict[str, Any]]: + def __resolve_aggregation_for_field_type(field_type: str, field_name: str) -> Optional[Dict[str, Any]]: if field_type in ["keyword", "long", "integer", "boolean"]: return aggregations.terms_aggregation(field_name=field_name, size=size) if field_type in ["float", "date"]: @@ -337,9 +316,7 @@ def __resolve_aggregation_for_field_type( return { key: aggregation for key, type_ in fields_definitions.items() - for aggregation in [ - __resolve_aggregation_for_field_type(type_, field_name=key) - ] + for aggregation in [__resolve_aggregation_for_field_type(type_, field_name=key)] if aggregation } @@ -348,9 +325,7 @@ def filters_aggregation(filters: Dict[str, Dict[str, Any]]) -> Dict[str, Any]: return {"filters": {"filters": filters}} -def find_nested_field_path( - field_name: str, mapping_definition: 
Dict[str, Any] -) -> Optional[str]: +def find_nested_field_path(field_name: str, mapping_definition: Dict[str, Any]) -> Optional[str]: """ Given a field name, find the nested path if any related to field name definition in provided mapping definition @@ -367,9 +342,7 @@ def find_nested_field_path( The found nested path if any, None otherwise """ - def build_flatten_properties_map( - properties: Dict[str, Any], prefix: str = "" - ) -> Dict[str, Any]: + def build_flatten_properties_map(properties: Dict[str, Any], prefix: str = "") -> Dict[str, Any]: results = {} for prop_name, prop_value in properties.items(): if prefix: @@ -377,11 +350,7 @@ def build_flatten_properties_map( if "type" in prop_value: results[prop_name] = prop_value["type"] if "properties" in prop_value: - results.update( - build_flatten_properties_map( - prop_value["properties"], prefix=prop_name - ) - ) + results.update(build_flatten_properties_map(prop_value["properties"], prefix=prop_name)) return results properties_map = build_flatten_properties_map(mapping_definition) diff --git a/src/argilla/server/daos/backend/search/model.py b/src/argilla/server/daos/backend/search/model.py index fcaad0e5fa..49a4986595 100644 --- a/src/argilla/server/daos/backend/search/model.py +++ b/src/argilla/server/daos/backend/search/model.py @@ -70,8 +70,7 @@ class VectorSearch(BaseModel): value: List[float] k: Optional[int] = Field( default=None, - description="Number of elements to retrieve. " - "If not provided, the request size will be used instead", + description="Number of elements to retrieve. " "If not provided, the request size will be used instead", ) diff --git a/src/argilla/server/daos/backend/search/query_builder.py b/src/argilla/server/daos/backend/search/query_builder.py index 76d9e989ff..3da0591864 100644 --- a/src/argilla/server/daos/backend/search/query_builder.py +++ b/src/argilla/server/daos/backend/search/query_builder.py @@ -36,13 +36,9 @@ class HighlightParser: __HIGHLIGHT_PRE_TAG__ = "<@@-ar-key>" __HIGHLIGHT_POST_TAG__ = "</@@-ar-key>" - __HIGHLIGHT_VALUES_REGEX__ = re.compile( - rf"{__HIGHLIGHT_PRE_TAG__}(.+?){__HIGHLIGHT_POST_TAG__}" - ) + __HIGHLIGHT_VALUES_REGEX__ = re.compile(rf"{__HIGHLIGHT_PRE_TAG__}(.+?){__HIGHLIGHT_POST_TAG__}") - __HIGHLIGHT_PHRASE_PRE_PARSER_REGEX__ = re.compile( - rf"{__HIGHLIGHT_POST_TAG__}\s+{__HIGHLIGHT_PRE_TAG__}" - ) + __HIGHLIGHT_PHRASE_PRE_PARSER_REGEX__ = re.compile(rf"{__HIGHLIGHT_POST_TAG__}\s+{__HIGHLIGHT_PRE_TAG__}") @property def search_keywords_field(self) -> str: @@ -102,9 +98,7 @@ def get_instance(cls): cls._INSTANCE = cls() return cls._INSTANCE - def _datasets_to_es_query( - self, query: Optional[BackendDatasetsQuery] = None - ) -> Dict[str, Any]: + def _datasets_to_es_query(self, query: Optional[BackendDatasetsQuery] = None) -> Dict[str, Any]: if not query: return filters.match_all() @@ -120,9 +114,7 @@ def _datasets_to_es_query( minimum_should_match=1, # OR Condition should_filters=[ owners_filter, - filters.boolean_filter( - must_not_query=filters.exists_field("owner") - ), + filters.boolean_filter(must_not_query=filters.exists_field("owner")), ], ) ) @@ -279,8 +271,7 @@ def map_2_es_sort_configuration( if valid_fields: if sortable_field.id.split(".")[0] not in valid_fields: raise AssertionError( - f"Wrong sort id {sortable_field.id}. Valid values are: " - f"{[str(v) for v in valid_fields]}" + f"Wrong sort id {sortable_field.id}. 
Valid values are: " f"{[str(v) for v in valid_fields]}" ) field = sortable_field.id if field == id_field and use_id_keyword: @@ -323,9 +314,7 @@ def _to_es_query(cls, query: BackendRecordsQuery) -> Dict[str, Any]: elif isinstance(value, (str, Enum)): key_filter = filters.term_filter(key, value) elif isinstance(value, QueryRange): - key_filter = filters.range_filter( - field=key, value_from=value.range_from, value_to=value.range_to - ) + key_filter = filters.range_filter(field=key, value_from=value.range_from, value_to=value.range_to) else: cls._LOGGER.warning(f"Cannot parse query value {value} for key {key}") diff --git a/src/argilla/server/daos/datasets.py b/src/argilla/server/daos/datasets.py index c45f8ddd20..d798f823ea 100644 --- a/src/argilla/server/daos/datasets.py +++ b/src/argilla/server/daos/datasets.py @@ -39,9 +39,7 @@ class DatasetsDAO: @classmethod def get_instance( cls, - es: GenericElasticEngineBackend = Depends( - GenericElasticEngineBackend.get_instance - ), + es: GenericElasticEngineBackend = Depends(GenericElasticEngineBackend.get_instance), records_dao: DatasetRecordsDAO = Depends(DatasetRecordsDAO.get_instance), ) -> "DatasetsDAO": """ @@ -117,9 +115,7 @@ def update_dataset( self, dataset: DatasetDB, ) -> DatasetDB: - self._es.update_dataset_document( - id=dataset.id, document=self._dataset_to_es_doc(dataset) - ) + self._es.update_dataset_document(id=dataset.id, document=self._dataset_to_es_doc(dataset)) return dataset def delete_dataset(self, dataset: DatasetDB): @@ -141,9 +137,7 @@ def find_by_name( return None base_ds = self._es_doc_to_instance(document) if task and task != base_ds.task: - raise WrongTaskError( - detail=f"Provided task {task} cannot be applied to dataset" - ) + raise WrongTaskError(detail=f"Provided task {task} cannot be applied to dataset") dataset_type = as_dataset_class or BaseDatasetDB return self._es_doc_to_instance(document, ds_class=dataset_type) @@ -154,9 +148,7 @@ def _es_doc_to_instance( ) -> DatasetDB: """Transforms a stored elasticsearch document into a `BaseDatasetDB`""" - def key_value_list_to_dict( - key_value_list: List[Dict[str, Any]] - ) -> Dict[str, Any]: + def key_value_list_to_dict(key_value_list: List[Dict[str, Any]]) -> Dict[str, Any]: return {data["key"]: json.loads(data["value"]) for data in key_value_list} tags = doc.get("tags", []) @@ -173,9 +165,7 @@ def key_value_list_to_dict( @staticmethod def _dataset_to_es_doc(dataset: DatasetDB) -> Dict[str, Any]: def dict_to_key_value_list(data: Dict[str, Any]) -> List[Dict[str, Any]]: - return [ - {"key": key, "value": json.dumps(value)} for key, value in data.items() - ] + return [{"key": key, "value": json.dumps(value)} for key, value in data.items()] data = dataset.dict(by_alias=True) tags = data.get("tags", {}) @@ -208,9 +198,7 @@ def open(self, dataset: DatasetDB): def get_all_workspaces(self) -> List[str]: """Get all datasets (Only for super users)""" - metric_data = self._es.compute_argilla_metric( - metric_id="all_argilla_workspaces" - ) + metric_data = self._es.compute_argilla_metric(metric_id="all_argilla_workspaces") return [k for k in metric_data] def save_settings( @@ -228,19 +216,14 @@ def save_settings( def _configure_vectors(self, dataset, settings): if not settings.vectors: return - vectors_cfg = { - k: v.dim if isinstance(v, EmbeddingsConfig) else int(v) - for k, v in settings.vectors.items() - } + vectors_cfg = {k: v.dim if isinstance(v, EmbeddingsConfig) else int(v) for k, v in settings.vectors.items()} self._es.create_dataset( id=dataset.id, task=dataset.task, 
vectors_cfg=vectors_cfg, ) - def load_settings( - self, dataset: DatasetDB, as_class: Type[DatasetSettingsDB] - ) -> Optional[DatasetSettingsDB]: + def load_settings(self, dataset: DatasetDB, as_class: Type[DatasetSettingsDB]) -> Optional[DatasetSettingsDB]: doc = self._es.find_dataset(id=dataset.id) if doc and "settings" in doc: settings = doc["settings"] diff --git a/src/argilla/server/daos/models/records.py b/src/argilla/server/daos/models/records.py index 117d39581e..14bb2e96d7 100644 --- a/src/argilla/server/daos/models/records.py +++ b/src/argilla/server/daos/models/records.py @@ -60,9 +60,7 @@ class BaseRecordInDB(GenericModel, Generic[AnnotationDB]): metadata: Dict[str, Any] = Field(default=None) event_timestamp: Optional[datetime] = None status: Optional[TaskStatus] = None - prediction: Optional[AnnotationDB] = Field( - None, description="Deprecated. Use `predictions` instead" - ) + prediction: Optional[AnnotationDB] = Field(None, description="Deprecated. Use `predictions` instead") annotation: Optional[AnnotationDB] = None vectors: Optional[Dict[str, BaseEmbeddingVectorDB]] = Field( @@ -98,21 +96,13 @@ def update_annotation(values, annotation_field: str): if not annotation.agent: raise AssertionError("Agent must be defined!") - annotations.update( - { - annotation.agent: annotation.__class__.parse_obj( - annotation.dict(exclude={"agent"}) - ) - } - ) + annotations.update({annotation.agent: annotation.__class__.parse_obj(annotation.dict(exclude={"agent"}))}) values[field_to_update] = annotations if annotations and not annotation: # set first annotation key, value = list(annotations.items())[0] - values[annotation_field] = value.__class__( - agent=key, **value.dict(exclude={"agent"}) - ) + values[annotation_field] = value.__class__(agent=key, **value.dict(exclude={"agent"})) return values diff --git a/src/argilla/server/daos/records.py b/src/argilla/server/daos/records.py index 1ab091d885..83ea922db1 100644 --- a/src/argilla/server/daos/records.py +++ b/src/argilla/server/daos/records.py @@ -38,9 +38,7 @@ class DatasetRecordsDAO: @classmethod def get_instance( cls, - es: GenericElasticEngineBackend = Depends( - GenericElasticEngineBackend.get_instance - ), + es: GenericElasticEngineBackend = Depends(GenericElasticEngineBackend.get_instance), ) -> "DatasetRecordsDAO": """ Creates a dataset records dao instance @@ -91,9 +89,7 @@ def add_records( mapping = self._es.get_schema(id=dataset.id) exclude_fields = [ - name - for name in record_class.schema()["properties"] - if name not in mapping["mappings"]["properties"] + name for name in record_class.schema()["properties"] if name not in mapping["mappings"]["properties"] ] vectors_configuration = {} @@ -173,9 +169,7 @@ def search_records( except ClosedIndexError: raise ClosedDatasetError(dataset.name) except IndexNotFoundError: - raise MissingDatasetRecordsError( - f"No records index found for dataset {dataset.name}" - ) + raise MissingDatasetRecordsError(f"No records index found for dataset {dataset.name}") def scan_dataset( self, diff --git a/src/argilla/server/errors/base_errors.py b/src/argilla/server/errors/base_errors.py index 36ed5bd700..80e8c2018d 100644 --- a/src/argilla/server/errors/base_errors.py +++ b/src/argilla/server/errors/base_errors.py @@ -46,11 +46,7 @@ def code(self) -> str: @property def arguments(self): - return ( - {k: v for k, v in vars(self).items() if v is not None} - if vars(self) - else None - ) + return {k: v for k, v in vars(self).items() if v is not None} if vars(self) else None def __str__(self): args = 
self.arguments or {} @@ -77,11 +73,7 @@ def __init__(self, error: Exception): @classmethod def api_documentation(cls): return { - "content": { - "application/json": { - "example": {"detail": {"code": "builtins.TypeError"}} - } - }, + "content": {"application/json": {"example": {"detail": {"code": "builtins.TypeError"}}}}, } diff --git a/src/argilla/server/helpers.py b/src/argilla/server/helpers.py index c523f918ba..81b6ad6be6 100644 --- a/src/argilla/server/helpers.py +++ b/src/argilla/server/helpers.py @@ -22,9 +22,7 @@ _LOGGER = logging.getLogger("argilla.server") -def unflatten_dict( - data: Dict[str, Any], sep: str = ".", stop_keys: Optional[List[str]] = None -) -> Dict[str, Any]: +def unflatten_dict(data: Dict[str, Any], sep: str = ".", stop_keys: Optional[List[str]] = None) -> Dict[str, Any]: """ Given a flat dictionary keys, build a hierarchical version by grouping keys @@ -57,9 +55,7 @@ def unflatten_dict( return resultDict -def flatten_dict( - data: Dict[str, Any], sep: str = ".", drop_empty: bool = False -) -> Dict[str, Any]: +def flatten_dict(data: Dict[str, Any], sep: str = ".", drop_empty: bool = False) -> Dict[str, Any]: """ Flatten a data dictionary diff --git a/src/argilla/server/responses/api_responses.py b/src/argilla/server/responses/api_responses.py index 157e38a2b1..0056b2cda0 100644 --- a/src/argilla/server/responses/api_responses.py +++ b/src/argilla/server/responses/api_responses.py @@ -23,9 +23,7 @@ async def stream_response(self, send: Send) -> None: try: return await super().stream_response(send) except Exception as ex: - json_response: JSONResponse = ( - await APIErrorHandler.common_exception_handler(send, error=ex) - ) + json_response: JSONResponse = await APIErrorHandler.common_exception_handler(send, error=ex) await send( { "type": "http.response.body", diff --git a/src/argilla/server/routes.py b/src/argilla/server/routes.py index be28d9fbb7..45fce83d42 100644 --- a/src/argilla/server/routes.py +++ b/src/argilla/server/routes.py @@ -34,9 +34,7 @@ ) from argilla.server.errors.base_errors import __ALL__ -api_router = APIRouter( - responses={error.HTTP_STATUS: error.api_documentation() for error in __ALL__} -) +api_router = APIRouter(responses={error.HTTP_STATUS: error.api_documentation() for error in __ALL__}) dependencies = [] diff --git a/src/argilla/server/security/auth_provider/local/provider.py b/src/argilla/server/security/auth_provider/local/provider.py index 3a8851df79..3658f7c049 100644 --- a/src/argilla/server/security/auth_provider/local/provider.py +++ b/src/argilla/server/security/auth_provider/local/provider.py @@ -77,17 +77,11 @@ async def login_for_access_token( user = self.users.authenticate_user(form_data.username, form_data.password) if not user: raise UnauthorizedError() - access_token_expires = timedelta( - minutes=self.settings.token_expiration_in_minutes - ) - access_token = self._create_access_token( - user.username, expires_delta=access_token_expires - ) + access_token_expires = timedelta(minutes=self.settings.token_expiration_in_minutes) + access_token = self._create_access_token(user.username, expires_delta=access_token_expires) return Token(access_token=access_token) - def _create_access_token( - self, username: str, expires_delta: Optional[timedelta] = None - ) -> str: + def _create_access_token(self, username: str, expires_delta: Optional[timedelta] = None) -> str: """ Creates an access token @@ -162,9 +156,7 @@ async def get_user( ------- """ - user = await self._find_user_by_api_key( - api_key - ) or await 
self._find_user_by_api_key(old_api_key) + user = await self._find_user_by_api_key(api_key) or await self._find_user_by_api_key(old_api_key) if user: return user if token: diff --git a/src/argilla/server/security/auth_provider/local/settings.py b/src/argilla/server/security/auth_provider/local/settings.py index 43bce77a3e..b2157e1e61 100644 --- a/src/argilla/server/security/auth_provider/local/settings.py +++ b/src/argilla/server/security/auth_provider/local/settings.py @@ -43,17 +43,13 @@ class Settings(BaseSettings): token_api_url: str = "/api/security/token" default_apikey: str = DEFAULT_API_KEY - default_password: str = ( - "$2y$12$MPcRR71ByqgSI8AaqgxrMeSdrD4BcxDIdYkr.ePQoKz7wsGK7SAca" # 1234 - ) + default_password: str = "$2y$12$MPcRR71ByqgSI8AaqgxrMeSdrD4BcxDIdYkr.ePQoKz7wsGK7SAca" # 1234 users_db_file: str = ".users.yml" @property def public_oauth_token_url(self): """The final public token url used for openapi doc setup""" - return ( - f"{server_settings.base_url}{helpers.remove_prefix(self.token_api_url,'/')}" - ) + return f"{server_settings.base_url}{helpers.remove_prefix(self.token_api_url,'/')}" class Config: env_prefix = "ARGILLA_LOCAL_AUTH_" diff --git a/src/argilla/server/security/auth_provider/local/users/dao.py b/src/argilla/server/security/auth_provider/local/users/dao.py index 0d6e1025bd..895c6fdf77 100644 --- a/src/argilla/server/security/auth_provider/local/users/dao.py +++ b/src/argilla/server/security/auth_provider/local/users/dao.py @@ -34,9 +34,7 @@ def __init__(self, users_file: str): self.__users__: Dict[str, UserInDB] = {} try: with open(users_file) as file: - user_list = [ - UserInDB(**user_data) for user_data in yaml.safe_load(file) - ] + user_list = [UserInDB(**user_data) for user_data in yaml.safe_load(file)] self.__users__ = {user.username: user for user in user_list} except FileNotFoundError: self.__users__ = { diff --git a/src/argilla/server/security/model.py b/src/argilla/server/security/model.py index bb8ac6c972..9f182f5e6c 100644 --- a/src/argilla/server/security/model.py +++ b/src/argilla/server/security/model.py @@ -37,8 +37,7 @@ class User(BaseModel): def check_username(cls, value): if not re.compile(DATASET_NAME_REGEX_PATTERN).match(value): raise ValueError( - "Wrong username. " - f"The username {value} does not match the pattern {DATASET_NAME_REGEX_PATTERN}" + "Wrong username. " f"The username {value} does not match the pattern {DATASET_NAME_REGEX_PATTERN}" ) return value @@ -48,8 +47,7 @@ def check_workspace_pattern(cls, workspace: str): if not workspace: return workspace assert WORKSPACE_NAME_PATTERN.match(workspace), ( - "Wrong workspace format. " - f"Workspace must match pattern {WORKSPACE_NAME_PATTERN.pattern}" + "Wrong workspace format. 
" f"Workspace must match pattern {WORKSPACE_NAME_PATTERN.pattern}" ) return workspace diff --git a/src/argilla/server/server.py b/src/argilla/server/server.py index 67e146d670..def025bf53 100644 --- a/src/argilla/server/server.py +++ b/src/argilla/server/server.py @@ -65,9 +65,7 @@ def configure_api_exceptions(api: FastAPI): """Configures fastapi exception handlers""" api.exception_handler(EntityNotFoundError)(APIErrorHandler.common_exception_handler) api.exception_handler(Exception)(APIErrorHandler.common_exception_handler) - api.exception_handler(RequestValidationError)( - APIErrorHandler.common_exception_handler - ) + api.exception_handler(RequestValidationError)(APIErrorHandler.common_exception_handler) def configure_api_router(app: FastAPI): @@ -185,11 +183,7 @@ def configure_telemetry(app): """ ) message += "\n\n " - message += ( - "#set ARGILLA_ENABLE_TELEMETRY=0" - if os.name == "nt" - else "$>export ARGILLA_ENABLE_TELEMETRY=0" - ) + message += "#set ARGILLA_ENABLE_TELEMETRY=0" if os.name == "nt" else "$>export ARGILLA_ENABLE_TELEMETRY=0" message += "\n" @app.on_event("startup") diff --git a/src/argilla/server/services/datasets.py b/src/argilla/server/services/datasets.py index d67914d354..b6af813d4b 100644 --- a/src/argilla/server/services/datasets.py +++ b/src/argilla/server/services/datasets.py @@ -38,18 +38,14 @@ class ServiceBaseDatasetSettings(BaseDatasetSettingsDB): ServiceDataset = TypeVar("ServiceDataset", bound=ServiceBaseDataset) -ServiceDatasetSettings = TypeVar( - "ServiceDatasetSettings", bound=ServiceBaseDatasetSettings -) +ServiceDatasetSettings = TypeVar("ServiceDatasetSettings", bound=ServiceBaseDatasetSettings) class DatasetsService: _INSTANCE: "DatasetsService" = None @classmethod - def get_instance( - cls, dao: DatasetsDAO = Depends(DatasetsDAO.get_instance) - ) -> "DatasetsService": + def get_instance(cls, dao: DatasetsDAO = Depends(DatasetsDAO.get_instance)) -> "DatasetsService": if not cls._INSTANCE: cls._INSTANCE = cls(dao) return cls._INSTANCE @@ -61,16 +57,10 @@ def create_dataset(self, user: User, dataset: ServiceDataset) -> ServiceDataset: user.check_workspace(dataset.owner) try: - self.find_by_name( - user=user, name=dataset.name, task=dataset.task, workspace=dataset.owner - ) - raise EntityAlreadyExistsError( - name=dataset.name, type=ServiceDataset, workspace=dataset.owner - ) + self.find_by_name(user=user, name=dataset.name, task=dataset.task, workspace=dataset.owner) + raise EntityAlreadyExistsError(name=dataset.name, type=ServiceDataset, workspace=dataset.owner) except WrongTaskError: # Found a dataset with same name but different task - raise EntityAlreadyExistsError( - name=dataset.name, type=ServiceDataset, workspace=dataset.owner - ) + raise EntityAlreadyExistsError(name=dataset.name, type=ServiceDataset, workspace=dataset.owner) except EntityNotFoundError: # The dataset does not exist -> create it ! 
date_now = datetime.utcnow() @@ -117,9 +107,7 @@ def __find_by_name_with_superuser_fallback__( as_dataset_class: Optional[Type[ServiceDataset]], task: Optional[str] = None, ): - found_ds = self.__dao__.find_by_name( - name=name, owner=owner, task=task, as_dataset_class=as_dataset_class - ) + found_ds = self.__dao__.find_by_name(name=name, owner=owner, task=task, as_dataset_class=as_dataset_class) if not found_ds and user.is_superuser(): try: found_ds = self.__dao__.find_by_name( @@ -157,15 +145,11 @@ def update( tags: Dict[str, str], metadata: Dict[str, Any], ) -> ServiceDataset: - found = self.find_by_name( - user=user, name=dataset.name, task=dataset.task, workspace=dataset.owner - ) + found = self.find_by_name(user=user, name=dataset.name, task=dataset.task, workspace=dataset.owner) dataset.tags = {**found.tags, **(tags or {})} dataset.metadata = {**found.metadata, **(metadata or {})} - updated = found.copy( - update={**dataset.dict(by_alias=True), "last_updated": datetime.utcnow()} - ) + updated = found.copy(update={**dataset.dict(by_alias=True), "last_updated": datetime.utcnow()}) return self.__dao__.update_dataset(updated) def list( @@ -175,9 +159,7 @@ def list( task2dataset_map: Dict[str, Type[ServiceDataset]] = None, ) -> List[ServiceDataset]: owners = user.check_workspaces(workspaces) - return self.__dao__.list_datasets( - owner_list=owners, task2dataset_map=task2dataset_map - ) + return self.__dao__.list_datasets(owner_list=owners, task2dataset_map=task2dataset_map) def close(self, user: User, dataset: ServiceDataset): found = self.find_by_name(user=user, name=dataset.name, workspace=dataset.owner) diff --git a/src/argilla/server/services/info.py b/src/argilla/server/services/info.py index 6e8b7c2f0d..d4f64a3cc6 100644 --- a/src/argilla/server/services/info.py +++ b/src/argilla/server/services/info.py @@ -72,9 +72,7 @@ class ApiInfoService: @classmethod def get_instance( cls, - backend: GenericElasticEngineBackend = Depends( - GenericElasticEngineBackend.get_instance - ), + backend: GenericElasticEngineBackend = Depends(GenericElasticEngineBackend.get_instance), ) -> "ApiInfoService": """ Creates an api info service diff --git a/src/argilla/server/services/search/service.py b/src/argilla/server/services/search/service.py index b543a50cf7..4aa71bd2fe 100644 --- a/src/argilla/server/services/search/service.py +++ b/src/argilla/server/services/search/service.py @@ -83,9 +83,7 @@ def search( size=size, record_from=record_from, exclude_fields=exclude_fields, - highligth_results=query is not None - and query.query_text is not None - and len(query.query_text) > 0, + highligth_results=query is not None and query.query_text is not None and len(query.query_text) > 0, ) metrics_results = {} for metric in metrics or []: @@ -98,9 +96,7 @@ def search( ) metrics_results[metric.id] = metrics_ except Exception as ex: - self.__LOGGER__.warning( - "Cannot compute metric [%s]. Error: %s", metric.id, ex - ) + self.__LOGGER__.warning("Cannot compute metric [%s]. 
Error: %s", metric.id, ex) metrics_results[metric.id] = {} return ServiceSearchResults( diff --git a/src/argilla/server/services/storage/service.py b/src/argilla/server/services/storage/service.py index 703adaab55..9fecae36d6 100644 --- a/src/argilla/server/services/storage/service.py +++ b/src/argilla/server/services/storage/service.py @@ -98,9 +98,7 @@ async def delete_records( "Only dataset creators or administrators can delete datasets" ) - processed, deleted = await self.__dao__.delete_records_by_query( - dataset, query=query - ) + processed, deleted = await self.__dao__.delete_records_by_query(dataset, query=query) return DeleteRecordsOut( processed=processed or 0, diff --git a/src/argilla/server/services/tasks/commons/models.py b/src/argilla/server/services/tasks/commons/models.py index a39388230c..6a99db7c7c 100644 --- a/src/argilla/server/services/tasks/commons/models.py +++ b/src/argilla/server/services/tasks/commons/models.py @@ -36,9 +36,7 @@ class BulkResponse(BaseModel): ServiceAnnotation = TypeVar("ServiceAnnotation", bound=ServiceBaseAnnotation) -class ServiceBaseRecordInputs( - BaseRecordInDB[ServiceAnnotation], Generic[ServiceAnnotation] -): +class ServiceBaseRecordInputs(BaseRecordInDB[ServiceAnnotation], Generic[ServiceAnnotation]): pass diff --git a/src/argilla/server/services/tasks/text2text/models.py b/src/argilla/server/services/tasks/text2text/models.py index f8d1e8438c..d07e0c1002 100644 --- a/src/argilla/server/services/tasks/text2text/models.py +++ b/src/argilla/server/services/tasks/text2text/models.py @@ -50,19 +50,11 @@ def all_text(self) -> str: @property def predicted_as(self) -> Optional[List[str]]: - return ( - [sentence.text for sentence in self.prediction.sentences] - if self.prediction - else None - ) + return [sentence.text for sentence in self.prediction.sentences] if self.prediction else None @property def annotated_as(self) -> Optional[List[str]]: - return ( - [sentence.text for sentence in self.annotation.sentences] - if self.annotation - else None - ) + return [sentence.text for sentence in self.annotation.sentences] if self.annotation else None @property def scores(self) -> List[float]: diff --git a/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py b/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py index 9b0b4871c0..38471bfccc 100644 --- a/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py +++ b/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py @@ -59,9 +59,7 @@ def __init__(self, datasets: DatasetsDAO, records: DatasetRecordsDAO): self.__records__ = records # TODO(@frascuchon): Move all rules management methods to the common datasets service like settings - def list_rules( - self, dataset: ServiceTextClassificationDataset - ) -> List[ServiceLabelingRule]: + def list_rules(self, dataset: ServiceTextClassificationDataset) -> List[ServiceLabelingRule]: """List a set of rules for a given dataset""" return dataset.rules @@ -72,9 +70,7 @@ def delete_rule(self, dataset: ServiceTextClassificationDataset, rule_query: str dataset.rules = new_rules_set self.__datasets__.update_dataset(dataset) - def add_rule( - self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule - ) -> ServiceLabelingRule: + def add_rule(self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule) -> ServiceLabelingRule: """Adds a rule to a dataset""" for r in dataset.rules: if r.query == rule.query: @@ -105,9 +101,7 @@ def 
compute_rule_metrics( LabelingRuleSummary.parse_obj(metric_data), ) - def _count_annotated_records( - self, dataset: ServiceTextClassificationDataset - ) -> int: + def _count_annotated_records(self, dataset: ServiceTextClassificationDataset) -> int: results = self.__records__.search_records( dataset, size=0, @@ -132,18 +126,14 @@ def all_rules_metrics( DatasetLabelingRulesSummary.parse_obj(metric_data), ) - def find_rule_by_query( - self, dataset: ServiceTextClassificationDataset, rule_query: str - ) -> ServiceLabelingRule: + def find_rule_by_query(self, dataset: ServiceTextClassificationDataset, rule_query: str) -> ServiceLabelingRule: rule_query = rule_query.strip() for rule in dataset.rules: if rule.query == rule_query: return rule raise EntityNotFoundError(rule_query, type=ServiceLabelingRule) - def replace_rule( - self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule - ): + def replace_rule(self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule): for idx, r in enumerate(dataset.rules): if r.query == rule.query: dataset.rules[idx] = rule diff --git a/src/argilla/server/services/tasks/text_classification/metrics.py b/src/argilla/server/services/tasks/text_classification/metrics.py index 93c2a856dc..d09891620b 100644 --- a/src/argilla/server/services/tasks/text_classification/metrics.py +++ b/src/argilla/server/services/tasks/text_classification/metrics.py @@ -41,11 +41,7 @@ class F1Metric(ServicePythonMetric): def apply(self, records: Iterable[ServiceTextClassificationRecord]) -> Any: filtered_records = list(filter(lambda r: r.predicted is not None, records)) # TODO: This must be precomputed with using a global dataset metric - ds_labels = { - label - for record in filtered_records - for label in record.annotated_as + record.predicted_as - } + ds_labels = {label for record in filtered_records for label in record.annotated_as + record.predicted_as} if not len(ds_labels): return {} @@ -69,12 +65,8 @@ def apply(self, records: Iterable[ServiceTextClassificationRecord]) -> Any: y_true = mlb.fit_transform(y_true) y_pred = mlb.fit_transform(y_pred) - micro_p, micro_r, micro_f, _ = precision_recall_fscore_support( - y_true=y_true, y_pred=y_pred, average="micro" - ) - macro_p, macro_r, macro_f, _ = precision_recall_fscore_support( - y_true=y_true, y_pred=y_pred, average="macro" - ) + micro_p, micro_r, micro_f, _ = precision_recall_fscore_support(y_true=y_true, y_pred=y_pred, average="micro") + macro_p, macro_r, macro_f, _ = precision_recall_fscore_support(y_true=y_true, y_pred=y_pred, average="macro") per_label = {} for label, p, r, f, s in zip( @@ -127,21 +119,15 @@ def apply( records: Iterable[ServiceTextClassificationRecord], ) -> Dict[str, Any]: ds_labels = set() - for _ in range( - 0, self.records_to_fetch - ): # Only a few of records will be parsed + for _ in range(0, self.records_to_fetch): # Only a few of records will be parsed record = next(records, None) if record is None: break if record.annotation: - ds_labels.update( - [label.class_label for label in record.annotation.labels] - ) + ds_labels.update([label.class_label for label in record.annotation.labels]) if record.prediction: - ds_labels.update( - [label.class_label for label in record.prediction.labels] - ) + ds_labels.update([label.class_label for label in record.prediction.labels]) return {"labels": ds_labels or []} diff --git a/src/argilla/server/services/tasks/text_classification/model.py b/src/argilla/server/services/tasks/text_classification/model.py index a94a2632a4..73c5cd7d05 
100644 --- a/src/argilla/server/services/tasks/text_classification/model.py +++ b/src/argilla/server/services/tasks/text_classification/model.py @@ -36,21 +36,14 @@ class ServiceLabelingRule(BaseModel): query: str = Field(description="The es rule query") author: str = Field(description="User who created the rule") - created_at: Optional[datetime] = Field( - default_factory=datetime.utcnow, description="Rule creation timestamp" - ) + created_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="Rule creation timestamp") - label: Optional[str] = Field( - default=None, description="@Deprecated::The label associated with the rule." - ) + label: Optional[str] = Field(default=None, description="@Deprecated::The label associated with the rule.") labels: List[str] = Field( default_factory=list, - description="For multi label problems, a list of labels. " - "It will replace the `label` field", - ) - description: Optional[str] = Field( - None, description="A brief description of the rule" + description="For multi label problems, a list of labels. " "It will replace the `label` field", ) + description: Optional[str] = Field(None, description="A brief description of the rule") @root_validator def initialize_labels(cls, values): @@ -198,19 +191,13 @@ def _check_annotation_integrity( status: TaskStatus, ): if status == TaskStatus.validated and not multi_label: - assert ( - annotation and len(annotation.labels) > 0 - ), "Annotation must include some label for validated records" + assert annotation and len(annotation.labels) > 0, "Annotation must include some label for validated records" if not multi_label and annotation: - assert ( - len(annotation.labels) == 1 - ), "Single label record must include only one annotation label" + assert len(annotation.labels) == 1, "Single label record must include only one annotation label" @classmethod - def _check_score_integrity( - cls, prediction: TextClassificationAnnotation, multi_label: bool - ): + def _check_score_integrity(cls, prediction: TextClassificationAnnotation, multi_label: bool): """ Checks the score value integrity @@ -235,24 +222,16 @@ def task(cls) -> TaskType: @property def predicted(self) -> Optional[PredictionStatus]: if self.predicted_by and self.annotated_by: - return ( - PredictionStatus.OK - if set(self.predicted_as) == set(self.annotated_as) - else PredictionStatus.KO - ) + return PredictionStatus.OK if set(self.predicted_as) == set(self.annotated_as) else PredictionStatus.KO return None @property def predicted_as(self) -> List[str]: - return self._labels_from_annotation( - self.prediction, multi_label=self.multi_label - ) + return self._labels_from_annotation(self.prediction, multi_label=self.multi_label) @property def annotated_as(self) -> List[str]: - return self._labels_from_annotation( - self.annotation, multi_label=self.multi_label - ) + return self._labels_from_annotation(self.annotation, multi_label=self.multi_label) @property def scores(self) -> List[float]: @@ -263,11 +242,7 @@ def scores(self) -> List[float]: if self.multi_label else [ prediction_class.score - for prediction_class in [ - self._max_class_prediction( - self.prediction, multi_label=self.multi_label - ) - ] + for prediction_class in [self._max_class_prediction(self.prediction, multi_label=self.multi_label)] if prediction_class ] ) @@ -303,22 +278,16 @@ def _labels_from_annotation( return [] if multi_label: - return [ - label.class_label for label in annotation.labels if label.score > 0.5 - ] + return [label.class_label for label in 
annotation.labels if label.score > 0.5] - class_prediction = cls._max_class_prediction( - annotation, multi_label=multi_label - ) + class_prediction = cls._max_class_prediction(annotation, multi_label=multi_label) if class_prediction is None: return [] return [class_prediction.class_label] @staticmethod - def _max_class_prediction( - p: TextClassificationAnnotation, multi_label: bool - ) -> Optional[ClassPrediction]: + def _max_class_prediction(p: TextClassificationAnnotation, multi_label: bool) -> Optional[ClassPrediction]: if multi_label or p is None or not p.labels: return None return p.labels[0] diff --git a/src/argilla/server/services/tasks/text_classification/service.py b/src/argilla/server/services/tasks/text_classification/service.py index e6351d13e9..8a5407258c 100644 --- a/src/argilla/server/services/tasks/text_classification/service.py +++ b/src/argilla/server/services/tasks/text_classification/service.py @@ -191,16 +191,13 @@ def _check_multi_label_integrity( is_multi_label_dataset = self._is_dataset_multi_label(dataset) if is_multi_label_dataset is not None: is_multi_label = records[0].multi_label - assert is_multi_label == is_multi_label_dataset, ( - "You cannot pass {labels_type} records for this dataset. " - "Stored records are {labels_type}".format( - labels_type="multi-label" if is_multi_label else "single-label" - ) + assert ( + is_multi_label == is_multi_label_dataset + ), "You cannot pass {labels_type} records for this dataset. " "Stored records are {labels_type}".format( + labels_type="multi-label" if is_multi_label else "single-label" ) - def _is_dataset_multi_label( - self, dataset: ServiceTextClassificationDataset - ) -> Optional[bool]: + def _is_dataset_multi_label(self, dataset: ServiceTextClassificationDataset) -> Optional[bool]: try: results = self.__search__.search( dataset, @@ -212,14 +209,10 @@ def _is_dataset_multi_label( if results.records: return results.records[0].multi_label - def get_labeling_rules( - self, dataset: ServiceTextClassificationDataset - ) -> Iterable[ServiceLabelingRule]: + def get_labeling_rules(self, dataset: ServiceTextClassificationDataset) -> Iterable[ServiceLabelingRule]: return self.__labeling__.list_rules(dataset) - def add_labeling_rule( - self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule - ) -> None: + def add_labeling_rule(self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule) -> None: """ Adds a labeling rule @@ -252,14 +245,10 @@ def update_labeling_rule( self.__labeling__.replace_rule(dataset, found_rule) return found_rule - def find_labeling_rule( - self, dataset: ServiceTextClassificationDataset, rule_query: str - ) -> ServiceLabelingRule: + def find_labeling_rule(self, dataset: ServiceTextClassificationDataset, rule_query: str) -> ServiceLabelingRule: return self.__labeling__.find_rule_by_query(dataset, rule_query=rule_query) - def delete_labeling_rule( - self, dataset: ServiceTextClassificationDataset, rule_query: str - ): + def delete_labeling_rule(self, dataset: ServiceTextClassificationDataset, rule_query: str): if rule_query.strip(): return self.__labeling__.delete_rule(dataset, rule_query) @@ -309,9 +298,7 @@ def compute_rule_metrics( ) coverage = metrics.covered_records / total if total > 0 else None - coverage_annotated = ( - metrics.annotated_covered_records / annotated if annotated > 0 else None - ) + coverage_annotated = metrics.annotated_covered_records / annotated if annotated > 0 else None return LabelingRuleMetricsSummary( total_records=total, @@ -326,9 +313,7 
@@ def compute_rule_metrics( def compute_overall_rules_metrics(self, dataset: ServiceTextClassificationDataset): total, annotated, metrics = self.__labeling__.all_rules_metrics(dataset) coverage = metrics.covered_records / total if total else None - coverage_annotated = ( - metrics.annotated_covered_records / annotated if annotated else None - ) + coverage_annotated = metrics.annotated_covered_records / annotated if annotated else None return DatasetLabelingRulesMetricsSummary( coverage=coverage, coverage_annotated=coverage_annotated, diff --git a/src/argilla/server/services/tasks/token_classification/metrics.py b/src/argilla/server/services/tasks/token_classification/metrics.py index e210b0960f..9169229b18 100644 --- a/src/argilla/server/services/tasks/token_classification/metrics.py +++ b/src/argilla/server/services/tasks/token_classification/metrics.py @@ -36,9 +36,7 @@ class F1Metric(ServicePythonMetric[ServiceTokenClassificationRecord]): A named entity is correct only if it is an exact match (...).”` """ - def apply( - self, records: Iterable[ServiceTokenClassificationRecord] - ) -> Dict[str, Any]: + def apply(self, records: Iterable[ServiceTokenClassificationRecord]) -> Dict[str, Any]: # store entities per label in dicts predicted_entities = {} annotated_entities = {} @@ -66,9 +64,7 @@ def apply( { f"{label}_precision": precision, f"{label}_recall": recall, - f"{label}_f1": self._safe_divide( - 2 * precision * recall, precision + recall - ), + f"{label}_f1": self._safe_divide(2 * precision * recall, precision + recall), } ) @@ -83,9 +79,7 @@ def apply( averaged_metrics = { "precision_macro": precision_macro, "recall_macro": recall_macro, - "f1_macro": self._safe_divide( - 2 * precision_macro * recall_macro, precision_macro + recall_macro - ), + "f1_macro": self._safe_divide(2 * precision_macro * recall_macro, precision_macro + recall_macro), } precision_micro = self._safe_divide(correct_total, predicted_total) @@ -94,18 +88,14 @@ def apply( { "precision_micro": precision_micro, "recall_micro": recall_micro, - "f1_micro": self._safe_divide( - 2 * precision_micro * recall_micro, precision_micro + recall_micro - ), + "f1_micro": self._safe_divide(2 * precision_micro * recall_micro, precision_micro + recall_micro), } ) return {**averaged_metrics, **per_label_metrics} @staticmethod - def _add_entities_to_dict( - entities: List[EntitySpan], dictionary: Dict[str, Set[Tuple[int, int]]] - ): + def _add_entities_to_dict(entities: List[EntitySpan], dictionary: Dict[str, Set[Tuple[int, int]]]): """Helper function for the apply method.""" for ent in entities: try: @@ -144,21 +134,15 @@ def apply( ) -> Dict[str, Any]: ds_labels = set() - for _ in range( - 0, self.records_to_fetch - ): # Only a few of records will be parsed + for _ in range(0, self.records_to_fetch): # Only a few of records will be parsed record: ServiceTokenClassificationRecord = next(records, None) if record is None: break if record.annotation: - ds_labels.update( - [entity.label for entity in record.annotation.entities] - ) + ds_labels.update([entity.label for entity in record.annotation.entities]) if record.prediction: - ds_labels.update( - [entity.label for entity in record.prediction.entities] - ) + ds_labels.update([entity.label for entity in record.prediction.entities]) return {"labels": ds_labels or []} @@ -252,9 +236,7 @@ def mention_tokens_length(entity: EntitySpan) -> int: label=entity.label, score=entity.score, capitalness=TokenClassificationMetrics.capitalness(mention), - density=TokenClassificationMetrics.density( - 
_tokens_length, sentence_length=len(record.tokens) - ), + density=TokenClassificationMetrics.density(_tokens_length, sentence_length=len(record.tokens)), tokens_length=_tokens_length, chars_length=len(mention), ) @@ -295,26 +277,18 @@ def record_metrics(cls, record: ServiceTokenClassificationRecord) -> Dict[str, A annotated_tags = cls._compute_iob_tags(span_utils, record.annotation) or [] predicted_tags = cls._compute_iob_tags(span_utils, record.prediction) or [] - tokens_metrics = cls.build_tokens_metrics( - record, predicted_tags or annotated_tags - ) + tokens_metrics = cls.build_tokens_metrics(record, predicted_tags or annotated_tags) return { **base_metrics, "tokens": tokens_metrics, "tokens_length": len(record.tokens), "predicted": { "mentions": cls.mentions_metrics(record, record.predicted_mentions()), - "tags": [ - TokenTagMetrics(tag=tag, value=token) - for tag, token in zip(predicted_tags, record.tokens) - ], + "tags": [TokenTagMetrics(tag=tag, value=token) for tag, token in zip(predicted_tags, record.tokens)], }, "annotated": { "mentions": cls.mentions_metrics(record, record.annotated_mentions()), - "tags": [ - TokenTagMetrics(tag=tag, value=token) - for tag, token in zip(annotated_tags, record.tokens) - ], + "tags": [TokenTagMetrics(tag=tag, value=token) for tag, token in zip(annotated_tags, record.tokens)], }, } diff --git a/src/argilla/server/services/tasks/token_classification/model.py b/src/argilla/server/services/tasks/token_classification/model.py index 926783ecbd..ce01b96887 100644 --- a/src/argilla/server/services/tasks/token_classification/model.py +++ b/src/argilla/server/services/tasks/token_classification/model.py @@ -74,9 +74,7 @@ class ServiceTokenClassificationAnnotation(ServiceBaseAnnotation): score: Optional[float] = None -class ServiceTokenClassificationRecord( - ServiceBaseRecord[ServiceTokenClassificationAnnotation] -): +class ServiceTokenClassificationRecord(ServiceBaseRecord[ServiceTokenClassificationAnnotation]): tokens: List[str] = Field(min_items=1) text: str = Field() _raw_text: Optional[str] = Field(alias="raw_text") @@ -94,8 +92,7 @@ def extended_fields(self) -> Dict[str, Any]: for mention, entity in self.predicted_mentions() ], MENTIONS_ES_FIELD_NAME: [ - {"mention": mention, "entity": entity.label} - for mention, entity in self.annotated_mentions() + {"mention": mention, "entity": entity.label} for mention, entity in self.annotated_mentions() ], } @@ -148,11 +145,7 @@ def predicted(self) -> Optional[PredictionStatus]: return PredictionStatus.KO for ann, pred in zip(annotated_entities, predicted_entities): - if ( - ann.start != pred.start - or ann.end != pred.end - or ann.label != pred.label - ): + if ann.start != pred.start or ann.end != pred.end or ann.label != pred.label: return PredictionStatus.KO return PredictionStatus.OK @@ -179,18 +172,12 @@ def all_text(self) -> str: def predicted_mentions(self) -> List[Tuple[str, EntitySpan]]: return [ - (mention, entity) - for mention, entity in self.__mentions_from_entities__( - self.predicted_entities() - ).items() + (mention, entity) for mention, entity in self.__mentions_from_entities__(self.predicted_entities()).items() ] def annotated_mentions(self) -> List[Tuple[str, EntitySpan]]: return [ - (mention, entity) - for mention, entity in self.__mentions_from_entities__( - self.annotated_entities() - ).items() + (mention, entity) for mention, entity in self.__mentions_from_entities__(self.annotated_entities()).items() ] def annotated_entities(self) -> Set[EntitySpan]: @@ -205,14 +192,8 @@ def 
predicted_entities(self) -> Set[EntitySpan]: return set() return set(self.prediction.entities) - def __mentions_from_entities__( - self, entities: Set[EntitySpan] - ) -> Dict[str, EntitySpan]: - return { - mention: entity - for entity in entities - for mention in [self.text[entity.start : entity.end]] - } + def __mentions_from_entities__(self, entities: Set[EntitySpan]) -> Dict[str, EntitySpan]: + return {mention: entity for entity in entities for mention in [self.text[entity.start : entity.end]]} class Config: allow_population_by_field_name = True diff --git a/src/argilla/server/services/tasks/token_classification/service.py b/src/argilla/server/services/tasks/token_classification/service.py index 534ff6920b..50ef8dae2d 100644 --- a/src/argilla/server/services/tasks/token_classification/service.py +++ b/src/argilla/server/services/tasks/token_classification/service.py @@ -131,12 +131,8 @@ def search( results.metrics["status"] = results.metrics["status_distribution"] results.metrics["predicted"] = results.metrics["error_distribution"] results.metrics["predicted"].pop("unknown", None) - results.metrics["mentions"] = results.metrics[ - "annotated_mentions_distribution" - ] - results.metrics["predicted_mentions"] = results.metrics[ - "predicted_mentions_distribution" - ] + results.metrics["mentions"] = results.metrics["annotated_mentions_distribution"] + results.metrics["predicted_mentions"] = results.metrics["predicted_mentions_distribution"] return results diff --git a/src/argilla/utils/span_utils.py b/src/argilla/utils/span_utils.py index bed9e172f7..208450a24b 100644 --- a/src/argilla/utils/span_utils.py +++ b/src/argilla/utils/span_utils.py @@ -100,9 +100,7 @@ def validate(self, spans: List[Tuple[str, int, int]]): if not_valid_spans_errors or misaligned_spans_errors: message = "" if not_valid_spans_errors: - message += ( - f"Following entity spans are not valid: {not_valid_spans_errors}\n" - ) + message += f"Following entity spans are not valid: {not_valid_spans_errors}\n" if misaligned_spans_errors: spans = "\n".join(misaligned_spans_errors) @@ -191,9 +189,7 @@ def get_prefix_and_entity(tag_str: str) -> Tuple[str, Optional[str]]: return splits[0], "-".join(splits[1:]) if len(tags) != len(self.tokens): - raise ValueError( - "The list of tags must have the same length as the list of tokens!" 
- ) + raise ValueError("The list of tags must have the same length as the list of tokens!") spans, start_idx = [], None for idx, tag in enumerate(tags): diff --git a/src/argilla/utils/utils.py b/src/argilla/utils/utils.py index e45e3fa4d9..902d440eca 100644 --- a/src/argilla/utils/utils.py +++ b/src/argilla/utils/utils.py @@ -44,9 +44,7 @@ def __init__( for value in values: self._class_to_module[value] = key # Needed for autocompletion in an IDE - self.__all__ = list(import_structure.keys()) + list( - chain(*import_structure.values()) - ) + self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values())) self.__file__ = module_file self.__spec__ = module_spec self.__path__ = [os.path.dirname(module_file)] @@ -83,9 +81,7 @@ def __getattr__(self, name: str) -> Any: elif name in self._deprecated_modules: value = self._get_module(name, deprecated=True) elif name in self._deprecated_class_to_module.keys(): - module = self._get_module( - self._deprecated_class_to_module[name], deprecated=True, class_name=name - ) + module = self._get_module(self._deprecated_class_to_module[name], deprecated=True, class_name=name) value = getattr(module, name) else: raise AttributeError(f"module {self.__name__} has no attribute {name}") @@ -139,9 +135,7 @@ def limit_value_length(data: Any, max_length: int) -> Any: if isinstance(data, str): return data[-max_length:] if isinstance(data, dict): - return { - k: limit_value_length(v, max_length=max_length) for k, v in data.items() - } + return {k: limit_value_length(v, max_length=max_length) for k, v in data.items()} if isinstance(data, (list, tuple, set)): new_values = map(lambda x: limit_value_length(x, max_length=max_length), data) return type(data)(new_values) @@ -167,9 +161,7 @@ def start_background_loop(loop: asyncio.AbstractEventLoop) -> None: if not (__LOOP__ and __THREAD__): loop = asyncio.new_event_loop() - thread = threading.Thread( - target=start_background_loop, args=(loop,), daemon=True - ) + thread = threading.Thread(target=start_background_loop, args=(loop,), daemon=True) thread.start() __LOOP__, __THREAD__ = loop, thread return __LOOP__, __THREAD__ diff --git a/tests/client/conftest.py b/tests/client/conftest.py index ab95abdd6f..f80b838c10 100644 --- a/tests/client/conftest.py +++ b/tests/client/conftest.py @@ -87,13 +87,7 @@ def singlelabel_textclassification_records( id=1, event_timestamp=datetime.datetime(2000, 1, 1), metadata={"mock_metadata": "mock"}, - explanation={ - "text": [ - ar.TokenAttributions( - token="mock", attributions={"a": 0.1, "b": 0.5} - ) - ] - }, + explanation={"text": [ar.TokenAttributions(token="mock", attributions={"a": 0.1, "b": 0.5})]}, status="Validated", ), ar.TextClassificationRecord( @@ -103,13 +97,7 @@ def singlelabel_textclassification_records( id=2, event_timestamp=datetime.datetime(2000, 2, 1), metadata={"mock2_metadata": "mock2"}, - explanation={ - "text": [ - ar.TokenAttributions( - token="mock2", attributions={"a": 0.7, "b": 0.2} - ) - ] - }, + explanation={"text": [ar.TokenAttributions(token="mock2", attributions={"a": 0.7, "b": 0.2})]}, status="Default", ), ar.TextClassificationRecord( @@ -152,8 +140,7 @@ def log_singlelabel_textclassification_records( "multi_label": False, }, records=[ - CreationTextClassificationRecord.from_client(rec) - for rec in singlelabel_textclassification_records + CreationTextClassificationRecord.from_client(rec) for rec in singlelabel_textclassification_records ], ).dict(by_alias=True), ) @@ -174,13 +161,7 @@ def multilabel_textclassification_records(request) 
-> List[ar.TextClassification id=1, event_timestamp=datetime.datetime(2000, 1, 1), metadata={"mock_metadata": "mock"}, - explanation={ - "text": [ - ar.TokenAttributions( - token="mock", attributions={"a": 0.1, "b": 0.5} - ) - ] - }, + explanation={"text": [ar.TokenAttributions(token="mock", attributions={"a": 0.1, "b": 0.5})]}, status="Validated", ), ar.TextClassificationRecord( @@ -191,13 +172,7 @@ def multilabel_textclassification_records(request) -> List[ar.TextClassification id=2, event_timestamp=datetime.datetime(2000, 2, 1), metadata={"mock2_metadata": "mock2"}, - explanation={ - "text": [ - ar.TokenAttributions( - token="mock2", attributions={"a": 0.7, "b": 0.2} - ) - ] - }, + explanation={"text": [ar.TokenAttributions(token="mock2", attributions={"a": 0.7, "b": 0.2})]}, status="Default", ), ar.TextClassificationRecord( @@ -245,8 +220,7 @@ def log_multilabel_textclassification_records( "multi_label": True, }, records=[ - CreationTextClassificationRecord.from_client(rec) - for rec in multilabel_textclassification_records + CreationTextClassificationRecord.from_client(rec) for rec in multilabel_textclassification_records ], ).dict(by_alias=True), ) @@ -321,10 +295,7 @@ def log_tokenclassification_records( "env": "test", "task": TaskType.token_classification, }, - records=[ - CreationTokenClassificationRecord.from_client(rec) - for rec in tokenclassification_records - ], + records=[CreationTokenClassificationRecord.from_client(rec) for rec in tokenclassification_records], ).dict(by_alias=True), ) @@ -395,9 +366,7 @@ def log_text2text_records( "env": "test", "task": TaskType.text2text, }, - records=[ - CreationText2TextRecord.from_client(rec) for rec in text2text_records - ], + records=[CreationText2TextRecord.from_client(rec) for rec in text2text_records], ).dict(by_alias=True), ) diff --git a/tests/client/sdk/commons/api.py b/tests/client/sdk/commons/api.py index b7ca635486..554a03611c 100644 --- a/tests/client/sdk/commons/api.py +++ b/tests/client/sdk/commons/api.py @@ -81,13 +81,9 @@ def test_build_bulk_response(status_code, expected): elif status_code == 500: server_response = ErrorMessage(detail="test") elif status_code == 422: - server_response = HTTPValidationError( - detail=[ValidationError(loc=["test"], msg="test", type="test")] - ) + server_response = HTTPValidationError(detail=[ValidationError(loc=["test"], msg="test", type="test")]) - httpx_response = HttpxResponse( - status_code=status_code, content=server_response.json() - ) + httpx_response = HttpxResponse(status_code=status_code, content=server_response.json()) response = build_bulk_response(httpx_response, name="mock-dataset", body={}) assert isinstance(response, Response) @@ -112,9 +108,7 @@ def test_build_data_response(status_code, expected): elif status_code == 500: server_response = ErrorMessage(detail="test") elif status_code == 422: - server_response = HTTPValidationError( - detail=[ValidationError(loc=["test"], msg="test", type="test")] - ) + server_response = HTTPValidationError(detail=[ValidationError(loc=["test"], msg="test", type="test")]) httpx_response = HttpxResponse( status_code=status_code, diff --git a/tests/client/sdk/conftest.py b/tests/client/sdk/conftest.py index 75ef26042b..68b91ffe0f 100644 --- a/tests/client/sdk/conftest.py +++ b/tests/client/sdk/conftest.py @@ -61,8 +61,7 @@ def check_schema_props(client_props, server_props): continue if name not in server_props: LOGGER.warning( - f"Client property {name} not found in server properties. 
" - "Make sure your API compatibility" + f"Client property {name} not found in server properties. " "Make sure your API compatibility" ) different_props.append(name) continue @@ -105,9 +104,7 @@ def _expands_schema( expanded_props = self._expands_schema(field_props, definitions) definition["items"] = expanded_props.get("properties", expanded_props) new_schema[name] = definition - elif "additionalProperties" in definition and "$ref" in definition.get( - "additionalProperties", {} - ): + elif "additionalProperties" in definition and "$ref" in definition.get("additionalProperties", {}): additionalProperties_refs = self._expands_schema( {name: definition["additionalProperties"]}, definitions=definitions, @@ -116,9 +113,7 @@ def _expands_schema( elif "allOf" in definition: allOf_expanded = [ self._expands_schema( - definitions[def_["$ref"].replace("#/definitions/", "")].get( - "properties", {} - ), + definitions[def_["$ref"].replace("#/definitions/", "")].get("properties", {}), definitions, ) for def_ in definition["allOf"] @@ -140,18 +135,14 @@ def helpers(): @pytest.fixture def sdk_client(mocked_client, monkeypatch): - client = AuthenticatedClient( - base_url="http://localhost:6900", token=DEFAULT_API_KEY - ) + client = AuthenticatedClient(base_url="http://localhost:6900", token=DEFAULT_API_KEY) monkeypatch.setattr(client, "__httpx__", mocked_client) return client @pytest.fixture def bulk_textclass_data(): - explanation = { - "text": [ar.TokenAttributions(token="test", attributions={"test": 0.5})] - } + explanation = {"text": [ar.TokenAttributions(token="test", attributions={"test": 0.5})]} records = [ ar.TextClassificationRecord( text="test", diff --git a/tests/client/sdk/datasets/test_models.py b/tests/client/sdk/datasets/test_models.py index c490db0e9b..a1281e782b 100644 --- a/tests/client/sdk/datasets/test_models.py +++ b/tests/client/sdk/datasets/test_models.py @@ -19,13 +19,9 @@ def test_dataset_schema(helpers): client_schema = Dataset.schema() - server_schema = helpers.remove_key( - ServerDataset.schema(), key="created_by" - ) # don't care about creator here + server_schema = helpers.remove_key(ServerDataset.schema(), key="created_by") # don't care about creator here - assert helpers.remove_description(client_schema) == helpers.remove_description( - server_schema - ) + assert helpers.remove_description(client_schema) == helpers.remove_description(server_schema) def test_TaskType_enum(): diff --git a/tests/client/sdk/text2text/test_models.py b/tests/client/sdk/text2text/test_models.py index 57f2e70178..0783606f3e 100644 --- a/tests/client/sdk/text2text/test_models.py +++ b/tests/client/sdk/text2text/test_models.py @@ -65,9 +65,7 @@ def test_from_client_prediction(prediction, expected): sdk_record = CreationText2TextRecord.from_client(record) assert len(sdk_record.prediction.sentences) == len(prediction) - assert all( - [sentence.score == expected for sentence in sdk_record.prediction.sentences] - ) + assert all([sentence.score == expected for sentence in sdk_record.prediction.sentences]) assert sdk_record.metrics == {} diff --git a/tests/client/sdk/text_classification/test_models.py b/tests/client/sdk/text_classification/test_models.py index 9bebb747c5..583483a951 100644 --- a/tests/client/sdk/text_classification/test_models.py +++ b/tests/client/sdk/text_classification/test_models.py @@ -68,15 +68,11 @@ def test_labeling_rule_metrics_schema(helpers): client_schema = LabelingRuleMetricsSummary.schema() server_schema = ServerLabelingRuleMetricsSummary.schema() - assert 
helpers.remove_description(client_schema) == helpers.remove_description( - server_schema - ) + assert helpers.remove_description(client_schema) == helpers.remove_description(server_schema) def test_from_client_explanation(): - token_attributions = [ - TokenAttributions(token="test", attributions={"label1": 1.0, "label2": 2.0}) - ] + token_attributions = [TokenAttributions(token="test", attributions={"label1": 1.0, "label2": 2.0})] record = TextClassificationRecord( inputs={"text": "test"}, prediction=[("label1", 0.5), ("label2", 0.5)], @@ -90,9 +86,7 @@ def test_from_client_explanation(): assert sdk_record.explanation["text"] == token_attributions -@pytest.mark.parametrize( - "annotation,expected", [("label1", 1), (["label1", "label2"], 2)] -) +@pytest.mark.parametrize("annotation,expected", [("label1", 1), (["label1", "label2"], 2)]) def test_from_client_annotation(annotation, expected): record = TextClassificationRecord( inputs={"text": "test"}, @@ -127,13 +121,9 @@ def test_from_client_agent(pred_agent, annot_agent, pred_expected, annot_expecte assert sdk_record.metrics == {} -@pytest.mark.parametrize( - "multi_label,expected", [(False, "annot_label"), (True, ["annot_label"])] -) +@pytest.mark.parametrize("multi_label,expected", [(False, "annot_label"), (True, ["annot_label"])]) def test_to_client(multi_label, expected): - annotation = TextClassificationAnnotation( - labels=[ClassPrediction(**{"class": "annot_label"})], agent="annot_agent" - ) + annotation = TextClassificationAnnotation(labels=[ClassPrediction(**{"class": "annot_label"})], agent="annot_agent") prediction = TextClassificationAnnotation( labels=[ ClassPrediction(**{"class": "label1", "score": 0.5}), diff --git a/tests/client/sdk/users/test_api.py b/tests/client/sdk/users/test_api.py index 8eb8f2b3c7..a4a1f211c9 100644 --- a/tests/client/sdk/users/test_api.py +++ b/tests/client/sdk/users/test_api.py @@ -26,15 +26,11 @@ def test_whoami(mocked_client, sdk_client): def test_whoami_with_auth_error(monkeypatch, mocked_client): with pytest.raises(UnauthorizedApiError): - sdk_client = AuthenticatedClient( - base_url="http://localhost:6900", token="wrong-apikey" - ) + sdk_client = AuthenticatedClient(base_url="http://localhost:6900", token="wrong-apikey") monkeypatch.setattr(sdk_client, "__httpx__", mocked_client) whoami(sdk_client) def test_whoami_with_connection_error(): with pytest.raises(BaseClientError): - whoami( - AuthenticatedClient(base_url="http://localhost:6900", token="wrong-apikey") - ) + whoami(AuthenticatedClient(base_url="http://localhost:6900", token="wrong-apikey")) diff --git a/tests/client/test_api.py b/tests/client/test_api.py index 12e66b634c..4c7ec4bb44 100644 --- a/tests/client/test_api.py +++ b/tests/client/test_api.py @@ -166,9 +166,7 @@ def test_log_something(monkeypatch, mocked_client): assert response.processed == 1 assert response.failed == 0 - response = mocked_client.post( - f"/api/datasets/{dataset_name}/TextClassification:search" - ) + response = mocked_client.post(f"/api/datasets/{dataset_name}/TextClassification:search") assert response.status_code == 200, response.json() results = TextClassificationSearchResults.parse_obj(response.json()) @@ -200,9 +198,7 @@ def test_load_limits(mocked_client, supported_vector_search): def test_log_records_with_too_long_text(mocked_client): dataset_name = "test_log_records_with_too_long_text" mocked_client.delete(f"/api/datasets/{dataset_name}") - item = ar.TextClassificationRecord( - inputs={"text": "This is a toooooo long text\n" * 10000} - ) + item = 
ar.TextClassificationRecord(inputs={"text": "This is a toooooo long text\n" * 10000}) api.log([item], name=dataset_name) @@ -218,9 +214,7 @@ def test_log_without_name(mocked_client): match="Empty dataset name has been passed as argument.", ): api.log( - ar.TextClassificationRecord( - inputs={"text": "This is a single record. Only this. No more."} - ), + ar.TextClassificationRecord(inputs={"text": "This is a single record. Only this. No more."}), name=None, ) @@ -311,9 +305,7 @@ def inner(*args, **kwargs): return inner with pytest.raises(error_type): - monkeypatch.setattr( - httpx, "delete", send_mock_response_with_http_status(status) - ) + monkeypatch.setattr(httpx, "delete", send_mock_response_with_http_status(status)) api.delete("dataset") diff --git a/tests/client/test_dataset.py b/tests/client/test_dataset.py index 83ae429af0..404b2768b4 100644 --- a/tests/client/test_dataset.py +++ b/tests/client/test_dataset.py @@ -50,9 +50,7 @@ def test_init_NotImplementedError(self): DatasetBase() def test_init(self, monkeypatch, singlelabel_textclassification_records): - monkeypatch.setattr( - "argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord - ) + monkeypatch.setattr("argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord) ds = DatasetBase( records=singlelabel_textclassification_records, @@ -62,9 +60,7 @@ def test_init(self, monkeypatch, singlelabel_textclassification_records): ds = DatasetBase() assert ds._records == [] - with pytest.raises( - WrongRecordTypeError, match="but you provided Text2TextRecord" - ): + with pytest.raises(WrongRecordTypeError, match="but you provided Text2TextRecord"): DatasetBase( records=[ar.Text2TextRecord(text="test")], ) @@ -90,9 +86,7 @@ def test_init(self, monkeypatch, singlelabel_textclassification_records): ds._from_pandas("mock") def test_to_dataframe(self, monkeypatch, singlelabel_textclassification_records): - monkeypatch.setattr( - "argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord - ) + monkeypatch.setattr("argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord) df = DatasetBase(singlelabel_textclassification_records).to_pandas() @@ -101,9 +95,7 @@ def test_to_dataframe(self, monkeypatch, singlelabel_textclassification_records) assert list(df.columns) == list(TextClassificationRecord.__fields__) def test_prepare_dataset_and_column_mapping(self, monkeypatch, caplog): - monkeypatch.setattr( - "argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord - ) + monkeypatch.setattr("argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord) ds = datasets.Dataset.from_dict( { @@ -119,12 +111,8 @@ def test_prepare_dataset_and_column_mapping(self, monkeypatch, caplog): with pytest.raises(ValueError, match="datasets.DatasetDict` are not supported"): DatasetBase._prepare_dataset_and_column_mapping(ds_dict, None) - col_mapping = dict( - id="ID", inputs=["inputs_a", "inputs_b"], metadata="metadata" - ) - prepared_ds, col_to_be_joined = DatasetBase._prepare_dataset_and_column_mapping( - ds, col_mapping - ) + col_mapping = dict(id="ID", inputs=["inputs_a", "inputs_b"], metadata="metadata") + prepared_ds, col_to_be_joined = DatasetBase._prepare_dataset_and_column_mapping(ds, col_mapping) assert prepared_ds.column_names == ["id", "inputs_a", "inputs_b", "metadata"] assert col_to_be_joined == { @@ -138,12 +126,8 @@ def test_prepare_dataset_and_column_mapping(self, monkeypatch, caplog): ) def test_from_pandas(self, monkeypatch, caplog): - 
monkeypatch.setattr( - "argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord - ) - monkeypatch.setattr( - "argilla.client.datasets.DatasetBase._from_pandas", lambda x: x - ) + monkeypatch.setattr("argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord) + monkeypatch.setattr("argilla.client.datasets.DatasetBase._from_pandas", lambda x: x) df = pd.DataFrame({"unsupported_column": [None]}) empty_df = DatasetBase.from_pandas(df) @@ -153,8 +137,7 @@ def test_from_pandas(self, monkeypatch, caplog): assert caplog.record_tuples[0][1] == 30 assert ( "Following columns are not supported by the " - "TextClassificationRecord model and are ignored: ['unsupported_column']" - == caplog.record_tuples[0][2] + "TextClassificationRecord model and are ignored: ['unsupported_column']" == caplog.record_tuples[0][2] ) def test_to_datasets(self, monkeypatch, caplog): @@ -170,9 +153,7 @@ def test_to_datasets(self, monkeypatch, caplog): assert len(datasets_ds) == 0 assert len(caplog.record_tuples) == 1 assert caplog.record_tuples[0][1] == 30 - assert ( - "The 'metadata' of the records were removed" in caplog.record_tuples[0][2] - ) + assert "The 'metadata' of the records were removed" in caplog.record_tuples[0][2] def test_datasets_not_installed(self, monkeypatch): monkeypatch.setattr("argilla.client.datasets.DatasetBase._RECORD_TYPE", "mock") @@ -186,12 +167,8 @@ def test_datasets_wrong_version(self, monkeypatch): with pytest.raises(ModuleNotFoundError, match="pip install -U datasets>1.17.0"): DatasetBase().to_datasets() - def test_iter_len_getitem( - self, monkeypatch, singlelabel_textclassification_records - ): - monkeypatch.setattr( - "argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord - ) + def test_iter_len_getitem(self, monkeypatch, singlelabel_textclassification_records): + monkeypatch.setattr("argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord) dataset = DatasetBase(singlelabel_textclassification_records) for record, expected in zip(dataset, singlelabel_textclassification_records): @@ -201,9 +178,7 @@ def test_iter_len_getitem( assert dataset[1] is singlelabel_textclassification_records[1] def test_setitem_delitem(self, monkeypatch, singlelabel_textclassification_records): - monkeypatch.setattr( - "argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord - ) + monkeypatch.setattr("argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord) dataset = DatasetBase( [rec.copy(deep=True) for rec in singlelabel_textclassification_records], ) @@ -226,12 +201,8 @@ def test_setitem_delitem(self, monkeypatch, singlelabel_textclassification_recor ): dataset[0] = ar.Text2TextRecord(text="mock") - def test_prepare_for_training_train_test_splits( - self, monkeypatch, singlelabel_textclassification_records - ): - monkeypatch.setattr( - "argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord - ) + def test_prepare_for_training_train_test_splits(self, monkeypatch, singlelabel_textclassification_records): + monkeypatch.setattr("argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord) temp_records = copy.deepcopy(singlelabel_textclassification_records) ds = DatasetBase(temp_records) @@ -241,9 +212,7 @@ def test_prepare_for_training_train_test_splits( ): ds.prepare_for_training(train_size=-1) - with pytest.raises( - AssertionError, match="`train_size` and `test_size` must sum to 1." 
- ): + with pytest.raises(AssertionError, match="`train_size` and `test_size` must sum to 1."): ds.prepare_for_training(test_size=0.1, train_size=0.6) for rec in ds: @@ -307,18 +276,14 @@ def test_to_from_datasets(self, records, request): assert rec.inputs == {"text": "mock"} def test_from_to_datasets_id(self): - dataset_rb = ar.DatasetForTextClassification( - [ar.TextClassificationRecord(text="mock")] - ) + dataset_rb = ar.DatasetForTextClassification([ar.TextClassificationRecord(text="mock")]) dataset_ds = dataset_rb.to_datasets() assert dataset_ds["id"] == [None] assert ar.read_datasets(dataset_ds, task="TextClassification")[0].id is None def test_datasets_empty_metadata(self): - dataset = ar.DatasetForTextClassification( - [ar.TextClassificationRecord(text="mock")] - ) + dataset = ar.DatasetForTextClassification([ar.TextClassificationRecord(text="mock")]) assert dataset.to_datasets()["metadata"] == [None] @pytest.mark.parametrize( @@ -359,9 +324,7 @@ def test_push_to_hub(self, request, name: str): # TODO(@frascuchon): move dataset to new organization dataset_name = f"rubrix/_test_text_classification_records-{name}" dataset_ds = ar.DatasetForTextClassification(records).to_datasets() - _push_to_hub_with_retries( - dataset_ds, repo_id=dataset_name, token=_HF_HUB_ACCESS_TOKEN, private=True - ) + _push_to_hub_with_retries(dataset_ds, repo_id=dataset_name, token=_HF_HUB_ACCESS_TOKEN, private=True) sleep(1) dataset_ds = datasets.load_dataset( dataset_name, @@ -490,9 +453,7 @@ def test_from_datasets_with_annotation_arg(self): } ), ) - dataset_rb = ar.DatasetForTextClassification.from_datasets( - dataset_ds, annotation="label" - ) + dataset_rb = ar.DatasetForTextClassification.from_datasets(dataset_ds, annotation="label") assert [rec.annotation for rec in dataset_rb] == ["HAM", None] @@ -546,32 +507,24 @@ def test_to_from_datasets(self, tokenclassification_records): assert isinstance(dataset, ar.DatasetForTokenClassification) _compare_datasets(dataset, expected_dataset) - missing_optional_cols = datasets.Dataset.from_dict( - {"text": ["mock"], "tokens": [["mock"]]} - ) + missing_optional_cols = datasets.Dataset.from_dict({"text": ["mock"], "tokens": [["mock"]]}) rec = ar.DatasetForTokenClassification.from_datasets(missing_optional_cols)[0] assert rec.text == "mock" and rec.tokens == ["mock"] def test_from_to_datasets_id(self): - dataset_rb = ar.DatasetForTokenClassification( - [ar.TokenClassificationRecord(text="mock", tokens=["mock"])] - ) + dataset_rb = ar.DatasetForTokenClassification([ar.TokenClassificationRecord(text="mock", tokens=["mock"])]) dataset_ds = dataset_rb.to_datasets() assert dataset_ds["id"] == [None] assert ar.read_datasets(dataset_ds, task="TokenClassification")[0].id is None def test_prepare_for_training_empty(self): - dataset = ar.DatasetForTokenClassification( - [ar.TokenClassificationRecord(text="mock", tokens=["mock"])] - ) + dataset = ar.DatasetForTokenClassification([ar.TokenClassificationRecord(text="mock", tokens=["mock"])]) with pytest.raises(AssertionError): dataset.prepare_for_training() def test_datasets_empty_metadata(self): - dataset = ar.DatasetForTokenClassification( - [ar.TokenClassificationRecord(text="mock", tokens=["mock"])] - ) + dataset = ar.DatasetForTokenClassification([ar.TokenClassificationRecord(text="mock", tokens=["mock"])]) assert dataset.to_datasets()["metadata"] == [None] def test_to_from_pandas(self, tokenclassification_records): @@ -593,9 +546,7 @@ def test_to_from_pandas(self, tokenclassification_records): reason="You need a HF Hub access 
token to test the push_to_hub feature", ) def test_push_to_hub(self, tokenclassification_records): - dataset_ds = ar.DatasetForTokenClassification( - tokenclassification_records - ).to_datasets() + dataset_ds = ar.DatasetForTokenClassification(tokenclassification_records).to_datasets() _push_to_hub_with_retries( dataset_ds, # TODO(@frascuchon): Move dataset to the new org @@ -624,26 +575,18 @@ def test_prepare_for_training_with_spacy(self): use_auth_token=_HF_HUB_ACCESS_TOKEN, split="train", ) - rb_dataset: DatasetForTokenClassification = ar.read_datasets( - ner_dataset, task="TokenClassification" - ) + rb_dataset: DatasetForTokenClassification = ar.read_datasets(ner_dataset, task="TokenClassification") for r in rb_dataset: - r.annotation = [ - (label, start, end) for label, start, end, _ in r.prediction - ] + r.annotation = [(label, start, end) for label, start, end, _ in r.prediction] with pytest.raises(ValueError): train = rb_dataset.prepare_for_training(framework="spacy") - train = rb_dataset.prepare_for_training( - framework="spacy", lang=spacy.blank("en") - ) + train = rb_dataset.prepare_for_training(framework="spacy", lang=spacy.blank("en")) assert isinstance(train, spacy.tokens.DocBin) assert len(train) == 100 - train, test = rb_dataset.prepare_for_training( - framework="spacy", lang=spacy.blank("en"), train_size=0.8 - ) + train, test = rb_dataset.prepare_for_training(framework="spacy", lang=spacy.blank("en"), train_size=0.8) assert isinstance(train, spacy.tokens.DocBin) assert isinstance(test, spacy.tokens.DocBin) assert len(train) == 80 @@ -660,21 +603,15 @@ def test_prepare_for_training_with_spark_nlp(self): use_auth_token=_HF_HUB_ACCESS_TOKEN, split="train", ) - rb_dataset: DatasetForTokenClassification = ar.read_datasets( - ner_dataset, task="TokenClassification" - ) + rb_dataset: DatasetForTokenClassification = ar.read_datasets(ner_dataset, task="TokenClassification") for r in rb_dataset: - r.annotation = [ - (label, start, end) for label, start, end, _ in r.prediction - ] + r.annotation = [(label, start, end) for label, start, end, _ in r.prediction] train = rb_dataset.prepare_for_training(framework="spark-nlp") assert isinstance(train, pd.DataFrame) assert len(train) == 100 - train, test = rb_dataset.prepare_for_training( - framework="spark-nlp", train_size=0.8 - ) + train, test = rb_dataset.prepare_for_training(framework="spark-nlp", train_size=0.8) assert isinstance(train, pd.DataFrame) assert isinstance(test, pd.DataFrame) assert len(train) == 80 @@ -691,18 +628,12 @@ def test_prepare_for_training(self): use_auth_token=_HF_HUB_ACCESS_TOKEN, split="train", ) - rb_dataset: DatasetForTokenClassification = ar.read_datasets( - ner_dataset, task="TokenClassification" - ) + rb_dataset: DatasetForTokenClassification = ar.read_datasets(ner_dataset, task="TokenClassification") for r in rb_dataset: - r.annotation = [ - (label, start, end) for label, start, end, _ in r.prediction - ] + r.annotation = [(label, start, end) for label, start, end, _ in r.prediction] train = rb_dataset.prepare_for_training() - assert isinstance(train, datasets.DatasetD.Dataset) or isinstance( - train, datasets.Dataset - ) + assert isinstance(train, datasets.DatasetD.Dataset) or isinstance(train, datasets.Dataset) assert "ner_tags" in train.column_names assert len(train) == 100 assert train.features["ner_tags"] == [ @@ -760,9 +691,7 @@ def test_from_dataset_with_non_argilla_format(self): use_auth_token=_HF_HUB_ACCESS_TOKEN, ) - rb_ds = ar.DatasetForTokenClassification.from_datasets( - ds, tags="ner_tags", 
metadata=["spans"] - ) + rb_ds = ar.DatasetForTokenClassification.from_datasets(ds, tags="ner_tags", metadata=["spans"]) again_the_ds = rb_ds.to_datasets() assert again_the_ds.column_names == [ @@ -781,9 +710,7 @@ def test_from_dataset_with_non_argilla_format(self): def test_from_datasets_with_empty_tokens(self, caplog): dataset_ds = datasets.Dataset.from_dict({"empty_tokens": [["mock"], []]}) - dataset_rb = ar.DatasetForTokenClassification.from_datasets( - dataset_ds, tokens="empty_tokens" - ) + dataset_rb = ar.DatasetForTokenClassification.from_datasets(dataset_ds, tokens="empty_tokens") assert caplog.record_tuples[0][1] == 30 assert caplog.record_tuples[0][2] == "Ignoring row with no tokens." @@ -834,9 +761,7 @@ def test_to_from_datasets(self, text2text_records): assert rec.text == "mock" # alternative format for the predictions - ds = datasets.Dataset.from_dict( - {"text": ["example"], "prediction": [["ejemplo"]]} - ) + ds = datasets.Dataset.from_dict({"text": ["example"], "prediction": [["ejemplo"]]}) rec = ar.DatasetForText2Text.from_datasets(ds)[0] assert rec.prediction[0][0] == "ejemplo" assert rec.prediction[0][1] == pytest.approx(1.0) @@ -899,9 +824,7 @@ def test_from_dataset_with_non_argilla_format(self): use_auth_token=_HF_HUB_ACCESS_TOKEN, ) - rb_ds = ar.DatasetForText2Text.from_datasets( - ds, text="description", annotation="abstract" - ) + rb_ds = ar.DatasetForText2Text.from_datasets(ds, text="description", annotation="abstract") again_the_ds = rb_ds.to_datasets() assert again_the_ds.column_names == [ @@ -924,9 +847,7 @@ def _compare_datasets(dataset, expected_dataset): # TODO: have to think about how we deal with `None`s if col in ["metadata", "metrics"]: continue - assert getattr(rec, col) == getattr( - expected, col - ), f"Wrong column value '{col}'" + assert getattr(rec, col) == getattr(expected, col), f"Wrong column value '{col}'" @pytest.mark.parametrize( @@ -941,9 +862,7 @@ def test_read_pandas(monkeypatch, task, dataset_class): def mock_from_pandas(mock): return mock - monkeypatch.setattr( - f"argilla.client.datasets.{dataset_class}.from_pandas", mock_from_pandas - ) + monkeypatch.setattr(f"argilla.client.datasets.{dataset_class}.from_pandas", mock_from_pandas) assert ar.read_pandas("mock", task) == "mock" @@ -960,8 +879,6 @@ def test_read_datasets(monkeypatch, task, dataset_class): def mock_from_datasets(mock): return mock - monkeypatch.setattr( - f"argilla.client.datasets.{dataset_class}.from_datasets", mock_from_datasets - ) + monkeypatch.setattr(f"argilla.client.datasets.{dataset_class}.from_datasets", mock_from_datasets) assert ar.read_datasets("mock", task) == "mock" diff --git a/tests/client/test_init.py b/tests/client/test_init.py index 2b998d2e02..5efebd74d4 100644 --- a/tests/client/test_init.py +++ b/tests/client/test_init.py @@ -45,9 +45,7 @@ def test_init_with_extra_headers(mocked_client): active_api = api.active_api() for key, value in expected_headers.items(): - assert ( - active_api.http_client.headers[key] == value - ), f"{key}:{value} not in client headers" + assert active_api.http_client.headers[key] == value, f"{key}:{value} not in client headers" def test_init(mocked_client): diff --git a/tests/client/test_models.py b/tests/client/test_models.py index bf3d72c366..7fe2c38e0f 100644 --- a/tests/client/test_models.py +++ b/tests/client/test_models.py @@ -40,29 +40,21 @@ def test_text_classification_record(annotation, status, expected_status): """Just testing its dynamic defaults""" if status: - record = TextClassificationRecord( - inputs={"text": 
"test"}, annotation=annotation, status=status - ) + record = TextClassificationRecord(inputs={"text": "test"}, annotation=annotation, status=status) else: - record = TextClassificationRecord( - inputs={"text": "test"}, annotation=annotation - ) + record = TextClassificationRecord(inputs={"text": "test"}, annotation=annotation) assert record.status == expected_status def test_text_classification_input_string(): event_timestamp = datetime.datetime.now() - assert TextClassificationRecord( - text="A text", event_timestamp=event_timestamp - ) == TextClassificationRecord( + assert TextClassificationRecord(text="A text", event_timestamp=event_timestamp) == TextClassificationRecord( inputs=dict(text="A text"), event_timestamp=event_timestamp ) assert TextClassificationRecord( inputs=["A text", "another text"], event_timestamp=event_timestamp - ) == TextClassificationRecord( - inputs=dict(text=["A text", "another text"]), event_timestamp=event_timestamp - ) + ) == TextClassificationRecord(inputs=dict(text=["A text", "another text"]), event_timestamp=event_timestamp) def test_text_classification_text_inputs(): @@ -74,38 +66,27 @@ def test_text_classification_text_inputs(): with pytest.warns( FutureWarning, - match=( - "the `inputs` argument of the `TextClassificationRecord` will not accept" - " strings." - ), + match=("the `inputs` argument of the `TextClassificationRecord` will not accept" " strings."), ): TextClassificationRecord(inputs="mock") event_timestamp = datetime.datetime.now() - assert TextClassificationRecord( - text="mock", event_timestamp=event_timestamp - ) == TextClassificationRecord( + assert TextClassificationRecord(text="mock", event_timestamp=event_timestamp) == TextClassificationRecord( inputs={"text": "mock"}, event_timestamp=event_timestamp ) - assert TextClassificationRecord( - inputs=["mock"], event_timestamp=event_timestamp - ) == TextClassificationRecord( + assert TextClassificationRecord(inputs=["mock"], event_timestamp=event_timestamp) == TextClassificationRecord( inputs={"text": ["mock"]}, event_timestamp=event_timestamp ) assert TextClassificationRecord( text="mock", inputs={"text": "mock"}, event_timestamp=event_timestamp - ) == TextClassificationRecord( - inputs={"text": "mock"}, event_timestamp=event_timestamp - ) + ) == TextClassificationRecord(inputs={"text": "mock"}, event_timestamp=event_timestamp) rec = TextClassificationRecord(text="mock") with pytest.raises(AttributeError, match="You cannot assign a new value to `text`"): rec.text = "mock" - with pytest.raises( - AttributeError, match="You cannot assign a new value to `inputs`" - ): + with pytest.raises(AttributeError, match="You cannot assign a new value to `inputs`"): rec.inputs = "mock" @@ -120,9 +101,7 @@ def test_text_classification_text_inputs(): ) def test_token_classification_record(annotation, status, expected_status, expected_iob): """Just testing its dynamic defaults""" - record = TokenClassificationRecord( - text="test text", tokens=["test", "text"], annotation=annotation, status=status - ) + record = TokenClassificationRecord(text="test text", tokens=["test", "text"], annotation=annotation, status=status) assert record.status == expected_status assert record.spans2iob(record.annotation) == expected_iob @@ -146,10 +125,7 @@ def test_token_classification_with_tokens_and_tags(tokens, tags, annotation): def test_token_classification_validations(): with pytest.raises( AssertionError, - match=( - "Missing fields: " - "At least one of `text` or `tokens` argument must be provided!" 
- ), + match=("Missing fields: " "At least one of `text` or `tokens` argument must be provided!"), ): TokenClassificationRecord() @@ -157,25 +133,17 @@ def test_token_classification_validations(): annotation = [("test", 0, 4)] with pytest.raises( AssertionError, - match=( - "Missing field `text`: " - "char level spans must be provided with a raw text sentence" - ), + match=("Missing field `text`: " "char level spans must be provided with a raw text sentence"), ): TokenClassificationRecord(tokens=tokens, annotation=annotation) with pytest.raises( AssertionError, - match=( - "Missing field `text`: " - "char level spans must be provided with a raw text sentence" - ), + match=("Missing field `text`: " "char level spans must be provided with a raw text sentence"), ): TokenClassificationRecord(tokens=tokens, prediction=annotation) - TokenClassificationRecord( - text=" ".join(tokens), tokens=tokens, prediction=annotation - ) + TokenClassificationRecord(text=" ".join(tokens), tokens=tokens, prediction=annotation) record = TokenClassificationRecord(tokens=tokens) assert record.text == "test text" @@ -185,16 +153,12 @@ def test_token_classification_with_mutation(): text_a = "The text" text_b = "Another text sample here !!!" - record = TokenClassificationRecord( - text=text_a, tokens=text_a.split(" "), annotation=[] - ) + record = TokenClassificationRecord(text=text_a, tokens=text_a.split(" "), annotation=[]) assert record.spans2iob(record.annotation) == ["O"] * len(text_a.split(" ")) with pytest.raises(AttributeError, match="You cannot assign a new value to `text`"): record.text = text_b - with pytest.raises( - AttributeError, match="You cannot assign a new value to `tokens`" - ): + with pytest.raises(AttributeError, match="You cannot assign a new value to `tokens`"): record.tokens = text_b.split(" ") @@ -212,9 +176,7 @@ def test_token_classification_with_mutation(): ], ) def test_token_classification_prediction_validator(prediction, expected): - record = TokenClassificationRecord( - text="this", tokens=["this"], prediction=prediction - ) + record = TokenClassificationRecord(text="this", tokens=["this"], prediction=prediction) assert record.prediction == expected @@ -246,9 +208,7 @@ def test_metadata_values_length(): def test_model_serialization_with_numpy_nan(): - record = Text2TextRecord( - text="My name is Sarah and I love my dog.", metadata={"nan": numpy.nan} - ) + record = Text2TextRecord(text="My name is Sarah and I love my dog.", metadata={"nan": numpy.nan}) json_record = json.loads(record.json()) @@ -260,13 +220,9 @@ class MockRecord(_Validators): annotation: Optional[Any] = None annotation_agent: Optional[str] = None - with pytest.warns( - UserWarning, match="`prediction_agent` will not be logged to the server." - ): + with pytest.warns(UserWarning, match="`prediction_agent` will not be logged to the server."): MockRecord(prediction_agent="mock") - with pytest.warns( - UserWarning, match="`annotation_agent` will not be logged to the server." 
- ): + with pytest.warns(UserWarning, match="`annotation_agent` will not be logged to the server."): MockRecord(annotation_agent="mock") @@ -312,9 +268,7 @@ def test_text2text_prediction_validator(prediction, expected): "record", [ TextClassificationRecord(text="This is a test"), - TokenClassificationRecord( - text="This is a test", tokens="This is a test".split() - ), + TokenClassificationRecord(text="This is a test", tokens="This is a test".split()), Text2TextRecord(text="This is a test"), ], ) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 4499d25964..7384aa9bae 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -57,9 +57,7 @@ def test_delete_dataset_by_non_creator(mocked_client): try: dataset = "test_delete_dataset_by_non_creator" ar.delete(dataset) - ar.configure_dataset( - dataset, settings=TextClassificationSettings(label_schema={"A", "B", "C"}) - ) + ar.configure_dataset(dataset, settings=TextClassificationSettings(label_schema={"A", "B", "C"})) mocked_client.change_current_user("mock-user") with pytest.raises(ForbiddenApiError): ar.delete(dataset) diff --git a/tests/functional_tests/datasets/test_delete_records_from_datasets.py b/tests/functional_tests/datasets/test_delete_records_from_datasets.py index 238c358338..56b2e6820c 100644 --- a/tests/functional_tests/datasets/test_delete_records_from_datasets.py +++ b/tests/functional_tests/datasets/test_delete_records_from_datasets.py @@ -26,10 +26,7 @@ def test_delete_records_from_dataset(mocked_client): ar.log( name=dataset, records=[ - ar.TextClassificationRecord( - id=i, text="This is the text", metadata=dict(idx=i) - ) - for i in range(0, 50) + ar.TextClassificationRecord(id=i, text="This is the text", metadata=dict(idx=i)) for i in range(0, 50) ], ) @@ -40,9 +37,7 @@ def test_delete_records_from_dataset(mocked_client): assert len(ds) == 50 time.sleep(1) - matched, processed = ar.delete_records( - name=dataset, query="id:10", discard_only=False - ) + matched, processed = ar.delete_records(name=dataset, query="id:10", discard_only=False) assert matched, processed == (1, 1) time.sleep(1) @@ -58,10 +53,7 @@ def test_delete_records_without_permission(mocked_client): ar.log( name=dataset, records=[ - ar.TextClassificationRecord( - id=i, text="This is the text", metadata=dict(idx=i) - ) - for i in range(0, 50) + ar.TextClassificationRecord(id=i, text="This is the text", metadata=dict(idx=i)) for i in range(0, 50) ], ) try: diff --git a/tests/functional_tests/search/test_search_service.py b/tests/functional_tests/search/test_search_service.py index 9fc0b9b439..90ea518ae1 100644 --- a/tests/functional_tests/search/test_search_service.py +++ b/tests/functional_tests/search/test_search_service.py @@ -70,9 +70,7 @@ def test_query_builder_with_query_range(backend: GenericElasticEngineBackend): } -def test_query_builder_with_nested( - mocked_client, dao, backend: GenericElasticEngineBackend -): +def test_query_builder_with_nested(mocked_client, dao, backend: GenericElasticEngineBackend): dataset = Dataset( name="test_query_builder_with_nested", owner=argilla.get_workspace(), @@ -106,20 +104,8 @@ def test_query_builder_with_nested( "query": { "bool": { "must": [ - { - "term": { - "metrics.predicted.mentions.label": { - "value": "NAME" - } - } - }, - { - "range": { - "metrics.predicted.mentions.score": { - "lte": "0.1" - } - } - }, + {"term": {"metrics.predicted.mentions.label": {"value": "NAME"}}}, + {"range": {"metrics.predicted.mentions.score": {"lte": "0.1"}}}, ] } }, 
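The hunk above asserts the exact Elasticsearch query emitted by the backend's query builder: a term filter and a range filter nested under bool/must, now collapsed onto single lines by the reformat. As a rough illustration of that query shape only, here is a minimal, dependency-free Python sketch; the helper name build_must_query is hypothetical and simply mirrors the dict layout shown in the test, it is not part of this patch or of Argilla's API.

from typing import Any, Dict, List, Tuple


def build_must_query(clauses: List[Tuple[str, str, Any]]) -> Dict[str, Any]:
    # Each clause is (kind, field, value); "term" wraps the value under
    # {"value": ...}, while "range" expects value to be a dict of bounds
    # such as {"lte": "0.1"}.
    must: List[Dict[str, Any]] = []
    for kind, field, value in clauses:
        if kind == "term":
            must.append({"term": {field: {"value": value}}})
        elif kind == "range":
            must.append({"range": {field: value}})
    return {"query": {"bool": {"must": must}}}


# Reproduces the expected query from the hunk above.
assert build_must_query(
    [
        ("term", "metrics.predicted.mentions.label", "NAME"),
        ("range", "metrics.predicted.mentions.score", {"lte": "0.1"}),
    ]
) == {
    "query": {
        "bool": {
            "must": [
                {"term": {"metrics.predicted.mentions.label": {"value": "NAME"}}},
                {"range": {"metrics.predicted.mentions.score": {"lte": "0.1"}}},
            ]
        }
    }
}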
diff --git a/tests/functional_tests/test_log_for_text_classification.py b/tests/functional_tests/test_log_for_text_classification.py index 5e2a16f8bb..be50cdf406 100644 --- a/tests/functional_tests/test_log_for_text_classification.py +++ b/tests/functional_tests/test_log_for_text_classification.py @@ -226,9 +226,7 @@ def test_search_keywords(mocked_client): top_keywords = set( [ keyword - for keywords in df.search_keywords.value_counts(sort=True, ascending=False) - .index[:3] - .tolist() + for keywords in df.search_keywords.value_counts(sort=True, ascending=False).index[:3].tolist() for keyword in keywords ] ) @@ -262,17 +260,12 @@ def test_logging_with_metadata_limits_exceeded(mocked_client): expected_record = ar.TextClassificationRecord( text="The input text", - metadata={ - k: f"this is a string {k}" - for k in range(0, settings.metadata_fields_limit + 1) - }, + metadata={k: f"this is a string {k}" for k in range(0, settings.metadata_fields_limit + 1)}, ) with pytest.raises(BadRequestApiError): ar.log(expected_record, name=dataset) - expected_record.metadata = { - k: f"This is a string {k}" for k in range(0, settings.metadata_fields_limit) - } + expected_record.metadata = {k: f"This is a string {k}" for k in range(0, settings.metadata_fields_limit)} # Dataset creation with data ar.log(expected_record, name=dataset) # This call will check already included fields diff --git a/tests/functional_tests/test_log_for_token_classification.py b/tests/functional_tests/test_log_for_token_classification.py index 2df8c37c5c..7aa45b3ed4 100644 --- a/tests/functional_tests/test_log_for_token_classification.py +++ b/tests/functional_tests/test_log_for_token_classification.py @@ -29,9 +29,7 @@ def test_log_with_empty_text(mocked_client): text = " " argilla.delete(dataset) - with pytest.raises( - Exception, match="The provided `text` contains only whitespaces." 
- ): + with pytest.raises(Exception, match="The provided `text` contains only whitespaces."): argilla.log( TokenClassificationRecord(id=0, text=text, tokens=["a", "b", "c"]), name=dataset, @@ -531,9 +529,7 @@ def test_search_keywords(mocked_client): top_keywords = set( [ keyword - for keywords in df.search_keywords.value_counts(sort=True, ascending=False) - .index[:3] - .tolist() + for keywords in df.search_keywords.value_counts(sort=True, ascending=False).index[:3].tolist() for keyword in keywords ] ) diff --git a/tests/labeling/text_classification/test_label_errors.py b/tests/labeling/text_classification/test_label_errors.py index f33f9d8ab6..b07780ec71 100644 --- a/tests/labeling/text_classification/test_label_errors.py +++ b/tests/labeling/text_classification/test_label_errors.py @@ -27,15 +27,11 @@ from pkg_resources import parse_version -@pytest.fixture( - params=[False, True], ids=["single_label", "multi_label"], scope="module" -) +@pytest.fixture(params=[False, True], ids=["single_label", "multi_label"], scope="module") def records(request): if request.param: return [ - ar.TextClassificationRecord( - text="test", annotation=anot, prediction=pred, multi_label=True, id=i - ) + ar.TextClassificationRecord(text="test", annotation=anot, prediction=pred, multi_label=True, id=i) for i, anot, pred in zip( range(2 * 6), [["bad"], ["bad", "good"]] * 6, @@ -70,9 +66,7 @@ def test_no_records(): ar.TextClassificationRecord(text="test", annotation="test"), ] - with pytest.raises( - NoRecordsError, match="none of your records have a prediction AND annotation" - ): + with pytest.raises(NoRecordsError, match="none of your records have a prediction AND annotation"): find_label_errors(records) @@ -84,10 +78,7 @@ def test_multi_label_warning(caplog): multi_label=True, ) find_label_errors([record], multi_label="True") - assert ( - "You provided the kwarg 'multi_label', but it is determined automatically" - in caplog.text - ) + assert "You provided the kwarg 'multi_label', but it is determined automatically" in caplog.text @pytest.mark.parametrize( @@ -120,9 +111,7 @@ def mock_find_label_issues(*args, **kwargs): mock_find_label_issues, ) - record = ar.TextClassificationRecord( - text="mock", prediction=[("mock", 0.1)], annotation="mock" - ) + record = ar.TextClassificationRecord(text="mock", prediction=[("mock", 0.1)], annotation="mock") find_label_errors(records=[record], sort_by=sort_by) @@ -144,9 +133,7 @@ def mock_get_noise_indices(s, psx, n_jobs, **kwargs): mock_get_noise_indices, ) - with pytest.raises( - ValueError, match="'sorted_index_method' kwarg is not supported" - ): + with pytest.raises(ValueError, match="'sorted_index_method' kwarg is not supported"): find_label_errors(records=records, sorted_index_method="mock") find_label_errors(records=records, mock="mock") @@ -156,9 +143,7 @@ def mock_find_label_issues(s, psx, n_jobs, **kwargs): assert kwargs == { "mock": "mock", "multi_label": is_multi_label, - "return_indices_ranked_by": "normalized_margin" - if not is_multi_label - else "self_confidence", + "return_indices_ranked_by": "normalized_margin" if not is_multi_label else "self_confidence", } return [] @@ -167,9 +152,7 @@ def mock_find_label_issues(s, psx, n_jobs, **kwargs): mock_find_label_issues, ) - with pytest.raises( - ValueError, match="'return_indices_ranked_by' kwarg is not supported" - ): + with pytest.raises(ValueError, match="'return_indices_ranked_by' kwarg is not supported"): find_label_errors(records=records, return_indices_ranked_by="mock") find_label_errors(records=records, 
mock="mock") @@ -207,22 +190,14 @@ def test_construct_s_and_psx(records): def test_missing_predictions(): - records = [ - ar.TextClassificationRecord( - text="test", annotation="mock", prediction=[("mock2", 0.1)] - ) - ] + records = [ar.TextClassificationRecord(text="test", annotation="mock", prediction=[("mock2", 0.1)])] with pytest.raises( MissingPredictionError, match="It seems predictions are missing for the label 'mock'", ): _construct_s_and_psx(records) - records.append( - ar.TextClassificationRecord( - text="test", annotation="mock", prediction=[("mock", 0.1)] - ) - ) + records.append(ar.TextClassificationRecord(text="test", annotation="mock", prediction=[("mock", 0.1)])) with pytest.raises( MissingPredictionError, match="It seems a prediction for 'mock' is missing in the following record", diff --git a/tests/labeling/text_classification/test_label_models.py b/tests/labeling/text_classification/test_label_models.py index 0cde14796f..4cad2d86f4 100644 --- a/tests/labeling/text_classification/test_label_models.py +++ b/tests/labeling/text_classification/test_label_models.py @@ -43,9 +43,7 @@ def mock_load(*args, **kwargs): TextClassificationRecord(text="test", annotation="neutral"), ] - monkeypatch.setattr( - "argilla.labeling.text_classification.weak_labels.load", mock_load - ) + monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load) def mock_apply(self, *args, **kwargs): weak_label_matrix = np.array( @@ -63,17 +61,13 @@ def mock_apply(self, *args, **kwargs): @pytest.fixture def weak_labels_from_guide(monkeypatch, resources): - matrix_and_annotation = np.load( - str(resources / "weak-supervision-guide-matrix.npy") - ) + matrix_and_annotation = np.load(str(resources / "weak-supervision-guide-matrix.npy")) matrix, annotation = matrix_and_annotation[:, :-1], matrix_and_annotation[:, -1] def mock_load(*args, **kwargs): return [TextClassificationRecord(text="mock", id=i) for i in range(len(matrix))] - monkeypatch.setattr( - "argilla.labeling.text_classification.weak_labels.load", mock_load - ) + monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load) def mock_apply(self, *args, **kwargs): return matrix, annotation, {None: -1, "SPAM": 0, "HAM": 1} @@ -87,19 +81,13 @@ def mock_apply(self, *args, **kwargs): def weak_multi_labels(monkeypatch): def mock_load(*args, **kwargs): return [ - TextClassificationRecord( - text="test", multi_label=True, annotation=["scared"] - ), - TextClassificationRecord( - text="test", multi_label=True, annotation=["sad", "scared"] - ), + TextClassificationRecord(text="test", multi_label=True, annotation=["scared"]), + TextClassificationRecord(text="test", multi_label=True, annotation=["sad", "scared"]), TextClassificationRecord(text="test", multi_label=True, annotation=[]), TextClassificationRecord(text="test", multi_label=True), ] - monkeypatch.setattr( - "argilla.labeling.text_classification.weak_labels.load", mock_load - ) + monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load) def mock_apply(self, *args, **kwargs): weak_label_matrix = np.array( @@ -111,9 +99,7 @@ def mock_apply(self, *args, **kwargs): ], dtype=np.byte, ) - annotation_array = np.array( - [[0, 0, 1], [1, 0, 1], [0, 0, 0], [-1, -1, -1]], dtype=np.byte - ) + annotation_array = np.array([[0, 0, 1], [1, 0, 1], [0, 0, 0], [-1, -1, -1]], dtype=np.byte) labels = ["sad", "happy", "scared"] return weak_label_matrix, annotation_array, labels @@ -159,9 +145,7 @@ def test_no_need_to_fit_error(self): 
("weak_multi_labels", False, 1), ], ) - def test_predict( - self, monkeypatch, request, wls, include_annotated_records, expected - ): + def test_predict(self, monkeypatch, request, wls, include_annotated_records, expected): def compute_probs(self, wl_matrix, **kwargs): assert len(wl_matrix) == expected compute_probs.called = None @@ -171,20 +155,13 @@ def make_records(self, probabilities, records, **kwargs): return records single_or_multi = "multi" if wls == "weak_multi_labels" else "single" - monkeypatch.setattr( - MajorityVoter, f"_compute_{single_or_multi}_label_probs", compute_probs - ) - monkeypatch.setattr( - MajorityVoter, f"_make_{single_or_multi}_label_records", make_records - ) + monkeypatch.setattr(MajorityVoter, f"_compute_{single_or_multi}_label_probs", compute_probs) + monkeypatch.setattr(MajorityVoter, f"_make_{single_or_multi}_label_records", make_records) weak_labels = request.getfixturevalue(wls) mj = MajorityVoter(weak_labels) - assert ( - len(mj.predict(include_annotated_records=include_annotated_records)) - == expected - ) + assert len(mj.predict(include_annotated_records=include_annotated_records)) == expected assert hasattr(compute_probs, "called") def test_compute_single_label_probs(self, weak_labels): @@ -210,9 +187,7 @@ def test_compute_single_label_probs(self, weak_labels): (False, TieBreakPolicy.RANDOM, 4), ], ) - def test_make_single_label_records( - self, weak_labels, include_abstentions, tie_break_policy, expected - ): + def test_make_single_label_records(self, weak_labels, include_abstentions, tie_break_policy, expected): mj = MajorityVoter(weak_labels) probs = mj._compute_single_label_probs(weak_labels.matrix()) @@ -273,9 +248,7 @@ def test_compute_multi_label_probs(self, weak_multi_labels): (False, 3), ], ) - def test_make_multi_label_records( - self, weak_multi_labels, include_abstentions, expected - ): + def test_make_multi_label_records(self, weak_multi_labels, include_abstentions, expected): mj = MajorityVoter(weak_multi_labels) probs = mj._compute_multi_label_probs(weak_multi_labels.matrix()) @@ -324,9 +297,7 @@ def score(self, probabilities, tie_break_policy=None): return np.array([[1, 1, 1], [0, 0, 0]]), np.array([[1, 1, 1], [1, 0, 0]]) single_or_multi = "multi" if wls == "weak_multi_labels" else "single" - monkeypatch.setattr( - MajorityVoter, f"_compute_{single_or_multi}_label_probs", compute_probs - ) + monkeypatch.setattr(MajorityVoter, f"_compute_{single_or_multi}_label_probs", compute_probs) monkeypatch.setattr(MajorityVoter, f"_score_{single_or_multi}_label", score) weak_labels = request.getfixturevalue(wls) @@ -349,45 +320,31 @@ def score(self, probabilities, tie_break_policy=None): def test_score_single_label(self, weak_labels, tie_break_policy, expected): mj = MajorityVoter(weak_labels) - probabilities = np.array( - [[0.5, 0.5, 0.0], [0.5, 0.0, 0.5], [1.0 / 3, 0.0, 2.0 / 3]] - ) + probabilities = np.array([[0.5, 0.5, 0.0], [0.5, 0.0, 0.5], [1.0 / 3, 0.0, 2.0 / 3]]) if tie_break_policy is TieBreakPolicy.TRUE_RANDOM: - with pytest.raises( - NotImplementedError, match="not implemented for MajorityVoter" - ): + with pytest.raises(NotImplementedError, match="not implemented for MajorityVoter"): mj._score_single_label(probabilities, tie_break_policy) return - annotation, prediction = mj._score_single_label( - probabilities=probabilities, tie_break_policy=tie_break_policy - ) + annotation, prediction = mj._score_single_label(probabilities=probabilities, tie_break_policy=tie_break_policy) assert np.allclose(annotation, expected[0]) assert 
np.allclose(prediction, expected[1]) def test_score_single_label_no_ties(self, weak_labels): mj = MajorityVoter(weak_labels) - probabilities = np.array( - [[0.5, 0.3, 0.0], [0.5, 0.0, 0.0], [1.0 / 3, 0.0, 2.0 / 3]] - ) + probabilities = np.array([[0.5, 0.3, 0.0], [0.5, 0.0, 0.0], [1.0 / 3, 0.0, 2.0 / 3]]) - _, prediction = mj._score_single_label( - probabilities=probabilities, tie_break_policy=TieBreakPolicy.ABSTAIN - ) - _, prediction2 = mj._score_single_label( - probabilities=probabilities, tie_break_policy=TieBreakPolicy.RANDOM - ) + _, prediction = mj._score_single_label(probabilities=probabilities, tie_break_policy=TieBreakPolicy.ABSTAIN) + _, prediction2 = mj._score_single_label(probabilities=probabilities, tie_break_policy=TieBreakPolicy.RANDOM) assert np.allclose(prediction, prediction2) def test_score_multi_label(self, weak_multi_labels): mj = MajorityVoter(weak_multi_labels) - probabilities = np.array( - [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0], [np.nan, np.nan, np.nan]] - ) + probabilities = np.array([[0.0, 0.0, 1.0], [1.0, 1.0, 1.0], [np.nan, np.nan, np.nan]]) annotation, prediction = mj._score_multi_label(probabilities=probabilities) @@ -428,9 +385,7 @@ def test_init_wrong_mapping(self, weak_labels, wrong_mapping, expected): label_model = Snorkel(weak_labels) assert label_model._weaklabelsInt2snorkelInt == expected - assert label_model._snorkelInt2weaklabelsInt == { - k: v for v, k in expected.items() - } + assert label_model._snorkelInt2weaklabelsInt == {k: v for v, k in expected.items()} @pytest.mark.parametrize( "include_annotated_records", @@ -450,9 +405,7 @@ def mock_fit(self, L_train, *args, **kwargs): ) label_model = Snorkel(weak_labels) - label_model.fit( - include_annotated_records=include_annotated_records, passed_on=None - ) + label_model.fit(include_annotated_records=include_annotated_records, passed_on=None) def test_fit_automatically_added_kwargs(self, weak_labels): label_model = Snorkel(weak_labels) @@ -525,12 +478,8 @@ def mock_predict(self, L, return_probs, tie_break_policy, *args, **kwargs): prediction_agent="mock_agent", ) assert len(records) == expected[0] - assert [ - rec.prediction[0][0] if rec.prediction else None for rec in records - ] == expected[1] - assert [ - rec.prediction[0][1] if rec.prediction else None for rec in records - ] == expected[2] + assert [rec.prediction[0][0] if rec.prediction else None for rec in records] == expected[1] + assert [rec.prediction[0][1] if rec.prediction else None for rec in records] == expected[2] assert records[0].prediction_agent == "mock_agent" @pytest.mark.parametrize("policy,expected", [("abstain", 0.5), ("random", 2.0 / 3)]) @@ -795,16 +744,12 @@ def mock_predict(self, weak_label_matrix, verbose): assert isinstance(label_model.score(output_str=True), str) - @pytest.mark.parametrize( - "tbp,vrb,expected", [("abstain", False, 1.0), ("random", True, 2 / 3.0)] - ) + @pytest.mark.parametrize("tbp,vrb,expected", [("abstain", False, 1.0), ("random", True, 2 / 3.0)]) def test_score_tbp(self, monkeypatch, weak_labels, tbp, vrb, expected): def mock_predict(self, weak_label_matrix, verbose): assert verbose is vrb assert len(weak_label_matrix) == 3 - return np.array( - [[0.8, 0.1, 0.1], [0.4, 0.4, 0.2], [1 / 3.0, 1 / 3.0, 1 / 3.0]] - ) + return np.array([[0.8, 0.1, 0.1], [0.4, 0.4, 0.2], [1 / 3.0, 1 / 3.0, 1 / 3.0]]) monkeypatch.setattr(FlyingSquid, "_predict", mock_predict) diff --git a/tests/labeling/text_classification/test_rule.py b/tests/labeling/text_classification/test_rule.py index c46c7cdf3e..7920319b4e 100644 --- 
a/tests/labeling/text_classification/test_rule.py +++ b/tests/labeling/text_classification/test_rule.py @@ -69,9 +69,7 @@ def log_dataset(mocked_client) -> str: "id": idx, } ) - for text, label, idx in zip( - ["negative", "positive"], ["negative", "positive"], [1, 2] - ) + for text, label, idx in zip(["negative", "positive"], ["negative", "positive"], [1, 2]) ] mocked_client.post( f"/api/datasets/{dataset_name}/TextClassification:bulk", @@ -83,9 +81,7 @@ def log_dataset(mocked_client) -> str: return dataset_name -@pytest.mark.parametrize( - "name,expected", [(None, "query_string"), ("test_name", "test_name")] -) +@pytest.mark.parametrize("name,expected", [(None, "query_string"), ("test_name", "test_name")]) def test_name(name, expected): rule = Rule(query="query_string", label="mock", name=name) assert rule.name == expected @@ -357,9 +353,7 @@ def test_rule_metrics(mocked_client, log_dataset, rule, expected_metrics): ), ], ) -def test_rule_metrics_without_annotated( - mocked_client, log_dataset_without_annotations, rule, expected_metrics -): +def test_rule_metrics_without_annotated(mocked_client, log_dataset_without_annotations, rule, expected_metrics): delete_rule_silently(mocked_client, log_dataset_without_annotations, rule) mocked_client.post( @@ -373,8 +367,6 @@ def test_rule_metrics_without_annotated( def delete_rule_silently(client, dataset: str, rule: Rule): try: - client.delete( - f"/api/datasets/TextClassification/{dataset}/labeling/rules/{rule.query}" - ) + client.delete(f"/api/datasets/TextClassification/{dataset}/labeling/rules/{rule.query}") except EntityNotFoundError: pass diff --git a/tests/labeling/text_classification/test_weak_labels.py b/tests/labeling/text_classification/test_weak_labels.py index 8a42263881..e6d69f1341 100644 --- a/tests/labeling/text_classification/test_weak_labels.py +++ b/tests/labeling/text_classification/test_weak_labels.py @@ -163,9 +163,7 @@ def test_rules_from_dataset(self, monkeypatch, log_dataset): assert wl.rules is mock_rules def test_norulesfounderror(self, monkeypatch): - monkeypatch.setattr( - "argilla.labeling.text_classification.weak_labels.load_rules", lambda x: [] - ) + monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load_rules", lambda x: []) with pytest.raises(NoRulesFoundError, match="No rules were found"): WeakLabelsBase("mock") @@ -179,13 +177,9 @@ def test_no_records_found_error(self, monkeypatch): def mock_load(*args, **kwargs): return [] - monkeypatch.setattr( - "argilla.labeling.text_classification.weak_labels.load", mock_load - ) + monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load) - with pytest.raises( - NoRecordsFoundError, match="No records found in dataset 'mock'." 
- ): + with pytest.raises(NoRecordsFoundError, match="No records found in dataset 'mock'."): WeakLabels(rules=[lambda x: None], dataset="mock") with pytest.raises( NoRecordsFoundError, @@ -212,9 +206,7 @@ def test_rules_records_properties(self, monkeypatch): def mock_load(*args, **kwargs): return expected_records - monkeypatch.setattr( - "argilla.labeling.text_classification.weak_labels.load", mock_load - ) + monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load) weak_labels = WeakLabelsBase(rules=[lambda x: "mock"] * 2, dataset="mock") @@ -233,9 +225,7 @@ def test_not_implemented_errors(self, monkeypatch): def mock_load(*args, **kwargs): return ["mock"] - monkeypatch.setattr( - "argilla.labeling.text_classification.weak_labels.load", mock_load - ) + monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load) weak_labels = WeakLabelsBase(rules=["mock"], dataset="mock") @@ -258,9 +248,7 @@ def test_faiss_not_installed(self, monkeypatch): def mock_load(*args, **kwargs): return ["mock"] - monkeypatch.setattr( - "argilla.labeling.text_classification.weak_labels.load", mock_load - ) + monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load) monkeypatch.setitem(sys.modules, "faiss", None) with pytest.raises(ModuleNotFoundError, match="pip install faiss-cpu"): weak_labels = WeakLabelsBase(rules=[lambda x: "mock"] * 2, dataset="mock") @@ -272,9 +260,7 @@ def test_multi_label_error(self, monkeypatch): def mock_load(*args, **kwargs): return [TextClassificationRecord(text="test", multi_label=True)] - monkeypatch.setattr( - "argilla.labeling.text_classification.weak_labels.load", mock_load - ) + monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load) with pytest.raises(MultiLabelError): WeakLabels(rules=[lambda x: None], dataset="mock") @@ -291,9 +277,7 @@ def mock_load(*args, **kwargs): ( {None: -10, "negative": 50, "positive": 10}, {None: -10, "negative": 50, "positive": 10}, - np.array( - [[50, -10, 10], [-10, 10, 10], [-10, 10, -10]], dtype=np.short - ), + np.array([[50, -10, 10], [-10, 10, 10], [-10, 10, -10]], dtype=np.short), np.array([50, 10, -10], dtype=np.short), ), ({}, None, None, None), @@ -335,9 +319,7 @@ def test_apply( assert (weak_labels._annotation == expected_annotation_array).all() def test_apply_MultiLabelError(self, log_dataset): - with pytest.raises( - MultiLabelError, match="For rules that do not return exactly 1 label" - ): + with pytest.raises(MultiLabelError, match="For rules that do not return exactly 1 label"): WeakLabels(rules=[lambda x: ["a", "b"]], dataset=log_dataset) def test_matrix_annotation_properties(self, monkeypatch): @@ -349,9 +331,7 @@ def test_matrix_annotation_properties(self, monkeypatch): def mock_load(*args, **kwargs): return expected_records - monkeypatch.setattr( - "argilla.labeling.text_classification.weak_labels.load", mock_load - ) + monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load) def mock_apply(self, *args, **kwargs): weak_label_matrix = np.array([[0, 1], [-1, 0]], dtype=np.short) @@ -363,36 +343,21 @@ def mock_apply(self, *args, **kwargs): weak_labels = WeakLabels(rules=[lambda x: "mock"] * 2, dataset="mock") - assert ( - weak_labels.matrix(has_annotation=False) - == np.array([[0, 1]], dtype=np.short) - ).all() - assert ( - weak_labels.matrix(has_annotation=True) - == np.array([[-1, 0]], dtype=np.short) - ).all() + assert (weak_labels.matrix(has_annotation=False) == np.array([[0, 1]], 
dtype=np.short)).all() + assert (weak_labels.matrix(has_annotation=True) == np.array([[-1, 0]], dtype=np.short)).all() assert (weak_labels.annotation() == np.array([[0]], dtype=np.short)).all() - assert ( - weak_labels.annotation(include_missing=True) - == np.array([[-1, 0]], dtype=np.short) - ).all() - with pytest.warns( - FutureWarning, match="'exclude_missing_annotations' is deprecated" - ): + assert (weak_labels.annotation(include_missing=True) == np.array([[-1, 0]], dtype=np.short)).all() + with pytest.warns(FutureWarning, match="'exclude_missing_annotations' is deprecated"): weak_labels.annotation(exclude_missing_annotations=True) def test_summary(self, monkeypatch, rules): def mock_load(*args, **kwargs): return [TextClassificationRecord(text="test")] * 4 - monkeypatch.setattr( - "argilla.labeling.text_classification.weak_labels.load", mock_load - ) + monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load) def mock_apply(self, *args, **kwargs): - weak_label_matrix = np.array( - [[0, 1, -1], [-1, 0, -1], [-1, -1, -1], [1, 1, -1]], dtype=np.short - ) + weak_label_matrix = np.array([[0, 1, -1], [-1, 0, -1], [-1, -1, -1], [1, 1, -1]], dtype=np.short) # weak_label_matrix = np.random.randint(-1, 30, (int(1e5), 50), dtype=np.short) annotation_array = np.array([-1, -1, -1, -1], dtype=np.short) # annotation_array = np.random.randint(-1, 30, int(1e5), dtype=np.short) @@ -465,9 +430,7 @@ def test_show_records(self, monkeypatch, rules): def mock_load(*args, **kwargs): return [TextClassificationRecord(text="test", id=i) for i in range(5)] - monkeypatch.setattr( - "argilla.labeling.text_classification.weak_labels.load", mock_load - ) + monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load) def mock_apply(self, *args, **kwargs): weak_label_matrix = np.array( @@ -495,17 +458,13 @@ def mock_apply(self, *args, **kwargs): 1, 4, ] - assert weak_labels.show_records( - labels=["positive"], rules=["argilla_rule"] - ).empty + assert weak_labels.show_records(labels=["positive"], rules=["argilla_rule"]).empty def test_change_mapping(self, monkeypatch, rules): def mock_load(*args, **kwargs): return [TextClassificationRecord(text="test", id=i) for i in range(5)] - monkeypatch.setattr( - "argilla.labeling.text_classification.weak_labels.load", mock_load - ) + monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load) def mock_apply(self, *args, **kwargs): weak_label_matrix = np.array( @@ -543,10 +502,7 @@ def mock_apply(self, *args, **kwargs): dtype=np.short, ) ).all() - assert ( - weak_labels.annotation(include_missing=True) - == np.array([2, 10, -10, 1, 2], dtype=np.short) - ).all() + assert (weak_labels.annotation(include_missing=True) == np.array([2, 10, -10, 1, 2], dtype=np.short)).all() assert weak_labels.label2int == new_mapping assert weak_labels.int2label == {val: key for key, val in new_mapping.items()} @@ -559,9 +515,7 @@ def weak_labels(self, monkeypatch, rules): def mock_load(*args, **kwargs): return [TextClassificationRecord(text="test", id=i) for i in range(3)] - monkeypatch.setattr( - "argilla.labeling.text_classification.weak_labels.load", mock_load - ) + monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load) def mock_apply(self, *args, **kwargs): weak_label_matrix = np.array( @@ -583,19 +537,13 @@ def test_extend_matrix(self, weak_labels): ): weak_labels.extend_matrix([1.0, 0.5, 0.5]) - weak_labels.extend_matrix( - [1.0, 0.5, 0.5], np.array([[0.1, 0.1], [0.1, 0.11], 
[0.11, 0.1]]) - ) + weak_labels.extend_matrix([1.0, 0.5, 0.5], np.array([[0.1, 0.1], [0.1, 0.11], [0.11, 0.1]])) - np.testing.assert_equal( - weak_labels.matrix(), np.array([[0, -1, -1], [-1, 1, -1], [-1, 1, -1]]) - ) + np.testing.assert_equal(weak_labels.matrix(), np.array([[0, -1, -1], [-1, 1, -1], [-1, 1, -1]])) weak_labels.extend_matrix([1.0, 1.0, 1.0]) - np.testing.assert_equal( - weak_labels.matrix(), np.array([[0, -1, -1], [-1, 1, -1], [-1, -1, -1]]) - ) + np.testing.assert_equal(weak_labels.matrix(), np.array([[0, -1, -1], [-1, 1, -1], [-1, -1, -1]])) class TestWeakMultiLabels: @@ -604,9 +552,7 @@ def test_apply( log_multilabel_dataset, multilabel_rules, ): - weak_labels = WeakMultiLabels( - rules=multilabel_rules, dataset=log_multilabel_dataset - ) + weak_labels = WeakMultiLabels(rules=multilabel_rules, dataset=log_multilabel_dataset) assert weak_labels.labels == ["negative", "positive"] @@ -625,30 +571,21 @@ def test_apply( ) ).all() - assert ( - weak_labels._annotation - == np.array([[1, 0], [1, 1], [-1, -1]], dtype=np.short) - ).all() + assert (weak_labels._annotation == np.array([[1, 0], [1, 1], [-1, -1]], dtype=np.short)).all() def test_matrix_annotation(self, monkeypatch): expected_records = [ TextClassificationRecord(text="test without annot", multi_label=True), - TextClassificationRecord( - text="test with annot", annotation="positive", multi_label=True - ), + TextClassificationRecord(text="test with annot", annotation="positive", multi_label=True), ] def mock_load(*args, **kwargs): return expected_records - monkeypatch.setattr( - "argilla.labeling.text_classification.weak_labels.load", mock_load - ) + monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load) def mock_apply(self, *args, **kwargs): - weak_label_matrix = np.array( - [[[1, 0], [0, 1]], [[-1, -1], [1, 0]]], dtype=np.short - ) + weak_label_matrix = np.array([[[1, 0], [0, 1]], [[-1, -1], [1, 0]]], dtype=np.short) annotation_array = np.array([[-1, -1], [1, 0]], dtype=np.short) labels = ["negative", "positive"] return weak_label_matrix, annotation_array, labels @@ -657,27 +594,16 @@ def mock_apply(self, *args, **kwargs): weak_labels = WeakMultiLabels(rules=[lambda x: "mock"] * 2, dataset="mock") - assert ( - weak_labels.matrix(has_annotation=False) - == np.array([[[1, 0], [0, 1]]], dtype=np.short) - ).all() - assert ( - weak_labels.matrix(has_annotation=True) - == np.array([[[-1, -1], [1, 0]]], dtype=np.short) - ).all() + assert (weak_labels.matrix(has_annotation=False) == np.array([[[1, 0], [0, 1]]], dtype=np.short)).all() + assert (weak_labels.matrix(has_annotation=True) == np.array([[[-1, -1], [1, 0]]], dtype=np.short)).all() assert (weak_labels.annotation() == np.array([[1, 0]], dtype=np.short)).all() - assert ( - weak_labels.annotation(include_missing=True) - == np.array([[[-1, -1], [1, 0]]], dtype=np.short) - ).all() + assert (weak_labels.annotation(include_missing=True) == np.array([[[-1, -1], [1, 0]]], dtype=np.short)).all() def test_summary(self, monkeypatch, multilabel_rules): def mock_load(*args, **kwargs): return [TextClassificationRecord(text="test", multi_label=True)] * 4 - monkeypatch.setattr( - "argilla.labeling.text_classification.weak_labels.load", mock_load - ) + monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load) def mock_apply(self, *args, **kwargs): weak_label_matrix = np.array( @@ -690,9 +616,7 @@ def mock_apply(self, *args, **kwargs): dtype=np.short, ) # weak_label_matrix = np.random.randint(-1, 30, (int(1e5), 50), 
dtype=np.short) - annotation_array = np.array( - [[-1, -1], [-1, -1], [-1, -1], [-1, -1]], dtype=np.short - ) + annotation_array = np.array([[-1, -1], [-1, -1], [-1, -1], [-1, -1]], dtype=np.short) # annotation_array = np.random.randint(-1, 30, int(1e5), dtype=np.short) labels = ["negative", "positive"] # label2int = {k: v for k, v in zip(["None"]+list(range(30)), list(range(-1, 30)))} @@ -734,9 +658,7 @@ def mock_apply(self, *args, **kwargs): ) assert_frame_equal(summary, expected) - summary = weak_labels.summary( - annotation=np.array([[0, 1], [-1, -1], [0, 1], [1, 1]]) - ) + summary = weak_labels.summary(annotation=np.array([[0, 1], [-1, -1], [0, 1], [1, 1]])) expected = pd.DataFrame( { "label": [ @@ -762,9 +684,7 @@ def test_compute_correct_incorrect(self, monkeypatch): def mock_load(*args, **kwargs): return [TextClassificationRecord(text="mock")] - monkeypatch.setattr( - "argilla.labeling.text_classification.weak_labels.load", mock_load - ) + monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load) def mock_apply(self, *args, **kwargs): weak_label_matrix = np.array([[[1, 0, 1, 0], [0, 1, 0, 1]]], dtype=np.short) @@ -773,23 +693,16 @@ def mock_apply(self, *args, **kwargs): monkeypatch.setattr(WeakMultiLabels, "_apply_rules", mock_apply) weak_labels = WeakMultiLabels(rules=[lambda x: "mock"] * 2, dataset="mock") - correct, incorrect = weak_labels._compute_correct_incorrect( - annotation=np.array([[1, 0, 1, 0]]) - ) + correct, incorrect = weak_labels._compute_correct_incorrect(annotation=np.array([[1, 0, 1, 0]])) assert np.allclose(correct, np.array([2, 0, 2])) assert np.allclose(incorrect, np.array([0, 2, 2])) def test_show_records(self, monkeypatch, multilabel_rules): def mock_load(*args, **kwargs): - return [ - TextClassificationRecord(text="test", id=i, multi_label=True) - for i in range(5) - ] + return [TextClassificationRecord(text="test", id=i, multi_label=True) for i in range(5)] - monkeypatch.setattr( - "argilla.labeling.text_classification.weak_labels.load", mock_load - ) + monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load) def mock_apply(self, *args, **kwargs): weak_label_matrix = np.array( @@ -812,9 +725,7 @@ def mock_apply(self, *args, **kwargs): assert weak_labels.show_records().id.tolist() == [0, 1, 2, 3, 4] assert weak_labels.show_records(labels=["positive"]).id.tolist() == [0, 1, 3] assert weak_labels.show_records(labels=["negative"]).id.tolist() == [0, 1, 4] - assert weak_labels.show_records( - labels=["negative", "positive"] - ).id.tolist() == [0, 1] + assert weak_labels.show_records(labels=["negative", "positive"]).id.tolist() == [0, 1] assert weak_labels.show_records(rules=[0]).id.tolist() == [0, 1, 3] assert weak_labels.show_records(rules=[0, "rule_1"]).id.tolist() == [0, 1, 3] @@ -823,21 +734,14 @@ def mock_apply(self, *args, **kwargs): 1, 4, ] - assert weak_labels.show_records( - labels=["positive"], rules=["argilla_rule"] - ).empty + assert weak_labels.show_records(labels=["positive"], rules=["argilla_rule"]).empty @pytest.fixture def weak_multi_labels(self, monkeypatch, rules): def mock_load(*args, **kwargs): - return [ - TextClassificationRecord(text="test", id=i, multi_label=True) - for i in range(3) - ] + return [TextClassificationRecord(text="test", id=i, multi_label=True) for i in range(3)] - monkeypatch.setattr( - "argilla.labeling.text_classification.weak_labels.load", mock_load - ) + monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load) def 
mock_apply(self, *args, **kwargs): weak_label_matrix = np.array( @@ -862,9 +766,7 @@ def test_extend_matrix(self, weak_multi_labels): ): weak_multi_labels.extend_matrix([1.0, 0.5, 0.5]) - weak_multi_labels.extend_matrix( - [1.0, 0.5, 0.5], np.array([[0.1, 0.1], [0.1, 0.11], [0.11, 0.1]]) - ) + weak_multi_labels.extend_matrix([1.0, 0.5, 0.5], np.array([[0.1, 0.1], [0.1, 0.11], [0.11, 0.1]])) np.testing.assert_equal( weak_multi_labels.matrix(), @@ -891,16 +793,10 @@ def test_extend_matrix(self, weak_multi_labels): ) # The "correct" and "incorrect" columns from `expected_summary` may infer a different # dtype than `weak_multi_labels.summary()` returns. - assert_frame_equal( - weak_multi_labels.summary(), expected_summary, check_dtype=False - ) + assert_frame_equal(weak_multi_labels.summary(), expected_summary, check_dtype=False) - expected_show_records = pd.DataFrame( - map(lambda x: x.dict(), weak_multi_labels.records()) - ) - assert_frame_equal( - weak_multi_labels.show_records(rules=["rule_1"]), expected_show_records - ) + expected_show_records = pd.DataFrame(map(lambda x: x.dict(), weak_multi_labels.records())) + assert_frame_equal(weak_multi_labels.show_records(rules=["rule_1"]), expected_show_records) weak_multi_labels.extend_matrix([1.0, 1.0, 1.0]) diff --git a/tests/listeners/test_listener.py b/tests/listeners/test_listener.py index 8549036869..7651a7dbaf 100644 --- a/tests/listeners/test_listener.py +++ b/tests/listeners/test_listener.py @@ -39,9 +39,7 @@ def condition_check_params(search): ("dataset", "val + {param}", None, condition_check_params, {"param": 100}), ], ) -def test_listener_with_parameters( - mocked_client, dataset, query, metrics, condition, query_params -): +def test_listener_with_parameters(mocked_client, dataset, query, metrics, condition, query_params): ar.delete(dataset) class TestListener: diff --git a/tests/metrics/test_common_metrics.py b/tests/metrics/test_common_metrics.py index 4cbb6f80f6..1bf51590aa 100644 --- a/tests/metrics/test_common_metrics.py +++ b/tests/metrics/test_common_metrics.py @@ -137,9 +137,7 @@ def test_keywords_metrics(mocked_client, gutenberg_spacy_ner): "two": 18, } - assert keywords(name=gutenberg_spacy_ner) == keywords( - name=gutenberg_spacy_ner, query="" - ) + assert keywords(name=gutenberg_spacy_ner) == keywords(name=gutenberg_spacy_ner, query="") with pytest.raises(AssertionError, match="size must be greater than 0"): keywords(name=gutenberg_spacy_ner, size=0) diff --git a/tests/monitoring/test_monitor.py b/tests/monitoring/test_monitor.py index 18240551af..39b5db059a 100644 --- a/tests/monitoring/test_monitor.py +++ b/tests/monitoring/test_monitor.py @@ -30,8 +30,7 @@ def test_monitor_with_non_supported_model(): assert len(warning_list) == 1 warn_text = warning_list[0].message.args[0] assert ( - warn_text - == "The provided task model is not supported by monitoring module. " + warn_text == "The provided task model is not supported by monitoring module. " "Predictions won't be logged into argilla" ) @@ -53,7 +52,6 @@ def test_monitor_non_supported_huggingface_model(): assert len(warning_list) == 1 warn_text = warning_list[0].message.args[0] assert ( - warn_text - == "The provided task model is not supported by monitoring module. " + warn_text == "The provided task model is not supported by monitoring module. 
" "Predictions won't be logged into argilla" ) diff --git a/tests/monitoring/test_transformers_monitoring.py b/tests/monitoring/test_transformers_monitoring.py index 225f606518..e478f1d572 100644 --- a/tests/monitoring/test_transformers_monitoring.py +++ b/tests/monitoring/test_transformers_monitoring.py @@ -198,10 +198,7 @@ def check_zero_shot_results( assert record.inputs["text"] == text assert record.metadata == {"labels": labels, "hypothesis_template": hypothesis} assert record.prediction_agent == zero_shot_classifier.model.config.name_or_path - assert record.prediction == [ - (label, score) - for label, score in zip(predictions["labels"], predictions["scores"]) - ] + assert record.prediction == [(label, score) for label, score in zip(predictions["labels"], predictions["scores"])] @pytest.mark.parametrize( @@ -306,9 +303,7 @@ def test_monitor_zero_shot_with_text_array( dataset, ): argilla.delete(dataset) - predictions = mocked_monitor( - [text], candidate_labels=labels, hypothesis_template=hypothesis - ) + predictions = mocked_monitor([text], candidate_labels=labels, hypothesis_template=hypothesis) check_zero_shot_results( predictions, diff --git a/tests/server/backend/test_query_builder.py b/tests/server/backend/test_query_builder.py index 0276da15eb..4db8dff1c3 100644 --- a/tests/server/backend/test_query_builder.py +++ b/tests/server/backend/test_query_builder.py @@ -62,9 +62,7 @@ def test_build_sort_configuration(index_schema, sort_cfg, expected_sort): builder = EsQueryBuilder() - es_sort = builder.map_2_es_sort_configuration( - sort=SortConfig(sort_by=sort_cfg), schema=index_schema - ) + es_sort = builder.map_2_es_sort_configuration(sort=SortConfig(sort_by=sort_cfg), schema=index_schema) assert es_sort == expected_sort @@ -72,9 +70,7 @@ def test_build_sort_with_wrong_field_name(): builder = EsQueryBuilder() with pytest.raises(Exception): - builder.map_2_es_sort_configuration( - sort=SortConfig(sort_by=[SortableField(id="wat?!")]) - ) + builder.map_2_es_sort_configuration(sort=SortConfig(sort_by=[SortableField(id="wat?!")])) def test_build_sort_without_sort_config(): diff --git a/tests/server/commons/test_telemetry.py b/tests/server/commons/test_telemetry.py index f3ce903918..0b2902ea03 100644 --- a/tests/server/commons/test_telemetry.py +++ b/tests/server/commons/test_telemetry.py @@ -45,9 +45,7 @@ async def test_track_bulk(telemetry_track_data): task, records = TaskType.token_classification, 100 await telemetry.track_bulk(task=task, records=records) - telemetry_track_data.assert_called_once_with( - "LogRecordsRequested", {"task": task, "records": records} - ) + telemetry_track_data.assert_called_once_with("LogRecordsRequested", {"task": task, "records": records}) @pytest.mark.asyncio diff --git a/tests/server/datasets/test_api.py b/tests/server/datasets/test_api.py index 2d0ebebf0b..4f613f0123 100644 --- a/tests/server/datasets/test_api.py +++ b/tests/server/datasets/test_api.py @@ -178,9 +178,7 @@ def test_update_dataset(mocked_client): delete_dataset(mocked_client, dataset) create_mock_dataset(mocked_client, dataset) - response = mocked_client.patch( - f"/api/datasets/{dataset}", json={"metadata": {"new": "value"}} - ) + response = mocked_client.patch(f"/api/datasets/{dataset}", json={"metadata": {"new": "value"}}) assert response.status_code == 200 response = mocked_client.get(f"/api/datasets/{dataset}") @@ -206,12 +204,7 @@ def test_open_and_close_dataset(mocked_client): } assert mocked_client.put(f"/api/datasets/{dataset}:open").status_code == 200 - assert ( - 
mocked_client.post( - f"/api/datasets/{dataset}/TextClassification:search" - ).status_code - == 200 - ) + assert mocked_client.post(f"/api/datasets/{dataset}/TextClassification:search").status_code == 200 def delete_dataset(client, dataset, workspace: Optional[str] = None): @@ -268,9 +261,7 @@ def test_delete_records(mocked_client): } } - response = mocked_client.delete( - f"/api/datasets/{dataset_name}/data?mark_as_discarded=true" - ) + response = mocked_client.delete(f"/api/datasets/{dataset_name}/data?mark_as_discarded=true") assert response.status_code == 200 assert response.json() == { "matched": 99, diff --git a/tests/server/metrics/test_api.py b/tests/server/metrics/test_api.py index 21032362b2..0dc30ba139 100644 --- a/tests/server/metrics/test_api.py +++ b/tests/server/metrics/test_api.py @@ -61,9 +61,7 @@ def test_wrong_dataset_metrics(mocked_client): assert response.json() == { "detail": { "code": "argilla.api.errors::WrongTaskError", - "params": { - "message": "Provided task TokenClassification cannot be applied to dataset" - }, + "params": {"message": "Provided task TokenClassification cannot be applied to dataset"}, } } @@ -76,9 +74,7 @@ def test_wrong_dataset_metrics(mocked_client): assert response.json() == { "detail": { "code": "argilla.api.errors::WrongTaskError", - "params": { - "message": "Provided task TokenClassification cannot be applied to dataset" - }, + "params": {"message": "Provided task TokenClassification cannot be applied to dataset"}, } } @@ -136,9 +132,7 @@ def test_dataset_for_token_classification(mocked_client): ).status_code == 200 ) - metrics = mocked_client.get( - f"/api/datasets/TokenClassification/{dataset}/metrics" - ).json() + metrics = mocked_client.get(f"/api/datasets/TokenClassification/{dataset}/metrics").json() assert len(metrics) == len(TokenClassificationMetrics.metrics) for metric in metrics: @@ -187,9 +181,7 @@ def test_dataset_metrics(mocked_client): == 200 ) - metrics = mocked_client.get( - f"/api/datasets/TextClassification/{dataset}/metrics" - ).json() + metrics = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/metrics").json() assert len(metrics) == COMMON_METRICS_LENGTH + 5 diff --git a/tests/server/security/test_model.py b/tests/server/security/test_model.py index 757349d6a0..1511695f63 100644 --- a/tests/server/security/test_model.py +++ b/tests/server/security/test_model.py @@ -24,17 +24,13 @@ def test_valid_mail(email): assert user.email == email -@pytest.mark.parametrize( - "wrong_email", ["non-valid-email", "wrong@mail", "@wrong" "wrong.mail"] -) +@pytest.mark.parametrize("wrong_email", ["non-valid-email", "wrong@mail", "@wrong" "wrong.mail"]) def test_email_validator(wrong_email): with pytest.raises(ValidationError): User(username="user", email=wrong_email) -@pytest.mark.parametrize( - "wrong_name", ["user name", "user/name", "user.name", "UserName", "userName"] -) +@pytest.mark.parametrize("wrong_name", ["user name", "user/name", "user.name", "UserName", "userName"]) def test_username_validator(wrong_name): with pytest.raises( ValidationError, @@ -43,9 +39,7 @@ def test_username_validator(wrong_name): User(username=wrong_name) -@pytest.mark.parametrize( - "wrong_workspace", ["work space", "work/space", "work.space", "_", "-"] -) +@pytest.mark.parametrize("wrong_workspace", ["work space", "work/space", "work.space", "_", "-"]) def test_workspace_validator(wrong_workspace): with pytest.raises(ValidationError): User(username="username", workspaces=[wrong_workspace]) diff --git 
a/tests/server/security/test_provider.py b/tests/server/security/test_provider.py index 6d169e74dc..27e73cfc7e 100644 --- a/tests/server/security/test_provider.py +++ b/tests/server/security/test_provider.py @@ -33,9 +33,7 @@ async def test_get_user_via_token(): @pytest.mark.asyncio async def test_get_user_via_api_key(): - user = await localAuth.get_user( - security_scopes=security_Scopes, api_key=DEFAULT_API_KEY - ) + user = await localAuth.get_user(security_scopes=security_Scopes, api_key=DEFAULT_API_KEY) assert user.username == "argilla" diff --git a/tests/server/test_errors.py b/tests/server/test_errors.py index 0090ab2e31..80da4da6f3 100644 --- a/tests/server/test_errors.py +++ b/tests/server/test_errors.py @@ -17,7 +17,4 @@ def test_generic_error(): err = GenericServerError(error=ValueError("this is an error")) - assert ( - str(err) - == "argilla.api.errors::GenericServerError(type=builtins.ValueError,message=this is an error)" - ) + assert str(err) == "argilla.api.errors::GenericServerError(type=builtins.ValueError,message=this is an error)" diff --git a/tests/server/text2text/test_api.py b/tests/server/text2text/test_api.py index c5db714f9d..aec7eec723 100644 --- a/tests/server/text2text/test_api.py +++ b/tests/server/text2text/test_api.py @@ -181,9 +181,7 @@ def test_api_with_new_predictions_data_model(mocked_client): { "text": "This is a text data", "predictions": { - "test": { - "sentences": [{"text": "This is a test data", "score": 0.6}] - }, + "test": {"sentences": [{"text": "This is a test data", "score": 0.6}]}, }, } ), diff --git a/tests/server/text_classification/test_api.py b/tests/server/text_classification/test_api.py index 6791614471..79210fdf55 100644 --- a/tests/server/text_classification/test_api.py +++ b/tests/server/text_classification/test_api.py @@ -99,9 +99,7 @@ def test_create_records_for_text_classification_with_multi_label(mocked_client): ).dict(by_alias=True), ) - get_dataset = Dataset.parse_obj( - mocked_client.get(f"/api/datasets/{dataset}").json() - ) + get_dataset = Dataset.parse_obj(mocked_client.get(f"/api/datasets/{dataset}").json()) assert get_dataset.tags == { "env": "test", "class": "text classification", @@ -202,9 +200,7 @@ def test_create_records_for_text_classification(mocked_client, telemetry_track_d condition=not SUPPORTED_VECTOR_SEARCH, reason="Vector search not supported", ) -def test_create_records_for_text_classification_vector_search( - mocked_client, telemetry_track_data -): +def test_create_records_for_text_classification_vector_search(mocked_client, telemetry_track_data): dataset = "test_create_records_for_text_classification_vector_search" assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200 tags = {"env": "test", "class": "text classification"} @@ -271,9 +267,7 @@ def test_create_records_for_text_classification_vector_search( assert created_dataset.tags == tags assert created_dataset.metadata == metadata - response = mocked_client.post( - f"/api/datasets/{dataset}/TextClassification:search", json={} - ) + response = mocked_client.post(f"/api/datasets/{dataset}/TextClassification:search", json={}) assert response.status_code == 200 results = TextClassificationSearchResults.parse_obj(response.json()) @@ -362,9 +356,7 @@ def test_partial_record_update(mocked_client): response = mocked_client.post( f"/api/datasets/{name}/TextClassification:search", json={ - "query": TextClassificationQuery(predicted=PredictionStatus.OK).dict( - by_alias=True - ), + "query": 
TextClassificationQuery(predicted=PredictionStatus.OK).dict(by_alias=True), }, ) @@ -374,9 +366,7 @@ def test_partial_record_update(mocked_client): first_record = results.records[0] assert first_record.last_updated is not None first_record.last_updated = None - assert TextClassificationRecord( - **first_record.dict(by_alias=True, exclude_none=True) - ) == TextClassificationRecord( + assert TextClassificationRecord(**first_record.dict(by_alias=True, exclude_none=True)) == TextClassificationRecord( **{ "id": 1, "inputs": {"text": "This is a text, oh yeah!"}, @@ -508,10 +498,7 @@ def test_some_sort_by(mocked_client): }, } } - assert ( - response.json()["detail"]["code"] - == expected_response_property_name_2_value["detail"]["code"] - ) + assert response.json()["detail"]["code"] == expected_response_property_name_2_value["detail"]["code"] assert ( response.json()["detail"]["params"]["message"] == expected_response_property_name_2_value["detail"]["params"]["message"] @@ -721,9 +708,7 @@ def test_wrong_text_query(mocked_client): response = mocked_client.post( f"/api/datasets/{dataset}/TextClassification:search", - json=TextClassificationSearchRequest( - query=TextClassificationQuery(query_text="!") - ).dict(), + json=TextClassificationSearchRequest(query=TextClassificationQuery(query_text="!")).dict(), ) assert response.status_code == 400 assert response.json() == { @@ -755,18 +740,14 @@ def test_search_using_text(mocked_client): response = mocked_client.post( f"/api/datasets/{dataset}/TextClassification:search", - json=TextClassificationSearchRequest( - query=TextClassificationQuery(query_text="text: texto") - ).dict(), + json=TextClassificationSearchRequest(query=TextClassificationQuery(query_text="text: texto")).dict(), ) assert response.status_code == 200 assert response.json()["total"] == 1 response = mocked_client.post( f"/api/datasets/{dataset}/TextClassification:search", - json=TextClassificationSearchRequest( - query=TextClassificationQuery(query_text="text.exact: texto") - ).dict(), + json=TextClassificationSearchRequest(query=TextClassificationQuery(query_text="text.exact: texto")).dict(), ) assert response.status_code == 200 assert response.json()["total"] == 0 diff --git a/tests/server/text_classification/test_api_rules.py b/tests/server/text_classification/test_api_rules.py index 464ab427af..16291950ff 100644 --- a/tests/server/text_classification/test_api_rules.py +++ b/tests/server/text_classification/test_api_rules.py @@ -57,9 +57,7 @@ def test_dataset_without_rules(mocked_client): dataset = "test_dataset_without_rules" log_some_records(mocked_client, dataset) - response = mocked_client.get( - f"/api/datasets/TextClassification/{dataset}/labeling/rules" - ) + response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules") assert response.status_code == 200 assert len(response.json()) == 0 @@ -80,9 +78,7 @@ def test_dataset_update_rule(mocked_client): json={"label": "NEW Label"}, ) - response = mocked_client.get( - f"/api/datasets/TextClassification/{dataset}/labeling/rules" - ) + response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules") rules = list(map(LabelingRule.parse_obj, response.json())) assert len(rules) == 1 assert rules[0].label == "NEW Label" @@ -94,9 +90,7 @@ def test_dataset_update_rule(mocked_client): json={"labels": ["A", "B"], "description": "New description"}, ) - response = mocked_client.get( - f"/api/datasets/TextClassification/{dataset}/labeling/rules" - ) + response = 
mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules") rules = list(map(LabelingRule.parse_obj, response.json())) assert len(rules) == 1 assert rules[0].description == "New description" @@ -109,9 +103,7 @@ def test_dataset_update_rule(mocked_client): [ CreateLabelingRule(query="a query", description="Description", label="LALA"), CreateLabelingRule(query="/a qu?ry/", description="Description", label="LALA"), - CreateLabelingRule( - query="another query", description="Description", labels=["A", "B", "C"] - ), + CreateLabelingRule(query="another query", description="Description", labels=["A", "B", "C"]), ], ) def test_dataset_with_rules(mocked_client, rule): @@ -130,9 +122,7 @@ def test_dataset_with_rules(mocked_client, rule): assert created_rule.labels == rule.labels assert created_rule.description == rule.description - response = mocked_client.get( - f"/api/datasets/TextClassification/{dataset}/labeling/rules" - ) + response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules") assert response.status_code == 200 rules = list(map(LabelingRule.parse_obj, response.json())) assert len(rules) == 1 @@ -143,12 +133,8 @@ def test_dataset_with_rules(mocked_client, rule): "rule", [ CreateLabelingRule(query="a query", description="Description", label="LALA"), - CreateLabelingRule( - query="/a qu(e|E)ry/", description="Description", label="LALA" - ), - CreateLabelingRule( - query="another query", description="Description", labels=["A", "B", "C"] - ), + CreateLabelingRule(query="/a qu(e|E)ry/", description="Description", label="LALA"), + CreateLabelingRule(query="another query", description="Description", labels=["A", "B", "C"]), ], ) def test_get_dataset_rule(mocked_client, rule): @@ -161,9 +147,7 @@ def test_get_dataset_rule(mocked_client, rule): ) assert response.status_code == 200 - response = mocked_client.get( - f"/api/datasets/TextClassification/{dataset}/labeling/rules/{rule.query}" - ) + response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules/{rule.query}") assert response.status_code == 200 found_rule = LabelingRule.parse_obj(response.json()) assert found_rule.query == rule.query @@ -178,20 +162,14 @@ def test_delete_dataset_rules(mocked_client): response = mocked_client.post( f"/api/datasets/TextClassification/{dataset}/labeling/rules", - json=CreateLabelingRule( - query="/a query/", label="TEST", description="Description" - ).dict(), + json=CreateLabelingRule(query="/a query/", label="TEST", description="Description").dict(), ) assert response.status_code == 200 - response = mocked_client.delete( - f"/api/datasets/TextClassification/{dataset}/labeling/rules//a query/" - ) + response = mocked_client.delete(f"/api/datasets/TextClassification/{dataset}/labeling/rules//a query/") assert response.status_code == 200 - response = mocked_client.get( - f"/api/datasets/TextClassification/{dataset}/labeling/rules" - ) + response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules") assert response.status_code == 200 assert len(response.json()) == 0 @@ -242,9 +220,7 @@ def test_rule_metrics_with_missing_label(mocked_client): dataset = "test_rule_metrics_with_missing_label" log_some_records(mocked_client, dataset, annotation="OK") - response = mocked_client.get( - f"/api/datasets/TextClassification/{dataset}/labeling/rules/a query/metrics" - ) + response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules/a query/metrics") assert response.status_code == 200, 
response.json() assert response.json() == { "coverage": 0.0, @@ -344,18 +320,12 @@ def test_rule_metrics_with_missing_label(mocked_client): ), ], ) -def test_rule_metrics_with_missing_label_for_stored_rule( - mocked_client, rule, expected_metrics -): +def test_rule_metrics_with_missing_label_for_stored_rule(mocked_client, rule, expected_metrics): dataset = "test_rule_metrics_with_missing_label_for_stored_rule" log_some_records(mocked_client, dataset, annotation="o.k.") - mocked_client.post( - f"/api/datasets/TextClassification/{dataset}/labeling/rules", json=rule.dict() - ) + mocked_client.post(f"/api/datasets/TextClassification/{dataset}/labeling/rules", json=rule.dict()) - response = mocked_client.get( - f"/api/datasets/TextClassification/{dataset}/labeling/rules/{rule.query}/metrics" - ) + response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules/{rule.query}/metrics") assert response.status_code == 200 assert response.json() == expected_metrics @@ -366,21 +336,15 @@ def test_create_rules_and_then_log(mocked_client): for query in ["ejemplo", "bad query"]: mocked_client.post( f"/api/datasets/TextClassification/{dataset}/labeling/rules", - json=CreateLabelingRule( - query=query, label="TEST", description="Description" - ).dict(), + json=CreateLabelingRule(query=query, label="TEST", description="Description").dict(), ) - response = mocked_client.get( - f"/api/datasets/TextClassification/{dataset}/labeling/rules" - ) + response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules") rules = list(map(LabelingRule.parse_obj, response.json())) assert len(rules) == 2 log_some_records(mocked_client, dataset, annotation="OK", delete=False) - response = mocked_client.get( - f"/api/datasets/TextClassification/{dataset}/labeling/rules" - ) + response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules") rules = list(map(LabelingRule.parse_obj, response.json())) assert len(rules) == 2 @@ -441,9 +405,7 @@ def test_dataset_rules_metrics(mocked_client, rules, expected_metrics, annotatio json=rule.dict(), ) - response = mocked_client.get( - f"/api/datasets/TextClassification/{dataset}/labeling/rules/metrics" - ) + response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules/metrics") assert response.status_code == 200, response.json() assert response.json() == expected_metrics @@ -465,9 +427,7 @@ def test_rule_metric(mocked_client): assert metrics.incorrect == 1 assert metrics.precision == 0 - response = mocked_client.get( - f"/api/datasets/TextClassification/{dataset}/labeling/rules/ejemplo/metrics?label=OK" - ) + response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules/ejemplo/metrics?label=OK") assert response.status_code == 200 metrics = LabelingRuleMetricsSummary.parse_obj(response.json()) @@ -475,9 +435,7 @@ def test_rule_metric(mocked_client): assert metrics.incorrect == 0 assert metrics.precision == 1 - response = mocked_client.get( - f"/api/datasets/TextClassification/{dataset}/labeling/rules/ejemplo/metrics" - ) + response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules/ejemplo/metrics") assert response.status_code == 200 metrics = LabelingRuleMetricsSummary.parse_obj(response.json()) @@ -486,9 +444,7 @@ def test_rule_metric(mocked_client): assert metrics.precision is None assert metrics.coverage_annotated == 1 - response = mocked_client.get( - 
f"/api/datasets/TextClassification/{dataset}/labeling/rules/badd/metrics?label=OK" - ) + response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules/badd/metrics?label=OK") assert response.status_code == 200 metrics = LabelingRuleMetricsSummary.parse_obj(response.json()) diff --git a/tests/server/text_classification/test_api_settings.py b/tests/server/text_classification/test_api_settings.py index 03e50a3c8c..e091ce7952 100644 --- a/tests/server/text_classification/test_api_settings.py +++ b/tests/server/text_classification/test_api_settings.py @@ -17,9 +17,7 @@ def create_dataset(client, name: str): - response = client.post( - "/api/datasets", json={"name": name, "task": TaskType.text_classification} - ) + response = client.post("/api/datasets", json={"name": name, "task": TaskType.text_classification}) assert response.status_code == 200 @@ -60,9 +58,7 @@ def test_delete_settings(mocked_client): create_dataset(mocked_client, name) assert create_settings(mocked_client, name).status_code == 200 - response = mocked_client.delete( - f"/api/datasets/{TaskType.text_classification}/{name}/settings" - ) + response = mocked_client.delete(f"/api/datasets/{TaskType.text_classification}/{name}/settings") assert response.status_code == 200 assert fetch_settings(mocked_client, name).status_code == 404 @@ -131,6 +127,4 @@ def log_some_data(mocked_client, name): def fetch_settings(mocked_client, name): - return mocked_client.get( - f"/api/datasets/{TaskType.text_classification}/{name}/settings" - ) + return mocked_client.get(f"/api/datasets/{TaskType.text_classification}/{name}/settings") diff --git a/tests/server/text_classification/test_model.py b/tests/server/text_classification/test_model.py index 87e344cb02..c20d984804 100644 --- a/tests/server/text_classification/test_model.py +++ b/tests/server/text_classification/test_model.py @@ -31,9 +31,7 @@ def test_flatten_metadata(): data = { "inputs": {"text": "bogh"}, - "metadata": { - "mail": {"subject": "The mail subject", "body": "This is a large text body"} - }, + "metadata": {"mail": {"subject": "The mail subject", "body": "This is a large text body"}}, } record = ServiceTextClassificationRecord.parse_obj(data) assert list(record.metadata.keys()) == ["mail.subject", "mail.body"] @@ -378,9 +376,7 @@ def test_empty_labels_for_no_multilabel(): record = ServiceTextClassificationRecord( inputs={"text": "The input text"}, prediction=TextClassificationAnnotation(agent="ann.", labels=[]), - annotation=TextClassificationAnnotation( - agent="ann.", labels=[ClassPrediction(class_label="B")] - ), + annotation=TextClassificationAnnotation(agent="ann.", labels=[ClassPrediction(class_label="B")]), ) assert record.predicted == PredictionStatus.KO @@ -400,12 +396,8 @@ def test_using_predictions_dict(): record = ServiceTextClassificationRecord( inputs={"text": "this is a text"}, predictions={ - "carl": TextClassificationAnnotation( - agent="wat at", labels=[ClassPrediction(class_label="YES")] - ), - "BOB": TextClassificationAnnotation( - agent="wot wot", labels=[ClassPrediction(class_label="NO")] - ), + "carl": TextClassificationAnnotation(agent="wat at", labels=[ClassPrediction(class_label="YES")]), + "BOB": TextClassificationAnnotation(agent="wot wot", labels=[ClassPrediction(class_label="NO")]), }, ) @@ -415,9 +407,7 @@ def test_using_predictions_dict(): } assert record.predictions == { "BOB": TextClassificationAnnotation(labels=[ClassPrediction(class_label="NO")]), - "carl": TextClassificationAnnotation( - 
labels=[ClassPrediction(class_label="YES")] - ), + "carl": TextClassificationAnnotation(labels=[ClassPrediction(class_label="YES")]), } @@ -425,7 +415,5 @@ def test_with_no_agent_at_all(): with pytest.raises(ValidationError): ServiceTextClassificationRecord( inputs={"text": "this is a text"}, - prediction=TextClassificationAnnotation( - labels=[ClassPrediction(class_label="YES")] - ), + prediction=TextClassificationAnnotation(labels=[ClassPrediction(class_label="YES")]), ) diff --git a/tests/server/token_classification/test_api.py b/tests/server/token_classification/test_api.py index 8260300012..577b84f7cd 100644 --- a/tests/server/token_classification/test_api.py +++ b/tests/server/token_classification/test_api.py @@ -58,10 +58,7 @@ def test_load_as_different_task(mocked_client): assert response.json() == { "detail": { "code": "argilla.api.errors::WrongTaskError", - "params": { - "message": "Provided task TextClassification cannot be " - "applied to dataset" - }, + "params": {"message": "Provided task TextClassification cannot be " "applied to dataset"}, } } @@ -90,9 +87,7 @@ def test_search_special_characters(mocked_client): response = mocked_client.post( f"/api/datasets/{dataset}/TokenClassification:search", - json=TokenClassificationSearchRequest( - query=TokenClassificationQuery(query_text="\!") - ).dict(), + json=TokenClassificationSearchRequest(query=TokenClassificationQuery(query_text="\!")).dict(), ) assert response.status_code == 200, response.json() results = TokenClassificationSearchResults.parse_obj(response.json()) @@ -100,9 +95,7 @@ def test_search_special_characters(mocked_client): response = mocked_client.post( f"/api/datasets/{dataset}/TokenClassification:search", - json=TokenClassificationSearchRequest( - query=TokenClassificationQuery(query_text="text.exact:\!") - ).dict(), + json=TokenClassificationSearchRequest(query=TokenClassificationQuery(query_text="text.exact:\!")).dict(), ) assert response.status_code == 200, response.json() results = TokenClassificationSearchResults.parse_obj(response.json()) @@ -161,9 +154,7 @@ def test_some_sort(mocked_client): (None, lambda r: len(r.metrics) == 0), ], ) -def test_create_records_for_token_classification( - mocked_client, include_metrics: bool, metrics_validator: Callable -): +def test_create_records_for_token_classification(mocked_client, include_metrics: bool, metrics_validator: Callable): dataset = "test_create_records_for_token_classification" assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200 entity_label = "TEST" diff --git a/tests/server/token_classification/test_api_settings.py b/tests/server/token_classification/test_api_settings.py index 3701e4de53..587ea89177 100644 --- a/tests/server/token_classification/test_api_settings.py +++ b/tests/server/token_classification/test_api_settings.py @@ -17,9 +17,7 @@ def create_dataset(client, name: str): - response = client.post( - "/api/datasets", json={"name": name, "task": TaskType.token_classification} - ) + response = client.post("/api/datasets", json={"name": name, "task": TaskType.token_classification}) assert response.status_code == 200 @@ -60,9 +58,7 @@ def test_delete_settings(mocked_client): create_dataset(mocked_client, name) assert create_settings(mocked_client, name).status_code == 200 - response = mocked_client.delete( - f"/api/datasets/{TaskType.token_classification}/{name}/settings" - ) + response = mocked_client.delete(f"/api/datasets/{TaskType.token_classification}/{name}/settings") assert response.status_code == 200 assert 
fetch_settings(mocked_client, name).status_code == 404 @@ -135,6 +131,4 @@ def test_validate_settings_after_logging(mocked_client): def fetch_settings(mocked_client, name): - return mocked_client.get( - f"/api/datasets/{TaskType.token_classification}/{name}/settings" - ) + return mocked_client.get(f"/api/datasets/{TaskType.token_classification}/{name}/settings") diff --git a/tests/server/token_classification/test_model.py b/tests/server/token_classification/test_model.py index 8d0c12bbff..878db93577 100644 --- a/tests/server/token_classification/test_model.py +++ b/tests/server/token_classification/test_model.py @@ -119,9 +119,7 @@ def test_model_with_predictions(): "metrics": {}, "predictions": { "test": { - "entities": [ - {"end": 24, "label": "test", "score": 1.0, "start": 9} - ], + "entities": [{"end": 24, "label": "test", "score": 1.0, "start": 9}], } }, "status": "Default", @@ -162,9 +160,7 @@ def test_too_long_metadata(): def test_entity_label_too_long(): text = "On one ones o no" - with pytest.raises( - ValidationError, match="ensure this value has at most 128 character" - ): + with pytest.raises(ValidationError, match="ensure this value has at most 128 character"): ServiceTokenClassificationRecord( text=text, tokens=text.split(), @@ -269,9 +265,7 @@ def test_annotated_without_entities(): record = ServiceTokenClassificationRecord( text=text, tokens=text.split(), - prediction=TokenClassificationAnnotation( - agent="pred.test", entities=[EntitySpan(start=0, end=3, label="DET")] - ), + prediction=TokenClassificationAnnotation(agent="pred.test", entities=[EntitySpan(start=0, end=3, label="DET")]), annotation=TokenClassificationAnnotation(agent="test", entities=[]), ) diff --git a/tests/utils/test_span_utils.py b/tests/utils/test_span_utils.py index 3cde38131c..16b71d10b1 100644 --- a/tests/utils/test_span_utils.py +++ b/tests/utils/test_span_utils.py @@ -43,9 +43,7 @@ def test_init(): def test_init_value_error(): - with pytest.raises( - ValueError, match="Token 'ValueError' not found in text: test error" - ): + with pytest.raises(ValueError, match="Token 'ValueError' not found in text: test error"): SpanUtils(text="test error", tokens=["test", "ValueError"]) @@ -56,9 +54,7 @@ def test_validate(): def test_validate_not_valid_spans(): span_utils = SpanUtils("test this.", ["test", "this", "."]) - with pytest.raises( - ValueError, match="Following entity spans are not valid: \[\('mock', 2, 1\)\]\n" - ): + with pytest.raises(ValueError, match="Following entity spans are not valid: \[\('mock', 2, 1\)\]\n"): span_utils.validate([("mock", 2, 1)]) diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 85fe24e9f5..6a08af0c66 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -34,14 +34,10 @@ def mock_import_module(name, package): assert lazy_module.title() == ".mock_module".title() assert lazy_module.string == str - with pytest.warns( - FutureWarning, match="Importing 'dep_mock_module' from the argilla namespace" - ): + with pytest.warns(FutureWarning, match="Importing 'dep_mock_module' from the argilla namespace"): assert lazy_module.dep_mock_module == ".dep_mock_module" - with pytest.warns( - FutureWarning, match="Importing 'upper' from the argilla namespace" - ): + with pytest.warns(FutureWarning, match="Importing 'upper' from the argilla namespace"): assert lazy_module.upper() == ".dep_mock_module".upper() with pytest.raises(AttributeError): From 4c5f51377e374fb30649bdc7b9a3291db21c5bb8 Mon Sep 17 00:00:00 2001 From: Tom Aarsen 
<37621491+tomaarsen@users.noreply.github.com> Date: Thu, 16 Feb 2023 14:40:26 +0100 Subject: [PATCH 10/45] Use `rich` for logging, tracebacks, printing, progressbars (#2350)

Closes #1843

Hello!

## Pull Request overview

* Use [`rich`](https://github.com/Textualize/rich) for logging, tracebacks, printing and progressbars.
* Add `rich` as a dependency.
* Remove `loguru` as a dependency and remove all mentions of it in the codebase.
* Simplify logging configuration according to the logging documentation.
* Update logging tests.

## Before & After

[`rich`](https://github.com/Textualize/rich) is a large Python library for very colorful formatting in the terminal. Most importantly (in my opinion), it improves the readability of logs and tracebacks. Let's go over some before-and-afters:
**Printing, Logging & Progressbars**

### Before:
![image](https://user-images.githubusercontent.com/37621491/219089678-e57906d3-568d-480e-88a4-9240397f5229.png)

### After:
![image](https://user-images.githubusercontent.com/37621491/219089826-646d57a6-7e5b-426f-9ab1-d6d6317ec885.png)

Note that for the logs, the repeated information on the left side is removed, and the file location from which the log originates is moved to the right side. Beyond that, the progressbar has been updated, and the URL in the printed output is highlighted automatically.
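For reference, here is a minimal sketch of the kind of setup behind these screenshots (illustrative only; the dummy loop stands in for the real record-logging code in this PR):

```python
import logging

from rich.logging import RichHandler
from rich.progress import Progress

# Route stdlib logging through rich: time, level and message on the left,
# the originating file location rendered on the right-hand side.
logging.basicConfig(level=logging.WARNING, handlers=[RichHandler()])

# A rich progress bar in place of the old tqdm one.
with Progress() as progress:
    task = progress.add_task("Logging...", total=100)
    for _ in range(100):
        progress.update(task, advance=1)
```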
**Tracebacks**

### Before:
![image](https://user-images.githubusercontent.com/37621491/219090868-42cfe128-fd98-47ec-9d38-6f6f52a21373.png)

### After:
![image](https://user-images.githubusercontent.com/37621491/219090903-86f1fe11-d509-440d-8a6a-2833c344707b.png)

---

### Before:
![image](https://user-images.githubusercontent.com/37621491/219091343-96bae874-a673-4281-80c5-caebb67e348e.png)

### After:
![image](https://user-images.githubusercontent.com/37621491/219091193-d4cb1f64-11a7-4783-a9b2-0aec1abb8eb7.png)

---

### Before:
![image](https://user-images.githubusercontent.com/37621491/219091791-aa8969a1-e0c1-4708-a23d-38d22c2406f2.png)

### After:
![image](https://user-images.githubusercontent.com/37621491/219091878-e24c1f6b-83fa-4fed-9705-ede522faee82.png)
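The traceback rendering above only requires registering rich's traceback handler once at import time; this is the same guarded pattern this patch adds to `src/argilla/__init__.py` (see the diff further down), so a missing `rich` never breaks imports:

```python
try:
    from rich.traceback import install as _install_rich

    # Replace the plain stdlib traceback rendering with rich's version.
    _install_rich()
except ModuleNotFoundError:
    # rich is unavailable; keep the standard tracebacks.
    pass
```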
## Notes

Note that there are some changes in the logging configuration. Most of all, it has been simplified according to the note from [here](https://docs.python.org/3/library/logging.html#logging.Logger.propagate). In my changes, I only attach our handler to the root logger and let propagation take care of the rest. Beyond that, I've set `rich` to 13.0.1, since newer versions experience a StopIteration error, as discussed [here](https://github.com/Textualize/rich/issues/2800#issuecomment-1428764064).

I've replaced `tqdm` with the `rich` progressbar when logging. However, I've kept the `tqdm` progressbar for the [Weak Labeling](https://github.com/argilla-io/argilla/blob/develop/src/argilla/labeling/text_classification/weak_labels.py) module for now.

One difference between the old situation and now is that all of the logs are displayed during `pytest` under "live log call" (so, including expected errors), while earlier only warnings were shown.

## What to review?

Please do the following when reviewing:

1. Ensure that `rich` is correctly set to be installed whenever someone installs `argilla`. I always set my dependencies explicitly in setup.py, like [here](https://github.com/nltk/nltk/blob/develop/setup.py#L115) or [here](https://github.com/huggingface/setfit/blob/78851287535305ef32f789c7a87004628172b5b6/setup.py#L47-L48), but the one for `argilla` is [empty](https://github.com/argilla-io/argilla/blob/develop/setup.py), and `pyproject.toml` is used instead. I'd like for someone to look this over.
2. Fetch this branch and run some arbitrary code. Load some data, log some data, crash some programs, and get an idea of the changes. Changes to loggers and tracebacks in particular can be a matter of personal taste, so I'd like to get people on board with this. Otherwise we can scrap it or find a compromise. After all, this is also a design PR.
3. Please have a look at my discussion points below.

## Discussion

`rich` is quite configurable, so there are still some changes we can make:

1. The `RichHandler` logging handler can be modified to e.g. include rich tracebacks in its logs, as discussed [here](https://rich.readthedocs.io/en/latest/logging.html#handle-exceptions). Are we interested in this?
2. The `rich` traceback handler can be set up to include local variables in its traceback:
**A rich traceback with local variables:**

![image](https://user-images.githubusercontent.com/37621491/219096029-796b57ee-2f1b-485f-af35-c3effd44200b.png)
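Enabling that would be a one-line change (a sketch; `show_locals` is the flag rich exposes for this):

```python
from rich.traceback import install

# Render the local variables of each frame inside the traceback.
install(show_locals=True)
```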
Are we interested in this? I think this is a bit overkill.
3. We can suppress frames from certain Python modules to exclude them from the rich tracebacks. Are we interested in this?
4. The default rich traceback shows a maximum of 100 frames, which is a *lot*. Are we interested in reducing this to only show the first and last X?
5. The progress bar doesn't automatically stretch to fill the full available width, while `tqdm` does. If we want, we can set `expand=True` and it'll also expand to the entire width. Are we interested in this?
6. The progress "bar" does not need to be a bar; we can also use e.g. a spinner animation. See some more info [here](https://rich.readthedocs.io/en/latest/progress.html#columns). Are we interested in this?

---

**Type of change**

- [x] Refactor (change restructuring the codebase without changing functionality)

**How Has This Been Tested**

I've updated the tests according to my changes.

**Checklist**

- [x] I have merged the original branch into my forked branch
- [ ] I added relevant documentation
- [x] follows the style guidelines of this project
- [x] I did a self-review of my code
- [x] I added comments to my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works

- Tom Aarsen

--- environment_dev.yml | 6 +- pyproject.toml | 6 +- src/argilla/__init__.py | 8 +++ src/argilla/client/client.py | 40 ++++++----- src/argilla/logging.py | 70 +++---------------- tests/conftest.py | 19 +---- .../text_classification/test_label_errors.py | 3 +- tests/test_init.py | 7 +- tests/test_logging.py | 6 +- 9 files changed, 56 insertions(+), 109 deletions(-) diff --git a/environment_dev.yml b/environment_dev.yml index dcb2c0b7fd..3b73689e5c 100644 --- a/environment_dev.yml +++ b/environment_dev.yml @@ -43,9 +43,9 @@ dependencies: - pgmpy - plotly>=4.1.0 - snorkel>=0.9.7 - - spacy==3.1.0 - - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.1.0/en_core_web_sm-3.1.0.tar.gz + - spacy==3.5.0 + - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.5.0/en_core_web_sm-3.5.0.tar.gz - transformers[torch]~=4.18.0 - - loguru + - rich==13.0.1 # install Argilla in editable mode - -e .[server,listeners] diff --git a/pyproject.toml b/pyproject.toml index 42133cb2c7..bf7123e24f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,11 +38,13 @@ dependencies = [ "wrapt >= 1.13,< 1.15", # weaksupervision "numpy < 1.24.0", + # for progressbars "tqdm >= 4.27.0", # monitor background consumers "backoff", - "monotonic" - + "monotonic", + # for logging, tracebacks, printing, progressbars + "rich <= 13.0.1" ] dynamic = ["version"] diff --git a/src/argilla/__init__.py b/src/argilla/__init__.py index 1dbd654c9f..81ed00085a 100644 --- a/src/argilla/__init__.py +++ b/src/argilla/__init__.py @@ -26,6 +26,14 @@ from .
import _version from .utils import LazyargillaModule as _LazyargillaModule +try: + from rich.traceback import install as _install_rich + + # Rely on `rich` for tracebacks + _install_rich() +except ModuleNotFoundError: + pass + __version__ = _version.version if _TYPE_CHECKING: diff --git a/src/argilla/client/client.py b/src/argilla/client/client.py index ef36f35800..70f9938509 100644 --- a/src/argilla/client/client.py +++ b/src/argilla/client/client.py @@ -20,7 +20,8 @@ from asyncio import Future from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union -from tqdm.auto import tqdm +from rich import print as rprint +from rich.progress import Progress from argilla._constants import ( _OLD_WORKSPACE_HEADER_NAME, @@ -363,25 +364,26 @@ async def log_async( raise InputValueError(f"Unknown record type {record_type}. Available values are" f" {Record.__args__}") processed, failed = 0, 0 - progress_bar = tqdm(total=len(records), disable=not verbose) - for i in range(0, len(records), chunk_size): - chunk = records[i : i + chunk_size] - - response = await async_bulk( - client=self._client, - name=name, - json_body=bulk_class( - tags=tags, - metadata=metadata, - records=[creation_class.from_client(r) for r in chunk], - ), - ) + with Progress() as progress_bar: + task = progress_bar.add_task("Logging...", total=len(records), visible=verbose) + + for i in range(0, len(records), chunk_size): + chunk = records[i : i + chunk_size] + + response = await async_bulk( + client=self._client, + name=name, + json_body=bulk_class( + tags=tags, + metadata=metadata, + records=[creation_class.from_client(r) for r in chunk], + ), + ) - processed += response.parsed.processed - failed += response.parsed.failed + processed += response.parsed.processed + failed += response.parsed.failed - progress_bar.update(len(chunk)) - progress_bar.close() + progress_bar.update(task, advance=len(chunk)) # TODO: improve logging policy in library if verbose: @@ -389,7 +391,7 @@ async def log_async( workspace = self.get_workspace() if not workspace: # Just for backward comp. with datasets with no workspaces workspace = "-" - print(f"{processed} records logged to" f" {self._client.base_url}/datasets/{workspace}/{name}") + rprint(f"{processed} records logged to {self._client.base_url}/datasets/{workspace}/{name}") # Creating a composite BulkResponse with the total processed and failed return BulkResponse(dataset=name, processed=processed, failed=failed) diff --git a/src/argilla/logging.py b/src/argilla/logging.py index 94b1cc249e..51e3c56b49 100644 --- a/src/argilla/logging.py +++ b/src/argilla/logging.py @@ -18,13 +18,13 @@ """ import logging -from logging import Logger +from logging import Logger, StreamHandler from typing import Type try: - from loguru import logger + from rich.logging import RichHandler as ArgillaHandler except ModuleNotFoundError: - logger = None + ArgillaHandler = StreamHandler def full_qualified_class_name(_class: Type) -> str: @@ -60,64 +60,10 @@ def logger(self) -> logging.Logger: return self.__logger__ -class LoguruLoggerHandler(logging.Handler): - """This logging handler enables an easy way to use loguru fo all built-in logger traces""" - - __LOGLEVEL_MAPPING__ = { - 50: "CRITICAL", - 40: "ERROR", - 30: "WARNING", - 20: "INFO", - 10: "DEBUG", - 0: "NOTSET", - } - - @property - def is_available(self) -> bool: - """Return True if handler can tackle log records. 
False otherwise""" - return logger is not None - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - if not self.is_available: - self.emit = lambda record: None - - def emit(self, record: logging.LogRecord): - try: - level = logger.level(record.levelname).name - except AttributeError: - level = self.__LOGLEVEL_MAPPING__[record.levelno] - - frame, depth = logging.currentframe(), 2 - while frame.f_code.co_filename == logging.__file__: - frame = frame.f_back - depth += 1 - - log = logger.bind(request_id="argilla") - log.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage()) - - def configure_logging(): """Normalizes logging configuration for argilla and its dependencies""" - intercept_handler = LoguruLoggerHandler() - if not intercept_handler.is_available: - return - - logging.basicConfig(handlers=[intercept_handler], level=logging.WARNING) - for name in logging.root.manager.loggerDict: - logger_ = logging.getLogger(name) - logger_.handlers = [] - - for name in [ - "uvicorn", - "uvicorn.lifespan", - "uvicorn.error", - "uvicorn.access", - "fastapi", - "argilla", - "argilla.server", - ]: - logger_ = logging.getLogger(name) - logger_.propagate = False - logger_.handlers = [intercept_handler] + handler = ArgillaHandler() + + # See the note here: https://docs.python.org/3/library/logging.html#logging.Logger.propagate + # We only attach our handler to the root logger and let propagation take care of the rest + logging.basicConfig(handlers=[handler], level=logging.WARNING) diff --git a/tests/conftest.py b/tests/conftest.py index 33738a6db6..f8495bff2b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,15 +15,10 @@ import httpx import pytest from _pytest.logging import LogCaptureFixture -from argilla.client.sdk.users import api as users_api -from argilla.server.commons import telemetry - -try: - from loguru import logger -except ModuleNotFoundError: - logger = None from argilla import app from argilla.client.api import active_api +from argilla.client.sdk.users import api as users_api +from argilla.server.commons import telemetry from starlette.testclient import TestClient from .helpers import SecuredClient @@ -68,13 +63,3 @@ def whoami_mocked(client): monkeypatch.setattr(rb_api._client, "__httpx__", client_) yield client_ - - -@pytest.fixture -def caplog(caplog: LogCaptureFixture): - if not logger: - yield caplog - else: - handler_id = logger.add(caplog.handler, format="{message}") - yield caplog - logger.remove(handler_id) diff --git a/tests/labeling/text_classification/test_label_errors.py b/tests/labeling/text_classification/test_label_errors.py index b07780ec71..0e46394bf0 100644 --- a/tests/labeling/text_classification/test_label_errors.py +++ b/tests/labeling/text_classification/test_label_errors.py @@ -17,6 +17,7 @@ import argilla as ar import cleanlab import pytest +from _pytest.logging import LogCaptureFixture from argilla.labeling.text_classification import find_label_errors from argilla.labeling.text_classification.label_errors import ( MissingPredictionError, @@ -70,7 +71,7 @@ def test_no_records(): find_label_errors(records) -def test_multi_label_warning(caplog): +def test_multi_label_warning(caplog: LogCaptureFixture): record = ar.TextClassificationRecord( text="test", prediction=[("mock", 0.0), ("mock2", 0.0)], diff --git a/tests/test_init.py b/tests/test_init.py index 9a3a37b103..753e230216 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -16,7 +16,7 @@ import logging import sys -from argilla.logging import 
LoguruLoggerHandler +from argilla.logging import ArgillaHandler from argilla.utils import LazyargillaModule @@ -25,4 +25,7 @@ def test_lazy_module(): def test_configure_logging_call(): - assert isinstance(logging.getLogger("argilla").handlers[0], LoguruLoggerHandler) + # Ensure that the root logger uses the ArgillaHandler (RichHandler if rich is installed), + # whereas the other loggers do not have handlers + assert isinstance(logging.getLogger().handlers[0], ArgillaHandler) + assert len(logging.getLogger("argilla").handlers) == 0 diff --git a/tests/test_logging.py b/tests/test_logging.py index 054deb7740..5f7f9f5be4 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from argilla.logging import LoggingMixin, LoguruLoggerHandler +from argilla.logging import ArgillaHandler, LoggingMixin class LoggingForTest(LoggingMixin): @@ -50,8 +50,8 @@ def test_logging_mixin_without_breaking_constructors(): def test_logging_handler(mocker): - mocker.patch.object(LoguruLoggerHandler, "emit", autospec=True) - handler = LoguruLoggerHandler() + mocker.patch.object(ArgillaHandler, "emit", autospec=True) + handler = ArgillaHandler() logger = logging.getLogger(__name__) logger.handlers = [handler] From 5a8bb28210ec1255d3805baf3bc1735e106fe4c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?= Date: Mon, 20 Feb 2023 07:57:00 +0100 Subject: [PATCH 11/45] chore: Replace old recognai emails with argilla ones (#2365)

# Description

Replace old `recogn.ai` emails with `argilla.io` ones. This will improve the [argilla PyPI package page](https://pypi.org/project/argilla/) by showing the correct contact emails.

We still have a reference to `recogn.ai` in this JS model file: https://github.com/argilla-io/argilla/blob/develop/frontend/models/Dataset.js#L22

@frascuchon and @keithCuniah, can you please confirm whether that JS code can be safely changed, and how?

--- CODE_OF_CONDUCT.md | 2 +- pyproject.toml | 4 ++-- tests/server/security/test_model.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index f18617bd23..d86e29843e 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -60,7 +60,7 @@ representative at an online or offline event. Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at -contact@recogn.ai. +contact@argilla.io. All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the diff --git a/pyproject.toml b/pyproject.toml index bf7123e24f..34f978ea40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,10 +20,10 @@ keywords = [ "mlops" ] authors = [ - {name = "recognai", email = "contact@recogn.ai"} + {name = "argilla", email = "contact@argilla.io"} ] maintainers = [ - {name = "recognai", email = "contact@recogn.ai"} + {name = "argilla", email = "contact@argilla.io"} ] dependencies = [ # Client diff --git a/tests/server/security/test_model.py b/tests/server/security/test_model.py index 1511695f63..e89ac76842 100644 --- a/tests/server/security/test_model.py +++ b/tests/server/security/test_model.py @@ -18,7 +18,7 @@ from pydantic import ValidationError -@pytest.mark.parametrize("email", ["my@email.com", "infra@recogn.ai"]) +@pytest.mark.parametrize("email", ["my@email.com", "infra@argilla.io"]) def test_valid_mail(email): user = User(username="user", email=email) assert user.email == email From 175a52a3046eba79c2d56772150a53f4ed469c9e Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Mon, 20 Feb 2023 16:39:56 +0100 Subject: [PATCH 12/45] refactor: remove the classification labeling rules service (#2361) # Description This PR moves the labeling rule operations from the labeling rules service to the text-classification service. **Type of change** (Please delete options that are not relevant. Remember to title the PR according to the type of change) - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [x] Refactor (change restructuring the codebase without changing functionality) - [ ] Improvement (change adding some improvement to an existing functionality) - [ ] Documentation update **Checklist** - [x] I have merged the original branch into my forked branch - [x] I added relevant documentation - [x] follows the style guidelines of this project - [x] I did a self-review of my code - [x] I added comments to my code - [x] I made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works --- .../server/apis/v0/handlers/metrics.py | 17 ++- .../apis/v0/handlers/text_classification.py | 11 +- .../server/daos/backend/metrics/base.py | 11 +- .../backend/metrics/text_classification.py | 2 +- src/argilla/server/security/settings.py | 14 -- src/argilla/server/services/datasets.py | 3 + .../server/services/metrics/service.py | 13 ++ .../tasks/text_classification/__init__.py | 1 - .../labeling_rules_service.py | 141 ------------------ .../tasks/text_classification/metrics.py | 8 + .../tasks/text_classification/model.py | 13 ++ .../tasks/text_classification/service.py | 119 +++++++++++---- tests/server/metrics/test_api.py | 38 ++++- 13 files changed, 189 insertions(+), 202 deletions(-) delete mode 100644 src/argilla/server/security/settings.py delete mode 100644 src/argilla/server/services/tasks/text_classification/labeling_rules_service.py diff --git a/src/argilla/server/apis/v0/handlers/metrics.py b/src/argilla/server/apis/v0/handlers/metrics.py index 3f4e0b6743..9408f1e3f0 100644 --- a/src/argilla/server/apis/v0/handlers/metrics.py +++ b/src/argilla/server/apis/v0/handlers/metrics.py @@ -14,9 +14,9 @@ # limitations under the License.
from dataclasses import dataclass -from typing import List, Optional +from typing import Any, Dict, List, Optional -from fastapi import APIRouter, Depends, Query, Security +from fastapi import APIRouter, Depends, Query, Request, Security from pydantic import BaseModel, Field from argilla.server.apis.v0.helpers import deprecate_endpoint @@ -36,6 +36,8 @@ class MetricInfo(BaseModel): @dataclass class MetricSummaryParams: + request: Request + interval: Optional[float] = Query( default=None, gt=0.0, @@ -47,6 +49,15 @@ class MetricSummaryParams: description="The number of terms for terminological summaries", ) + @property + def parameters(self) -> Dict[str, Any]: + """Returns dynamic metric args found in the request query params""" + return { + "interval": self.interval, + "size": self.size, + **{k: v for k, v in self.request.query_params.items() if k not in ["interval", "size"]}, + } + def configure_router(router: APIRouter, cfg: TaskConfig): base_metrics_endpoint = f"/{cfg.task}/{{name}}/metrics" @@ -112,7 +123,7 @@ def metric_summary( metric=metric_, record_class=record_class, query=query, - **vars(metric_params), + **metric_params.parameters, ) diff --git a/src/argilla/server/apis/v0/handlers/text_classification.py b/src/argilla/server/apis/v0/handlers/text_classification.py index e5d4ca840b..59f9436425 100644 --- a/src/argilla/server/apis/v0/handlers/text_classification.py +++ b/src/argilla/server/apis/v0/handlers/text_classification.py @@ -316,7 +316,7 @@ async def list_labeling_rules( as_dataset_class=TasksFactory.get_task_dataset(task_type), ) - return [LabelingRule.parse_obj(rule) for rule in service.get_labeling_rules(dataset)] + return [LabelingRule.parse_obj(rule) for rule in service.list_labeling_rules(dataset)] @deprecate_endpoint( path=f"{new_base_endpoint}/labeling/rules", @@ -379,7 +379,7 @@ async def compute_rule_metrics( as_dataset_class=TasksFactory.get_task_dataset(task_type), ) - return service.compute_rule_metrics(dataset, rule_query=query, labels=labels) + return service.compute_labeling_rule(dataset, rule_query=query, labels=labels) @deprecate_endpoint( path=f"{new_base_endpoint}/labeling/rules/metrics", @@ -404,7 +404,7 @@ async def compute_dataset_rules_metrics( workspace=common_params.workspace, as_dataset_class=TasksFactory.get_task_dataset(task_type), ) - metrics = service.compute_overall_rules_metrics(dataset) + metrics = service.compute_all_labeling_rules(dataset) return DatasetLabelingRulesMetricsSummary.parse_obj(metrics) @deprecate_endpoint( @@ -456,10 +456,7 @@ async def get_rule( workspace=common_params.workspace, as_dataset_class=TasksFactory.get_task_dataset(task_type), ) - rule = service.find_labeling_rule( - dataset, - rule_query=query, - ) + rule = service.find_labeling_rule(dataset, rule_query=query) return LabelingRule.parse_obj(rule) @deprecate_endpoint( diff --git a/src/argilla/server/daos/backend/metrics/base.py b/src/argilla/server/daos/backend/metrics/base.py index 82bd60a1e7..344dfeba18 100644 --- a/src/argilla/server/daos/backend/metrics/base.py +++ b/src/argilla/server/daos/backend/metrics/base.py @@ -13,6 +13,7 @@ # limitations under the License. 
import dataclasses +import logging from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from argilla.server.daos.backend.query_helpers import aggregations @@ -22,6 +23,9 @@ from argilla.server.daos.backend.client_adapters.base import IClientAdapter +_LOGGER = logging.getLogger(__file__) + + @dataclasses.dataclass class ElasticsearchMetric: id: str @@ -37,11 +41,14 @@ def __post_init__(self): def get_function_arg_names(func): return func.__code__.co_varnames - def aggregation_request(self, *args, **kwargs) -> Union[Dict[str, Any], List[Dict[str, Any]]]: + def aggregation_request(self, *args, **kwargs) -> Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]: """ Configures the summary es aggregation definition """ - return {self.id: self._build_aggregation(*args, **kwargs)} + try: + return {self.id: self._build_aggregation(*args, **kwargs)} + except TypeError as ex: + _LOGGER.warning(f"Cannot build metric for metric {self.id}. Error: {ex}. Skipping...") def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]: """ diff --git a/src/argilla/server/daos/backend/metrics/text_classification.py b/src/argilla/server/daos/backend/metrics/text_classification.py index 8d3a83228c..cf54f5e12d 100644 --- a/src/argilla/server/daos/backend/metrics/text_classification.py +++ b/src/argilla/server/daos/backend/metrics/text_classification.py @@ -43,7 +43,7 @@ def _build_aggregation(self, queries: List[str]) -> Dict[str, Any]: class LabelingRulesMetric(ElasticsearchMetric): id: str - def _build_aggregation(self, rule_query: str, labels: Optional[List[str]]) -> Dict[str, Any]: + def _build_aggregation(self, rule_query: str, labels: Optional[List[str]] = None) -> Dict[str, Any]: annotated_records_filter = filters.exists_field("annotated_as") rule_query_filter = filters.text_query(rule_query) aggr_filters = { diff --git a/src/argilla/server/security/settings.py b/src/argilla/server/security/settings.py deleted file mode 100644 index 168dce9eb1..0000000000 --- a/src/argilla/server/security/settings.py +++ /dev/null @@ -1,14 +0,0 @@ -# coding=utf-8 -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
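To illustrate the change to `aggregation_request` above: the metrics handler now forwards arbitrary request query params to the metric, and a `_build_aggregation` that does not accept them no longer raises; the metric is skipped with a warning instead. Below is a minimal, runnable sketch of that skip behavior. `SketchMetric` and its toy aggregation body are illustrative assumptions, not code from this repository.

```python
import dataclasses
import logging
from typing import Any, Dict, List, Optional

logging.basicConfig(level=logging.WARNING)
_LOGGER = logging.getLogger(__name__)


@dataclasses.dataclass
class SketchMetric:
    # Simplified stand-in for ElasticsearchMetric; the real class builds
    # Elasticsearch aggregation bodies via the DAO query helpers.
    id: str

    def _build_aggregation(self, rule_query: str, labels: Optional[List[str]] = None) -> Dict[str, Any]:
        # Toy aggregation body, just to have something concrete to return.
        return {"filter": {"query_string": {"query": rule_query}}, "labels": labels}

    def aggregation_request(self, *args, **kwargs) -> Optional[Dict[str, Any]]:
        # Mirrors the new behavior: parameters the builder does not accept
        # no longer propagate a TypeError; the metric is skipped with a warning.
        try:
            return {self.id: self._build_aggregation(*args, **kwargs)}
        except TypeError as ex:
            _LOGGER.warning(f"Cannot build metric for metric {self.id}. Error: {ex}. Skipping...")
            return None


metric = SketchMetric(id="labeling_rule")
print(metric.aggregation_request(rule_query="t*"))     # builds the aggregation
print(metric.aggregation_request(unexpected="param"))  # warns and returns None
```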
diff --git a/src/argilla/server/services/datasets.py b/src/argilla/server/services/datasets.py index b6af813d4b..1c4e7f6843 100644 --- a/src/argilla/server/services/datasets.py +++ b/src/argilla/server/services/datasets.py @@ -230,6 +230,9 @@ async def get_settings( raise EntityNotFoundError(name=dataset.name, type=class_type) return class_type.parse_obj(settings.dict()) + def raw_dataset_update(self, dataset): + self.__dao__.update_dataset(dataset) + async def save_settings( self, user: User, dataset: ServiceDataset, settings: ServiceDatasetSettings ) -> ServiceDatasetSettings: diff --git a/src/argilla/server/services/metrics/service.py b/src/argilla/server/services/metrics/service.py index 9989690387..90b86cf294 100644 --- a/src/argilla/server/services/metrics/service.py +++ b/src/argilla/server/services/metrics/service.py @@ -116,3 +116,16 @@ def summarize_metric( dataset=dataset, query=query, ) + + def annotated_records(self, dataset: ServiceDataset) -> int: + """Return the number of annotated records for a dataset""" + results = self.__dao__.search_records( + dataset, + size=0, + search=DaoRecordsSearch(query=ServiceBaseRecordsQuery(has_annotation=True)), + ) + return results.total + + def total_records(self, dataset: ServiceDataset) -> int: + """Return the total number of records for a given dataset""" + return self.__dao__.search_records(dataset, size=0).total diff --git a/src/argilla/server/services/tasks/text_classification/__init__.py b/src/argilla/server/services/tasks/text_classification/__init__.py index 1a0c078cea..b8298320b7 100644 --- a/src/argilla/server/services/tasks/text_classification/__init__.py +++ b/src/argilla/server/services/tasks/text_classification/__init__.py @@ -12,5 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .labeling_rules_service import LabelingService from .service import TextClassificationService diff --git a/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py b/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py deleted file mode 100644 index 38471bfccc..0000000000 --- a/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import List, Optional, Tuple - -from fastapi import Depends -from pydantic import BaseModel, Field - -from argilla.server.daos.datasets import DatasetsDAO -from argilla.server.daos.models.records import DaoRecordsSearch -from argilla.server.daos.records import DatasetRecordsDAO -from argilla.server.errors import EntityAlreadyExistsError, EntityNotFoundError -from argilla.server.services.search.model import ServiceBaseRecordsQuery -from argilla.server.services.tasks.text_classification.model import ( - ServiceLabelingRule, - ServiceTextClassificationDataset, -) - - -class DatasetLabelingRulesSummary(BaseModel): - covered_records: int - annotated_covered_records: int - - -class LabelingRuleSummary(BaseModel): - covered_records: int - annotated_covered_records: int - correct_records: int = Field(default=0) - incorrect_records: int = Field(default=0) - precision: Optional[float] = None - - -class LabelingService: - _INSTANCE = None - - @classmethod - def get_instance( - cls, - datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), - records: DatasetRecordsDAO = Depends(DatasetRecordsDAO.get_instance), - ): - if cls._INSTANCE is None: - cls._INSTANCE = cls(datasets, records) - return cls._INSTANCE - - def __init__(self, datasets: DatasetsDAO, records: DatasetRecordsDAO): - self.__datasets__ = datasets - self.__records__ = records - - # TODO(@frascuchon): Move all rules management methods to the common datasets service like settings - def list_rules(self, dataset: ServiceTextClassificationDataset) -> List[ServiceLabelingRule]: - """List a set of rules for a given dataset""" - return dataset.rules - - def delete_rule(self, dataset: ServiceTextClassificationDataset, rule_query: str): - """Delete a rule from a dataset by its defined query string""" - new_rules_set = [r for r in dataset.rules if r.query != rule_query] - if len(dataset.rules) != new_rules_set: - dataset.rules = new_rules_set - self.__datasets__.update_dataset(dataset) - - def add_rule(self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule) -> ServiceLabelingRule: - """Adds a rule to a dataset""" - for r in dataset.rules: - if r.query == rule.query: - raise EntityAlreadyExistsError(rule.query, type=ServiceLabelingRule) - dataset.rules.append(rule) - self.__datasets__.update_dataset(dataset) - return rule - - def compute_rule_metrics( - self, - dataset: ServiceTextClassificationDataset, - rule_query: str, - labels: Optional[List[str]] = None, - ) -> Tuple[int, int, LabelingRuleSummary]: - """Computes metrics for given rule query and optional label against a set of rules""" - - annotated_records = self._count_annotated_records(dataset) - dataset_records = self.__records__.search_records(dataset, size=0).total - metric_data = self.__records__.compute_metric( - dataset=dataset, - metric_id="labeling_rule", - metric_params=dict(rule_query=rule_query, labels=labels), - ) - - return ( - dataset_records, - annotated_records, - LabelingRuleSummary.parse_obj(metric_data), - ) - - def _count_annotated_records(self, dataset: ServiceTextClassificationDataset) -> int: - results = self.__records__.search_records( - dataset, - size=0, - search=DaoRecordsSearch(query=ServiceBaseRecordsQuery(has_annotation=True)), - ) - return results.total - - def all_rules_metrics( - self, dataset: ServiceTextClassificationDataset - ) -> Tuple[int, int, DatasetLabelingRulesSummary]: - annotated_records = self._count_annotated_records(dataset) - dataset_records = self.__records__.search_records(dataset, size=0).total - 
metric_data = self.__records__.compute_metric( - dataset=dataset, - metric_id="dataset_labeling_rules", - metric_params=dict(queries=[r.query for r in dataset.rules]), - ) - - return ( - dataset_records, - annotated_records, - DatasetLabelingRulesSummary.parse_obj(metric_data), - ) - - def find_rule_by_query(self, dataset: ServiceTextClassificationDataset, rule_query: str) -> ServiceLabelingRule: - rule_query = rule_query.strip() - for rule in dataset.rules: - if rule.query == rule_query: - return rule - raise EntityNotFoundError(rule_query, type=ServiceLabelingRule) - - def replace_rule(self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule): - for idx, r in enumerate(dataset.rules): - if r.query == rule.query: - dataset.rules[idx] = rule - break - self.__datasets__.update_dataset(dataset) diff --git a/src/argilla/server/services/tasks/text_classification/metrics.py b/src/argilla/server/services/tasks/text_classification/metrics.py index d09891620b..2b9bd006a9 100644 --- a/src/argilla/server/services/tasks/text_classification/metrics.py +++ b/src/argilla/server/services/tasks/text_classification/metrics.py @@ -159,5 +159,13 @@ class TextClassificationMetrics(CommonTasksMetrics[ServiceTextClassificationReco id="annotated_as", name="Annotated labels distribution", ), + ServiceBaseMetric( + id="labeling_rule", + name="Labeling rule metric based on a query rule and a set of labels", + ), + ServiceBaseMetric( + id="dataset_labeling_rules", + name="Computes the overall labeling rules stats", + ), ] ) diff --git a/src/argilla/server/services/tasks/text_classification/model.py b/src/argilla/server/services/tasks/text_classification/model.py index 73c5cd7d05..f8a82db328 100644 --- a/src/argilla/server/services/tasks/text_classification/model.py +++ b/src/argilla/server/services/tasks/text_classification/model.py @@ -307,3 +307,16 @@ class ServiceTextClassificationQuery(ServiceBaseRecordsQuery): predicted: Optional[PredictionStatus] = Field(default=None, nullable=True) uncovered_by_rules: List[str] = Field(default_factory=list) + + +class DatasetLabelingRulesSummary(BaseModel): + covered_records: int + annotated_covered_records: int + + +class LabelingRuleSummary(BaseModel): + covered_records: int + annotated_covered_records: int + correct_records: int = Field(default=0) + incorrect_records: int = Field(default=0) + precision: Optional[float] = None diff --git a/src/argilla/server/services/tasks/text_classification/service.py b/src/argilla/server/services/tasks/text_classification/service.py index 8a5407258c..79d6ddb473 100644 --- a/src/argilla/server/services/tasks/text_classification/service.py +++ b/src/argilla/server/services/tasks/text_classification/service.py @@ -13,12 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Iterable, List, Optional +from typing import Iterable, List, Optional, Tuple from fastapi import Depends -from argilla.server.commons.config import TasksFactory -from argilla.server.errors.base_errors import MissingDatasetRecordsError +from argilla.server.errors.base_errors import ( + EntityAlreadyExistsError, + EntityNotFoundError, + MissingDatasetRecordsError, +) +from argilla.server.services.datasets import DatasetsService +from argilla.server.services.metrics import MetricsService from argilla.server.services.search.model import ( ServiceSearchResults, ServiceSortableField, @@ -27,10 +32,14 @@ from argilla.server.services.search.service import SearchRecordsService from argilla.server.services.storage.service import RecordsStorageService from argilla.server.services.tasks.commons import BulkResponse -from argilla.server.services.tasks.text_classification import LabelingService +from argilla.server.services.tasks.text_classification.metrics import ( + TextClassificationMetrics, +) from argilla.server.services.tasks.text_classification.model import ( DatasetLabelingRulesMetricsSummary, + DatasetLabelingRulesSummary, LabelingRuleMetricsSummary, + LabelingRuleSummary, ServiceLabelingRule, ServiceTextClassificationDataset, ServiceTextClassificationQuery, @@ -49,23 +58,26 @@ class TextClassificationService: @classmethod def get_instance( cls, + datasets: DatasetsService = Depends(DatasetsService.get_instance), + metrics: MetricsService = Depends(MetricsService.get_instance), storage: RecordsStorageService = Depends(RecordsStorageService.get_instance), - labeling: LabelingService = Depends(LabelingService.get_instance), search: SearchRecordsService = Depends(SearchRecordsService.get_instance), ) -> "TextClassificationService": if not cls._INSTANCE: - cls._INSTANCE = cls(storage, labeling=labeling, search=search) + cls._INSTANCE = cls(datasets=datasets, metrics=metrics, storage=storage, search=search) return cls._INSTANCE def __init__( self, + datasets: DatasetsService, + metrics: MetricsService, storage: RecordsStorageService, search: SearchRecordsService, - labeling: LabelingService, ): self.__storage__ = storage self.__search__ = search - self.__labeling__ = labeling + self.__metrics__ = metrics + self.__datasets__ = datasets async def add_records( self, @@ -116,9 +128,9 @@ def search( """ - metrics = TasksFactory.find_task_metrics( - dataset.task, - metric_ids={ + metrics = [ + TextClassificationMetrics.find_metric(id) + for id in { "words_cloud", "predicted_by", "predicted_as", @@ -128,8 +140,8 @@ def search( "status_distribution", "metadata", "score", - }, - ) + } + ] results = self.__search__.search( dataset, @@ -209,8 +221,18 @@ def _is_dataset_multi_label(self, dataset: ServiceTextClassificationDataset) -> if results.records: return results.records[0].multi_label - def get_labeling_rules(self, dataset: ServiceTextClassificationDataset) -> Iterable[ServiceLabelingRule]: - return self.__labeling__.list_rules(dataset) + def find_labeling_rule( + self, dataset: ServiceTextClassificationDataset, rule_query: str, error_on_missing: bool = True + ) -> Optional[ServiceLabelingRule]: + rule_query = rule_query.strip() + for rule in dataset.rules: + if rule.query == rule_query: + return rule + if error_on_missing: + raise EntityNotFoundError(rule_query, type=ServiceLabelingRule) + + def list_labeling_rules(self, dataset: ServiceTextClassificationDataset) -> Iterable[ServiceLabelingRule]: + return dataset.rules def add_labeling_rule(self, dataset: 
ServiceTextClassificationDataset, rule: ServiceLabelingRule) -> None: """ @@ -224,8 +246,13 @@ def add_labeling_rule(self, dataset: ServiceTextClassificationDataset, rule: Ser rule: The rule """ + self.__normalized_rule__(rule) - self.__labeling__.add_rule(dataset, rule) + if self.find_labeling_rule(dataset, rule_query=rule.query, error_on_missing=False): + raise EntityAlreadyExistsError(rule.query, type=ServiceLabelingRule) + + dataset.rules.append(rule) + self.__datasets__.raw_dataset_update(dataset) def update_labeling_rule( self, @@ -234,7 +261,7 @@ def update_labeling_rule( labels: List[str], description: Optional[str] = None, ) -> ServiceLabelingRule: - found_rule = self.__labeling__.find_rule_by_query(dataset, rule_query) + found_rule = self.find_labeling_rule(dataset, rule_query) found_rule.labels = labels found_rule.label = labels[0] if len(labels) == 1 else None @@ -242,17 +269,22 @@ def update_labeling_rule( found_rule.description = description self.__normalized_rule__(found_rule) - self.__labeling__.replace_rule(dataset, found_rule) - return found_rule + for idx, r in enumerate(dataset.rules): + if r.query == found_rule.query: + dataset.rules[idx] = found_rule + break + self.__datasets__.raw_dataset_update(dataset) - def find_labeling_rule(self, dataset: ServiceTextClassificationDataset, rule_query: str) -> ServiceLabelingRule: - return self.__labeling__.find_rule_by_query(dataset, rule_query=rule_query) + return found_rule def delete_labeling_rule(self, dataset: ServiceTextClassificationDataset, rule_query: str): - if rule_query.strip(): - return self.__labeling__.delete_rule(dataset, rule_query) + """Delete a rule from a dataset by its defined query string""" + new_rules_set = [r for r in dataset.rules if r.query != rule_query] + if len(dataset.rules) != len(new_rules_set): + dataset.rules = new_rules_set + self.__datasets__.raw_dataset_update(dataset) - def compute_rule_metrics( + def compute_labeling_rule( self, dataset: ServiceTextClassificationDataset, rule_query: str, @@ -288,14 +320,20 @@ def compute_rule_metrics( rule_query = rule_query.strip() if labels is None: - for rule in self.get_labeling_rules(dataset): - if rule.query == rule_query: - labels = rule.labels - break + rule = self.find_labeling_rule(dataset, rule_query=rule_query, error_on_missing=False) + if rule: + labels = rule.labels - total, annotated, metrics = self.__labeling__.compute_rule_metrics( - dataset, rule_query=rule_query, labels=labels + metric_data = self.__metrics__.summarize_metric( + dataset=dataset, + metric=TextClassificationMetrics.find_metric("labeling_rule"), + rule_query=rule_query, + labels=labels, ) + annotated = self.__metrics__.annotated_records(dataset) + total = self.__metrics__.total_records(dataset) + + metrics = LabelingRuleSummary.parse_obj(metric_data) coverage = metrics.covered_records / total if total > 0 else None coverage_annotated = metrics.annotated_covered_records / annotated if annotated > 0 else None @@ -310,8 +348,8 @@ def compute_rule_metrics( precision=metrics.precision if annotated > 0 else None, ) - def compute_overall_rules_metrics(self, dataset: ServiceTextClassificationDataset): - total, annotated, metrics = self.__labeling__.all_rules_metrics(dataset) + def compute_all_labeling_rules(self, dataset: ServiceTextClassificationDataset): + total, annotated, metrics = self._compute_all_lb_rules_metrics(dataset) coverage = metrics.covered_records / total if total else None coverage_annotated = metrics.annotated_covered_records / annotated if annotated else None return
DatasetLabelingRulesMetricsSummary( @@ -321,6 +359,23 @@ def compute_overall_rules_metrics(self, dataset: ServiceTextClassificationDatase annotated_records=annotated, ) + def _compute_all_lb_rules_metrics( + self, dataset: ServiceTextClassificationDataset + ) -> Tuple[int, int, DatasetLabelingRulesSummary]: + annotated_records = self.__metrics__.annotated_records(dataset) + dataset_records = self.__metrics__.total_records(dataset) + metric_data = self.__metrics__.summarize_metric( + dataset=dataset, + metric=TextClassificationMetrics.find_metric(id="dataset_labeling_rules"), + queries=[r.query for r in dataset.rules], + ) + + return ( + dataset_records, + annotated_records, + DatasetLabelingRulesSummary.parse_obj(metric_data), + ) + @staticmethod def __normalized_rule__(rule: ServiceLabelingRule) -> ServiceLabelingRule: if rule.labels and len(rule.labels) == 1: diff --git a/tests/server/metrics/test_api.py b/tests/server/metrics/test_api.py index 0dc30ba139..dec0169c54 100644 --- a/tests/server/metrics/test_api.py +++ b/tests/server/metrics/test_api.py @@ -26,11 +26,15 @@ TokenClassificationRecord, ) from argilla.server.services.metrics.models import CommonTasksMetrics +from argilla.server.services.tasks.text_classification.metrics import ( + TextClassificationMetrics, +) from argilla.server.services.tasks.token_classification.metrics import ( TokenClassificationMetrics, ) COMMON_METRICS_LENGTH = len(CommonTasksMetrics.metrics) +CLASSIFICATION_METRICS_LENGTH = len(TextClassificationMetrics.metrics) def test_wrong_dataset_metrics(mocked_client): @@ -183,7 +187,7 @@ def test_dataset_metrics(mocked_client): metrics = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/metrics").json() - assert len(metrics) == COMMON_METRICS_LENGTH + 5 + assert len(metrics) == CLASSIFICATION_METRICS_LENGTH response = mocked_client.post( f"/api/datasets/TextClassification/{dataset}/metrics/missing_metric:summary", @@ -206,6 +210,38 @@ def test_dataset_metrics(mocked_client): assert response.status_code == 200, f"{metric}: {response.json()}" +def create_some_classification_data(mocked_client, dataset: str, records: list): + request = TextClassificationBulkRequest(records=[TextClassificationRecord.parse_obj(r) for r in records]) + + assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200 + assert ( + mocked_client.post( + f"/api/datasets/{dataset}/TextClassification:bulk", + json=request.dict(by_alias=True), + ).status_code + == 200 + ) + + +def test_labeling_rule_metric(mocked_client): + dataset = "test_labeling_rule_metric" + create_some_classification_data( + mocked_client, dataset, records=[{"inputs": {"text": "This is classification record"}}] * 10 + ) + + rule_query = "t*" + response = mocked_client.post( + f"/api/datasets/TextClassification/{dataset}/metrics/labeling_rule:summary?rule_query={rule_query}", + json={}, + ) + assert response.json() == { + "annotated_covered_records": 0, + "correct_records": 0, + "covered_records": 10, + "incorrect_records": 0, + } + + def test_dataset_labels_for_text_classification(mocked_client): records = [ TextClassificationRecord.parse_obj(data) From 3c27fd576719965e09400b01efdada542ec7d993 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Feb 2023 04:40:00 +0000 Subject: [PATCH 13/45] [pre-commit.ci] pre-commit autoupdate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/charliermarsh/ruff-pre-commit: v0.0.244 
→ v0.0.249](https://github.com/charliermarsh/ruff-pre-commit/compare/v0.0.244...v0.0.249) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5ba8658505..0f8bffe411 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,7 +26,7 @@ repos: additional_dependencies: ["click==8.0.4"] - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.244 + rev: v0.0.249 hooks: # Simulate isort via (the much faster) ruff - id: ruff From 8f0d10de2c409d562e3bb758998625a1a430836f Mon Sep 17 00:00:00 2001 From: Gnonpi Date: Tue, 21 Feb 2023 12:56:53 +0100 Subject: [PATCH 14/45] Documentation update: adding missing n (#2362) # Description There's an "n" character missing at the end of the sentence; not a super critical change. **Type of change** - [x] Documentation update **How Has This Been Tested** By eye **Checklist** - [x] I have merged the original branch into my forked branch - [x] I added relevant documentation - [x] follows the style guidelines of this project - [x] I did a self-review of my code - [x] I added comments to my code - [x] I made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works --- docs/_source/guides/query_datasets.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_source/guides/query_datasets.md b/docs/_source/guides/query_datasets.md index 0b80e19384..7ae8d1861c 100644 --- a/docs/_source/guides/query_datasets.md +++ b/docs/_source/guides/query_datasets.md @@ -35,7 +35,7 @@ If you do not retrieve any results after a version update, you should use the `w The (arguably) most important fields are the `text` and `text.exact` fields. They both contain the text of the records, however in two different forms: -- the `text` field uses Elasticsearch's [standard analyzer](https://www.elastic.co/guide/en/elasticsearch/reference/7.10/analysis-standard-analyzer.html) that ignores capitalization and removes most of the punctuatio; +- the `text` field uses Elasticsearch's [standard analyzer](https://www.elastic.co/guide/en/elasticsearch/reference/7.10/analysis-standard-analyzer.html) that ignores capitalization and removes most of the punctuation; - the `text.exact` field uses the [whitespace analyzer](https://www.elastic.co/guide/en/elasticsearch/reference/7.10/analysis-whitespace-analyzer.html) that differentiates between lower and upper case, and does take into account punctuation; Let's have a look at a few examples. From 5ab38be2a188eb3f0d7a5159a4b75e358e208fd8 Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Wed, 22 Feb 2023 11:41:47 +0100 Subject: [PATCH 15/45] ci: Remove Pyre from CI (#2358) Hello! ## Pull Request overview * Remove `pyre` from the CI. * Remove the `pyre` configuration file. ## Details Pyre has not been shown to be useful. It only causes CI failures, and I don't think anyone actually looks at the logs. Let's save ourselves the red crosses and remove it. After all, the CI doesn't even run on the `develop` branch, only on `main` (and `master` and `releases`). See [here](https://github.com/argilla-io/argilla/actions/runs/4184361432/jobs/7249906904) for an example CI output of the `pyre` workflow.
**Type of change** - [x] Removal of unused code **Checklist** - [x] I have merged the original branch into my forked branch - [ ] I added relevant documentation - [x] follows the style guidelines of this project - [x] I did a self-review of my code - [ ] I added comments to my code - [ ] I made corresponding changes to the documentation - [x] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works --- I intend to create a handful of other PRs to improve the CI situation. In particular, I want green checkmarks to once again mean that a PR is likely ready, and for red crosses to mean that it most certainly is not. - Tom Aarsen --- .github/workflows/pyre-check.yml | 78 -------------------------------- .pyre_configuration | 33 -------------- 2 files changed, 111 deletions(-) delete mode 100644 .github/workflows/pyre-check.yml delete mode 100644 .pyre_configuration diff --git a/.github/workflows/pyre-check.yml b/.github/workflows/pyre-check.yml deleted file mode 100644 index 3941c480b5..0000000000 --- a/.github/workflows/pyre-check.yml +++ /dev/null @@ -1,78 +0,0 @@ -name: pyre - -on: - push: - branches: [main, master, releases/*] - pull_request: - # The branches below must be a subset of the branches above - branches: [main, master, releases/*] - -jobs: - pyre: - runs-on: ubuntu-latest - defaults: - run: - shell: bash -l {0} - steps: - - uses: actions/checkout@v2 - - - name: Setup Conda Env 🐍 - uses: conda-incubator/setup-miniconda@v2 - with: - miniforge-variant: Mambaforge - miniforge-version: latest - activate-environment: argilla - use-mamba: true - - - name: Get date for conda cache - id: get-date - run: echo "::set-output name=today::$(/bin/date -u '+%Y%m%d')" - shell: bash - - - name: Cache Conda env - uses: actions/cache@v2 - id: cache - with: - path: ${{ env.CONDA }}/envs - key: conda-${{ runner.os }}-${{ runner.arch }}-${{ steps.get-date.outputs.today }}-${{ hashFiles('environment_dev.yml') }}-${{ env.CACHE_NUMBER }} - env: - # Increase this value to reset cache if etc/example-environment.yml has not changed - CACHE_NUMBER: 0 - - - name: Update environment - if: steps.cache.outputs.cache-hit != 'true' - run: mamba env update -n argilla -f environment_dev.yml - - - name: Cache pip 👜 - uses: actions/cache@v2 - if: steps.filter.outputs.python_code == 'true' - env: - # Increase this value to reset cache if pyproject.toml has not changed - CACHE_NUMBER: 0 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ env.CACHE_NUMBER }}-${{ hashFiles('pyproject.toml') }} - - - name: Install dependencies - # Force install this version since older versions fail with click < 8.0 (included by spacy 3.x) - run: pip install pyre-check==0.9.15 - - - name: Run Pyre - continue-on-error: false - run: | - pyre --output=text check - -# - name: Expose SARIF Resultss -# uses: actions/upload-artifact@v2 -# with: -# name: SARIF Results -# path: sarif.json -# -# - name: Upload SARIF Results -# uses: github/codeql-action/upload-sarif@v1 -# with: -# sarif_file: sarif.json -# -# - name: Fail Command On Errors -# run: | -# if [ "$(cat sarif.json | grep 'PYRE-ERROR')" != "" ]; then cat sarif.json && exit 1; fi diff --git a/.pyre_configuration b/.pyre_configuration deleted file mode 100644 index 6fef55869e..0000000000 --- a/.pyre_configuration +++ /dev/null @@ -1,33 +0,0 @@ -{ - "site_package_search_strategy": "pep561", - - "source_directories": [ - "src" - ], - "search_path": [ - {"site-package": "pytest"}, - {"site-package": "httpx"}, - 
{"site-package": "datasets"}, - {"site-package": "pandas"}, - {"site-package": "pydantic"}, - {"site-package": "fastapi"}, - {"site-package": "snorkel"}, - {"site-package": "spacy"}, - {"site-package": "cleanlab"}, - {"site-package": "flair"}, - {"site-package": "uvicorn"}, - {"site-package": "tqdm"}, - {"site-package": "sklearn"}, - {"site-package": "flyingsquid"}, - {"site-package": "pgmpy"}, - {"site-package": "faiss"}, - {"site-package": "schedule"}, - {"site-package": "prodict"}, - {"site-package": "plotly"}, - {"site-package": "wrapt"}, - {"site-package": "stopwordsiso"}, - {"site-package": "luqum"}, - {"site-package": "jose"}, - {"site-package": "brotli_asgi"} - ] -} From e999fe211eafbccfb5698ff40e6c06dddf6f5334 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Wed, 22 Feb 2023 12:07:21 +0100 Subject: [PATCH 16/45] Refactor/deprecate dataset owner (#2386) # Description This PR adds a `workspace` attribute to the current Dataset model and deprecates the `owner` attribute. The `owner` attribute will be removed in the next release. **Type of change** (Please delete options that are not relevant. Remember to title the PR according to the type of change) - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [x] Refactor (change restructuring the codebase without changing functionality) - [ ] Improvement (change adding some improvement to an existing functionality) - [ ] Documentation update **Checklist** - [x] I have merged the original branch into my forked branch - [ ] I added relevant documentation - [x] follows the style guidelines of this project - [x] I did a self-review of my code - [x] I added comments to my code - [x] I made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works --------- Co-authored-by: keithCuniah --- .../_source/guides/log_load_and_prepare_data.ipynb | 3 ++- .../commons/header/help-info/HelpInfo.spec.js | 2 +- .../components/commons/results/results.spec.js | 2 +- .../components/commons/results/resultsList.spec.js | 2 +- .../SimilarityRecordReference.component.spec.js | 2 +- frontend/components/commons/taskSearch.spec.js | 2 +- .../results/recordTextClassification.spec.js | 2 +- .../text2text/results/recordText2Text.spec.js | 2 +- .../text2text/results/text2TextResultsList.spec.js | 2 +- .../token-classifier/results/textSpan.spec.js | 2 +- .../results/tokenClassificationResultsList.spec.js | 2 +- frontend/database/modules/datasets.js | 4 ++-- frontend/models/Dataset.js | 12 ++++++------ frontend/models/TextClassification.js | 14 +++++++------- frontend/models/TokenClassification.js | 2 +- frontend/pages/datasets/index.vue | 10 +++++----- .../components/core/table/BaseTableInfo.spec.js | 8 ++++---- .../core/table/TableFiltrableColumn.spec.js | 6 +++--- .../table/__snapshots__/BaseTableInfo.spec.js.snap | 2 +- .../results/EntitiesSelector.spec.js | 2 +- src/argilla/client/apis/datasets.py | 1 + src/argilla/client/sdk/datasets/models.py | 1 + src/argilla/server/daos/models/datasets.py | 13 +++++++++++-- tests/datasets/test_datasets.py | 10 ++++++++++ tests/server/datasets/test_api.py | 6 +++--- 25 files changed, 68 insertions(+), 46 deletions(-) diff --git a/docs/_source/guides/log_load_and_prepare_data.ipynb b/docs/_source/guides/log_load_and_prepare_data.ipynb index fe0861f2f9..ce21c887ee
100644 --- a/docs/_source/guides/log_load_and_prepare_data.ipynb +++ b/docs/_source/guides/log_load_and_prepare_data.ipynb @@ -452,7 +452,8 @@ "empty_workspace_datasets = [\n", " ds[\"name\"]\n", " for ds in rg_client.http_client.get(\"/api/datasets\")\n", - " if not ds.get(\"owner\", None) # filtering dataset with no workspace (owner field)\n", + " # filtering dataset with no workspace (use `\"owner\"` if you're running this code with server versions <=1.3.0)\n", + " if not ds.get(\"workspace\", None)\n", "]\n", "\n", "rg.set_workspace(\"\") # working from the \"empty\" workspace\n", diff --git a/frontend/components/commons/header/help-info/HelpInfo.spec.js b/frontend/components/commons/header/help-info/HelpInfo.spec.js index 4a39e0c005..ec4f9736fd 100644 --- a/frontend/components/commons/header/help-info/HelpInfo.spec.js +++ b/frontend/components/commons/header/help-info/HelpInfo.spec.js @@ -20,7 +20,7 @@ const options = { component: "helpInfoExplain", }, propsData: { - datasetId: ["owner", "name"], + datasetId: ["workspace", "name"], datasetTask: "TextClassification", datasetName: "name", }, diff --git a/frontend/components/commons/results/results.spec.js b/frontend/components/commons/results/results.spec.js index d03da9e8c5..04ef8f98bb 100644 --- a/frontend/components/commons/results/results.spec.js +++ b/frontend/components/commons/results/results.spec.js @@ -9,7 +9,7 @@ const options = { "Text2TextResultsList", ], propsData: { - datasetId: ["owner", "name"], + datasetId: ["workspace", "name"], datasetTask: "TextClassification", datasetName: "name", }, diff --git a/frontend/components/commons/results/resultsList.spec.js b/frontend/components/commons/results/resultsList.spec.js index 586d5e4062..1b093feda6 100644 --- a/frontend/components/commons/results/resultsList.spec.js +++ b/frontend/components/commons/results/resultsList.spec.js @@ -21,7 +21,7 @@ const options = { localVue, store, propsData: { - datasetId: ["owner", "name"], + datasetId: ["workspace", "name"], datasetTask: "TextClassification", }, }; diff --git a/frontend/components/commons/results/similarity/SimilarityRecordReference.component.spec.js b/frontend/components/commons/results/similarity/SimilarityRecordReference.component.spec.js index 1224292455..4aac5320f2 100644 --- a/frontend/components/commons/results/similarity/SimilarityRecordReference.component.spec.js +++ b/frontend/components/commons/results/similarity/SimilarityRecordReference.component.spec.js @@ -6,7 +6,7 @@ let wrapper = null; const options = { stubs: ["nuxt", "results-record"], propsData: { - datasetId: ["owner", "name"], + datasetId: ["workspace", "name"], datasetTask: "TextClassification", dataset: { type: Object, diff --git a/frontend/components/commons/taskSearch.spec.js b/frontend/components/commons/taskSearch.spec.js index b3f12b9bda..929f27ea64 100644 --- a/frontend/components/commons/taskSearch.spec.js +++ b/frontend/components/commons/taskSearch.spec.js @@ -5,7 +5,7 @@ let wrapper = null; const options = { stubs: ["results"], propsData: { - datasetId: ["owner", "name"], + datasetId: ["workspace", "name"], datasetTask: "TextClassification", datasetName: "name", }, diff --git a/frontend/components/text-classifier/results/recordTextClassification.spec.js b/frontend/components/text-classifier/results/recordTextClassification.spec.js index 94a3456d58..64090bae0b 100644 --- a/frontend/components/text-classifier/results/recordTextClassification.spec.js +++ b/frontend/components/text-classifier/results/recordTextClassification.spec.js @@ -6,7 +6,7 @@ 
const options = { stubs: ["record-inputs", "classifier-exploration-area", "base-tag"], propsData: { viewSettings: {}, - datasetId: ["owner", "name"], + datasetId: ["workspace", "name"], datasetName: "name", datasetLabels: ["label 1", "label 2"], record: { diff --git a/frontend/components/text2text/results/recordText2Text.spec.js b/frontend/components/text2text/results/recordText2Text.spec.js index 20857b1f97..26c2df6b04 100644 --- a/frontend/components/text2text/results/recordText2Text.spec.js +++ b/frontend/components/text2text/results/recordText2Text.spec.js @@ -6,7 +6,7 @@ let wrapper = null; const options = { stubs: ["text-2-text-list", "record-string-text-2-text"], propsData: { - datasetId: ["owner", "name"], + datasetId: ["workspace", "name"], datasetName: "name", viewSettings: {}, record: { diff --git a/frontend/components/text2text/results/text2TextResultsList.spec.js b/frontend/components/text2text/results/text2TextResultsList.spec.js index fc9b1a6f1f..bda7173df8 100644 --- a/frontend/components/text2text/results/text2TextResultsList.spec.js +++ b/frontend/components/text2text/results/text2TextResultsList.spec.js @@ -5,7 +5,7 @@ let wrapper = null; const options = { stubs: ["results-list"], propsData: { - datasetId: ["owner", "name"], + datasetId: ["workspace", "name"], datasetTask: "TextClassification", }, }; diff --git a/frontend/components/token-classifier/results/textSpan.spec.js b/frontend/components/token-classifier/results/textSpan.spec.js index f327deca1d..d2879eb1ad 100644 --- a/frontend/components/token-classifier/results/textSpan.spec.js +++ b/frontend/components/token-classifier/results/textSpan.spec.js @@ -9,7 +9,7 @@ const options = { }, }, propsData: { - datasetId: ["owner", "name"], + datasetId: ["workspace", "name"], datasetName: "name", datasetLastSelectedEntity: {}, datasetEntities: [ diff --git a/frontend/components/token-classifier/results/tokenClassificationResultsList.spec.js b/frontend/components/token-classifier/results/tokenClassificationResultsList.spec.js index 8a9c2f13e4..8951a7744c 100644 --- a/frontend/components/token-classifier/results/tokenClassificationResultsList.spec.js +++ b/frontend/components/token-classifier/results/tokenClassificationResultsList.spec.js @@ -5,7 +5,7 @@ let wrapper = null; const options = { stubs: ["results-list"], propsData: { - datasetId: ["owner", "name"], + datasetId: ["workspace", "name"], datasetTask: "TextClassification", }, }; diff --git a/frontend/database/modules/datasets.js b/frontend/database/modules/datasets.js index 221699b533..d58006c1ec 100644 --- a/frontend/database/modules/datasets.js +++ b/frontend/database/modules/datasets.js @@ -69,7 +69,7 @@ async function _getOrFetchDataset({ workspace, name }) { } await ObservationDataset.api().get(`/datasets/${name}`, { dataTransformer: ({ data }) => { - data.owner = data.owner || workspace; + data.workspace = data.workspace || workspace; return data; }, }); @@ -600,7 +600,7 @@ const actions = { persistBy: "create", dataTransformer: ({ data }) => { return data.map((datasource) => { - datasource.owner = datasource.owner || NO_WORKSPACE; + datasource.workspace = datasource.workspace || NO_WORKSPACE; return datasource; }); }, diff --git a/frontend/models/Dataset.js b/frontend/models/Dataset.js index 2d97418a3a..efc79c8eab 100644 --- a/frontend/models/Dataset.js +++ b/frontend/models/Dataset.js @@ -24,9 +24,9 @@ const USER_DATA_METADATA_KEY = "rubrix.recogn.ai/ui/custom/userData.v1"; class ObservationDataset extends Model { static entity = "datasets"; - // TODO: Combine 
name + owner for primary key. + // TODO: Combine name + workspace for primary key. // This should fix https://github.com/recognai/rubrix/issues/736 - static primaryKey = ["owner", "name"]; + static primaryKey = ["workspace", "name"]; static #registeredDatasetClasses = {}; @@ -75,7 +75,7 @@ class ObservationDataset extends Model { } get id() { - return [this.owner, this.name]; + return [this.workspace, this.name]; } get visibleRecords() { @@ -85,19 +85,19 @@ class ObservationDataset extends Model { static fields() { return { name: this.string(null), - owner: this.string(null), + workspace: this.string(null), metadata: this.attr(null), tags: this.attr(null), task: this.string(null), created_at: this.string(null), last_updated: this.string(null), - // This will be normalized in a future PR using also owner for relational ids + // This will be normalized in a future PR using also workspace for relational ids viewSettings: this.hasOne(DatasetViewSettings, "id", "name"), }; } } -const getDatasetModelPrimaryKey = ({ owner, name }) => [owner, name]; +const getDatasetModelPrimaryKey = ({ workspace, name }) => [workspace, name]; export { ObservationDataset, diff --git a/frontend/models/TextClassification.js b/frontend/models/TextClassification.js index ccc9144272..0d26a8e02a 100644 --- a/frontend/models/TextClassification.js +++ b/frontend/models/TextClassification.js @@ -152,7 +152,7 @@ class TextClassificationDataset extends ObservationDataset { where: this.id, data: [ { - owner: this.owner, + workspace: this.workspace, name: this.name, _labels: labels, settings, @@ -303,7 +303,7 @@ class TextClassificationDataset extends ObservationDataset { return await TextClassificationDataset.insertOrUpdate({ where: this.id, data: { - owner: this.owner, + workspace: this.workspace, name: this.name, perRuleQueryMetrics, rulesOveralMetrics: overalMetrics, @@ -315,7 +315,7 @@ class TextClassificationDataset extends ObservationDataset { const rules = await this._fetchAllRules(); await TextClassificationDataset.insertOrUpdate({ data: { - owner: this.owner, + workspace: this.workspace, name: this.name, rules, }, @@ -394,7 +394,7 @@ class TextClassificationDataset extends ObservationDataset { await TextClassificationDataset.insertOrUpdate({ data: { - owner: this.owner, + workspace: this.workspace, name: this.name, activeRule: rule || { query, @@ -427,7 +427,7 @@ class TextClassificationDataset extends ObservationDataset { await TextClassificationDataset.insertOrUpdate({ data: { - owner: this.owner, + workspace: this.workspace, name: this.name, rules, activeRule, @@ -454,7 +454,7 @@ class TextClassificationDataset extends ObservationDataset { await TextClassificationDataset.insertOrUpdate({ data: { - owner: this.owner, + workspace: this.workspace, name: this.name, rules, perRuleQueryMetrics, @@ -466,7 +466,7 @@ class TextClassificationDataset extends ObservationDataset { async clearCurrentLabelingRule() { await TextClassificationDataset.insertOrUpdate({ data: { - owner: this.owner, + workspace: this.workspace, name: this.name, activeRule: null, activeRuleMetrics: null, diff --git a/frontend/models/TokenClassification.js b/frontend/models/TokenClassification.js index c937d5d154..5dd03c4382 100644 --- a/frontend/models/TokenClassification.js +++ b/frontend/models/TokenClassification.js @@ -104,7 +104,7 @@ class TokenClassificationDataset extends ObservationDataset { where: this.id, data: [ { - owner: this.owner, + workspace: this.workspace, name: this.name, settings, }, diff --git a/frontend/pages/datasets/index.vue 
b/frontend/pages/datasets/index.vue index 3521c9c3b6..6267d64acb 100644 --- a/frontend/pages/datasets/index.vue +++ b/frontend/pages/datasets/index.vue @@ -82,7 +82,7 @@ export default { { name: "Name", field: "name", class: "table-info__title", type: "link" }, { name: "Workspace", - field: "owner", + field: "workspace", class: "text", type: "text", filtrable: "true", @@ -142,7 +142,7 @@ export default { const tasks = this.tasks; const tags = this.tags; return [ - { column: "owner", values: workspaces || [] }, + { column: "workspace", values: workspaces || [] }, { column: "task", values: tasks || [] }, { column: "tags", values: tags || [] }, ]; @@ -217,7 +217,7 @@ export default { _deleteDataset: "entities/datasets/deleteDataset", }), onColumnFilterApplied({ column, values }) { - if (column === "owner") { + if (column === "workspace") { if (values !== this.workspaces) { this.$router.push({ query: { ...this.$route.query, workspace: values }, @@ -243,7 +243,7 @@ export default { } }, datasetWorkspace(dataset) { - var workspace = dataset.owner; + var workspace = dataset.workspace; if (workspace === null || workspace === "null") { workspace = this.workspace; } @@ -285,7 +285,7 @@ export default { }, deleteDataset(dataset) { this._deleteDataset({ - workspace: dataset.owner, + workspace: dataset.workspace, name: dataset.name, }); this.closeModal(); diff --git a/frontend/specs/components/core/table/BaseTableInfo.spec.js b/frontend/specs/components/core/table/BaseTableInfo.spec.js index 13b6b885a7..f6e02b996c 100644 --- a/frontend/specs/components/core/table/BaseTableInfo.spec.js +++ b/frontend/specs/components/core/table/BaseTableInfo.spec.js @@ -20,7 +20,7 @@ function mountBaseTableInfo() { idx: 1, key: "column1", class: "text", - field: "owner", + field: "workspace", filtrable: "true", name: "Workspace", type: "text", @@ -30,13 +30,13 @@ function mountBaseTableInfo() { { key: "data1", name: "dataset_1", - owner: "recognai", + workspace: "recognai", task: "TokenClassification", }, { key: "data2", name: "dataset_2", - owner: "recognai", + workspace: "recognai", task: "TokenClassification", }, ], @@ -52,7 +52,7 @@ function mountBaseTableInfo() { hideButton: false, noDataInfo: undefined, querySearch: undefined, - filterFromRoute: "owner", + filterFromRoute: "workspace", searchOn: "name", sortedByField: "last_updated", sortedOrder: "desc", diff --git a/frontend/specs/components/core/table/TableFiltrableColumn.spec.js b/frontend/specs/components/core/table/TableFiltrableColumn.spec.js index e41446f0d1..e6cc404f40 100644 --- a/frontend/specs/components/core/table/TableFiltrableColumn.spec.js +++ b/frontend/specs/components/core/table/TableFiltrableColumn.spec.js @@ -5,11 +5,11 @@ function mountTableFiltrableColumn() { return mount(TableFiltrableColumn, { propsData: { filters: { - owner: ["recognai"], + workspace: ["recognai"], }, column: { class: "text", - field: "owner", + field: "workspace", filtrable: "true", name: "Workspace", type: "text", @@ -17,7 +17,7 @@ function mountTableFiltrableColumn() { data: [ { name: "dataset_1", - owner: "recognai", + workspace: "recognai", task: "TokenClassification", }, ], diff --git a/frontend/specs/components/core/table/__snapshots__/BaseTableInfo.spec.js.snap b/frontend/specs/components/core/table/__snapshots__/BaseTableInfo.spec.js.snap index bdd477500b..672efd231d 100644 --- a/frontend/specs/components/core/table/__snapshots__/BaseTableInfo.spec.js.snap +++ b/frontend/specs/components/core/table/__snapshots__/BaseTableInfo.spec.js.snap @@ -1,7 +1,7 @@ // Jest 
Snapshot v1, https://goo.gl/fbAQLP exports[`BaseTableInfo renders properly 1`] = ` - +
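The server-side model changes that follow keep old clients working by defaulting the new `workspace` field to the deprecated `owner` value via a pydantic validator. Here is a minimal, runnable sketch of that fallback pattern, assuming pydantic v1 (the version used in this codebase); the `DatasetSketch` model is illustrative, not part of the repository:

```python
from typing import Optional

from pydantic import BaseModel, validator


class DatasetSketch(BaseModel):
    # Illustrative model mirroring the owner -> workspace fallback
    # adopted by BaseDatasetDB in this patch.
    owner: Optional[str] = None      # deprecated field
    workspace: Optional[str] = None  # its replacement

    @validator("workspace", pre=True, always=True)
    def _default_workspace(cls, value, values):
        # Old payloads only carry `owner`; keep them working by defaulting
        # `workspace` to that value during the deprecation window.
        return value or values.get("owner")


print(DatasetSketch(owner="recognai").workspace)                       # -> recognai
print(DatasetSketch(owner="recognai", workspace="argilla").workspace)  # -> argilla
```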
diff --git a/frontend/specs/components/token-classification/results/EntitiesSelector.spec.js b/frontend/specs/components/token-classification/results/EntitiesSelector.spec.js index d31d7b8db8..47964aef55 100644 --- a/frontend/specs/components/token-classification/results/EntitiesSelector.spec.js +++ b/frontend/specs/components/token-classification/results/EntitiesSelector.spec.js @@ -6,7 +6,7 @@ function mountEntitiesSelector() { return mount(EntitiesSelector, { stubs: ["entity-label"], propsData: { - datasetId: ["name", "owner"], + datasetId: ["workspace", "name"], datasetLastSelectedEntity: { colorId: 14, shortCut: "1", diff --git a/src/argilla/client/apis/datasets.py b/src/argilla/client/apis/datasets.py index c887300cb1..fb1135b1e7 100644 --- a/src/argilla/client/apis/datasets.py +++ b/src/argilla/client/apis/datasets.py @@ -107,6 +107,7 @@ class _DatasetApiModel(BaseModel): name: str task: TaskType owner: Optional[str] = None + workspace: Optional[str] = None created_at: Optional[datetime] = None last_updated: Optional[datetime] = None diff --git a/src/argilla/client/sdk/datasets/models.py b/src/argilla/client/sdk/datasets/models.py index f08e47d085..5d8ea06cac 100644 --- a/src/argilla/client/sdk/datasets/models.py +++ b/src/argilla/client/sdk/datasets/models.py @@ -42,6 +42,7 @@ class BaseDatasetModel(BaseModel): class Dataset(BaseDatasetModel): task: TaskType owner: str = None + workspace: str = None created_at: datetime = None last_updated: datetime = None diff --git a/src/argilla/server/daos/models/datasets.py b/src/argilla/server/daos/models/datasets.py index 027e4a7011..03f690f777 100644 --- a/src/argilla/server/daos/models/datasets.py +++ b/src/argilla/server/daos/models/datasets.py @@ -15,7 +15,7 @@ from datetime import datetime from typing import Any, Dict, Optional, TypeVar, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, validator from argilla._constants import DATASET_NAME_REGEX_PATTERN from argilla.server.commons.models import TaskType @@ -24,7 +24,9 @@ class BaseDatasetDB(BaseModel): name: str = Field(regex=DATASET_NAME_REGEX_PATTERN) task: TaskType - owner: Optional[str] = None + owner: Optional[str] = Field(description="Deprecated. Use `workspace` instead. 
Will be removed in v1.5.0") + workspace: Optional[str] = None + tags: Dict[str, str] = Field(default_factory=dict) metadata: Dict[str, Any] = Field(default_factory=dict) created_at: datetime = None @@ -34,6 +36,13 @@ class BaseDatasetDB(BaseModel): ) last_updated: datetime = None + @validator("workspace", pre=True, always=True) + def set_workspace_defaults(cls, value, values): + if value: + return value + else: + return values.get("owner") + @classmethod def build_dataset_id(cls, name: str, owner: Optional[str] = None) -> str: """Build a dataset id for a given name and owner""" diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 7384aa9bae..5f71f9d282 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -53,6 +53,16 @@ def test_settings_workflow(mocked_client, settings_, wrong_settings): ar.configure_dataset(dataset, wrong_settings) +def test_list_dataset(mocked_client): + from argilla.client.api import active_client + + client = active_client() + datasets = client.http_client.get("/api/datasets") + + for ds in datasets: + assert ds["owner"] == ds["workspace"] + + def test_delete_dataset_by_non_creator(mocked_client): try: dataset = "test_delete_dataset_by_non_creator" diff --git a/tests/server/datasets/test_api.py b/tests/server/datasets/test_api.py index 4f613f0123..5660a2393f 100644 --- a/tests/server/datasets/test_api.py +++ b/tests/server/datasets/test_api.py @@ -58,7 +58,7 @@ def test_create_dataset(mocked_client): assert dataset.metadata == request["metadata"] assert dataset.tags == request["tags"] assert dataset.name == dataset_name - assert dataset.owner == "argilla" + assert dataset.workspace == "argilla" assert dataset.task == TaskType.text_classification response = mocked_client.post( @@ -88,7 +88,7 @@ def test_fetch_dataset_using_workspaces(mocked_client: SecuredClient): dataset = Dataset.parse_obj(response.json()) assert dataset.created_by == "argilla" assert dataset.name == dataset_name - assert dataset.owner == ws + assert dataset.workspace == ws assert dataset.task == TaskType.text_classification response = mocked_client.post( @@ -106,7 +106,7 @@ dataset = Dataset.parse_obj(response.json()) assert dataset.created_by == "argilla" assert dataset.name == dataset_name - assert dataset.owner == "argilla" + assert dataset.workspace == "argilla" assert dataset.task == TaskType.text_classification From 4e623d4c1adc97776812355d2587e41cb487221e Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Wed, 22 Feb 2023 12:10:20 +0100 Subject: [PATCH 17/45] feat: Add `active_client` function to main argilla module (#2387) # Description This PR allows fetching the active argilla client instance as follows: ```python import argilla as rg client = rg.active_client() # from here, we can interact with the API without extra header configuration client.http_client.get("/api/datasets") ``` Closes #2183 **Type of change** (Please delete options that are not relevant.
Remember to title the PR according to the type of change) - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Refactor (change restructuring the codebase without changing functionality) - [x] Improvement (change adding some improvement to an existing functionality) - [ ] Documentation update **Checklist** - [x] I have merged the original branch into my forked branch - [ ] I added relevant documentation - [x] follows the style guidelines of this project - [x] I did a self-review of my code - [x] I added comments to my code - [ ] I made corresponding changes to the documentation - [x] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works --- docs/_source/guides/log_load_and_prepare_data.ipynb | 3 +-- docs/_source/reference/python/python_client.rst | 2 +- src/argilla/__init__.py | 2 ++ src/argilla/client/__init__.py | 2 ++ 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/docs/_source/guides/log_load_and_prepare_data.ipynb b/docs/_source/guides/log_load_and_prepare_data.ipynb index ce21c887ee..190d8bdea4 100644 --- a/docs/_source/guides/log_load_and_prepare_data.ipynb +++ b/docs/_source/guides/log_load_and_prepare_data.ipynb @@ -442,10 +442,9 @@ "outputs": [], "source": [ "import argilla as rg\n", - "from argilla.client import api\n", "\n", "rg.init()\n", - "rg_client = api.active_client()\n", + "rg_client = rg.active_client()\n", "\n", "new_workspace = \"\"\n", "\n", diff --git a/docs/_source/reference/python/python_client.rst b/docs/_source/reference/python/python_client.rst index dd60c49986..9d3f941dfd 100644 --- a/docs/_source/reference/python/python_client.rst +++ b/docs/_source/reference/python/python_client.rst @@ -15,7 +15,7 @@ Methods ------- .. automodule:: argilla - :members: init, log, load, copy, delete, set_workspace, get_workspace, delete_records + :members: init, log, load, copy, delete, set_workspace, get_workspace, delete_records, active_client .. _python ref records: diff --git a/src/argilla/__init__.py b/src/argilla/__init__.py index 81ed00085a..8f0454b6db 100644 --- a/src/argilla/__init__.py +++ b/src/argilla/__init__.py @@ -38,6 +38,7 @@ if _TYPE_CHECKING: from argilla.client.api import ( + active_client, copy, delete, delete_records, @@ -84,6 +85,7 @@ "log", "log_async", "set_workspace", + "active_client", ], "client.models": [ "Text2TextRecord", diff --git a/src/argilla/client/__init__.py b/src/argilla/client/__init__.py index 168dce9eb1..6d25f8c3fb 100644 --- a/src/argilla/client/__init__.py +++ b/src/argilla/client/__init__.py @@ -12,3 +12,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .api import active_client From 4e18c6b4d649e97c8eba5a8b1e2936747f4e8b4f Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Wed, 22 Feb 2023 12:34:09 +0100 Subject: [PATCH 18/45] Refactor/remove no workspace usage and better superuser computation (#2373) # Description For old argilla versions, workspaces were not required to create datasets. Some code changes were included to support this behavior This PR clean these code sections and refactor them for those code parts where old logic was used. 
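To make the resulting behavior concrete, here is a minimal sketch of how the refactored `User` model behaves (illustrative only; based on the `security/model.py` changes included below, not code shipped in this patch):

```python
from argilla.server.security.model import User

# A user created without any `workspaces` value is treated as a superuser,
# and the username is always appended as an implicit default workspace.
admin = User(username="admin")
assert admin.is_superuser()
assert admin.workspaces == ["admin"]

# An explicit workspace list (even an empty one) means a regular user.
user = User(username="alice", workspaces=["team-a"])
assert not user.is_superuser()
assert set(user.workspaces) == {"alice", "team-a"}
```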
In particular, the `user.is_superuser` flag is still computed from how the workspace list is set up, but it is now also persisted on the user model to make the code easier to work with.

We still need to provide a way to migrate datasets that have no workspace to a default workspace. @davidberenstein1957 can you provide a section in the docs where we can include these steps from the release notes?

**Type of change** (Please delete options that are not relevant. Remember to title the PR according to the type of change)

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [x] Refactor (change restructuring the codebase without changing functionality)
- [ ] Improvement (change adding some improvement to an existing functionality)
- [ ] Documentation update

**Checklist**

- [x] I have merged the original branch into my forked branch
- [ ] I added relevant documentation
- [x] follows the style guidelines of this project
- [x] I did a self-review of my code
- [x] I added comments to my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my feature works

---

 frontend/database/modules/datasets.js      | 13 +----
 frontend/models/Workspace.js               |  4 +-
 frontend/plugins/vuex-orm-axios.js         |  6 +-
 src/argilla/client/client.py               |  2 +-
 .../server/daos/backend/search/model.py    |  3 +
 src/argilla/server/daos/datasets.py        |  3 -
 .../auth_provider/local/users/service.py   |  8 +--
 src/argilla/server/security/model.py       | 44 ++++++++++++---
 tests/client/sdk/users/test_model.py       |  6 +-
 tests/client/test_api.py                   | 56 ++++---------------
 tests/helpers.py                           |  2 +-
 tests/server/security/test_model.py        | 42 ++++++++++----
 12 files changed, 94 insertions(+), 95 deletions(-)

diff --git a/frontend/database/modules/datasets.js b/frontend/database/modules/datasets.js
index d58006c1ec..136851cacf 100644
--- a/frontend/database/modules/datasets.js
+++ b/frontend/database/modules/datasets.js
@@ -18,7 +18,7 @@
 import _ from "lodash";
 import { ObservationDataset, USER_DATA_METADATA_KEY } from "@/models/Dataset";
 import { DatasetViewSettings, Pagination } from "@/models/DatasetViewSettings";
 import { AnnotationProgress } from "@/models/AnnotationProgress";
-import { currentWorkspace, NO_WORKSPACE } from "@/models/Workspace";
+import { currentWorkspace } from "@/models/Workspace";
 import { Base64 } from "js-base64";
 import { Vector as VectorModel } from "@/models/Vector";
 
@@ -578,9 +578,8 @@ const actions = {
   async deleteDataset(_, { workspace, name }) {
     var url = `/datasets/${name}`;
-    if (workspace !== NO_WORKSPACE) {
-      url += `?workspace=${workspace}`;
-    }
+
+    url += `?workspace=${workspace}`;
     const deleteResults = await ObservationDataset.api().delete(url, {
       delete: [workspace, name],
     });
@@ -598,12 +597,6 @@ const actions = {
     return await ObservationDataset.api().get("/datasets/", {
       persistBy: "create",
-      dataTransformer: ({ data }) => {
-        return data.map((datasource) => {
-          datasource.workspace = datasource.workspace || NO_WORKSPACE;
-          return datasource;
-        });
-      },
     });
   },
   async fetchByName(_, name) {
diff --git a/frontend/models/Workspace.js b/frontend/models/Workspace.js
index 0e0cf37c85..56af837bd0 100644
--- a/frontend/models/Workspace.js
+++ b/frontend/models/Workspace.js
@@ -23,5 +23,5 @@ function currentWorkspace(route) {
   return route.params.workspace;
 }
 
-const NO_WORKSPACE = "-";
-export { defaultWorkspace, currentWorkspace, NO_WORKSPACE };
+
+export { defaultWorkspace, currentWorkspace };
diff --git a/frontend/plugins/vuex-orm-axios.js b/frontend/plugins/vuex-orm-axios.js
index bf15e2a5c8..ebac803c1c 100644
--- a/frontend/plugins/vuex-orm-axios.js
+++ b/frontend/plugins/vuex-orm-axios.js
@@ -19,7 +19,7 @@
 import { Model } from "@vuex-orm/core";
 import { ExpiredAuthSessionError } from "@nuxtjs/auth-next/dist/runtime";
 import { Notification } from "@/models/Notifications";
-import { currentWorkspace, NO_WORKSPACE } from "@/models/Workspace";
+import { currentWorkspace } from "@/models/Workspace";
 
 export default ({ $axios, app }) => {
   Model.setAxios($axios);
@@ -32,9 +32,7 @@ export default ({ $axios, app }) => {
     }
 
     let ws = currentWorkspace(app.context.route);
-    if (ws === NO_WORKSPACE) {
-      config.headers["X-Argilla-Workspace"] = "";
-    } else if (ws) {
+    if (ws) {
       config.headers["X-Argilla-Workspace"] = ws;
     }
     return config;
diff --git a/src/argilla/client/client.py b/src/argilla/client/client.py
index 70f9938509..66003c8729 100644
--- a/src/argilla/client/client.py
+++ b/src/argilla/client/client.py
@@ -205,7 +205,7 @@ def set_workspace(self, workspace: str):
         Args:
             workspace: The new workspace
         """
-        if workspace is None:
+        if not workspace:
             raise Exception("Must provide a workspace")
 
         if workspace != self.get_workspace():
diff --git a/src/argilla/server/daos/backend/search/model.py b/src/argilla/server/daos/backend/search/model.py
index 49a4986595..3db5db977e 100644
--- a/src/argilla/server/daos/backend/search/model.py
+++ b/src/argilla/server/daos/backend/search/model.py
@@ -61,6 +61,9 @@ class BaseQuery(BaseModel):
 class BaseDatasetsQuery(BaseQuery):
     tasks: Optional[List[str]] = None
     owners: Optional[List[str]] = None
+    # This is used to fetch datasets without owner/workspace.
But this should be moved to + # a default workspace + # TODO: Should be deprecated include_no_owner: bool = None name: Optional[str] = None diff --git a/src/argilla/server/daos/datasets.py b/src/argilla/server/daos/datasets.py index d798f823ea..0601209301 100644 --- a/src/argilla/server/daos/datasets.py +++ b/src/argilla/server/daos/datasets.py @@ -28,8 +28,6 @@ from argilla.server.daos.records import DatasetRecordsDAO from argilla.server.errors import WrongTaskError -NO_WORKSPACE = "" - class DatasetsDAO: """Datasets DAO""" @@ -84,7 +82,6 @@ def list_datasets( owner_list = owner_list or [] query = BaseDatasetsQuery( owners=owner_list, - include_no_owner=NO_WORKSPACE in owner_list, tasks=[task for task in task2dataset_map] if task2dataset_map else None, name=name, ) diff --git a/src/argilla/server/security/auth_provider/local/users/service.py b/src/argilla/server/security/auth_provider/local/users/service.py index 09e3c63010..da7680beb0 100644 --- a/src/argilla/server/security/auth_provider/local/users/service.py +++ b/src/argilla/server/security/auth_provider/local/users/service.py @@ -18,8 +18,6 @@ from fastapi import Depends from passlib.context import CryptContext -from argilla.server.daos.datasets import NO_WORKSPACE - from .dao import UsersDAO, create_users_dao from .model import User @@ -66,9 +64,7 @@ def authenticate_user(self, username: str, password: str) -> Optional[User]: def get_user(self, username) -> Optional[User]: user = self.__dao__.get_user(username) if user and user.is_superuser(): - workspaces = list(self._fetch_all_workspaces()) - if NO_WORKSPACE not in workspaces: - workspaces.append(NO_WORKSPACE) + workspaces = self._fetch_all_workspaces() user.workspaces = workspaces return user @@ -84,7 +80,7 @@ def _fetch_all_workspaces(self) -> List[str]: async def find_user_by_api_key(self, api_key: str) -> Optional[User]: user = await self.__dao__.get_user_by_api_key(api_key) if user and user.is_superuser(): - user.workspaces = [NO_WORKSPACE] + list(self._fetch_all_workspaces()) + user.workspaces = list(self._fetch_all_workspaces()) return user def __verify_password__(self, password: str, hashed_password: str) -> bool: diff --git a/src/argilla/server/security/model.py b/src/argilla/server/security/model.py index 9f182f5e6c..007f043507 100644 --- a/src/argilla/server/security/model.py +++ b/src/argilla/server/security/model.py @@ -15,7 +15,7 @@ import re from typing import List, Optional -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, root_validator, validator from argilla._constants import DATASET_NAME_REGEX_PATTERN from argilla.server.errors import EntityNotFoundError @@ -31,7 +31,9 @@ class User(BaseModel): email: Optional[str] = Field(None, regex=_EMAIL_REGEX_PATTERN) full_name: Optional[str] = None disabled: Optional[bool] = None - workspaces: Optional[List[str]] = None + + superuser: Optional[bool] + workspaces: Optional[List[str]] @validator("username") def check_username(cls, value): @@ -51,6 +53,32 @@ def check_workspace_pattern(cls, workspace: str): ) return workspace + @root_validator(pre=True) + def check_defaults(cls, values): + superuser = values.get("superuser") + workspaces = values.get("workspaces") + + values["superuser"] = cls._set_default_superuser(superuser, values) + values["workspaces"] = cls._set_default_workspace(workspaces, values) + + return values + + @classmethod + def _set_default_superuser(cls, value, values): + """This will setup the superuser flag when no workspaces are defined""" + if value is not None: 
+ return value + # The current way to define super-users is create them with no workspaces at all + # (IT'S NOT THE SAME AS PASSING AN EMPTY LIST) + return values.get("workspaces", None) is None + + @classmethod + def _set_default_workspace(cls, value, values): + value = (value or []).copy() + value.append(values["username"]) + + return list(set(value)) + @property def default_workspace(self) -> Optional[str]: """Get the default user workspace""" @@ -76,8 +104,8 @@ def check_workspaces(self, workspaces: List[str]) -> List[str]: for workspace in workspaces: self.check_workspace(workspace) return workspaces - - return [self.default_workspace] + (self.workspaces or []) + else: + return self.workspaces def check_workspace(self, workspace: str) -> str: """ @@ -93,17 +121,15 @@ def check_workspace(self, workspace: str) -> str: The original workspace name if user belongs to it """ - if workspace is None or workspace == self.default_workspace: + if not workspace: return self.default_workspace - if not workspace and self.is_superuser(): - return workspace - if workspace not in (self.workspaces or []): + elif workspace not in self.workspaces: raise EntityNotFoundError(name=workspace, type="Workspace") return workspace def is_superuser(self) -> bool: """Check if a user is superuser""" - return self.workspaces is None or "" in self.workspaces + return self.superuser class Token(BaseModel): diff --git a/tests/client/sdk/users/test_model.py b/tests/client/sdk/users/test_model.py index d83da86677..a26ee150cd 100644 --- a/tests/client/sdk/users/test_model.py +++ b/tests/client/sdk/users/test_model.py @@ -20,8 +20,4 @@ def test_users_schema(helpers): client_schema = User.schema() server_schema = ServerUser.schema() - for clean_method in [helpers.remove_description, helpers.remove_pattern]: - client_schema = clean_method(client_schema) - server_schema = clean_method(server_schema) - - assert client_schema == server_schema + assert helpers.are_compatible_api_schemas(client_schema, server_schema) diff --git a/tests/client/test_api.py b/tests/client/test_api.py index 4c7ec4bb44..810d7f006a 100644 --- a/tests/client/test_api.py +++ b/tests/client/test_api.py @@ -28,6 +28,7 @@ WORKSPACE_HEADER_NAME, ) from argilla.client import api +from argilla.client.client import Argilla from argilla.client.sdk.client import AuthenticatedClient from argilla.client.sdk.commons.errors import ( AlreadyExistsApiError, @@ -659,26 +660,22 @@ def test_load_text2text(mocked_client, supported_vector_search): def test_client_workspace(mocked_client): - try: - ws = api.get_workspace() - assert ws == "argilla" - - api.set_workspace("") - assert api.get_workspace() == "" + api = Argilla() + ws = api.get_workspace() + assert ws == "argilla" + for ws in [None, ""]: with pytest.raises(Exception, match="Must provide a workspace"): - api.set_workspace(None) + api.set_workspace(ws) - # Mocking user - api.active_api().user.workspaces = ["a", "b"] + # Mocking user + api.user.workspaces = ["a", "b"] - with pytest.raises(Exception, match="Wrong provided workspace c"): - api.set_workspace("c") + with pytest.raises(Exception, match="Wrong provided workspace c"): + api.set_workspace("c") - api.set_workspace("argilla") - assert api.get_workspace() == "argilla" - finally: - api.init() # reset workspace + api.set_workspace("argilla") + assert api.get_workspace() == "argilla" def test_load_sort(mocked_client): @@ -706,32 +703,3 @@ def test_load_sort(mocked_client): ds = api.load(name=dataset, ids=["1str", "2str", "11str"]) df = ds.to_pandas() assert 
list(df.id) == ["11str", "1str", "2str"] - - -def test_load_workspace_from_different_workspace(mocked_client): - records = [ - ar.TextClassificationRecord( - text="test text", - id=i, - ) - for i in ["1str", 1, 2, 11, "2str", "11str"] - ] - - dataset = "test_load_workspace_from_different_workspace" - workspace = api.get_workspace() - try: - api.set_workspace("") # empty workspace - api.delete(dataset) - api.log(records, name=dataset) - - # check sorting policies - ds = api.load(name=dataset) - df = ds.to_pandas() - assert list(df.id) == [1, 11, "11str", "1str", 2, "2str"] - - api.set_workspace(workspace) - df = api.load(name=dataset) - df = df.to_pandas() - assert list(df.id) == [1, 11, "11str", "1str", 2, "2str"] - finally: - api.set_workspace(workspace) diff --git a/tests/helpers.py b/tests/helpers.py index 4246a8fd45..06c1d1825f 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -56,7 +56,7 @@ def reset_default_user(self): rb_api = active_api() rb_api._user = default_user rb_api.http_client.token = default_user.api_key - rb_api.http_client.headers.pop(WORKSPACE_HEADER_NAME) + rb_api.http_client.headers.pop(WORKSPACE_HEADER_NAME, None) self._header[API_KEY_HEADER_NAME] = default_user.api_key def add_workspaces_to_argilla_user(self, workspaces: List[str]): diff --git a/tests/server/security/test_model.py b/tests/server/security/test_model.py index e89ac76842..07d0dfa5e8 100644 --- a/tests/server/security/test_model.py +++ b/tests/server/security/test_model.py @@ -49,8 +49,8 @@ def test_check_non_provided_workspaces(): user = User(username="test") assert user.check_workspaces([]) == ["test"] - user.workspaces = ["ws"] - assert user.check_workspaces([]) == [user.default_workspace] + user.workspaces + user = User(username="test", workspaces=["ws"]) + assert set(user.check_workspaces([])) == {"ws", "test"} with pytest.raises(EntityNotFoundError, match="not-found"): assert user.check_workspaces(["ws", "not-found"]) @@ -83,23 +83,45 @@ def test_workspace_for_superuser(): assert user.check_workspace("some") == "some" assert user.check_workspace(None) == "admin" - assert user.check_workspace("") == "" + assert user.check_workspace("") == "admin" user.workspaces = ["some"] assert user.check_workspaces(["some"]) == ["some"] +def test_workspaces_with_default(): + expected_workspaces = ["user", "ws1"] + user = User(username="user", workspaces=expected_workspaces) + assert len(user.workspaces) == len(expected_workspaces) + for ws in expected_workspaces: + assert ws in user.workspaces + + +def test_is_superuser(): + admin_user = User(username="admin") + assert admin_user.is_superuser() + + admin_user.workspaces.append("other-workspace") + assert admin_user.is_superuser() + assert set(admin_user.workspaces) == {"other-workspace", "admin"} + + user = User(username="test", workspaces=["bod"]) + assert not user.is_superuser() + user.superuser = True + assert user.is_superuser() + + @pytest.mark.parametrize( "workspaces, expected", [ - (None, ["user"]), - ([], ["user"]), - (["a"], ["user", "a"]), + (None, {"user"}), + ([], {"user"}), + (["a"], {"user", "a"}), ], ) def test_check_workspaces_with_default(workspaces, expected): user = User(username="user", workspaces=workspaces) - assert user.check_workspaces([]) == expected - assert user.check_workspaces(None) == expected - assert user.check_workspaces([None]) == expected - assert user.check_workspace(user.username) == user.username + assert set(user.check_workspaces([])) == expected + assert set(user.check_workspaces(None)) == expected + assert 
set(user.check_workspaces([None])) == expected
+    assert set(user.check_workspace(user.username)) == set(user.username)

From 317ce42c5f4292190e1be2b94c917ac2589b439d Mon Sep 17 00:00:00 2001
From: Keith Cuniah <88380932+keithCuniah@users.noreply.github.com>
Date: Wed, 22 Feb 2023 12:38:33 +0100
Subject: [PATCH 19/45] ci: remove checkpoint from PR template (#2390)

# Description

If code needs comments, except in some particular cases, it's not clean code. So maybe it's better to remove the line `I added comments to my code` from the PR template.

**Type of change** (Please delete options that are not relevant. Remember to title the PR according to the type of change)

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Refactor (change restructuring the codebase without changing functionality)
- [ ] Improvement (change adding some improvement to an existing functionality)
- [x] Documentation update

**How Has This Been Tested** (Please describe the tests that you ran to verify your changes. And ideally, reference `tests`)

- N/A

**Checklist**

- N/A I have merged the original branch into my forked branch
- N/A I added relevant documentation
- N/A follows the style guidelines of this project
- N/A I did a self-review of my code
- N/A I added comments to my code
- N/A I made corresponding changes to the documentation
- N/A My changes generate no new warnings
- N/A I have added tests that prove my fix is effective or that my feature works

---

 pull_request_template.md | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pull_request_template.md b/pull_request_template.md
index 654386b4fd..d83dc14f3d 100644
--- a/pull_request_template.md
+++ b/pull_request_template.md
@@ -28,7 +28,6 @@ Closes #
 - [ ] I added relevant documentation
 - [ ] follows the style guidelines of this project
 - [ ] I did a self-review of my code
-- [ ] I added comments to my code
 - [ ] I made corresponding changes to the documentation
 - [ ] My changes generate no new warnings
 - [ ] I have added tests that prove my fix is effective or that my feature works

From c35d63a5323fd360f1ef79456e6816d3470b3151 Mon Sep 17 00:00:00 2001
From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Date: Fri, 24 Feb 2023 09:07:14 +0100
Subject: [PATCH 20/45] CI: Skip rather than failing in 2 common scenarios (#2392)

Hello!

## Pull Request overview

* Skip CI jobs rather than letting them fail in 2 common scenarios:
  1. If a PR is created from a fork.
  2. If a PR does not contain code changes.

## Problems

### Scenario 1

If a PR is created from a fork, then certain secrets will be inaccessible, e.g. `secrets.AR_DOCKER_USERNAME` and `secrets.AR_DOCKER_PASSWORD`. This causes the "Docker Deploy" job to fail, understandably. It would be better if these jobs would be skipped instead. See [here](https://github.com/argilla-io/argilla/actions/runs/4240437103/jobs/7369703791) for an example case.

### Scenario 2

There are two subsequent jobs: one to build the Python package, and one to deploy it as a Docker image. The first one will stop halfway through if there are no code changes found in the PR, but the second job will still try to deploy the build artifact. Upon trying to load it, it will crash. See [here](https://github.com/argilla-io/argilla/actions/runs/4229327600/jobs/7345607762) for an example case.
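Both fixes reduce to conditionally skipping work, as sketched in this stripped-down workflow (hypothetical job and step names, not the actual `package.yml`):

```yaml
jobs:
  build:
    runs-on: ubuntu-latest
    # Secrets cannot be read in a job-level `if:`, so expose the check as an
    # environment variable and gate each individual step on it instead.
    env:
      IS_DEPLOYABLE: ${{ secrets.AR_DOCKER_USERNAME != '' }}
    outputs:
      code_changes: ${{ steps.filter.outputs.code_changes }}
    steps:
      - id: filter
        run: echo "code_changes=true" >> "$GITHUB_OUTPUT"
      - name: Build package
        if: env.IS_DEPLOYABLE == 'true'
        run: echo "building..."

  deploy_docker:
    needs: build
    # Job outputs can be read at this level, so the whole deploy job is
    # skipped when the PR contains no code changes.
    if: needs.build.outputs.code_changes == 'true'
    runs-on: ubuntu-latest
    steps:
      - run: echo "deploying..."
```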
## Fixes

### Scenario 1

Sadly, secrets are not accessible on the `build.if` and `deploy_docker.if` level, so we can't use e.g. `if: secrets.AR_DOCKER_USERNAME != ''`. Instead, we are required to use an environment variable, e.g. `IS_DEPLOYABLE`, to track whether the secrets are accessible, and then add `if: env.IS_DEPLOYABLE == 'true'` to all steps.

### Scenario 2

This was as simple as setting `code_changes` as an output of the `build` job, which can then be read as input by the `deploy_docker` job. This job is now skipped if there are no code changes in the PR/commit.

---

**Type of change**

- [x] Refactor (change restructuring the codebase without changing functionality)

**How Has This Been Tested**

Through this PR. Note: I have only been able to test scenario 1 (PR from a fork). I haven't been able to test scenario 2 nor the "normal" scenarios where these jobs do actually execute (i.e. first-party PR with code changes). These are annoying to test easily. I think we're best off merging this and reverting if any scenarios work unexpectedly, rather than trying to test all of the scenarios in a separate repo.

**Checklist**

- [x] I have merged the original branch into my forked branch
- [ ] I added relevant documentation
- [x] follows the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works

---

With this PR I hope to have reached the point where CI failures imply actual issues and CI passes imply that a PR is "likely ready".

- Tom Aarsen

---

 .github/workflows/package.yml | 24 ++++++++++++++++++++----
 1 file changed, 20 insertions(+), 4 deletions(-)

diff --git a/.github/workflows/package.yml b/.github/workflows/package.yml
index d6927a4149..9eac0ff5f7 100644
--- a/.github/workflows/package.yml
+++ b/.github/workflows/package.yml
@@ -208,6 +208,11 @@ jobs:
     defaults:
       run:
         shell: bash -l {0}
+    # Only build the package if we can deploy it as a docker image
+    env:
+      IS_DEPLOYABLE: ${{ secrets.AR_DOCKER_USERNAME != '' }}
+    outputs:
+      code_changes: ${{ steps.filter.outputs.code_changes }}
     steps:
       - name: Checkout Code 🛎
         uses: actions/checkout@v2
@@ -228,7 +233,7 @@
             - '.github/workflows/package.yml'
       - name: Cache pip 👜
         uses: actions/cache@v2
-        if: steps.filter.outputs.code_changes == 'true'
+        if: steps.filter.outputs.code_changes == 'true' && env.IS_DEPLOYABLE == 'true'
         env:
           # Increase this value to reset cache if pyproject.toml has not changed
           CACHE_NUMBER: 0
@@ -238,18 +243,18 @@
 
       - name: Setup Node.js
         uses: actions/setup-node@v2
-        if: steps.filter.outputs.code_changes == 'true'
+        if: steps.filter.outputs.code_changes == 'true' && env.IS_DEPLOYABLE == 'true'
         with:
           node-version: "14"
 
       - name: Build Package 🍟
-        if: steps.filter.outputs.code_changes == 'true'
+        if: steps.filter.outputs.code_changes == 'true' && env.IS_DEPLOYABLE == 'true'
        run: |
          pip install -U build
          scripts/build_distribution.sh
 
       - name: Upload package artifact
-        if: steps.filter.outputs.code_changes == 'true'
+        if: steps.filter.outputs.code_changes == 'true' && env.IS_DEPLOYABLE == 'true'
         uses: actions/upload-artifact@v2
         with:
           name: python-package
@@ -262,6 +267,9 @@
       - build
       - test-elastic
       - test-opensearch
+    env:
+      IS_DEPLOYABLE: ${{ secrets.AR_DOCKER_USERNAME != '' }}
+    if: needs.build.outputs.code_changes == 'true'
     strategy:
       matrix:
         include:
@@ -281,30 +289,36 @@
     steps:
       - name: Checkout Code 🛎
         uses: actions/checkout@v2
+        if: env.IS_DEPLOYABLE 
== 'true' - name: Download python package uses: actions/download-artifact@v2 with: name: python-package path: dist + if: env.IS_DEPLOYABLE == 'true' - name: Set up QEMU uses: docker/setup-qemu-action@v2 + if: env.IS_DEPLOYABLE == 'true' - name: Set up Docker Buildx uses: docker/setup-buildx-action@v2 + if: env.IS_DEPLOYABLE == 'true' - name: Docker meta id: meta uses: crazy-max/ghaction-docker-meta@v2 with: images: ${{ matrix.image }} + if: env.IS_DEPLOYABLE == 'true' - name: Login to DockerHub uses: docker/login-action@v1 with: username: ${{ secrets.AR_DOCKER_USERNAME }} password: ${{ secrets.AR_DOCKER_PASSWORD }} + if: env.IS_DEPLOYABLE == 'true' - name: Build & push Docker image uses: docker/build-push-action@v2 @@ -315,6 +329,7 @@ jobs: tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} push: true + if: env.IS_DEPLOYABLE == 'true' - name: Docker Hub Description uses: peter-evans/dockerhub-description@v3 @@ -323,6 +338,7 @@ jobs: password: ${{ secrets.AR_DOCKER_PASSWORD }} repository: ${{ matrix.image }} readme-filepath: ${{ matrix.readme }} + if: env.IS_DEPLOYABLE == 'true' # This job will upload a Python Package using Twine when a release is created # For more information see: From 68eddcb054bb7db42e7741f4f06691c0f78b3928 Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Mon, 27 Feb 2023 11:21:27 +0100 Subject: [PATCH 21/45] Refactor: Replace "ar" with "rg" in test imports (#2393) Hello! # Description As discussed in Slack, I've replaced "ar" with "rg" as the shorthand import in the tests. Long live project-wide "find-and-replace" functionality. **Type of change** - [x] Refactor (change restructuring the codebase without changing functionality) **How Has This Been Tested** Running the test suite. 
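For reference, the convention after this change is simply (a trivial sketch; the dataset name is a placeholder):

```python
# Project-wide shorthand is now `rg` everywhere, tests included
# (tests previously used `import argilla as ar`).
import argilla as rg

rg.init()
dataset = rg.load(name="my-dataset")
```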
**Checklist** - [x] I have merged the original branch into my forked branch - [ ] I added relevant documentation - [x] follows the style guidelines of this project - [x] I did a self-review of my code - [ ] I made corresponding changes to the documentation - [x] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works --- - Tom Aarsen --- docs/_source/conf.py | 4 +- tests/client/conftest.py | 58 +++---- tests/client/sdk/conftest.py | 10 +- tests/client/test_api.py | 50 +++--- tests/client/test_client_errors.py | 4 +- tests/client/test_dataset.py | 120 +++++++------- tests/datasets/test_datasets.py | 16 +- tests/functional_tests/datasets/helpers.py | 6 +- .../test_delete_records_from_datasets.py | 40 ++--- .../test_log_for_text_classification.py | 146 +++++++++--------- .../text_classification/test_label_errors.py | 24 +-- .../labeling/text_classification/test_rule.py | 6 +- tests/listeners/test_listener.py | 6 +- tests/metrics/test_common_metrics.py | 20 +-- tests/metrics/test_text_classification.py | 16 +- tests/metrics/test_token_classification.py | 26 ++-- tests/monitoring/test_flair_monitoring.py | 8 +- tests/monitoring/test_spacy_monitoring.py | 16 +- .../text_classification/test_api_settings.py | 12 +- .../token_classification/test_api_settings.py | 12 +- 20 files changed, 300 insertions(+), 300 deletions(-) diff --git a/docs/_source/conf.py b/docs/_source/conf.py index 0cb7d64a8b..f86545a8d2 100644 --- a/docs/_source/conf.py +++ b/docs/_source/conf.py @@ -33,9 +33,9 @@ import os try: - import argilla as ar + import argilla as rg - version_ = ar.__version__ + version_ = rg.__version__ except ModuleNotFoundError: version_ = os.environ["VERSION"] diff --git a/tests/client/conftest.py b/tests/client/conftest.py index f80b838c10..e5940cb560 100644 --- a/tests/client/conftest.py +++ b/tests/client/conftest.py @@ -16,7 +16,7 @@ from typing import List import argilla -import argilla as ar +import argilla as rg import pytest from argilla.client.sdk.datasets.models import TaskType from argilla.client.sdk.text2text.models import ( @@ -76,9 +76,9 @@ def gutenberg_spacy_ner(mocked_client): @pytest.fixture(scope="session") def singlelabel_textclassification_records( request, -) -> List[ar.TextClassificationRecord]: +) -> List[rg.TextClassificationRecord]: return [ - ar.TextClassificationRecord( + rg.TextClassificationRecord( inputs={"text": "mock", "context": "mock"}, prediction=[("a", 0.5), ("b", 0.5)], prediction_agent="mock_pagent", @@ -87,27 +87,27 @@ def singlelabel_textclassification_records( id=1, event_timestamp=datetime.datetime(2000, 1, 1), metadata={"mock_metadata": "mock"}, - explanation={"text": [ar.TokenAttributions(token="mock", attributions={"a": 0.1, "b": 0.5})]}, + explanation={"text": [rg.TokenAttributions(token="mock", attributions={"a": 0.1, "b": 0.5})]}, status="Validated", ), - ar.TextClassificationRecord( + rg.TextClassificationRecord( inputs={"text": "mock2", "context": "mock2"}, prediction=[("a", 0.5), ("b", 0.2)], prediction_agent="mock2_pagent", id=2, event_timestamp=datetime.datetime(2000, 2, 1), metadata={"mock2_metadata": "mock2"}, - explanation={"text": [ar.TokenAttributions(token="mock2", attributions={"a": 0.7, "b": 0.2})]}, + explanation={"text": [rg.TokenAttributions(token="mock2", attributions={"a": 0.7, "b": 0.2})]}, status="Default", ), - ar.TextClassificationRecord( + rg.TextClassificationRecord( inputs={"text": "mock2", "context": "mock2"}, prediction=[("a", 0.5), ("b", 0.2)], 
prediction_agent="mock2_pagent", id=3, status="Discarded", ), - ar.TextClassificationRecord( + rg.TextClassificationRecord( inputs={"text": "mock3", "context": "mock3"}, annotation="a", annotation_agent="mock_aagent", @@ -115,7 +115,7 @@ def singlelabel_textclassification_records( event_timestamp=datetime.datetime(2000, 3, 1), metadata={"mock_metadata": "mock"}, ), - ar.TextClassificationRecord( + rg.TextClassificationRecord( text="mock", id="b", status="Default", @@ -149,9 +149,9 @@ def log_singlelabel_textclassification_records( @pytest.fixture(scope="session") -def multilabel_textclassification_records(request) -> List[ar.TextClassificationRecord]: +def multilabel_textclassification_records(request) -> List[rg.TextClassificationRecord]: return [ - ar.TextClassificationRecord( + rg.TextClassificationRecord( inputs={"text": "mock", "context": "mock"}, prediction=[("a", 0.6), ("b", 0.4)], prediction_agent="mock_pagent", @@ -161,10 +161,10 @@ def multilabel_textclassification_records(request) -> List[ar.TextClassification id=1, event_timestamp=datetime.datetime(2000, 1, 1), metadata={"mock_metadata": "mock"}, - explanation={"text": [ar.TokenAttributions(token="mock", attributions={"a": 0.1, "b": 0.5})]}, + explanation={"text": [rg.TokenAttributions(token="mock", attributions={"a": 0.1, "b": 0.5})]}, status="Validated", ), - ar.TextClassificationRecord( + rg.TextClassificationRecord( inputs={"text": "mock2", "context": "mock2"}, prediction=[("a", 0.5), ("b", 0.2)], prediction_agent="mock2_pagent", @@ -172,10 +172,10 @@ def multilabel_textclassification_records(request) -> List[ar.TextClassification id=2, event_timestamp=datetime.datetime(2000, 2, 1), metadata={"mock2_metadata": "mock2"}, - explanation={"text": [ar.TokenAttributions(token="mock2", attributions={"a": 0.7, "b": 0.2})]}, + explanation={"text": [rg.TokenAttributions(token="mock2", attributions={"a": 0.7, "b": 0.2})]}, status="Default", ), - ar.TextClassificationRecord( + rg.TextClassificationRecord( inputs={"text": "mock2", "context": "mock2"}, prediction=[("a", 0.5), ("b", 0.2)], prediction_agent="mock2_pagent", @@ -183,7 +183,7 @@ def multilabel_textclassification_records(request) -> List[ar.TextClassification id=3, status="Discarded", ), - ar.TextClassificationRecord( + rg.TextClassificationRecord( inputs={"text": "mock3", "context": "mock3"}, annotation=["a"], annotation_agent="mock_aagent", @@ -193,7 +193,7 @@ def multilabel_textclassification_records(request) -> List[ar.TextClassification metadata={"mock_metadata": "mock"}, metrics={}, ), - ar.TextClassificationRecord( + rg.TextClassificationRecord( text="mock", multi_label=True, id="b", @@ -229,9 +229,9 @@ def log_multilabel_textclassification_records( @pytest.fixture(scope="session") -def tokenclassification_records(request) -> List[ar.TokenClassificationRecord]: +def tokenclassification_records(request) -> List[rg.TokenClassificationRecord]: return [ - ar.TokenClassificationRecord( + rg.TokenClassificationRecord( text="This is an example", tokens=["This", "is", "an", "example"], prediction=[("a", 5, 7), ("b", 11, 18)], @@ -243,7 +243,7 @@ def tokenclassification_records(request) -> List[ar.TokenClassificationRecord]: metadata={"mock_metadata": "mock"}, status="Validated", ), - ar.TokenClassificationRecord( + rg.TokenClassificationRecord( text="This is a second example", tokens=["This", "is", "a", "second", "example"], prediction=[("a", 5, 7), ("b", 8, 9)], @@ -252,7 +252,7 @@ def tokenclassification_records(request) -> List[ar.TokenClassificationRecord]: 
event_timestamp=datetime.datetime(2000, 1, 1), metadata={"mock_metadata": "mock"}, ), - ar.TokenClassificationRecord( + rg.TokenClassificationRecord( text="This is a secondd example", tokens=["This", "is", "a", "secondd", "example"], prediction=[("a", 5, 7), ("b", 8, 9, 0.5)], @@ -260,7 +260,7 @@ def tokenclassification_records(request) -> List[ar.TokenClassificationRecord]: id=3, status="Default", ), - ar.TokenClassificationRecord( + rg.TokenClassificationRecord( text="This is a third example", tokens=["This", "is", "a", "third", "example"], annotation=[("a", 0, 4), ("b", 16, 23)], @@ -270,7 +270,7 @@ def tokenclassification_records(request) -> List[ar.TokenClassificationRecord]: metadata={"mock_metadata": "mock"}, metrics={}, ), - ar.TokenClassificationRecord( + rg.TokenClassificationRecord( text="This is a third example", tokens=["This", "is", "a", "third", "example"], id="b", @@ -303,9 +303,9 @@ def log_tokenclassification_records( @pytest.fixture(scope="session") -def text2text_records(request) -> List[ar.Text2TextRecord]: +def text2text_records(request) -> List[rg.Text2TextRecord]: return [ - ar.Text2TextRecord( + rg.Text2TextRecord( text="This is an example", prediction=["Das ist ein Beispiel", "Esto es un ejemplo"], prediction_agent="mock_pagent", @@ -316,7 +316,7 @@ def text2text_records(request) -> List[ar.Text2TextRecord]: metadata={"mock_metadata": "mock"}, status="Validated", ), - ar.Text2TextRecord( + rg.Text2TextRecord( text="This is a one and a half example", prediction=[("Das ist ein Beispiell", 0.9), ("Esto es un ejemploo", 0.1)], prediction_agent="mock_pagent", @@ -324,7 +324,7 @@ def text2text_records(request) -> List[ar.Text2TextRecord]: event_timestamp=datetime.datetime(2000, 1, 1), metadata={"mock_metadata": "mock"}, ), - ar.Text2TextRecord( + rg.Text2TextRecord( text="This is a second example", prediction=["Esto es un ejemplooo", ("Das ist ein Beispielll", 0.9)], prediction_agent="mock_pagent", @@ -333,7 +333,7 @@ def text2text_records(request) -> List[ar.Text2TextRecord]: metadata={"mock_metadata": "mock"}, metrics={}, ), - ar.Text2TextRecord( + rg.Text2TextRecord( text="This is a third example", annotation="C'est une très bonne baguette", annotation_agent="mock_pagent", @@ -342,7 +342,7 @@ def text2text_records(request) -> List[ar.Text2TextRecord]: metadata={"mock_metadata": "mock"}, metrics={}, ), - ar.Text2TextRecord( + rg.Text2TextRecord( text="This is a forth example", id="b", status="Discarded", diff --git a/tests/client/sdk/conftest.py b/tests/client/sdk/conftest.py index 68b91ffe0f..e872471376 100644 --- a/tests/client/sdk/conftest.py +++ b/tests/client/sdk/conftest.py @@ -16,7 +16,7 @@ from datetime import datetime from typing import Any, Dict, List -import argilla as ar +import argilla as rg import pytest from argilla._constants import DEFAULT_API_KEY from argilla.client.sdk.client import AuthenticatedClient @@ -142,9 +142,9 @@ def sdk_client(mocked_client, monkeypatch): @pytest.fixture def bulk_textclass_data(): - explanation = {"text": [ar.TokenAttributions(token="test", attributions={"test": 0.5})]} + explanation = {"text": [rg.TokenAttributions(token="test", attributions={"test": 0.5})]} records = [ - ar.TextClassificationRecord( + rg.TextClassificationRecord( text="test", prediction=[("test", 0.5)], prediction_agent="agent", @@ -170,7 +170,7 @@ def bulk_textclass_data(): @pytest.fixture def bulk_text2text_data(): records = [ - ar.Text2TextRecord( + rg.Text2TextRecord( text="test", prediction=[("prueba", 0.5), ("intento", 0.5)], 
prediction_agent="agent", @@ -194,7 +194,7 @@ def bulk_text2text_data(): @pytest.fixture def bulk_tokenclass_data(): records = [ - ar.TokenClassificationRecord( + rg.TokenClassificationRecord( text="a raw text", tokens=["a", "raw", "text"], prediction=[("test", 2, 5, 0.9)], diff --git a/tests/client/test_api.py b/tests/client/test_api.py index 810d7f006a..822a8e6607 100644 --- a/tests/client/test_api.py +++ b/tests/client/test_api.py @@ -17,7 +17,7 @@ from time import sleep from typing import Any, Iterable -import argilla as ar +import argilla as rg import datasets import httpx import pandas as pd @@ -161,7 +161,7 @@ def test_log_something(monkeypatch, mocked_client): response = api.log( name=dataset_name, - records=ar.TextClassificationRecord(inputs={"text": "This is a test"}), + records=rg.TextClassificationRecord(inputs={"text": "This is a test"}), ) assert response.processed == 1 @@ -199,7 +199,7 @@ def test_load_limits(mocked_client, supported_vector_search): def test_log_records_with_too_long_text(mocked_client): dataset_name = "test_log_records_with_too_long_text" mocked_client.delete(f"/api/datasets/{dataset_name}") - item = ar.TextClassificationRecord(inputs={"text": "This is a toooooo long text\n" * 10000}) + item = rg.TextClassificationRecord(inputs={"text": "This is a toooooo long text\n" * 10000}) api.log([item], name=dataset_name) @@ -215,7 +215,7 @@ def test_log_without_name(mocked_client): match="Empty dataset name has been passed as argument.", ): api.log( - ar.TextClassificationRecord(inputs={"text": "This is a single record. Only this. No more."}), + rg.TextClassificationRecord(inputs={"text": "This is a single record. Only this. No more."}), name=None, ) @@ -236,7 +236,7 @@ def test_log_background(mocked_client): # Log in the background, and extract the future sample_text = "Sample text for testing" future = api.log( - ar.TextClassificationRecord(text=sample_text), + rg.TextClassificationRecord(text=sample_text), name=dataset_name, background=True, ) @@ -272,7 +272,7 @@ def raise_http_error(*args, **kwargs): monkeypatch.setattr(httpx.AsyncClient, "post", raise_http_error) future = api.log( - ar.TextClassificationRecord(text=sample_text), + rg.TextClassificationRecord(text=sample_text), name=dataset_name, background=True, ) @@ -313,10 +313,10 @@ def inner(*args, **kwargs): @pytest.mark.parametrize( "records, dataset_class", [ - ("singlelabel_textclassification_records", ar.DatasetForTextClassification), - ("multilabel_textclassification_records", ar.DatasetForTextClassification), - ("tokenclassification_records", ar.DatasetForTokenClassification), - ("text2text_records", ar.DatasetForText2Text), + ("singlelabel_textclassification_records", rg.DatasetForTextClassification), + ("multilabel_textclassification_records", rg.DatasetForTextClassification), + ("tokenclassification_records", rg.DatasetForTokenClassification), + ("text2text_records", rg.DatasetForText2Text), ], ) def test_general_log_load(mocked_client, monkeypatch, request, records, dataset_class): @@ -366,9 +366,9 @@ def test_log_with_generator(mocked_client, monkeypatch): dataset_name = "test_log_with_generator" mocked_client.delete(f"/api/datasets/{dataset_name}") - def generator(items: int = 10) -> Iterable[ar.TextClassificationRecord]: + def generator(items: int = 10) -> Iterable[rg.TextClassificationRecord]: for i in range(0, items): - yield ar.TextClassificationRecord(id=i, inputs={"text": "The text data"}) + yield rg.TextClassificationRecord(id=i, inputs={"text": "The text data"}) api.log(generator(), 
name=dataset_name) @@ -378,7 +378,7 @@ def test_create_ds_with_wrong_name(mocked_client): with pytest.raises(InputValueError): api.log( - ar.TextClassificationRecord( + rg.TextClassificationRecord( inputs={"text": "The text data"}, ), name=dataset_name, @@ -390,7 +390,7 @@ def test_delete_dataset(mocked_client): mocked_client.delete(f"/api/datasets/{dataset_name}") api.log( - ar.TextClassificationRecord( + rg.TextClassificationRecord( id=0, inputs={"text": "The text data"}, annotation_agent="test", @@ -422,7 +422,7 @@ def test_dataset_copy(mocked_client): mocked_client.delete(f"/api/datasets/{dataset_copy}") mocked_client.delete(f"/api/datasets/{dataset}") - record = ar.TextClassificationRecord( + record = rg.TextClassificationRecord( id=0, text="This is the record input", annotation_agent="test", @@ -458,7 +458,7 @@ def test_dataset_copy_to_another_workspace(mocked_client): mocked_client.delete(f"/api/datasets/{dataset_copy}?workspace={new_workspace}") api.log( - ar.TextClassificationRecord( + rg.TextClassificationRecord( id=0, text="This is the record input", annotation_agent="test", @@ -485,7 +485,7 @@ def test_update_record(mocked_client): mocked_client.delete(f"/api/datasets/{dataset}") expected_inputs = ["This is a text"] - record = ar.TextClassificationRecord( + record = rg.TextClassificationRecord( id=0, inputs=expected_inputs, annotation_agent="test", @@ -502,7 +502,7 @@ def test_update_record(mocked_client): assert len(records) == 1 assert records[0]["annotation"] == "T" # This record will replace the old one - record = ar.TextClassificationRecord( + record = rg.TextClassificationRecord( id=0, inputs=expected_inputs, ) @@ -526,7 +526,7 @@ def test_text_classifier_with_inputs_list(mocked_client): expected_inputs = ["A", "List", "of", "values"] api.log( - ar.TextClassificationRecord( + rg.TextClassificationRecord( id=0, inputs=expected_inputs, annotation_agent="test", @@ -589,8 +589,8 @@ def test_load_as_pandas(mocked_client, supported_vector_search): ) records = api.load(name=dataset) - assert isinstance(records, ar.DatasetForTextClassification) - assert isinstance(records[0], ar.TextClassificationRecord) + assert isinstance(records, rg.DatasetForTextClassification) + assert isinstance(records[0], rg.TextClassificationRecord) if supported_vector_search: for record in records: @@ -609,7 +609,7 @@ def test_load_as_pandas(mocked_client, supported_vector_search): def test_token_classification_spans(span, valid): texto = "Esto es una prueba" if valid: - ar.TokenClassificationRecord( + rg.TokenClassificationRecord( text=texto, tokens=texto.split(), prediction=[("test", *span)], @@ -621,7 +621,7 @@ def test_token_classification_spans(span, valid): r"Spans:\n\('test', 1, 2\) - 's'\n" r"Tokens:\n\['Esto', 'es', 'una', 'prueba'\]", ): - ar.TokenClassificationRecord( + rg.TokenClassificationRecord( text=texto, tokens=texto.split(), prediction=[("test", *span)], @@ -633,7 +633,7 @@ def test_load_text2text(mocked_client, supported_vector_search): records = [] for i in range(0, 2): - record = ar.Text2TextRecord( + record = rg.Text2TextRecord( text="test text", prediction=["test prediction"], annotation="test annotation", @@ -680,7 +680,7 @@ def test_client_workspace(mocked_client): def test_load_sort(mocked_client): records = [ - ar.TextClassificationRecord( + rg.TextClassificationRecord( text="test text", id=i, ) diff --git a/tests/client/test_client_errors.py b/tests/client/test_client_errors.py index f9f0cd4703..32d713ade2 100644 --- a/tests/client/test_client_errors.py +++ 
b/tests/client/test_client_errors.py @@ -18,6 +18,6 @@ def test_unauthorized_response_error(mocked_client): with pytest.raises(UnauthorizedApiError, match="Could not validate credentials"): - import argilla as ar + import argilla as rg - ar.init(api_key="wrong-api-key") + rg.init(api_key="wrong-api-key") diff --git a/tests/client/test_dataset.py b/tests/client/test_dataset.py index 404b2768b4..050bdfee5f 100644 --- a/tests/client/test_dataset.py +++ b/tests/client/test_dataset.py @@ -17,7 +17,7 @@ import sys from time import sleep -import argilla as ar +import argilla as rg import datasets import pandas as pd import pytest @@ -62,7 +62,7 @@ def test_init(self, monkeypatch, singlelabel_textclassification_records): with pytest.raises(WrongRecordTypeError, match="but you provided Text2TextRecord"): DatasetBase( - records=[ar.Text2TextRecord(text="test")], + records=[rg.Text2TextRecord(text="test")], ) with pytest.raises( @@ -71,8 +71,8 @@ def test_init(self, monkeypatch, singlelabel_textclassification_records): ): DatasetBase( records=[ - ar.TextClassificationRecord(text="test"), - ar.Text2TextRecord(text="test"), + rg.TextClassificationRecord(text="test"), + rg.Text2TextRecord(text="test"), ], ) @@ -183,7 +183,7 @@ def test_setitem_delitem(self, monkeypatch, singlelabel_textclassification_recor [rec.copy(deep=True) for rec in singlelabel_textclassification_records], ) - record = ar.TextClassificationRecord(text="mock") + record = rg.TextClassificationRecord(text="mock") dataset[0] = record assert dataset._records[0] is record @@ -199,7 +199,7 @@ def test_setitem_delitem(self, monkeypatch, singlelabel_textclassification_recor " .*TextClassificationRecord.* but you provided .*Text2TextRecord.*" ), ): - dataset[0] = ar.Text2TextRecord(text="mock") + dataset[0] = rg.Text2TextRecord(text="mock") def test_prepare_for_training_train_test_splits(self, monkeypatch, singlelabel_textclassification_records): monkeypatch.setattr("argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord) @@ -223,8 +223,8 @@ def test_prepare_for_training_train_test_splits(self, monkeypatch, singlelabel_t class TestDatasetForTextClassification: def test_init(self, singlelabel_textclassification_records): - ds = ar.DatasetForTextClassification(singlelabel_textclassification_records) - assert ds._RECORD_TYPE == ar.TextClassificationRecord + ds = rg.DatasetForTextClassification(singlelabel_textclassification_records) + assert ds._RECORD_TYPE == rg.TextClassificationRecord assert ds._records == singlelabel_textclassification_records @pytest.mark.parametrize( @@ -236,7 +236,7 @@ def test_init(self, singlelabel_textclassification_records): ) def test_to_from_datasets(self, records, request): records = request.getfixturevalue(records) - expected_dataset = ar.DatasetForTextClassification(records) + expected_dataset = rg.DatasetForTextClassification(records) expected_dataset.prepare_for_training(train_size=0.5) dataset_ds = expected_dataset.to_datasets() @@ -266,24 +266,24 @@ def test_to_from_datasets(self, records, request): } ] - dataset = ar.DatasetForTextClassification.from_datasets(dataset_ds) + dataset = rg.DatasetForTextClassification.from_datasets(dataset_ds) - assert isinstance(dataset, ar.DatasetForTextClassification) + assert isinstance(dataset, rg.DatasetForTextClassification) _compare_datasets(dataset, expected_dataset) missing_optional_cols = datasets.Dataset.from_dict({"inputs": ["mock"]}) - rec = ar.DatasetForTextClassification.from_datasets(missing_optional_cols)[0] + rec = 
rg.DatasetForTextClassification.from_datasets(missing_optional_cols)[0] assert rec.inputs == {"text": "mock"} def test_from_to_datasets_id(self): - dataset_rb = ar.DatasetForTextClassification([ar.TextClassificationRecord(text="mock")]) + dataset_rb = rg.DatasetForTextClassification([rg.TextClassificationRecord(text="mock")]) dataset_ds = dataset_rb.to_datasets() assert dataset_ds["id"] == [None] - assert ar.read_datasets(dataset_ds, task="TextClassification")[0].id is None + assert rg.read_datasets(dataset_ds, task="TextClassification")[0].id is None def test_datasets_empty_metadata(self): - dataset = ar.DatasetForTextClassification([ar.TextClassificationRecord(text="mock")]) + dataset = rg.DatasetForTextClassification([rg.TextClassificationRecord(text="mock")]) assert dataset.to_datasets()["metadata"] == [None] @pytest.mark.parametrize( @@ -295,16 +295,16 @@ def test_datasets_empty_metadata(self): ) def test_to_from_pandas(self, records, request): records = request.getfixturevalue(records) - expected_dataset = ar.DatasetForTextClassification(records) + expected_dataset = rg.DatasetForTextClassification(records) dataset_df = expected_dataset.to_pandas() assert isinstance(dataset_df, pd.DataFrame) assert list(dataset_df.columns) == list(expected_dataset[0].__fields__.keys()) - dataset = ar.DatasetForTextClassification.from_pandas(dataset_df) + dataset = rg.DatasetForTextClassification.from_pandas(dataset_df) - assert isinstance(dataset, ar.DatasetForTextClassification) + assert isinstance(dataset, rg.DatasetForTextClassification) for rec, expected in zip(dataset, expected_dataset): assert rec == expected @@ -323,7 +323,7 @@ def test_push_to_hub(self, request, name: str): records = request.getfixturevalue(name) # TODO(@frascuchon): move dataset to new organization dataset_name = f"rubrix/_test_text_classification_records-{name}" - dataset_ds = ar.DatasetForTextClassification(records).to_datasets() + dataset_ds = rg.DatasetForTextClassification(records).to_datasets() _push_to_hub_with_retries(dataset_ds, repo_id=dataset_name, token=_HF_HUB_ACCESS_TOKEN, private=True) sleep(1) dataset_ds = datasets.load_dataset( @@ -344,7 +344,7 @@ def test_push_to_hub(self, request, name: str): def test_prepare_for_training(self, request, records): records = request.getfixturevalue(records) - ds = ar.DatasetForTextClassification(records) + ds = rg.DatasetForTextClassification(records) train = ds.prepare_for_training() if not ds[0].multi_label: @@ -381,14 +381,14 @@ def test_from_dataset_with_non_argilla_format_multilabel(self): use_auth_token=_HF_HUB_ACCESS_TOKEN, ) - rb_ds = ar.DatasetForTextClassification.from_datasets( + rb_ds = rg.DatasetForTextClassification.from_datasets( ds, inputs="id", annotation="labels", ) assert rb_ds[0].inputs == {"id": "eecwqtt"} - rb_ds = ar.DatasetForTextClassification.from_datasets( + rb_ds = rg.DatasetForTextClassification.from_datasets( ds, text="text", annotation="labels", @@ -422,7 +422,7 @@ def test_from_dataset_with_non_argilla_format(self): use_auth_token=_HF_HUB_ACCESS_TOKEN, ) - rb_ds = ar.DatasetForTextClassification.from_datasets( + rb_ds = rg.DatasetForTextClassification.from_datasets( ds, text="review", annotation="star", metadata=["package_name", "date"] ) @@ -453,19 +453,19 @@ def test_from_datasets_with_annotation_arg(self): } ), ) - dataset_rb = ar.DatasetForTextClassification.from_datasets(dataset_ds, annotation="label") + dataset_rb = rg.DatasetForTextClassification.from_datasets(dataset_ds, annotation="label") assert [rec.annotation for rec in dataset_rb] 
== ["HAM", None] class TestDatasetForTokenClassification: def test_init(self, tokenclassification_records): - ds = ar.DatasetForTokenClassification(tokenclassification_records) - assert ds._RECORD_TYPE == ar.TokenClassificationRecord + ds = rg.DatasetForTokenClassification(tokenclassification_records) + assert ds._RECORD_TYPE == rg.TokenClassificationRecord assert ds._records == tokenclassification_records def test_to_from_datasets(self, tokenclassification_records): - expected_dataset = ar.DatasetForTokenClassification(tokenclassification_records) + expected_dataset = rg.DatasetForTokenClassification(tokenclassification_records) dataset_ds = expected_dataset.to_datasets() @@ -502,42 +502,42 @@ def test_to_from_datasets(self, tokenclassification_records): } ] - dataset = ar.DatasetForTokenClassification.from_datasets(dataset_ds) + dataset = rg.DatasetForTokenClassification.from_datasets(dataset_ds) - assert isinstance(dataset, ar.DatasetForTokenClassification) + assert isinstance(dataset, rg.DatasetForTokenClassification) _compare_datasets(dataset, expected_dataset) missing_optional_cols = datasets.Dataset.from_dict({"text": ["mock"], "tokens": [["mock"]]}) - rec = ar.DatasetForTokenClassification.from_datasets(missing_optional_cols)[0] + rec = rg.DatasetForTokenClassification.from_datasets(missing_optional_cols)[0] assert rec.text == "mock" and rec.tokens == ["mock"] def test_from_to_datasets_id(self): - dataset_rb = ar.DatasetForTokenClassification([ar.TokenClassificationRecord(text="mock", tokens=["mock"])]) + dataset_rb = rg.DatasetForTokenClassification([rg.TokenClassificationRecord(text="mock", tokens=["mock"])]) dataset_ds = dataset_rb.to_datasets() assert dataset_ds["id"] == [None] - assert ar.read_datasets(dataset_ds, task="TokenClassification")[0].id is None + assert rg.read_datasets(dataset_ds, task="TokenClassification")[0].id is None def test_prepare_for_training_empty(self): - dataset = ar.DatasetForTokenClassification([ar.TokenClassificationRecord(text="mock", tokens=["mock"])]) + dataset = rg.DatasetForTokenClassification([rg.TokenClassificationRecord(text="mock", tokens=["mock"])]) with pytest.raises(AssertionError): dataset.prepare_for_training() def test_datasets_empty_metadata(self): - dataset = ar.DatasetForTokenClassification([ar.TokenClassificationRecord(text="mock", tokens=["mock"])]) + dataset = rg.DatasetForTokenClassification([rg.TokenClassificationRecord(text="mock", tokens=["mock"])]) assert dataset.to_datasets()["metadata"] == [None] def test_to_from_pandas(self, tokenclassification_records): - expected_dataset = ar.DatasetForTokenClassification(tokenclassification_records) + expected_dataset = rg.DatasetForTokenClassification(tokenclassification_records) dataset_df = expected_dataset.to_pandas() assert isinstance(dataset_df, pd.DataFrame) assert list(dataset_df.columns) == list(expected_dataset[0].__fields__.keys()) - dataset = ar.DatasetForTokenClassification.from_pandas(dataset_df) + dataset = rg.DatasetForTokenClassification.from_pandas(dataset_df) - assert isinstance(dataset, ar.DatasetForTokenClassification) + assert isinstance(dataset, rg.DatasetForTokenClassification) for rec, expected in zip(dataset, expected_dataset): assert rec == expected @@ -546,7 +546,7 @@ def test_to_from_pandas(self, tokenclassification_records): reason="You need a HF Hub access token to test the push_to_hub feature", ) def test_push_to_hub(self, tokenclassification_records): - dataset_ds = ar.DatasetForTokenClassification(tokenclassification_records).to_datasets() + dataset_ds 
= rg.DatasetForTokenClassification(tokenclassification_records).to_datasets() _push_to_hub_with_retries( dataset_ds, # TODO(@frascuchon): Move dataset to the new org @@ -575,7 +575,7 @@ def test_prepare_for_training_with_spacy(self): use_auth_token=_HF_HUB_ACCESS_TOKEN, split="train", ) - rb_dataset: DatasetForTokenClassification = ar.read_datasets(ner_dataset, task="TokenClassification") + rb_dataset: DatasetForTokenClassification = rg.read_datasets(ner_dataset, task="TokenClassification") for r in rb_dataset: r.annotation = [(label, start, end) for label, start, end, _ in r.prediction] @@ -603,7 +603,7 @@ def test_prepare_for_training_with_spark_nlp(self): use_auth_token=_HF_HUB_ACCESS_TOKEN, split="train", ) - rb_dataset: DatasetForTokenClassification = ar.read_datasets(ner_dataset, task="TokenClassification") + rb_dataset: DatasetForTokenClassification = rg.read_datasets(ner_dataset, task="TokenClassification") for r in rb_dataset: r.annotation = [(label, start, end) for label, start, end, _ in r.prediction] @@ -628,7 +628,7 @@ def test_prepare_for_training(self): use_auth_token=_HF_HUB_ACCESS_TOKEN, split="train", ) - rb_dataset: DatasetForTokenClassification = ar.read_datasets(ner_dataset, task="TokenClassification") + rb_dataset: DatasetForTokenClassification = rg.read_datasets(ner_dataset, task="TokenClassification") for r in rb_dataset: r.annotation = [(label, start, end) for label, start, end, _ in r.prediction] @@ -691,7 +691,7 @@ def test_from_dataset_with_non_argilla_format(self): use_auth_token=_HF_HUB_ACCESS_TOKEN, ) - rb_ds = ar.DatasetForTokenClassification.from_datasets(ds, tags="ner_tags", metadata=["spans"]) + rb_ds = rg.DatasetForTokenClassification.from_datasets(ds, tags="ner_tags", metadata=["spans"]) again_the_ds = rb_ds.to_datasets() assert again_the_ds.column_names == [ @@ -710,7 +710,7 @@ def test_from_dataset_with_non_argilla_format(self): def test_from_datasets_with_empty_tokens(self, caplog): dataset_ds = datasets.Dataset.from_dict({"empty_tokens": [["mock"], []]}) - dataset_rb = ar.DatasetForTokenClassification.from_datasets(dataset_ds, tokens="empty_tokens") + dataset_rb = rg.DatasetForTokenClassification.from_datasets(dataset_ds, tokens="empty_tokens") assert caplog.record_tuples[0][1] == 30 assert caplog.record_tuples[0][2] == "Ignoring row with no tokens." 
@@ -721,12 +721,12 @@ def test_from_datasets_with_empty_tokens(self, caplog): class TestDatasetForText2Text: def test_init(self, text2text_records): - ds = ar.DatasetForText2Text(text2text_records) - assert ds._RECORD_TYPE == ar.Text2TextRecord + ds = rg.DatasetForText2Text(text2text_records) + assert ds._RECORD_TYPE == rg.Text2TextRecord assert ds._records == text2text_records def test_to_from_datasets(self, text2text_records): - expected_dataset = ar.DatasetForText2Text(text2text_records) + expected_dataset = rg.DatasetForText2Text(text2text_records) dataset_ds = expected_dataset.to_datasets() @@ -751,43 +751,43 @@ def test_to_from_datasets(self, text2text_records): } ] - dataset = ar.DatasetForText2Text.from_datasets(dataset_ds) + dataset = rg.DatasetForText2Text.from_datasets(dataset_ds) - assert isinstance(dataset, ar.DatasetForText2Text) + assert isinstance(dataset, rg.DatasetForText2Text) _compare_datasets(dataset, expected_dataset) missing_optional_cols = datasets.Dataset.from_dict({"text": ["mock"]}) - rec = ar.DatasetForText2Text.from_datasets(missing_optional_cols)[0] + rec = rg.DatasetForText2Text.from_datasets(missing_optional_cols)[0] assert rec.text == "mock" # alternative format for the predictions ds = datasets.Dataset.from_dict({"text": ["example"], "prediction": [["ejemplo"]]}) - rec = ar.DatasetForText2Text.from_datasets(ds)[0] + rec = rg.DatasetForText2Text.from_datasets(ds)[0] assert rec.prediction[0][0] == "ejemplo" assert rec.prediction[0][1] == pytest.approx(1.0) def test_from_to_datasets_id(self): - dataset_rb = ar.DatasetForText2Text([ar.Text2TextRecord(text="mock")]) + dataset_rb = rg.DatasetForText2Text([rg.Text2TextRecord(text="mock")]) dataset_ds = dataset_rb.to_datasets() assert dataset_ds["id"] == [None] - assert ar.read_datasets(dataset_ds, task="Text2Text")[0].id is None + assert rg.read_datasets(dataset_ds, task="Text2Text")[0].id is None def test_datasets_empty_metadata(self): - dataset = ar.DatasetForText2Text([ar.Text2TextRecord(text="mock")]) + dataset = rg.DatasetForText2Text([rg.Text2TextRecord(text="mock")]) assert dataset.to_datasets()["metadata"] == [None] def test_to_from_pandas(self, text2text_records): - expected_dataset = ar.DatasetForText2Text(text2text_records) + expected_dataset = rg.DatasetForText2Text(text2text_records) dataset_df = expected_dataset.to_pandas() assert isinstance(dataset_df, pd.DataFrame) assert list(dataset_df.columns) == list(expected_dataset[0].__fields__.keys()) - dataset = ar.DatasetForText2Text.from_pandas(dataset_df) + dataset = rg.DatasetForText2Text.from_pandas(dataset_df) - assert isinstance(dataset, ar.DatasetForText2Text) + assert isinstance(dataset, rg.DatasetForText2Text) for rec, expected in zip(dataset, expected_dataset): assert rec == expected @@ -796,7 +796,7 @@ def test_to_from_pandas(self, text2text_records): reason="You need a HF Hub access token to test the push_to_hub feature", ) def test_push_to_hub(self, text2text_records): - dataset_ds = ar.DatasetForText2Text(text2text_records).to_datasets() + dataset_ds = rg.DatasetForText2Text(text2text_records).to_datasets() _push_to_hub_with_retries( dataset_ds, # TODO(@frascuchon): Move dataset to the new org @@ -824,7 +824,7 @@ def test_from_dataset_with_non_argilla_format(self): use_auth_token=_HF_HUB_ACCESS_TOKEN, ) - rb_ds = ar.DatasetForText2Text.from_datasets(ds, text="description", annotation="abstract") + rb_ds = rg.DatasetForText2Text.from_datasets(ds, text="description", annotation="abstract") again_the_ds = rb_ds.to_datasets() assert 
again_the_ds.column_names == [ @@ -864,7 +864,7 @@ def mock_from_pandas(mock): monkeypatch.setattr(f"argilla.client.datasets.{dataset_class}.from_pandas", mock_from_pandas) - assert ar.read_pandas("mock", task) == "mock" + assert rg.read_pandas("mock", task) == "mock" @pytest.mark.parametrize( @@ -881,4 +881,4 @@ def mock_from_datasets(mock): monkeypatch.setattr(f"argilla.client.datasets.{dataset_class}.from_datasets", mock_from_datasets) - assert ar.read_datasets("mock", task) == "mock" + assert rg.read_datasets("mock", task) == "mock" diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 5f71f9d282..5c7a385ceb 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import argilla as ar +import argilla as rg import pytest from argilla import TextClassificationSettings, TokenClassificationSettings from argilla.client import api @@ -34,8 +34,8 @@ ) def test_settings_workflow(mocked_client, settings_, wrong_settings): dataset = "test-dataset" - ar.delete(dataset) - ar.configure_dataset(dataset, settings=settings_) + rg.delete(dataset) + rg.configure_dataset(dataset, settings=settings_) current_api = api.active_api() datasets_api = current_api.datasets @@ -44,13 +44,13 @@ def test_settings_workflow(mocked_client, settings_, wrong_settings): assert found_settings == settings_ settings_.label_schema = {"LALALA"} - ar.configure_dataset(dataset, settings_) + rg.configure_dataset(dataset, settings_) found_settings = datasets_api.load_settings(dataset) assert found_settings == settings_ with pytest.raises(ValueError, match="Task type mismatch"): - ar.configure_dataset(dataset, wrong_settings) + rg.configure_dataset(dataset, wrong_settings) def test_list_dataset(mocked_client): @@ -66,10 +66,10 @@ def test_list_dataset(mocked_client): def test_delete_dataset_by_non_creator(mocked_client): try: dataset = "test_delete_dataset_by_non_creator" - ar.delete(dataset) - ar.configure_dataset(dataset, settings=TextClassificationSettings(label_schema={"A", "B", "C"})) + rg.delete(dataset) + rg.configure_dataset(dataset, settings=TextClassificationSettings(label_schema={"A", "B", "C"})) mocked_client.change_current_user("mock-user") with pytest.raises(ForbiddenApiError): - ar.delete(dataset) + rg.delete(dataset) finally: mocked_client.reset_default_user() diff --git a/tests/functional_tests/datasets/helpers.py b/tests/functional_tests/datasets/helpers.py index aeea8bd0da..a138e95e35 100644 --- a/tests/functional_tests/datasets/helpers.py +++ b/tests/functional_tests/datasets/helpers.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import argilla as ar +import argilla as rg from argilla import Text2TextRecord, TextClassificationRecord, TokenClassificationRecord from argilla.server.commons.models import TaskType @@ -38,7 +38,7 @@ def text2text(idx): ) dataset = "test_dataset" - ar.delete(dataset) + rg.delete(dataset) if task == TaskType.text_classification: record_builder = text_class @@ -49,7 +49,7 @@ def text2text(idx): records = [record_builder(i) for i in range(0, 50)] - ar.log( + rg.log( name=dataset, records=records, ) diff --git a/tests/functional_tests/datasets/test_delete_records_from_datasets.py b/tests/functional_tests/datasets/test_delete_records_from_datasets.py index 56b2e6820c..f13509f6a7 100644 --- a/tests/functional_tests/datasets/test_delete_records_from_datasets.py +++ b/tests/functional_tests/datasets/test_delete_records_from_datasets.py @@ -20,45 +20,45 @@ def test_delete_records_from_dataset(mocked_client): dataset = "test_delete_records_from_dataset" - import argilla as ar + import argilla as rg - ar.delete(dataset) - ar.log( + rg.delete(dataset) + rg.log( name=dataset, records=[ - ar.TextClassificationRecord(id=i, text="This is the text", metadata=dict(idx=i)) for i in range(0, 50) + rg.TextClassificationRecord(id=i, text="This is the text", metadata=dict(idx=i)) for i in range(0, 50) ], ) - matched, processed = ar.delete_records(name=dataset, ids=[10], discard_only=True) + matched, processed = rg.delete_records(name=dataset, ids=[10], discard_only=True) assert matched, processed == (1, 1) - ds = ar.load(name=dataset) + ds = rg.load(name=dataset) assert len(ds) == 50 time.sleep(1) - matched, processed = ar.delete_records(name=dataset, query="id:10", discard_only=False) + matched, processed = rg.delete_records(name=dataset, query="id:10", discard_only=False) assert matched, processed == (1, 1) time.sleep(1) - ds = ar.load(name=dataset) + ds = rg.load(name=dataset) assert len(ds) == 49 def test_delete_records_without_permission(mocked_client): dataset = "test_delete_records_without_permission" - import argilla as ar + import argilla as rg - ar.delete(dataset) - ar.log( + rg.delete(dataset) + rg.log( name=dataset, records=[ - ar.TextClassificationRecord(id=i, text="This is the text", metadata=dict(idx=i)) for i in range(0, 50) + rg.TextClassificationRecord(id=i, text="This is the text", metadata=dict(idx=i)) for i in range(0, 50) ], ) try: mocked_client.change_current_user("mock-user") - matched, processed = ar.delete_records( + matched, processed = rg.delete_records( name=dataset, ids=[10], discard_only=True, @@ -66,14 +66,14 @@ def test_delete_records_without_permission(mocked_client): assert matched, processed == (1, 1) with pytest.raises(ForbiddenApiError): - ar.delete_records( + rg.delete_records( name=dataset, query="id:10", discard_only=False, discard_when_forbidden=False, ) - matched, processed = ar.delete_records( + matched, processed = rg.delete_records( name=dataset, query="id:10", discard_only=False, @@ -86,13 +86,13 @@ def test_delete_records_without_permission(mocked_client): def test_delete_records_with_unmatched_records(mocked_client): dataset = "test_delete_records_with_unmatched_records" - import argilla as ar + import argilla as rg - ar.delete(dataset) - ar.log( + rg.delete(dataset) + rg.log( name=dataset, records=[ - ar.TextClassificationRecord( + rg.TextClassificationRecord( id=i, text="This is the text", metadata=dict(idx=i), @@ -101,5 +101,5 @@ def test_delete_records_with_unmatched_records(mocked_client): ], ) - matched, processed = ar.delete_records(dataset, 
ids=["you-wont-find-me-here"]) + matched, processed = rg.delete_records(dataset, ids=["you-wont-find-me-here"]) assert (matched, processed) == (0, 0) diff --git a/tests/functional_tests/test_log_for_text_classification.py b/tests/functional_tests/test_log_for_text_classification.py index be50cdf406..78c305f13c 100644 --- a/tests/functional_tests/test_log_for_text_classification.py +++ b/tests/functional_tests/test_log_for_text_classification.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import argilla as ar +import argilla as rg import pytest from argilla.client.sdk.commons.errors import ( BadRequestApiError, @@ -29,14 +29,14 @@ def test_log_records_with_multi_and_single_label_task(mocked_client): dataset = "test_log_records_with_multi_and_single_label_task" expected_inputs = ["This is a text"] - ar.delete(dataset) + rg.delete(dataset) records = [ - ar.TextClassificationRecord( + rg.TextClassificationRecord( id=0, inputs=expected_inputs, multi_label=False, ), - ar.TextClassificationRecord( + rg.TextClassificationRecord( id=1, inputs=expected_inputs, multi_label=True, @@ -44,30 +44,30 @@ def test_log_records_with_multi_and_single_label_task(mocked_client): ] with pytest.raises(ValidationApiError): - ar.log( + rg.log( records, name=dataset, ) - ar.log(records[0], name=dataset) + rg.log(records[0], name=dataset) with pytest.raises(Exception): - ar.log(records[1], name=dataset) + rg.log(records[1], name=dataset) def test_delete_and_create_for_different_task(mocked_client): dataset = "test_delete_and_create_for_different_task" text = "This is a text" - ar.delete(dataset) - ar.log(ar.TextClassificationRecord(id=0, inputs=text), name=dataset) - ar.load(dataset) + rg.delete(dataset) + rg.log(rg.TextClassificationRecord(id=0, inputs=text), name=dataset) + rg.load(dataset) - ar.delete(dataset) - ar.log( - ar.TokenClassificationRecord(id=0, text=text, tokens=text.split(" ")), + rg.delete(dataset) + rg.log( + rg.TokenClassificationRecord(id=0, text=text, tokens=text.split(" ")), name=dataset, ) - ar.load(dataset) + rg.load(dataset) @pytest.mark.skipif( @@ -81,34 +81,34 @@ def test_similarity_search_in_python_client( text = "This is a text" vectors = {"my_bert": [1, 2, 3, 4]} - ar.delete(dataset) - ar.log( - ar.TextClassificationRecord( + rg.delete(dataset) + rg.log( + rg.TextClassificationRecord( id=0, inputs=text, vectors=vectors, ), name=dataset, ) - ds = ar.load(dataset, vector=("my_bert", [1, 1, 1, 1])) + ds = rg.load(dataset, vector=("my_bert", [1, 1, 1, 1])) assert len(ds) == 1 - ar.log( - ar.TextClassificationRecord( + rg.log( + rg.TextClassificationRecord( id=1, inputs=text, vectors={"my_bert_2": [1, 2, 3, 4]}, ), name=dataset, ) - ds = ar.load(dataset, vector=("my_bert_2", [1, 1, 1, 1])) + ds = rg.load(dataset, vector=("my_bert_2", [1, 1, 1, 1])) assert len(ds) == 1 with pytest.raises( BadRequestApiError, match="Cannot create more than 5 kind of vectors per dataset", ): - ar.log( - ar.TextClassificationRecord( + rg.log( + rg.TextClassificationRecord( id=3, inputs=text, vectors={ @@ -132,10 +132,10 @@ def test_log_data_with_vectors_and_update_ok( ): dataset = "test_log_data_with_vectors_and_update_ok" text = "This is a text" - ar.delete(dataset) + rg.delete(dataset) records = [ - ar.TextClassificationRecord( + rg.TextClassificationRecord( id=i, inputs=text, vectors={"text": [i] * 5}, @@ -143,11 +143,11 @@ def test_log_data_with_vectors_and_update_ok( for i in range(1, 10) ] - ar.log( + rg.log( records=records, 
name=dataset, ) - ds = ar.load( + ds = rg.load( dataset, vector=( "text", @@ -169,17 +169,17 @@ def test_log_data_with_vectors_and_update_ko(mocked_client: SecuredClient): text = "This is a text" vectors = {"my_bert": [1, 2, 3, 4]} - ar.delete(dataset) - ar.log( - ar.TextClassificationRecord(id=0, inputs=text, vectors=vectors), + rg.delete(dataset) + rg.log( + rg.TextClassificationRecord(id=0, inputs=text, vectors=vectors), name=dataset, ) - ar.load(dataset) + rg.load(dataset) updated_vectors = {"my_bert": [2, 3, 5]} with pytest.raises(GenericApiError): - ar.log( - ar.TextClassificationRecord(id=0, text=text, vectors=updated_vectors), + rg.log( + rg.TextClassificationRecord(id=0, text=text, vectors=updated_vectors), name=dataset, ) @@ -191,21 +191,21 @@ def test_log_data_in_several_workspaces(mocked_client: SecuredClient): mocked_client.add_workspaces_to_argilla_user([workspace]) - curr_ws = ar.get_workspace() + curr_ws = rg.get_workspace() for ws in [curr_ws, workspace]: - ar.set_workspace(ws) - ar.delete(dataset) + rg.set_workspace(ws) + rg.delete(dataset) - ar.set_workspace(curr_ws) - ar.log(ar.TextClassificationRecord(id=0, inputs=text), name=dataset) + rg.set_workspace(curr_ws) + rg.log(rg.TextClassificationRecord(id=0, inputs=text), name=dataset) - ar.set_workspace(workspace) - ar.log(ar.TextClassificationRecord(id=1, inputs=text), name=dataset) - ds = ar.load(dataset) + rg.set_workspace(workspace) + rg.log(rg.TextClassificationRecord(id=1, inputs=text), name=dataset) + ds = rg.load(dataset) assert len(ds) == 1 - ar.set_workspace(curr_ws) - ds = ar.load(dataset) + rg.set_workspace(curr_ws) + ds = rg.load(dataset) assert len(ds) == 1 @@ -214,12 +214,12 @@ def test_search_keywords(mocked_client): from datasets import load_dataset dataset_ds = load_dataset("Recognai/sentiment-banking", split="train") - dataset_rb = ar.read_datasets(dataset_ds, task="TextClassification") + dataset_rb = rg.read_datasets(dataset_ds, task="TextClassification") - ar.delete(dataset) - ar.log(name=dataset, records=dataset_rb) + rg.delete(dataset) + rg.log(name=dataset, records=dataset_rb) - ds = ar.load(dataset, query="lim*") + ds = rg.load(dataset, query="lim*") df = ds.to_pandas() assert not df.empty assert "search_keywords" in df.columns @@ -236,16 +236,16 @@ def test_search_keywords(mocked_client): def test_log_records_with_empty_metadata_list(mocked_client): dataset = "test_log_records_with_empty_metadata_list" - ar.delete(dataset) + rg.delete(dataset) expected_records = [ - ar.TextClassificationRecord(text="The input text", metadata={"emptyList": []}), - ar.TextClassificationRecord(text="The input text", metadata={"emptyTuple": ()}), - ar.TextClassificationRecord(text="The input text", metadata={"emptyDict": {}}), - ar.TextClassificationRecord(text="The input text", metadata={"none": None}), + rg.TextClassificationRecord(text="The input text", metadata={"emptyList": []}), + rg.TextClassificationRecord(text="The input text", metadata={"emptyTuple": ()}), + rg.TextClassificationRecord(text="The input text", metadata={"emptyDict": {}}), + rg.TextClassificationRecord(text="The input text", metadata={"none": None}), ] - ar.log(expected_records, name=dataset) + rg.log(expected_records, name=dataset) - df = ar.load(dataset) + df = rg.load(dataset) df = df.to_pandas() assert len(df) == len(expected_records) @@ -256,67 +256,67 @@ def test_log_records_with_empty_metadata_list(mocked_client): def test_logging_with_metadata_limits_exceeded(mocked_client): dataset = "test_logging_with_metadata_limits_exceeded" - 
ar.delete(dataset) + rg.delete(dataset) - expected_record = ar.TextClassificationRecord( + expected_record = rg.TextClassificationRecord( text="The input text", metadata={k: f"this is a string {k}" for k in range(0, settings.metadata_fields_limit + 1)}, ) with pytest.raises(BadRequestApiError): - ar.log(expected_record, name=dataset) + rg.log(expected_record, name=dataset) expected_record.metadata = {k: f"This is a string {k}" for k in range(0, settings.metadata_fields_limit)} # Dataset creation with data - ar.log(expected_record, name=dataset) + rg.log(expected_record, name=dataset) # This call will check already included fields - ar.log(expected_record, name=dataset) + rg.log(expected_record, name=dataset) expected_record.metadata["new_key"] = "value" with pytest.raises(BadRequestApiError): - ar.log(expected_record, name=dataset) + rg.log(expected_record, name=dataset) def test_log_with_other_task(mocked_client): dataset = "test_log_with_other_task" - ar.delete(dataset) - record = ar.TextClassificationRecord( + rg.delete(dataset) + record = rg.TextClassificationRecord( text="The input text", ) - ar.log(record, name=dataset) + rg.log(record, name=dataset) with pytest.raises(BadRequestApiError): - ar.log( - ar.TokenClassificationRecord(text="The text", tokens=["The", "text"]), + rg.log( + rg.TokenClassificationRecord(text="The text", tokens=["The", "text"]), name=dataset, ) def test_dynamics_metadata(mocked_client): dataset = "test_dynamics_metadata" - ar.log( - ar.TextClassificationRecord(text="This is a text", metadata={"a": "value"}), + rg.log( + rg.TextClassificationRecord(text="This is a text", metadata={"a": "value"}), name=dataset, ) - ar.log( - ar.TextClassificationRecord(text="Another text", metadata={"b": "value"}), + rg.log( + rg.TextClassificationRecord(text="Another text", metadata={"b": "value"}), name=dataset, ) def test_log_with_bulk_error(mocked_client): dataset = "test_log_with_bulk_error" - ar.delete(dataset) + rg.delete(dataset) try: - ar.log( + rg.log( [ - ar.TextClassificationRecord( + rg.TextClassificationRecord( id=0, text="This is an special text", metadata={"key": 1}, ), - ar.TextClassificationRecord( + rg.TextClassificationRecord( id=1, text="This is an special text", metadata={"key": "wrong-value"}, diff --git a/tests/labeling/text_classification/test_label_errors.py b/tests/labeling/text_classification/test_label_errors.py index 0e46394bf0..612933aa36 100644 --- a/tests/labeling/text_classification/test_label_errors.py +++ b/tests/labeling/text_classification/test_label_errors.py @@ -14,7 +14,7 @@ # limitations under the License. 
import sys -import argilla as ar +import argilla as rg import cleanlab import pytest from _pytest.logging import LogCaptureFixture @@ -32,7 +32,7 @@ def records(request): if request.param: return [ - ar.TextClassificationRecord(text="test", annotation=anot, prediction=pred, multi_label=True, id=i) + rg.TextClassificationRecord(text="test", annotation=anot, prediction=pred, multi_label=True, id=i) for i, anot, pred in zip( range(2 * 6), [["bad"], ["bad", "good"]] * 6, @@ -41,7 +41,7 @@ def records(request): ] return [ - ar.TextClassificationRecord(text="test", annotation=anot, prediction=pred, id=i) + rg.TextClassificationRecord(text="test", annotation=anot, prediction=pred, id=i) for i, anot, pred in zip( range(2 * 6), ["good", "bad"] * 6, @@ -63,8 +63,8 @@ def test_not_installed(monkeypatch): def test_no_records(): records = [ - ar.TextClassificationRecord(text="test", prediction=[("mock", 0.0)]), - ar.TextClassificationRecord(text="test", annotation="test"), + rg.TextClassificationRecord(text="test", prediction=[("mock", 0.0)]), + rg.TextClassificationRecord(text="test", annotation="test"), ] with pytest.raises(NoRecordsError, match="none of your records have a prediction AND annotation"): @@ -72,7 +72,7 @@ def test_no_records(): def test_multi_label_warning(caplog: LogCaptureFixture): - record = ar.TextClassificationRecord( + record = rg.TextClassificationRecord( text="test", prediction=[("mock", 0.0), ("mock2", 0.0)], annotation=["mock", "mock2"], @@ -112,7 +112,7 @@ def mock_find_label_issues(*args, **kwargs): mock_find_label_issues, ) - record = ar.TextClassificationRecord(text="mock", prediction=[("mock", 0.1)], annotation="mock") + record = rg.TextClassificationRecord(text="mock", prediction=[("mock", 0.1)], annotation="mock") find_label_errors(records=[record], sort_by=sort_by) @@ -191,14 +191,14 @@ def test_construct_s_and_psx(records): def test_missing_predictions(): - records = [ar.TextClassificationRecord(text="test", annotation="mock", prediction=[("mock2", 0.1)])] + records = [rg.TextClassificationRecord(text="test", annotation="mock", prediction=[("mock2", 0.1)])] with pytest.raises( MissingPredictionError, match="It seems predictions are missing for the label 'mock'", ): _construct_s_and_psx(records) - records.append(ar.TextClassificationRecord(text="test", annotation="mock", prediction=[("mock", 0.1)])) + records.append(rg.TextClassificationRecord(text="test", annotation="mock", prediction=[("mock", 0.1)])) with pytest.raises( MissingPredictionError, match="It seems a prediction for 'mock' is missing in the following record", @@ -210,13 +210,13 @@ def test_missing_predictions(): def dataset(mocked_client, records): dataset = "dataset_for_label_errors" - ar.log(records, name=dataset) + rg.log(records, name=dataset) yield dataset - ar.delete(dataset) + rg.delete(dataset) def test_find_label_errors_integration(dataset): - records = ar.load(dataset) + records = rg.load(dataset) recs = find_label_errors(records) assert [rec.id for rec in recs] == list(range(0, 11, 2)) + list(range(1, 12, 2)) diff --git a/tests/labeling/text_classification/test_rule.py b/tests/labeling/text_classification/test_rule.py index 7920319b4e..04acd71abe 100644 --- a/tests/labeling/text_classification/test_rule.py +++ b/tests/labeling/text_classification/test_rule.py @@ -252,7 +252,7 @@ def test_update_rules(mocked_client, log_dataset): def test_copy_dataset_with_rules(mocked_client, log_dataset): - import argilla as ar + import argilla as rg rule = Rule(query="a query", label="LALA") 
delete_rule_silently(mocked_client, log_dataset, rule) @@ -263,8 +263,8 @@ def test_copy_dataset_with_rules(mocked_client, log_dataset): ) copied_dataset = f"{log_dataset}_copy" - ar.delete(copied_dataset) - ar.copy(log_dataset, name_of_copy=copied_dataset) + rg.delete(copied_dataset) + rg.copy(log_dataset, name_of_copy=copied_dataset) assert [{"q": r.query, "l": r.label} for r in load_rules(copied_dataset)] == [ {"q": r.query, "l": r.label} for r in load_rules(log_dataset) diff --git a/tests/listeners/test_listener.py b/tests/listeners/test_listener.py index 7651a7dbaf..13302bd448 100644 --- a/tests/listeners/test_listener.py +++ b/tests/listeners/test_listener.py @@ -15,7 +15,7 @@ import time from typing import List -import argilla as ar +import argilla as rg import pytest from argilla import RGListenerContext, listener from argilla.client.models import Record @@ -40,7 +40,7 @@ def condition_check_params(search): ], ) def test_listener_with_parameters(mocked_client, dataset, query, metrics, condition, query_params): - ar.delete(dataset) + rg.delete(dataset) class TestListener: executed = False @@ -78,7 +78,7 @@ def action(self, records: List[Record], ctx: RGListenerContext): time.sleep(1.5) assert test.action.is_running() - ar.log(ar.TextClassificationRecord(text="This is a text"), name=dataset) + rg.log(rg.TextClassificationRecord(text="This is a text"), name=dataset) with pytest.raises(ValueError): test.action.start() diff --git a/tests/metrics/test_common_metrics.py b/tests/metrics/test_common_metrics.py index 1bf51590aa..c26d9915a3 100644 --- a/tests/metrics/test_common_metrics.py +++ b/tests/metrics/test_common_metrics.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import argilla -import argilla as ar +import argilla as rg import pytest from argilla.metrics.commons import keywords, records_status, text_length @@ -40,17 +40,17 @@ def gutenberg_spacy_ner(mocked_client): def test_status_distribution(mocked_client): dataset = "test_status_distribution" - ar.delete(dataset) + rg.delete(dataset) - ar.log( + rg.log( [ - ar.TextClassificationRecord( + rg.TextClassificationRecord( id=1, inputs={"text": "my first example"}, prediction=[("spam", 0.8), ("ham", 0.2)], annotation=["spam"], ), - ar.TextClassificationRecord( + rg.TextClassificationRecord( id=2, inputs={"text": "my second example"}, prediction=[("ham", 0.8), ("spam", 0.2)], @@ -70,24 +70,24 @@ def test_status_distribution(mocked_client): def test_text_length(mocked_client): dataset = "test_text_length" - ar.delete(dataset) + rg.delete(dataset) - ar.log( + rg.log( [ - ar.TextClassificationRecord( + rg.TextClassificationRecord( id=1, inputs={"text": "my first example"}, prediction=[("spam", 0.8), ("ham", 0.2)], annotation=["spam"], ), - ar.TextClassificationRecord( + rg.TextClassificationRecord( id=2, inputs={"text": "my second example"}, prediction=[("ham", 0.8), ("spam", 0.2)], annotation=["ham"], status="Default", ), - ar.TextClassificationRecord( + rg.TextClassificationRecord( id=3, inputs={"text": "simple text"}, prediction=[("ham", 0.8), ("spam", 0.2)], diff --git a/tests/metrics/test_text_classification.py b/tests/metrics/test_text_classification.py index 30e3d61564..f6741166f8 100644 --- a/tests/metrics/test_text_classification.py +++ b/tests/metrics/test_text_classification.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -import argilla as ar +import argilla as rg from argilla import TextClassificationRecord from argilla.client import api from argilla.metrics.text_classification import f1, f1_multilabel @@ -21,15 +21,15 @@ def test_metrics_for_text_classification(mocked_client): dataset = "test_metrics_for_text_classification" - ar.log( + rg.log( [ - ar.TextClassificationRecord( + rg.TextClassificationRecord( id=1, text="my first argilla example", prediction=[("spam", 0.8), ("ham", 0.2)], annotation=["spam"], ), - ar.TextClassificationRecord( + rg.TextClassificationRecord( id=2, inputs={"text": "my first argilla example"}, prediction=[("ham", 0.8), ("spam", 0.2)], @@ -81,13 +81,13 @@ def test_metrics_for_text_classification(mocked_client): def test_f1_without_results(mocked_client): dataset = "test_f1_without_results" - ar.log( + rg.log( [ - ar.TextClassificationRecord( + rg.TextClassificationRecord( id=1, text="my first argilla example", ), - ar.TextClassificationRecord( + rg.TextClassificationRecord( id=2, inputs={"text": "my first argilla example"}, ), @@ -133,7 +133,7 @@ def test_dataset_labels_metric(mocked_client): for i in range(2000, 3000) ] ) - ar.log( + rg.log( name=dataset, records=records, ) diff --git a/tests/metrics/test_token_classification.py b/tests/metrics/test_token_classification.py index 89ee11c961..466b71ecca 100644 --- a/tests/metrics/test_token_classification.py +++ b/tests/metrics/test_token_classification.py @@ -13,7 +13,7 @@ # limitations under the License. import argilla -import argilla as ar +import argilla as rg import pytest from argilla.metrics import entity_consistency from argilla.metrics.token_classification import ( @@ -35,30 +35,30 @@ def log_some_data(dataset: str): argilla.delete(dataset) text = "My first great example \n" tokens = text.split(" ") - ar.log( + rg.log( [ - ar.TokenClassificationRecord( + rg.TokenClassificationRecord( id=1, text=text, tokens=tokens, prediction=[("CARDINAL", 3, 8)], annotation=[("CARDINAL", 3, 8)], ), - ar.TokenClassificationRecord( + rg.TokenClassificationRecord( id=2, text=text, tokens=tokens, prediction=[("CARDINAL", 3, 8)], annotation=[("CARDINAL", 3, 8)], ), - ar.TokenClassificationRecord( + rg.TokenClassificationRecord( id=3, text=text, tokens=tokens, prediction=[("NUMBER", 3, 8)], annotation=[("NUMBER", 3, 8)], ), - ar.TokenClassificationRecord( + rg.TokenClassificationRecord( id=4, text=text, tokens=tokens, @@ -72,10 +72,10 @@ def log_some_data(dataset: str): def test_search_by_nested_metric(mocked_client): dataset = "test_search_by_nested_metric" - ar.delete(dataset) + rg.delete(dataset) log_some_data(dataset) - df = ar.load(dataset, query="metrics.predicted.mentions.capitalness: LOWER") + df = rg.load(dataset, query="metrics.predicted.mentions.capitalness: LOWER") assert len(df) > 0 @@ -286,12 +286,12 @@ def validate_mentions( ) def test_metrics_without_data(mocked_client, metric, expected_results, monkeypatch): dataset = "test_metrics_without_data" - ar.delete(dataset) + rg.delete(dataset) text = "M" tokens = text.split(" ") - ar.log( - ar.TokenClassificationRecord( + rg.log( + rg.TokenClassificationRecord( id=1, text=text, tokens=tokens, @@ -309,8 +309,8 @@ def test_metrics_for_text_classification(mocked_client): dataset = "test_metrics_for_token_classification" text = "test the f1 metric of the token classification task" - ar.log( - ar.TokenClassificationRecord( + rg.log( + rg.TokenClassificationRecord( id=1, text=text, 
tokens=text.split(), diff --git a/tests/monitoring/test_flair_monitoring.py b/tests/monitoring/test_flair_monitoring.py index 18aa38201e..91b7886ce1 100644 --- a/tests/monitoring/test_flair_monitoring.py +++ b/tests/monitoring/test_flair_monitoring.py @@ -15,18 +15,18 @@ def test_flair_monitoring(mocked_client, monkeypatch): - import argilla as ar + import argilla as rg from flair.data import Sentence from flair.models import SequenceTagger dataset = "test_flair_monitoring" model = "flair/ner-english" - ar.delete(dataset) + rg.delete(dataset) # load tagger tagger = SequenceTagger.load(model) - tagger = ar.monitor( + tagger = rg.monitor( tagger, dataset=dataset, sample_rate=1.0, @@ -42,7 +42,7 @@ def test_flair_monitoring(mocked_client, monkeypatch): sleep(1) # wait for the consumer time detected_labels = sentence.get_labels("ner") - records = ar.load(dataset) + records = rg.load(dataset) assert len(records) == 1 record = records[0] diff --git a/tests/monitoring/test_spacy_monitoring.py b/tests/monitoring/test_spacy_monitoring.py index 83a1414dfe..29cb8b6080 100644 --- a/tests/monitoring/test_spacy_monitoring.py +++ b/tests/monitoring/test_spacy_monitoring.py @@ -15,17 +15,17 @@ import random from time import sleep -import argilla as ar +import argilla as rg def test_spacy_ner_monitor(monkeypatch, mocked_client): dataset = "spacy-dataset" - ar.delete(dataset) + rg.delete(dataset) import spacy nlp = spacy.load("en_core_web_sm") - nlp = ar.monitor( + nlp = rg.monitor( nlp, dataset=dataset, sample_rate=0.5, @@ -38,26 +38,26 @@ def test_spacy_ner_monitor(monkeypatch, mocked_client): nlp("Paris is my favourite city") sleep(1) # wait for the consumer time - df = ar.load(dataset) + df = rg.load(dataset) df = df.to_pandas() assert len(df) == 11 # assert 10 - std < len(df) < 10 + std assert df.text.unique().tolist() == ["Paris is my favourite city"] - ar.delete(dataset) + rg.delete(dataset) list(nlp.pipe(["This is a text"] * 20)) sleep(1) # wait for the consumer time - df = ar.load(dataset) + df = rg.load(dataset) df = df.to_pandas() assert len(df) == 6 assert df.text.unique().tolist() == ["This is a text"] - ar.delete(dataset) + rg.delete(dataset) list(nlp.pipe([("This is a text", {"meta": "data"})] * 20, as_tuples=True)) sleep(1) # wait for the consumer time - df = ar.load(dataset) + df = rg.load(dataset) df = df.to_pandas() assert len(df) == 14 for metadata in df.metadata.values.tolist(): diff --git a/tests/server/text_classification/test_api_settings.py b/tests/server/text_classification/test_api_settings.py index e091ce7952..d69472a59f 100644 --- a/tests/server/text_classification/test_api_settings.py +++ b/tests/server/text_classification/test_api_settings.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import argilla as ar +import argilla as rg from argilla.server.commons.models import TaskType @@ -23,7 +23,7 @@ def create_dataset(client, name: str): def test_create_dataset_settings(mocked_client): name = "test_create_dataset_settings" - ar.delete(name) + rg.delete(name) create_dataset(mocked_client, name) response = create_settings(mocked_client, name) @@ -44,7 +44,7 @@ def create_settings(mocked_client, name): def test_get_dataset_settings_not_found(mocked_client): name = "test_get_dataset_settings" - ar.delete(name) + rg.delete(name) create_dataset(mocked_client, name) response = fetch_settings(mocked_client, name) @@ -53,7 +53,7 @@ def test_get_dataset_settings_not_found(mocked_client): def test_delete_settings(mocked_client): name = "test_delete_settings" - ar.delete(name) + rg.delete(name) create_dataset(mocked_client, name) assert create_settings(mocked_client, name).status_code == 200 @@ -65,7 +65,7 @@ def test_delete_settings(mocked_client): def test_validate_settings_when_logging_data(mocked_client): name = "test_validate_settings_when_logging_data" - ar.delete(name) + rg.delete(name) create_dataset(mocked_client, name) assert create_settings(mocked_client, name).status_code == 200 @@ -88,7 +88,7 @@ def test_validate_settings_when_logging_data(mocked_client): def test_validate_settings_after_logging(mocked_client): name = "test_validate_settings_after_logging" - ar.delete(name) + rg.delete(name) response = log_some_data(mocked_client, name) assert response.status_code == 200 diff --git a/tests/server/token_classification/test_api_settings.py b/tests/server/token_classification/test_api_settings.py index 587ea89177..7370903d05 100644 --- a/tests/server/token_classification/test_api_settings.py +++ b/tests/server/token_classification/test_api_settings.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import argilla as ar
+import argilla as rg
 from argilla.server.commons.models import TaskType


@@ -23,7 +23,7 @@ def create_dataset(client, name: str):

 def test_create_dataset_settings(mocked_client):
     name = "test_create_dataset_settings"
-    ar.delete(name)
+    rg.delete(name)

     create_dataset(mocked_client, name)
     response = create_settings(mocked_client, name)
@@ -44,7 +44,7 @@ def create_settings(mocked_client, name):

 def test_get_dataset_settings_not_found(mocked_client):
     name = "test_get_dataset_settings"
-    ar.delete(name)
+    rg.delete(name)

     create_dataset(mocked_client, name)
     response = fetch_settings(mocked_client, name)
@@ -53,7 +53,7 @@ def test_delete_settings(mocked_client):
     name = "test_delete_settings"
-    ar.delete(name)
+    rg.delete(name)

     create_dataset(mocked_client, name)
     assert create_settings(mocked_client, name).status_code == 200
@@ -65,7 +65,7 @@ def test_validate_settings_when_logging_data(mocked_client):
     name = "test_validate_settings_when_logging_data"
-    ar.delete(name)
+    rg.delete(name)

     create_dataset(mocked_client, name)
     assert create_settings(mocked_client, name).status_code == 200
@@ -110,7 +110,7 @@ def log_some_data(mocked_client, name):

 def test_validate_settings_after_logging(mocked_client):
     name = "test_validate_settings_after_logging"
-    ar.delete(name)
+    rg.delete(name)

     response = log_some_data(mocked_client, name)
     assert response.status_code == 200

From c27672f49781cb0127eb9f4e14cc9cf9b707084f Mon Sep 17 00:00:00 2001
From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Date: Mon, 27 Feb 2023 11:22:18 +0100
Subject: [PATCH 22/45] Refactor: Add `require_version` function and
 `requires_version` decorator (#2380)

Closes #2367

Hello!

## Pull Request overview
* Implementation of `require_version`, adapted from Hugging Face's `transformers` module.
* Implementation of `requires_version`, a decorator that wraps `require_version`.
* `requires_datasets`, `requires_spacy` and `requires_sklearn` as shorthands for `requires_version` for specific modules.
* Removed the `try`-`except ModuleNotFoundError` blocks for e.g. `cleanlab`, `snorkel`, `flyingsquid`, `pgmpy` and `starlette` throughout the project, and replaced them with `requires_version` decorators.
* Added tests to accompany the additions.
* Modified existing tests.

## Details
I have added `require_version`, which ought to be used like this:
```python
...
if train_size and test_size:
    require_version("scikit-learn")
    from sklearn.model_selection import train_test_split
...
```
And also `requires_version`:
```python
...
@requires_version("cleanlab")
def find_label_errors(
    records: Union[List[TextClassificationRecord], DatasetForTextClassification],
...
```
and some shorthands, e.g.:
```python
...
@requires_datasets
def to_datasets() -> "datasets.Dataset":
    """Exports your records to a `datasets.Dataset`.
...
```
When `require_version` is called, or when a function/method wrapped with `requires_...` is called, [`importlib.metadata`](https://docs.python.org/3/library/importlib.metadata.html) is used to check whether the package is installed, and at which version. This is the recommended approach for collecting version information, and it is preferred over the less efficient `pkg_resources`.

### Advanced usage
We can set version requirements using this approach:
```python
require_version("datasets>1.17.0")
# or
@requires_version("datasets>1.17.0")
def ...
```
And not just one, we can set multiple:
```python
require_version("datasets>1.17.0,<2.0.0")
# or
@requires_version("datasets>1.17.0,<2.0.0")
def ...
```
We can also specify Python versions for certain functions/methods/sections of code:
```python
require_version("python>3.7")
# or
@requires_version("python>3.7")
def ...
```

### Missing & outdated dependencies
Consider the following example:
```python
from argilla.utils.dependency import requires_datasets


@requires_datasets
def foo():
    pass


foo()
```
When executed without `datasets` installed, the following exception is thrown:
```
ModuleNotFoundError: 'datasets' must be installed to use `foo`! You can install 'datasets' with this command: `pip install datasets>1.17.0`
```
Alternatively, if I install `datasets==1.16.0` and run it again, I get this exception:
```
ImportError: datasets>1.17.0 must be installed to use `foo`, but found datasets==1.16.0. You can install a supported version of 'datasets' with this command: `pip install -U datasets>1.17.0`
```

### Notes
I had to update various tests that contained the following snippet:
```python
monkeypatch.setitem(sys.modules, "datasets", None)
with pytest.raises(ModuleNotFoundError):
    ...
```
The reasoning is that although `import datasets` would fail in this case, `require_version` is not fooled by the monkeypatch: it never consults `sys.modules`, but instead looks up the installed distribution through the import finders registered in `sys.meta_path`. As a result, `require_version` still recognises that the module is installed. That way, it becomes a bit harder to pretend that a module is not installed. The alternative I went with is clearing out `sys.meta_path`, which means that nothing at all can be imported after the monkeypatch.

One other tiny note: the following snippet also stopped working.
```
monkeypatch.setattr(FlyingSquid, "_predict", mock_predict)
```
This is because `FlyingSquid` is now decorated by `requires_version`, which breaks the `monkeypatch.setattr` on the class. As a result, I now perform the `monkeypatch.setattr` on the class instance instead of the class. (A consolidated sketch of both updated patterns follows at the end of this message.)

I want to reiterate that everything from this section only affects the tests: the behaviour in practice works well. Lastly, the tests result in 97% coverage on the new file.

---

**Type of change**
- [x] Refactor (change restructuring the codebase without changing functionality)

**How Has This Been Tested**
Through the introduction of various tests under `tests/utils/test_dependency.py` and through updates to tests throughout the project. I also ran some test cases manually.

**Checklist**
- [x] I have merged the original branch into my forked branch
- [ ] I added relevant documentation
- [x] My code follows the style guidelines of this project
- [x] I did a self-review of my code
- [x] I added comments to my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my feature works

---

I'm open to feedback, as always.
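P.S. For reference, here is the consolidated sketch of the two updated test patterns mentioned in the notes above. It is a minimal illustration, assuming pytest's `monkeypatch` fixture; `needs_datasets` and `label_model` are hypothetical names used only for this sketch, not the actual names in the test suite:

```python
import sys

import pytest

from argilla.utils.dependency import requires_version


@requires_version("datasets>1.17.0")
def needs_datasets():
    return "ok"


def test_dependency_appears_missing(monkeypatch):
    # Clearing sys.meta_path removes every import finder, so importlib.metadata
    # can no longer locate *any* installed distribution; nothing new can be
    # imported for the remainder of this test.
    monkeypatch.setattr(sys, "meta_path", [], raising=False)
    with pytest.raises(ModuleNotFoundError, match="pip install datasets>1.17.0"):
        needs_datasets()


def test_patch_decorated_class(monkeypatch, label_model):
    # `FlyingSquid` (the class) is wrapped by `requires_version`, so patching
    # `_predict` on the class no longer reaches instances; patch the bound
    # attribute on the instance instead. `label_model` stands in for a fixture
    # returning a FlyingSquid instance.
    monkeypatch.setattr(label_model, "_predict", lambda *args, **kwargs: None)
```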
cc: @davidberenstein1957 @dvsrepo due to the discussion in #2367 - Tom Aarsen --- src/argilla/client/datasets.py | 62 ++----- .../text_classification/label_errors.py | 18 +-- .../text_classification/label_models.py | 45 ++---- src/argilla/monitoring/asgi.py | 23 +-- .../tasks/text_classification/metrics.py | 8 +- src/argilla/utils/dependency.py | 151 ++++++++++++++++++ tests/client/test_dataset.py | 8 +- .../text_classification/test_label_errors.py | 2 +- .../text_classification/test_label_models.py | 28 ++-- tests/utils/test_dependency.py | 107 +++++++++++++ 10 files changed, 321 insertions(+), 131 deletions(-) create mode 100644 src/argilla/utils/dependency.py create mode 100644 tests/utils/test_dependency.py diff --git a/src/argilla/client/datasets.py b/src/argilla/client/datasets.py index 35cf564883..e6d80e4f14 100644 --- a/src/argilla/client/datasets.py +++ b/src/argilla/client/datasets.py @@ -12,14 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import functools import logging import random import uuid from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union import pandas as pd -from pkg_resources import parse_version from argilla.client.models import ( Framework, @@ -30,6 +28,7 @@ TokenClassificationRecord, ) from argilla.client.sdk.datasets.models import TaskType +from argilla.utils.dependency import require_version, requires_version from argilla.utils.span_utils import SpanUtils if TYPE_CHECKING: @@ -40,42 +39,6 @@ _LOGGER = logging.getLogger(__name__) -def _requires_datasets(func): - @functools.wraps(func) - def check_if_datasets_installed(*args, **kwargs): - try: - import datasets - except ModuleNotFoundError: - raise ModuleNotFoundError( - f"'datasets' must be installed to use `{func.__name__}`! You can" - " install 'datasets' with the command: `pip install datasets>1.17.0`" - ) - if not (parse_version(datasets.__version__) > parse_version("1.17.0")): - raise ModuleNotFoundError( - "Version >1.17.0 of 'datasets' must be installed to use `to_datasets`!" - " You can update 'datasets' with the command: `pip install -U" - " datasets>1.17.0`" - ) - return func(*args, **kwargs) - - return check_if_datasets_installed - - -def _requires_spacy(func): - @functools.wraps(func) - def check_if_spacy_installed(*args, **kwargs): - try: - import spacy # noqa: F401 - except ModuleNotFoundError: - raise ModuleNotFoundError( - f"'spacy' must be installed to use `{func.__name__}`" - "You can install 'spacy' with the command: `pip install spacy`" - ) - return func(*args, **kwargs) - - return check_if_spacy_installed - - class DatasetBase: """The Dataset classes are containers for argilla records. @@ -156,7 +119,7 @@ def __repr__(self): def __str__(self): return repr(self) - @_requires_datasets + @requires_version("datasets>1.17.0") def to_datasets(self) -> "datasets.Dataset": """Exports your records to a `datasets.Dataset`. @@ -473,6 +436,7 @@ def prepare_for_training( ) elif framework in [Framework.SPACY, Framework.SPARK_NLP]: if train_size and test_size: + require_version("scikit-learn") from sklearn.model_selection import train_test_split records_train, records_test = train_test_split( @@ -498,7 +462,7 @@ def prepare_for_training( else: raise NotImplementedError(f"Framework {framework} is not supported. 
Choose from:" f" {list(Framework)}") - @_requires_spacy + @requires_version("spacy") def _prepare_for_training_with_spacy( self, **kwargs ) -> Union["spacy.token.DocBin", Tuple["spacy.token.DocBin", "spacy.token.DocBin"]]: @@ -513,7 +477,7 @@ def _prepare_for_training_with_spacy( raise NotImplementedError - @_requires_datasets + @requires_version("datasets>1.17.0") def _prepare_for_training_with_transformers(self, **kwargs) -> "datasets.Dataset": """Prepares the dataset for training using the "transformers" framework. @@ -526,7 +490,7 @@ def _prepare_for_training_with_transformers(self, **kwargs) -> "datasets.Dataset raise NotImplementedError - @_requires_datasets + @requires_version("datasets>1.17.0") def _prepare_for_training_with_spark_nlp(self, **kwargs) -> "datasets.Dataset": """Prepares the dataset for training using the "spark-nlp" framework. @@ -595,6 +559,7 @@ def __init__(self, records: Optional[List[TextClassificationRecord]] = None): super().__init__(records=records) @classmethod + @requires_version("datasets>1.17.0") def from_datasets( cls, dataset: "datasets.Dataset", @@ -745,7 +710,7 @@ def _to_datasets_dict(self) -> Dict: def _from_pandas(cls, dataframe: pd.DataFrame) -> "DatasetForTextClassification": return cls([TextClassificationRecord(**row) for row in dataframe.to_dict("records")]) - @_requires_datasets + @requires_version("datasets>1.17.0") def _prepare_for_training_with_transformers( self, train_size: Optional[float] = None, @@ -784,6 +749,7 @@ def _prepare_for_training_with_transformers( ds = datasets.Dataset.from_dict(ds_dict, features=datasets.Features(feature_dict)) if self._records[0].multi_label: + require_version("scikit-learn") from sklearn.preprocessing import MultiLabelBinarizer labels = [rec["label"] for rec in ds] @@ -804,7 +770,7 @@ def _prepare_for_training_with_transformers( return ds - @_requires_spacy + @requires_version("spacy") def _prepare_for_training_with_spacy(self, nlp: "spacy.Language", records: List[Record]) -> "spacy.tokens.DocBin": from spacy.tokens import DocBin @@ -903,6 +869,7 @@ def _record_init_args(cls) -> List[str]: return parent_fields + ["tags"] # compute annotation from tags @classmethod + @requires_version("datasets>1.17.0") def from_datasets( cls, dataset: "datasets.Dataset", @@ -980,7 +947,7 @@ def from_pandas( ) -> "DatasetForTokenClassification": return super().from_pandas(dataframe) - @_requires_datasets + @requires_version("datasets>1.17.0") def _prepare_for_training_with_transformers( self, train_size: Optional[float] = None, @@ -1019,7 +986,7 @@ def spans2iob(example): return ds - @_requires_spacy + @requires_version("spacy") def _prepare_for_training_with_spacy(self, nlp: "spacy.Language", records: List[Record]) -> "spacy.tokens.DocBin": from spacy.tokens import DocBin @@ -1160,6 +1127,7 @@ def __init__(self, records: Optional[List[Text2TextRecord]] = None): super().__init__(records=records) @classmethod + @requires_version("datasets>1.17.0") def from_datasets( cls, dataset: "datasets.Dataset", @@ -1262,7 +1230,7 @@ def pred_to_dict(pred: Union[str, Tuple[str, float]]): def _from_pandas(cls, dataframe: pd.DataFrame) -> "DatasetForText2Text": return cls([Text2TextRecord(**row) for row in dataframe.to_dict("records")]) - @_requires_datasets + @requires_version("datasets>1.17.0") def prepare_for_training(self, **kwargs) -> "datasets.Dataset": """Prepares the dataset for training. 
diff --git a/src/argilla/labeling/text_classification/label_errors.py b/src/argilla/labeling/text_classification/label_errors.py index 2460c816cf..a8e6c37e4b 100644 --- a/src/argilla/labeling/text_classification/label_errors.py +++ b/src/argilla/labeling/text_classification/label_errors.py @@ -21,6 +21,7 @@ from argilla.client.datasets import DatasetForTextClassification from argilla.client.models import TextClassificationRecord +from argilla.utils.dependency import requires_version _LOGGER = logging.getLogger(__name__) @@ -39,6 +40,7 @@ def _missing_(cls, value): ) +@requires_version("cleanlab") def find_label_errors( records: Union[List[TextClassificationRecord], DatasetForTextClassification], sort_by: Union[str, SortBy] = "likelihood", @@ -76,18 +78,12 @@ def find_label_errors( >>> records = rg.load("my_dataset") >>> records_with_label_errors = find_label_errors(records) """ - try: - import cleanlab - except ModuleNotFoundError: - raise ModuleNotFoundError( - "'cleanlab' must be installed to use the `find_label_errors` method! " - "You can install 'cleanlab' with the command: `pip install cleanlab`" - ) + import cleanlab + + if parse_version(cleanlab.__version__) < parse_version("2.0"): + from cleanlab.pruning import get_noise_indices as find_label_issues else: - if parse_version(cleanlab.__version__) < parse_version("2.0"): - from cleanlab.pruning import get_noise_indices as find_label_issues - else: - from cleanlab.filter import find_label_issues + from cleanlab.filter import find_label_issues if isinstance(sort_by, str): sort_by = SortBy(sort_by) diff --git a/src/argilla/labeling/text_classification/label_models.py b/src/argilla/labeling/text_classification/label_models.py index 705894549b..468c60527a 100644 --- a/src/argilla/labeling/text_classification/label_models.py +++ b/src/argilla/labeling/text_classification/label_models.py @@ -21,6 +21,7 @@ from argilla import DatasetForTextClassification, TextClassificationRecord from argilla.labeling.text_classification.weak_labels import WeakLabels, WeakMultiLabels +from argilla.utils.dependency import requires_version _LOGGER = logging.getLogger(__name__) @@ -310,6 +311,7 @@ def _make_multi_label_records( return records_with_prediction + @requires_version("scikit-learn") def score( self, tie_break_policy: Union[TieBreakPolicy, str] = "abstain", @@ -344,13 +346,6 @@ def score( Raises: MissingAnnotationError: If the ``weak_labels`` do not contain annotated records. """ - try: - import sklearn # noqa: F401 - except ModuleNotFoundError: - raise ModuleNotFoundError( - "'sklearn' must be installed to compute the metrics! " - "You can install 'sklearn' with the command: `pip install scikit-learn`" - ) from sklearn.metrics import classification_report wl_matrix = self._weak_labels.matrix(has_annotation=True) @@ -443,6 +438,7 @@ def _score_multi_label(self, probabilities: np.ndarray) -> Tuple[np.ndarray, np. return annotation, prediction +@requires_version("snorkel") class Snorkel(LabelModel): """The label model by `Snorkel `__. @@ -463,15 +459,7 @@ class Snorkel(LabelModel): """ def __init__(self, weak_labels: WeakLabels, verbose: bool = True, device: str = "cpu"): - try: - import snorkel # noqa: F401 - except ModuleNotFoundError: - raise ModuleNotFoundError( - "'snorkel' must be installed to use the `Snorkel` label model! 
" - "You can install 'snorkel' with the command: `pip install snorkel`" - ) - else: - from snorkel.labeling.model import LabelModel as SnorkelLabelModel + from snorkel.labeling.model import LabelModel as SnorkelLabelModel super().__init__(weak_labels) @@ -618,6 +606,7 @@ def predict( return DatasetForTextClassification(records_with_prediction) + @requires_version("scikit-learn") def score( self, tie_break_policy: Union[TieBreakPolicy, str] = "abstain", @@ -689,6 +678,8 @@ def score( ) +@requires_version("flyingsquid") +@requires_version("pgmpy") class FlyingSquid(LabelModel): """The label model by `FlyingSquid `__. @@ -708,19 +699,9 @@ class FlyingSquid(LabelModel): """ def __init__(self, weak_labels: WeakLabels, **kwargs): - try: - import flyingsquid # noqa: F401 - import pgmpy # noqa: F401 - except ModuleNotFoundError: - raise ModuleNotFoundError( - "'flyingsquid' must be installed to use the `FlyingSquid` label model!" - " You can install 'flyingsquid' with the command: `pip install pgmpy" - " flyingsquid`" - ) - else: - from flyingsquid.label_model import LabelModel as FlyingSquidLabelModel + from flyingsquid.label_model import LabelModel as FlyingSquidLabelModel - self._FlyingSquidLabelModel = FlyingSquidLabelModel + self._FlyingSquidLabelModel = FlyingSquidLabelModel super().__init__(weak_labels) @@ -895,6 +876,7 @@ def _predict(self, weak_label_matrix: np.ndarray, verbose: bool) -> np.ndarray: return probas + @requires_version("scikit-learn") def score( self, tie_break_policy: Union[TieBreakPolicy, str] = "abstain", @@ -932,13 +914,6 @@ def score( NotFittedError: If the label model was still not fitted. MissingAnnotationError: If the ``weak_labels`` do not contain annotated records. """ - try: - import sklearn # noqa: F401 - except ModuleNotFoundError: - raise ModuleNotFoundError( - "'sklearn' must be installed to compute the metrics! " - "You can install 'sklearn' with the command: `pip install scikit-learn`" - ) from sklearn.metrics import classification_report if isinstance(tie_break_policy, str): diff --git a/src/argilla/monitoring/asgi.py b/src/argilla/monitoring/asgi.py index deabe0ef47..ea5e299c9a 100644 --- a/src/argilla/monitoring/asgi.py +++ b/src/argilla/monitoring/asgi.py @@ -18,26 +18,19 @@ import re from typing import Any, Callable, Dict, List, Optional, Tuple -from argilla.monitoring.base import BaseMonitor - -try: - import starlette # noqa: F401 -except ModuleNotFoundError: - raise ModuleNotFoundError( - "'starlette' must be installed to use the middleware feature! 
" - "You can install 'starlette' with the command: `pip install starlette>=0.13.0`" - ) -else: - from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint - from starlette.requests import Request - from starlette.responses import JSONResponse, Response, StreamingResponse - from starlette.types import Message, Receive - from argilla.client.models import ( Record, TextClassificationRecord, TokenClassificationRecord, ) +from argilla.monitoring.base import BaseMonitor +from argilla.utils.dependency import require_version + +require_version("starlette>=0.13.0") +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.requests import Request +from starlette.responses import JSONResponse, Response, StreamingResponse +from starlette.types import Message, Receive _logger = logging.getLogger(__name__) _default_tokenization_pattern = re.compile(r"\W+") diff --git a/src/argilla/server/services/tasks/text_classification/metrics.py b/src/argilla/server/services/tasks/text_classification/metrics.py index 2b9bd006a9..38caf57e21 100644 --- a/src/argilla/server/services/tasks/text_classification/metrics.py +++ b/src/argilla/server/services/tasks/text_classification/metrics.py @@ -15,8 +15,6 @@ from typing import Any, ClassVar, Dict, Iterable, List from pydantic import Field -from sklearn.metrics import precision_recall_fscore_support -from sklearn.preprocessing import MultiLabelBinarizer from argilla.server.services.metrics import ServiceBaseMetric, ServicePythonMetric from argilla.server.services.metrics.models import CommonTasksMetrics @@ -24,6 +22,7 @@ from argilla.server.services.tasks.text_classification.model import ( ServiceTextClassificationRecord, ) +from argilla.utils.dependency import requires_version class F1Metric(ServicePythonMetric): @@ -38,7 +37,10 @@ class F1Metric(ServicePythonMetric): multi_label: bool = False + @requires_version("scikit-learn") def apply(self, records: Iterable[ServiceTextClassificationRecord]) -> Any: + from sklearn.metrics import precision_recall_fscore_support + filtered_records = list(filter(lambda r: r.predicted is not None, records)) # TODO: This must be precomputed with using a global dataset metric ds_labels = {label for record in filtered_records for label in record.annotated_as + record.predicted_as} @@ -61,6 +63,8 @@ def apply(self, records: Iterable[ServiceTextClassificationRecord]) -> Any: y_pred.append([labels_mapping[label] for label in predictions]) if self.multi_label: + from sklearn.preprocessing import MultiLabelBinarizer + mlb = MultiLabelBinarizer(classes=list(labels_mapping.values())) y_true = mlb.fit_transform(y_true) y_pred = mlb.fit_transform(y_pred) diff --git a/src/argilla/utils/dependency.py b/src/argilla/utils/dependency.py new file mode 100644 index 0000000000..6883fcb4a2 --- /dev/null +++ b/src/argilla/utils/dependency.py @@ -0,0 +1,151 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import functools +import operator +import re +import sys +from typing import Optional + +from packaging import version + +# This file was adapted from Hugging Face's wonderful transformers module + +# The package importlib_metadata is in a different place, depending on the python version. +if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + +ops = { + "<": operator.lt, + "<=": operator.le, + "==": operator.eq, + "!=": operator.ne, + ">=": operator.ge, + ">": operator.gt, +} + + +def _compare_versions( + op: str, + got_version: Optional[str], + want_version: Optional[str], + requirement: str, + package: str, + func_name: Optional[str], +): + if got_version is None or want_version is None: + raise ValueError( + f"Unable to compare versions for {requirement}: need={want_version} found={got_version}. This is unusual. Consider" + f" reinstalling {package}." + ) + if not ops[op](version.parse(got_version), version.parse(want_version)): + raise ImportError( + f"{requirement} must be installed{f' to use `{func_name}`' if func_name else ''}, but found {package}=={got_version}." + f" You can install a supported version of '{package}' with this command: `pip install -U {requirement}`" + ) + + +def require_version(requirement: str, func_name: Optional[str] = None) -> None: + """ + Perform a runtime check of the dependency versions, using the exact same syntax used by pip. + The installed module version comes from the *site-packages* dir via *importlib_metadata*. + + Args: + requirement (`str`): pip style definition, e.g., "tokenizers==0.9.4", "tqdm>=4.27", "numpy" + func_name (`str`, *optional*): what suggestion to print in case of requirements not being met + + Example: + ```python + require_version("pandas>1.1.2") + require_version("datasets>1.17.0", "from_datasets") + ``` + """ + + # non-versioned check + if re.match(r"^[\w_\-\d]+$", requirement): + package, op, want_version = requirement, None, None + else: + match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2}.+)", requirement) + if not match: + raise ValueError( + "requirement needs to be in the pip package format, e.g., package_a==1.23, or package_b>=1.23, but" + f" got {requirement!r}." + ) + package, want_full = match[0] + want_range = want_full.split(",") # there could be multiple requirements + wanted = {} + for w in want_range: + match = re.findall(r"^([\s!=<>]{1,2})(.+)", w) + if not match: + raise ValueError( + "requirement needs to be in the pip package format, e.g., package_a==1.23, or package_b>=1.23," + f" but got {requirement!r}." + ) + op, want_version = match[0] + wanted[op] = want_version + if op not in ops: + raise ValueError(f"{requirement}: need one of {list(ops.keys())}, but got {op!r}.") + + # special case + if package == "python": + got_version = ".".join([str(x) for x in sys.version_info[:3]]) + for op, want_version in wanted.items(): + _compare_versions(op, got_version, want_version, requirement, package, func_name=func_name) + return + + # check if any version is installed + try: + got_version = importlib_metadata.version(package) + except importlib_metadata.PackageNotFoundError: + raise ModuleNotFoundError( + f"'{package}' must be installed{f' to use `{func_name}`' if func_name else ''}! 
You can" + f" install '{package}' with this command: `pip install {requirement}`" + ) + + # check that the right version is installed if version number or a range was provided + if want_version is not None: + for op, want_version in wanted.items(): + _compare_versions(op, got_version, want_version, requirement, package, func_name=func_name) + + +def requires_decorator(requirement, func): + @functools.wraps(func) + def check_if_installed(*args, **kwargs): + require_version(requirement, func.__name__) + return func(*args, **kwargs) + + return check_if_installed + + +def requires_version(requirement): + """Decorator variant of `require_version`. + Perform a runtime check of the dependency versions, using the exact same syntax used by pip. + The installed module version comes from the *site-packages* dir via *importlib_metadata*. + + Args: + requirement (`str`): pip style definition, e.g., "tokenizers==0.9.4", "tqdm>=4.27", "numpy" + + Example: + ```python + @requires_version("datasets>1.17.0") + def from_datasets(self, ...): + ... + ``` + """ + return functools.partial(requires_decorator, requirement) + + +__all__ = ["requires_version", "require_version"] diff --git a/tests/client/test_dataset.py b/tests/client/test_dataset.py index 050bdfee5f..a5c67bc68e 100644 --- a/tests/client/test_dataset.py +++ b/tests/client/test_dataset.py @@ -157,16 +157,10 @@ def test_to_datasets(self, monkeypatch, caplog): def test_datasets_not_installed(self, monkeypatch): monkeypatch.setattr("argilla.client.datasets.DatasetBase._RECORD_TYPE", "mock") - monkeypatch.setitem(sys.modules, "datasets", None) + monkeypatch.setattr(sys, "meta_path", [], raising=False) with pytest.raises(ModuleNotFoundError, match="pip install datasets>1.17.0"): DatasetBase().to_datasets() - def test_datasets_wrong_version(self, monkeypatch): - monkeypatch.setattr("argilla.client.datasets.DatasetBase._RECORD_TYPE", "mock") - monkeypatch.setattr("datasets.__version__", "1.16.0") - with pytest.raises(ModuleNotFoundError, match="pip install -U datasets>1.17.0"): - DatasetBase().to_datasets() - def test_iter_len_getitem(self, monkeypatch, singlelabel_textclassification_records): monkeypatch.setattr("argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord) dataset = DatasetBase(singlelabel_textclassification_records) diff --git a/tests/labeling/text_classification/test_label_errors.py b/tests/labeling/text_classification/test_label_errors.py index 612933aa36..00d9eff1d4 100644 --- a/tests/labeling/text_classification/test_label_errors.py +++ b/tests/labeling/text_classification/test_label_errors.py @@ -56,7 +56,7 @@ def test_sort_by_enum(): def test_not_installed(monkeypatch): - monkeypatch.setitem(sys.modules, "cleanlab", None) + monkeypatch.setattr(sys, "meta_path", [], raising=False) with pytest.raises(ModuleNotFoundError, match="pip install cleanlab"): find_label_errors(None) diff --git a/tests/labeling/text_classification/test_label_models.py b/tests/labeling/text_classification/test_label_models.py index 4cad2d86f4..68349bd453 100644 --- a/tests/labeling/text_classification/test_label_models.py +++ b/tests/labeling/text_classification/test_label_models.py @@ -270,7 +270,7 @@ def test_make_multi_label_records(self, weak_multi_labels, include_abstentions, assert records[2].prediction is None def test_score_sklearn_not_installed(self, monkeypatch, weak_labels): - monkeypatch.setitem(sys.modules, "sklearn", None) + monkeypatch.setattr(sys, "meta_path", [], raising=False) mj = MajorityVoter(weak_labels) with 
pytest.raises(ModuleNotFoundError, match="pip install scikit-learn"): @@ -354,7 +354,7 @@ def test_score_multi_label(self, weak_multi_labels): class TestSnorkel: def test_not_installed(self, monkeypatch): - monkeypatch.setitem(sys.modules, "snorkel", None) + monkeypatch.setattr(sys, "meta_path", [], raising=False) with pytest.raises(ModuleNotFoundError, match="pip install snorkel"): Snorkel(None) @@ -540,8 +540,8 @@ def test_integration(self, weak_labels_from_guide, change_mapping): class TestFlyingSquid: def test_not_installed(self, monkeypatch): - monkeypatch.setitem(sys.modules, "flyingsquid", None) - with pytest.raises(ModuleNotFoundError, match="pip install pgmpy flyingsquid"): + monkeypatch.setattr(sys, "meta_path", [], raising=False) + with pytest.raises(ModuleNotFoundError, match="pip install flyingsquid"): FlyingSquid(None) def test_init(self, weak_labels): @@ -720,22 +720,23 @@ def test_score_not_fitted_error(self, weak_labels): with pytest.raises(NotFittedError, match="not fitted yet"): label_model.score() - def test_score_sklearn_not_installed(self, monkeypatch, weak_labels): - monkeypatch.setitem(sys.modules, "sklearn", None) - + def test_score_sklearn_not_installed(self, monkeypatch: pytest.MonkeyPatch, weak_labels): label_model = FlyingSquid(weak_labels) + + monkeypatch.setattr(sys, "meta_path", [], raising=False) with pytest.raises(ModuleNotFoundError, match="pip install scikit-learn"): label_model.score() def test_score(self, monkeypatch, weak_labels): - def mock_predict(self, weak_label_matrix, verbose): + def mock_predict(weak_label_matrix, verbose): assert verbose is False assert len(weak_label_matrix) == 3 return np.array([[0.8, 0.1, 0.1], [0.1, 0.8, 0.1], [0.1, 0.1, 0.8]]) - monkeypatch.setattr(FlyingSquid, "_predict", mock_predict) - label_model = FlyingSquid(weak_labels) + # We have to monkeypatch the instance rather than the class due to decorators + # on the class + monkeypatch.setattr(label_model, "_predict", mock_predict) metrics = label_model.score() assert "accuracy" in metrics @@ -746,14 +747,15 @@ def mock_predict(self, weak_label_matrix, verbose): @pytest.mark.parametrize("tbp,vrb,expected", [("abstain", False, 1.0), ("random", True, 2 / 3.0)]) def test_score_tbp(self, monkeypatch, weak_labels, tbp, vrb, expected): - def mock_predict(self, weak_label_matrix, verbose): + def mock_predict(weak_label_matrix, verbose): assert verbose is vrb assert len(weak_label_matrix) == 3 return np.array([[0.8, 0.1, 0.1], [0.4, 0.4, 0.2], [1 / 3.0, 1 / 3.0, 1 / 3.0]]) - monkeypatch.setattr(FlyingSquid, "_predict", mock_predict) - label_model = FlyingSquid(weak_labels) + + monkeypatch.setattr(label_model, "_predict", mock_predict) + metrics = label_model.score(tie_break_policy=tbp, verbose=vrb) assert metrics["accuracy"] == pytest.approx(expected) diff --git a/tests/utils/test_dependency.py b/tests/utils/test_dependency.py new file mode 100644 index 0000000000..b5f33f5626 --- /dev/null +++ b/tests/utils/test_dependency.py @@ -0,0 +1,107 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import sys + +import pytest +from argilla.utils.dependency import ( + require_version, + requires_version, +) + + +class TestDependencyRequirements: + @pytest.mark.parametrize( + ("decorator", "package_name", "import_name", "version"), + [ + (requires_version("datasets>1.17.0"), "datasets", "datasets", ">1.17.0"), + (requires_version("spacy"), "spacy", "spacy", ""), + (requires_version("scikit-learn"), "scikit-learn", "sklearn", ""), + (requires_version("faiss"), "faiss", "faiss", ""), + ], + ) + def test_missing_dependency_decorator( + self, monkeypatch: pytest.MonkeyPatch, decorator, package_name: str, import_name: str, version: str + ): + monkeypatch.setitem(sys.modules, import_name, None) + monkeypatch.setattr(sys, "meta_path", [], raising=False) + + # Ensure that the package indeed cannot be imported due to the monkeypatch + with pytest.raises(ModuleNotFoundError): + importlib.import_module(import_name) + + @decorator + def test_inner(): + pass + + requirement = package_name + version + # Verify that the decorator does its work and shows the desired output with `pip install ...` + with pytest.raises( + ModuleNotFoundError, + match=f"'{package_name}' must be installed to use `test_inner`!.*?`pip install {requirement}`", + ): + test_inner() + + @pytest.mark.parametrize( + ("decorator"), + [ + requires_version("datasets>1.17.0"), + requires_version("spacy"), + requires_version("scikit-learn"), + ], + ) + def test_installed_dependency_decorator(self, decorator): + # Ensure that the decorated function can be called just fine if the dependencies are installed, + # which they should be for these tests + + @decorator + def test_inner(): + return True + + assert test_inner() + + def test_installed_dependency_but_incorrect_version(self): + def test_inner(): + require_version("datasets<1.0.0") + return True + + # This method should fail, as our dependencies require a higher version of datasets + with pytest.raises( + ImportError, + match=f"but found datasets==.*?You can install a supported version of 'datasets' with this command: `pip install -U datasets<1.0.0`", + ): + test_inner() + + def test_require_version_failures(self): + # This operation is not supported + with pytest.raises(ValueError): + require_version("datasets~=1.0.0") + + # Add some unsupported tokens, e.g. 
" " + with pytest.raises(ValueError): + require_version(" datasets") + + # Add unsupported operation in second requirement version + with pytest.raises(ValueError): + require_version("datasets>1.0.0,~1.17.0") + + def test_special_python_case(self): + require_version("python>3.6") + + def test_multiple_version_requirements(self): + # This is equivalent to just datasets>1.17.0, but we expect it to work still + require_version("datasets>1.0.0,>1.8.0,>1.17.0") + # A more common example (designed not to break eventually): + require_version("datasets>1.17.0,<1000.0.0") From fd728347919d314fef98faee2d77f0444a7baa50 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 Feb 2023 04:35:15 +0000 Subject: [PATCH 23/45] [pre-commit.ci] pre-commit autoupdate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/charliermarsh/ruff-pre-commit: v0.0.249 → v0.0.253](https://github.com/charliermarsh/ruff-pre-commit/compare/v0.0.249...v0.0.253) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0f8bffe411..0395640409 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,7 +26,7 @@ repos: additional_dependencies: ["click==8.0.4"] - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.249 + rev: v0.0.253 hooks: # Simulate isort via (the much faster) ruff - id: ruff From 4a92b3510a9c1703cddd925782bf5d57c4d28259 Mon Sep 17 00:00:00 2001 From: Ceyda Cinarel <15624271+cceyda@users.noreply.github.com> Date: Tue, 28 Feb 2023 20:58:42 +0900 Subject: [PATCH 24/45] feat: Extend shortcuts to include alphabet for token classification (#2339) # Description This is a stop-gap solution to issue https://github.com/argilla-io/argilla/issues/1852 A more sophisticated solution can be found but in the mean time this solves the inconvenience when there are tags more than 10. **Type of change** (Please delete options that are not relevant. Remember to title the PR according to the type of change) - [x] Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** (Please describe the tests that you ran to verify your changes. 
And ideally, reference `tests`) - [ ] Test A - [ ] Test B **Checklist** - [ ] I have merged the original branch into my forked branch - [ ] I added relevant documentation - [ ] follows the style guidelines of this project - [x] I did a self-review of my code - [ ] I added comments to my code - [ ] I made corresponding changes to the documentation - [ ] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works --- frontend/components/token-classifier/header/EntitiesHeader.vue | 2 +- frontend/components/token-classifier/results/TextSpan.vue | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/components/token-classifier/header/EntitiesHeader.vue b/frontend/components/token-classifier/header/EntitiesHeader.vue index 7a0acbfb9a..c4614b5b86 100755 --- a/frontend/components/token-classifier/header/EntitiesHeader.vue +++ b/frontend/components/token-classifier/header/EntitiesHeader.vue @@ -53,7 +53,7 @@ export default { }), computed: { visibleEntities() { - const characters = "1234567890".split(""); + const characters = "1234567890QWERTYUIOPASDFGHJKLZXCVBNM".split(""); let entities = [...this.dataset.entities] .sort((a, b) => a.text.localeCompare(b.text)) .map((ent, index) => ({ diff --git a/frontend/components/token-classifier/results/TextSpan.vue b/frontend/components/token-classifier/results/TextSpan.vue index 806689a82a..232ce210ee 100755 --- a/frontend/components/token-classifier/results/TextSpan.vue +++ b/frontend/components/token-classifier/results/TextSpan.vue @@ -111,7 +111,7 @@ export default { .sort((a, b) => a.text.localeCompare(b.text)); }, formattedEntities() { - const characters = "1234567890".split(""); + const characters = "1234567890QWERTYUIOPASDFGHJKLZXCVBNM".split(""); return this.filteredEntities.map((ent, index) => ({ ...ent, shortCut: characters[index], From f5834a5408051bf1fa60d42aede6b325dc07fdbd Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Wed, 1 Mar 2023 14:04:05 +0100 Subject: [PATCH 25/45] refactor: Improve efficiency of `.scan` (and `.load`) if `limit` is set (#2434) Hello! ## Pull Request overview * Refactor `Datasets.scan` to never request more samples than required. * Implement a `batch_size` variable which is now set to `self.DEFAULT_SCAN_SIZE`. This parameter can easily be exposed in the future for `batch_size` support, as discussed previously in Slack etc. ## Details The core of the changes is that the URL is now built on-the-fly at every request, rather than just once. For each request, the limit for that request is computed using `min(limit, batch_size)`, and `limit` is decremented, allowing it to represent the "remaining samples to fetch". When `limit` reaches 0, i.e. there are no more samples to fetch, we return from `scan`. I've also added a check to ensure that the `limit` passed to `scan` can't be negative. A simplified sketch of this loop is shown after the test scenarios below. ## Tests For tests, I've created a `test_scan_efficient_limiting` function which verifies the new and improved behaviour. It contains two monkeypatches: 1. `DEFAULT_SCAN_SIZE` is set to 10. Because our dataset in this test only has 100 samples, we want to ensure that we can't just sample the entire dataset with one request. 2. The `http_client.post` method is monkeypatched to allow us to track the calls to the server. We test the following scenarios: * `limit=23` -> 3 requests: for 10, 10 and 3 samples. * `limit=20` -> 2 requests: for 10 and 10 samples. * `limit=6` -> 1 request: for 6 samples. 
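As referenced in the Details section above, here is a minimal sketch of the decrement-and-batch loop (not the actual patch; `scan_batches` and its arguments are hypothetical stand-ins for the HTTP calls inside `Datasets.scan`), showing how it produces exactly those request sizes:

```python
import math

def scan_batches(total_available: int, limit: float = math.inf, batch_size: int = 10):
    """Yield the per-request limits the refactored `.scan` would use.

    Each yielded value models one POST to `/records/:search?limit={request_limit}`.
    """
    served = 0
    while served < total_available and limit > 0:
        request_limit = min(limit, batch_size)  # never ask for more than still needed
        yield int(min(request_limit, total_available - served))
        served += request_limit
        limit -= request_limit  # `limit` now tracks the remaining samples to fetch

# The three test scenarios against a 100-record dataset:
assert list(scan_batches(100, limit=23)) == [10, 10, 3]
assert list(scan_batches(100, limit=20)) == [10, 10]
assert list(scan_batches(100, limit=6)) == [6]
```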
There's also a test to cover the new ValueError if `limit` < 0. ## Effects Consider the following script: ```python import argilla as rg import time def test_scan_records(fields): client = rg.active_client() data = client.datasets.scan( name="banking77-topics-setfit", projection=fields, limit=1, # <- Note, just one sample ) start_t = time.time() list(data) print(f"load time for 1 sample with fields={fields}: {time.time() - start_t:.4f}s") test_scan_records(set()) test_scan_records({"text"}) test_scan_records({"tokens"}) ``` On this PR, this outputs: ``` load time for 1 sample with fields=set(): 0.0774s load time for 1 sample with fields={'text'}: 0.0646s load time for 1 sample with fields={'tokens'}: 0.0669s ``` On the `develop` branch, this outputs: ``` load time for 1 sample with fields=set(): 0.1355s load time for 1 sample with fields={'text'}: 0.1347s load time for 1 sample with fields={'tokens'}: 0.1173s ``` This can be repeated for `rg.load(..., limit=1)`, as that relies on `.scan` under the hood. Note that this is the most extreme example of performance gains. The performance increases in almost all scenarios if `limit` is set, but the effects are generally not this big. Going from fetching 250 samples 8 times to fetching 250 samples 7 times and 173 once doesn't have as big of an impact. --- **Type of change** - [x] Refactor (change restructuring the codebase without changing functionality) **How Has This Been Tested** See above. **Checklist** - [x] I have merged the original branch into my forked branch - [ ] I added relevant documentation - [x] follows the style guidelines of this project - [x] I did a self-review of my code - [ ] I made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works --- - Tom Aarsen --- src/argilla/client/apis/datasets.py | 33 +++++------ .../functional_tests/test_scan_raw_records.py | 55 +++++++++++++++++++ 2 files changed, 72 insertions(+), 16 deletions(-) diff --git a/src/argilla/client/apis/datasets.py b/src/argilla/client/apis/datasets.py index fb1135b1e7..cf717a4089 100644 --- a/src/argilla/client/apis/datasets.py +++ b/src/argilla/client/apis/datasets.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math import warnings from dataclasses import dataclass from datetime import datetime @@ -163,23 +164,23 @@ def scan( name: the dataset query: the search query projection: a subset of record fields to retrieve. If not provided, - limit: The number of records to retrieve + only id's will be returned + limit: The number of records to retrieve. id_from: If provided, starts gathering the records starting from that Record. As the Records returned with the load method are sorted by ID, ´id_from´ can be used to load using batches. 
- only id's will be returned Returns: - An iterable of raw object containing per-record info - """ - url = f"{self._API_PREFIX}/{name}/records/:search?limit={self.DEFAULT_SCAN_SIZE}" - query = self._parse_query(query=query) + if limit and limit < 0: + raise ValueError("The scan limit must be non-negative.") - if limit == 0: - limit = None + batch_size = self.DEFAULT_SCAN_SIZE + limit = limit if limit else math.inf + url = f"{self._API_PREFIX}/{name}/records/:search?limit={{limit}}" + query = self._parse_query(query=query) request = { "fields": list(projection) if projection else ["id"], @@ -189,24 +190,24 @@ if id_from: request["next_idx"] = id_from - yield_fields = 0 with api_compatibility(self, min_version="1.2.0"): + request_limit = min(limit, batch_size) response = self.http_client.post( - url, + url.format(limit=request_limit), json=request, ) while response.get("records"): - for record in response["records"]: - yield record - yield_fields += 1 - if limit and limit <= yield_fields: - return + yield from response["records"] + limit -= request_limit + if limit <= 0: + return next_idx = response.get("next_idx") if next_idx: + request_limit = min(limit, batch_size) response = self.http_client.post( - path=url, + path=url.format(limit=request_limit), json={**request, "next_idx": next_idx}, ) diff --git a/tests/client/functional_tests/test_scan_raw_records.py b/tests/client/functional_tests/test_scan_raw_records.py index 7af6af9b6f..ea3224b3f3 100644 --- a/tests/client/functional_tests/test_scan_raw_records.py +++ b/tests/client/functional_tests/test_scan_raw_records.py @@ -67,3 +67,58 @@ def test_scan_records_without_results( ) data = list(data) assert len(data) == 0 + + +def test_scan_fail_negative_limit( + mocked_client, + gutenberg_spacy_ner, +): + with pytest.raises(ValueError, match="limit.*negative"): + data = active_api().datasets.scan( + name=gutenberg_spacy_ner, + limit=-20, + ) + # Actually consume the generator's data + data = list(data) + + +@pytest.mark.parametrize(("limit"), [6, 23, 20]) +def test_scan_efficient_limiting( + monkeypatch: pytest.MonkeyPatch, + limit, + gutenberg_spacy_ner, +): + client_datasets = active_api().datasets + # Reduce the default scan size to something small to better test the situation + # where limit > DEFAULT_SCAN_SIZE + batch_size = 10 + monkeypatch.setattr(client_datasets, "DEFAULT_SCAN_SIZE", batch_size) + + # Monkeypatch the .post() call to track with what URLs the server is called + called_paths = [] + original_post = active_api().http_client.post + + def tracked_post(path, *args, **kwargs): + called_paths.append(path) + return original_post(path, *args, **kwargs) + + monkeypatch.setattr(active_api().http_client, "post", tracked_post) + + # Try to fetch `limit` samples from the 100 + data = client_datasets.scan(name=gutenberg_spacy_ner, limit=limit) + data = list(data) + + # Ensure that `limit` samples were indeed fetched + assert len(data) == limit + # Ensure that the samples were fetched in the expected number of requests + # Equivalent to math.ceil(limit / batch_size): + assert len(called_paths) == (limit - 1) // batch_size + 1 + + if limit % batch_size == 0: + # If limit is divisible by batch_size, then we expect all calls to have a limit of batch_size + assert all(path.endswith(f"?limit={batch_size}") for 
path in called_paths[:-1]) + assert called_paths[-1].endswith(f"?limit={limit % batch_size}") From d789fa1570e70cd0fdeb9d9ec0b4a7ac1c12f5aa Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Thu, 2 Mar 2023 12:07:10 +0100 Subject: [PATCH 26/45] fix: added regex match to set workspace method (#2427) # Description I re-used the logic for the dataset regex check for the workspace declaration too. Closes #2388 **Type of change** (Please delete options that are not relevant. Remember to title the PR according to the type of change) - [X] Bug fix (non-breaking change which fixes an issue) **How Has This Been Tested** (Please describe the tests that you ran to verify your changes. And ideally, reference `tests`) - [X] [Test A](https://github.com/argilla-io/argilla/blob/5138961e4148749dca85e647fa32d8c9a13f87f2/tests/client/test_api.py#L97) **Checklist** --- src/argilla/_constants.py | 2 +- src/argilla/client/client.py | 17 +++++++++++++---- .../server/apis/v0/models/commons/params.py | 4 ++-- src/argilla/server/apis/v0/models/datasets.py | 4 ++-- src/argilla/server/daos/models/datasets.py | 4 ++-- src/argilla/server/security/model.py | 6 +++--- tests/client/test_api.py | 5 +++++ 7 files changed, 28 insertions(+), 14 deletions(-) diff --git a/src/argilla/_constants.py b/src/argilla/_constants.py index b721488e1b..3f7af54304 100644 --- a/src/argilla/_constants.py +++ b/src/argilla/_constants.py @@ -26,4 +26,4 @@ _OLD_API_KEY_HEADER_NAME = "X-Rubrix-Api-Key" _OLD_WORKSPACE_HEADER_NAME = "X-Rubrix-Workspace" -DATASET_NAME_REGEX_PATTERN = r"^(?!-|_)[a-z0-9-_]+$" +ES_INDEX_REGEX_PATTERN = r"^(?!-|_)[a-z0-9-_]+$" diff --git a/src/argilla/client/client.py b/src/argilla/client/client.py index 66003c8729..1df03aca60 100644 --- a/src/argilla/client/client.py +++ b/src/argilla/client/client.py @@ -18,15 +18,15 @@ import re import warnings from asyncio import Future -from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from rich import print as rprint from rich.progress import Progress from argilla._constants import ( _OLD_WORKSPACE_HEADER_NAME, - DATASET_NAME_REGEX_PATTERN, DEFAULT_API_KEY, + ES_INDEX_REGEX_PATTERN, WORKSPACE_HEADER_NAME, ) from argilla.client.apis.datasets import Datasets @@ -208,6 +208,15 @@ def set_workspace(self, workspace: str): if not workspace: raise Exception("Must provide a workspace") + if not re.match(ES_INDEX_REGEX_PATTERN, workspace): + raise InputValueError( + f"Provided workspace name {workspace} does not match the pattern" + f" {ES_INDEX_REGEX_PATTERN}. Please, use a valid name for your" + " workspace. This limitation is caused by naming conventions for indexes" + " in Elasticsearch. If applicable, you can try to lowercase the name of your workspace." + " https://www.elastic.co/guide/en/elasticsearch/reference/current/indices-create-index.html" + ) + if workspace != self.get_workspace(): if workspace == self._user.username: self._client.headers.pop(WORKSPACE_HEADER_NAME, workspace) @@ -326,10 +335,10 @@ async def log_async( if not name: raise InputValueError("Empty dataset name has been passed as argument.") - if not re.match(DATASET_NAME_REGEX_PATTERN, name): + if not re.match(ES_INDEX_REGEX_PATTERN, name): raise InputValueError( f"Provided dataset name {name} does not match the pattern" - f" {DATASET_NAME_REGEX_PATTERN}. Please, use a valid name for your" + f" {ES_INDEX_REGEX_PATTERN}. Please, use a valid name for your" " dataset. 
This limitation is caused by naming conventions for indexes" " in Elasticsearch." " https://www.elastic.co/guide/en/elasticsearch/reference/current/indices-create-index.html" diff --git a/src/argilla/server/apis/v0/models/commons/params.py b/src/argilla/server/apis/v0/models/commons/params.py index 4bd92aa4e2..68b5842649 100644 --- a/src/argilla/server/apis/v0/models/commons/params.py +++ b/src/argilla/server/apis/v0/models/commons/params.py @@ -18,12 +18,12 @@ from argilla._constants import ( _OLD_WORKSPACE_HEADER_NAME, - DATASET_NAME_REGEX_PATTERN, + ES_INDEX_REGEX_PATTERN, WORKSPACE_HEADER_NAME, ) from argilla.server.security.model import WORKSPACE_NAME_PATTERN -DATASET_NAME_PATH_PARAM = Path(..., regex=DATASET_NAME_REGEX_PATTERN, description="The dataset name") +DATASET_NAME_PATH_PARAM = Path(..., regex=ES_INDEX_REGEX_PATTERN, description="The dataset name") @dataclass diff --git a/src/argilla/server/apis/v0/models/datasets.py b/src/argilla/server/apis/v0/models/datasets.py index 7590f4ed13..65e4a2cc30 100644 --- a/src/argilla/server/apis/v0/models/datasets.py +++ b/src/argilla/server/apis/v0/models/datasets.py @@ -21,7 +21,7 @@ from pydantic import BaseModel, Field -from argilla._constants import DATASET_NAME_REGEX_PATTERN +from argilla._constants import ES_INDEX_REGEX_PATTERN from argilla.server.commons.models import TaskType from argilla.server.services.datasets import ServiceBaseDataset @@ -43,7 +43,7 @@ class UpdateDatasetRequest(BaseModel): class _BaseDatasetRequest(UpdateDatasetRequest): - name: str = Field(regex=DATASET_NAME_REGEX_PATTERN, description="The dataset name") + name: str = Field(regex=ES_INDEX_REGEX_PATTERN, description="The dataset name") class CreateDatasetRequest(_BaseDatasetRequest): diff --git a/src/argilla/server/daos/models/datasets.py b/src/argilla/server/daos/models/datasets.py index 03f690f777..6b3c55f6cf 100644 --- a/src/argilla/server/daos/models/datasets.py +++ b/src/argilla/server/daos/models/datasets.py @@ -17,12 +17,12 @@ from pydantic import BaseModel, Field, validator -from argilla._constants import DATASET_NAME_REGEX_PATTERN +from argilla._constants import ES_INDEX_REGEX_PATTERN from argilla.server.commons.models import TaskType class BaseDatasetDB(BaseModel): - name: str = Field(regex=DATASET_NAME_REGEX_PATTERN) + name: str = Field(regex=ES_INDEX_REGEX_PATTERN) task: TaskType owner: Optional[str] = Field(description="Deprecated. Use `workspace` instead. Will be removed in v1.5.0") workspace: Optional[str] = None diff --git a/src/argilla/server/security/model.py b/src/argilla/server/security/model.py index 007f043507..5e7f77a491 100644 --- a/src/argilla/server/security/model.py +++ b/src/argilla/server/security/model.py @@ -17,7 +17,7 @@ from pydantic import BaseModel, Field, root_validator, validator -from argilla._constants import DATASET_NAME_REGEX_PATTERN +from argilla._constants import ES_INDEX_REGEX_PATTERN from argilla.server.errors import EntityNotFoundError WORKSPACE_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_\-]*$") @@ -37,9 +37,9 @@ class User(BaseModel): @validator("username") def check_username(cls, value): - if not re.compile(DATASET_NAME_REGEX_PATTERN).match(value): + if not re.compile(ES_INDEX_REGEX_PATTERN).match(value): raise ValueError( - "Wrong username. " f"The username {value} does not match the pattern {DATASET_NAME_REGEX_PATTERN}" + "Wrong username. 
" f"The username {value} does not match the pattern {ES_INDEX_REGEX_PATTERN}" ) return value diff --git a/tests/client/test_api.py b/tests/client/test_api.py index 822a8e6607..d9de6616cb 100644 --- a/tests/client/test_api.py +++ b/tests/client/test_api.py @@ -94,6 +94,11 @@ def mock_get(*args, **kwargs): monkeypatch.setattr(users_api, "whoami", mock_get) +def test_init_uppercase_workspace(mocked_client): + with pytest.raises(InputValueError): + api.init(workspace="UPPERCASE_WORKSPACE") + + def test_init_correct(mock_response_200): """Testing correct default initialization From fc71c3b538e767133ec4fa53333a76490e0d9449 Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Thu, 2 Mar 2023 12:24:23 +0100 Subject: [PATCH 27/45] fix: error when loading record with empty string query (#2429) # Description empty query strings were not returning any records. These are currently set to None and therefore return records again. Closes #2400 Closes #2303 **Type of change** (Please delete options that are not relevant. Remember to title the PR according to the type of change) - [X] Bug fix (non-breaking change which fixes an issue) **How Has This Been Tested** (Please describe the tests that you ran to verify your changes. And ideally, reference `tests`) - [x] [Test A](https://github.com/argilla-io/argilla/blob/34da84b5b0ac2c71d17459519915630dda7393a2/tests/client/test_api.py#L179) **Checklist** N.A. --- .../installation/deployments/docker.md | 3 ++- src/argilla/server/daos/backend/search/model.py | 9 ++++++++- tests/client/test_api.py | 12 ++++++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/docs/_source/getting_started/installation/deployments/docker.md b/docs/_source/getting_started/installation/deployments/docker.md index b24d451760..d967392fd8 100644 --- a/docs/_source/getting_started/installation/deployments/docker.md +++ b/docs/_source/getting_started/installation/deployments/docker.md @@ -9,12 +9,13 @@ First, you need to create a network to make both standalone containers visibles Just run the folowing command: ```bash docker network create argilla-net +``` Setting up Elasticsearch (ES) via docker is straightforward. Simply run the following command: ```bash -docker run -d --name elasticsearch-for-argilla --network argilla-net -p 9200:9200 -p 9300:9300 -e "ES_JAVA_OPTS=-Xms512m -Xmx512m" -e "discovery.type=single-node" docker.elastic.co/elasticsearch/elasticsearch:8.5.3 +docker run -d --name elasticsearch-for-argilla --network argilla-net -p 9200:9200 -p 9300:9300 -e "ES_JAVA_OPTS=-Xms512m -Xmx512m" -e "discovery.type=single-node" -e "xpack.security.enabled=false" docker.elastic.co/elasticsearch/elasticsearch:8.5.3 ``` This will create an ES docker container named *"elasticsearch-for-argilla"* that will run in the background. 
diff --git a/src/argilla/server/daos/backend/search/model.py b/src/argilla/server/daos/backend/search/model.py index 3db5db977e..2f9941d8a1 100644 --- a/src/argilla/server/daos/backend/search/model.py +++ b/src/argilla/server/daos/backend/search/model.py @@ -15,7 +15,7 @@ from enum import Enum from typing import Any, Dict, List, Optional, TypeVar, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, validator from argilla.server.commons.models import TaskStatus @@ -94,6 +94,13 @@ class BaseRecordsQuery(BaseQuery): vector: Optional[VectorSearch] = Field(default=None) + @validator("query_text") + def check_empty_query_text(cls, value): + if value is not None: + if value.strip() == "": + value = None + return value + BackendQuery = TypeVar("BackendQuery", bound=BaseQuery) BackendRecordsQuery = TypeVar("BackendRecordsQuery", bound=BaseRecordsQuery) diff --git a/tests/client/test_api.py b/tests/client/test_api.py index d9de6616cb..7aaff4e86c 100644 --- a/tests/client/test_api.py +++ b/tests/client/test_api.py @@ -181,6 +181,18 @@ def test_log_something(monkeypatch, mocked_client): assert results.records[0].inputs["text"] == "This is a test" +def test_load_empty_string(monkeypatch, mocked_client): + dataset_name = "test-dataset" + mocked_client.delete(f"/api/datasets/{dataset_name}") + + api.log( + name=dataset_name, + records=rg.TextClassificationRecord(inputs={"text": "This is a test"}), + ) + assert len(api.load(name=dataset_name, query="")) == 1 + assert len(api.load(name=dataset_name, query=" ")) == 1 + + def test_load_limits(mocked_client, supported_vector_search): dataset = "test_load_limits" api_ds_prefix = f"/api/datasets/{dataset}" From 6649e5fefbf711e6b6f86996b0fa01c16d362045 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Thu, 2 Mar 2023 16:44:25 +0100 Subject: [PATCH 28/45] Refactor/prepare datasets endpoints (#2403) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR will simplify current datasets-related endpoints to adapt to the new API definition. 
- [x] Moving pydantic api models under the `server.schemas.datasets` module - [x] Making workspace required for dataset creation - [x] Add dataset id to dataset endpoints response - [x] Working with `dataset.workspace` instead of `dataset.owner` - [x] Avoid dataset inheritance as much as possible --------- Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Co-authored-by: David Berenstein Co-authored-by: keithCuniah Co-authored-by: Keith Cuniah <88380932+keithCuniah@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: José Francisco Calvo --- scripts/migrations/es_migration_25042021.py | 132 ------------------ .../server/apis/v0/handlers/datasets.py | 30 ++-- .../server/apis/v0/handlers/metrics.py | 3 - .../server/apis/v0/handlers/text2text.py | 18 +-- .../apis/v0/handlers/text_classification.py | 30 ++-- .../apis/v0/handlers/token_classification.py | 18 +-- .../server/apis/v0/models/text2text.py | 7 +- .../apis/v0/models/text_classification.py | 2 +- .../apis/v0/models/token_classification.py | 9 +- .../apis/v0/validators/text_classification.py | 2 +- .../v0/validators/token_classification.py | 2 +- src/argilla/server/commons/config.py | 6 +- .../server/daos/backend/generic_elastic.py | 19 --- .../server/daos/backend/search/model.py | 6 +- .../daos/backend/search/query_builder.py | 22 +-- src/argilla/server/daos/datasets.py | 21 +-- src/argilla/server/daos/models/datasets.py | 38 +++-- .../datasets.py => schemas/__init__.py} | 5 - .../{apis/v0/models => schemas}/datasets.py | 34 +++-- src/argilla/server/services/datasets.py | 100 +++++-------- .../server/services/tasks/text2text/models.py | 4 - .../services/tasks/text2text/service.py | 8 +- .../tasks/text_classification/model.py | 2 +- .../tasks/token_classification/model.py | 5 - .../tasks/token_classification/service.py | 10 +- tests/client/sdk/datasets/test_models.py | 4 +- .../search/test_search_service.py | 16 +-- tests/server/commons/test_records_dao.py | 1 + tests/server/datasets/test_api.py | 9 +- tests/server/datasets/test_dao.py | 13 +- tests/server/datasets/test_model.py | 40 +++++- tests/server/text_classification/test_api.py | 2 +- 32 files changed, 204 insertions(+), 414 deletions(-) delete mode 100644 scripts/migrations/es_migration_25042021.py rename src/argilla/server/{daos/backend/metrics/datasets.py => schemas/__init__.py} (74%) rename src/argilla/server/{apis/v0/models => schemas}/datasets.py (72%) diff --git a/scripts/migrations/es_migration_25042021.py b/scripts/migrations/es_migration_25042021.py deleted file mode 100644 index 1f5a04d602..0000000000 --- a/scripts/migrations/es_migration_25042021.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from itertools import zip_longest -from typing import Any, Dict, List, Optional - -from elasticsearch import Elasticsearch -from elasticsearch.helpers import bulk, scan -from pydantic import BaseSettings -from rubrix.server.tasks.commons import TaskType - - -class Settings(BaseSettings): - """ - Migration argument settings - """ - - elasticsearch: str = "http://localhost:9200" - migration_datasets: List[str] = [] - chunk_size: int = 1000 - task: TaskType - - -settings = Settings() - - -source_datasets_index = ".rubric.datasets-v1" -target_datasets_index = ".rubrix.datasets-v0" -source_record_index_pattern = ".rubric.dataset.{}.records-v1" -target_record_index_pattern = ".rubrix.dataset.{}.records-v0" - - -def batcher(iterable, n, fillvalue=None): - "batches an iterable" - args = [iter(iterable)] * n - return zip_longest(*args, fillvalue=fillvalue) - - -def map_doc_2_action(index: str, doc: Dict[str, Any], task: TaskType) -> Optional[Dict[str, Any]]: - """Configures bulk action""" - doc_data = doc["_source"] - new_record = { - "id": doc_data["id"], - "metadata": doc_data.get("metadata"), - "last_updated": doc_data.get("last_updated"), - "words": doc_data.get("words"), - } - - task_info = doc_data["tasks"].get(task) - if task_info is None: - return None - - new_record.update( - { - "status": task_info.get("status"), - "prediction": task_info.get("prediction"), - "annotation": task_info.get("annotation"), - "event_timestamp": task_info.get("event_timestamp"), - "predicted": task_info.get("predicted"), - "annotated_as": task_info.get("annotated_as"), - "predicted_as": task_info.get("predicted_as"), - "annotated_by": task_info.get("annotated_by"), - "predicted_by": task_info.get("predicted_by"), - "score": task_info.get("confidences"), - "owner": task_info.get("owner"), - } - ) - - if task == TaskType.text_classification: - new_record.update( - { - "inputs": doc_data.get("text"), - "multi_label": task_info.get("multi_label"), - "explanation": task_info.get("explanation"), - } - ) - elif task == TaskType.token_classification: - new_record.update( - { - "tokens": doc_data.get("tokens"), - "text": doc_data.get("raw_text"), - } - ) - return { - "_op_type": "index", - "_index": index, - "_id": doc["_id"], - **new_record, - } - - -if __name__ == "__main__": - client = Elasticsearch(hosts=settings.elasticsearch) - - for dataset in settings.migration_datasets: - source_index = source_record_index_pattern.format(dataset) - source_index_info = client.get(index=source_datasets_index, id=dataset) - - target_dataset_name = f"{dataset}-{settings.task}".lower() - target_index = target_record_index_pattern.format(target_dataset_name) - - target_index_info = source_index_info["_source"] - target_index_info["task"] = settings.task - target_index_info["name"] = target_dataset_name - - client.index( - index=target_datasets_index, - id=target_index_info["name"], - body=target_index_info, - ) - - index_docs = scan(client, index=source_index) - for batch in batcher(index_docs, n=settings.chunk_size): - bulk( - client, - actions=( - map_doc_2_action(index=target_index, doc=doc, task=settings.task) - for doc in batch - if doc is not None - ), - ) diff --git a/src/argilla/server/apis/v0/handlers/datasets.py b/src/argilla/server/apis/v0/handlers/datasets.py index 46adf21652..e785645a46 100644 --- a/src/argilla/server/apis/v0/handlers/datasets.py +++ b/src/argilla/server/apis/v0/handlers/datasets.py @@ -16,17 +16,17 @@ from typing import List from fastapi import APIRouter, Body, Depends, Security +from pydantic import 
parse_obj_as from argilla.server.apis.v0.helpers import deprecate_endpoint from argilla.server.apis.v0.models.commons.params import CommonTaskHandlerDependencies -from argilla.server.apis.v0.models.datasets import ( +from argilla.server.errors import EntityNotFoundError +from argilla.server.schemas.datasets import ( CopyDatasetRequest, CreateDatasetRequest, Dataset, UpdateDatasetRequest, ) -from argilla.server.commons.config import TasksFactory -from argilla.server.errors import EntityNotFoundError from argilla.server.security import auth from argilla.server.security.model import User from argilla.server.services.datasets import DatasetsService @@ -47,11 +47,13 @@ async def list_datasets( service: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> List[Dataset]: - return service.list( + datasets = service.list( user=current_user, workspaces=[request_deps.workspace] if request_deps.workspace is not None else None, ) + return parse_obj_as(List[Dataset], datasets) + @router.post( "", @@ -67,14 +69,10 @@ async def create_dataset( datasets: DatasetsService = Depends(DatasetsService.get_instance), user: User = Security(auth.get_user, scopes=["create:datasets"]), ) -> Dataset: - owner = user.check_workspace(ws_params.workspace) + request.workspace = request.workspace or ws_params.workspace + dataset = datasets.create_dataset(user=user, dataset=request) - dataset_class = TasksFactory.get_task_dataset(request.task) - dataset = dataset_class.parse_obj({**request.dict()}) - dataset.owner = owner - - response = datasets.create_dataset(user=user, dataset=dataset) - return Dataset.parse_obj(response) + return Dataset.from_orm(dataset) @router.get( @@ -89,7 +87,7 @@ def get_dataset( service: DatasetsService = Depends(DatasetsService.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ) -> Dataset: - return Dataset.parse_obj( + return Dataset.from_orm( service.find_by_name( user=current_user, name=name, @@ -113,13 +111,15 @@ def update_dataset( ) -> Dataset: found_ds = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace) - return service.update( + dataset = service.update( user=current_user, dataset=found_ds, tags=request.tags, metadata=request.metadata, ) + return Dataset.from_orm(dataset) + @router.delete( "/{name}", @@ -187,7 +187,7 @@ def copy_dataset( current_user: User = Security(auth.get_user, scopes=[]), ) -> Dataset: found = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace) - return service.copy_dataset( + dataset = service.copy_dataset( user=current_user, dataset=found, copy_name=copy_request.name, @@ -195,3 +195,5 @@ def copy_dataset( copy_tags=copy_request.tags, copy_metadata=copy_request.metadata, ) + + return Dataset.from_orm(dataset) diff --git a/src/argilla/server/apis/v0/handlers/metrics.py b/src/argilla/server/apis/v0/handlers/metrics.py index 9408f1e3f0..eec6397c7d 100644 --- a/src/argilla/server/apis/v0/handlers/metrics.py +++ b/src/argilla/server/apis/v0/handlers/metrics.py @@ -81,7 +81,6 @@ def get_dataset_metrics( name=name, task=cfg.task, workspace=request_deps.workspace, - as_dataset_class=TasksFactory.get_task_dataset(cfg.task), ) metrics = TasksFactory.get_task_metrics(dataset.task) @@ -111,7 +110,6 @@ def metric_summary( name=name, task=cfg.task, workspace=request_deps.workspace, - as_dataset_class=TasksFactory.get_task_dataset(cfg.task), ) metric_ = TasksFactory.find_task_metric(task=cfg.task, metric_id=metric) @@ -119,7 +117,6 @@ 
def metric_summary( return metrics.summarize_metric( dataset=dataset, - owner=current_user.check_workspace(request_deps.workspace), metric=metric_, record_class=record_class, query=query, diff --git a/src/argilla/server/apis/v0/handlers/text2text.py b/src/argilla/server/apis/v0/handlers/text2text.py index cf6230c4ad..24e7d03493 100644 --- a/src/argilla/server/apis/v0/handlers/text2text.py +++ b/src/argilla/server/apis/v0/handlers/text2text.py @@ -27,7 +27,6 @@ ) from argilla.server.apis.v0.models.text2text import ( Text2TextBulkRequest, - Text2TextDataset, Text2TextMetrics, Text2TextQuery, Text2TextRecord, @@ -40,9 +39,10 @@ from argilla.server.errors import EntityNotFoundError from argilla.server.helpers import takeuntil from argilla.server.responses import StreamingResponseWithErrorHandling +from argilla.server.schemas.datasets import CreateDatasetRequest from argilla.server.security import auth from argilla.server.security.model import User -from argilla.server.services.datasets import DatasetsService +from argilla.server.services.datasets import DatasetsService, ServiceBaseDataset from argilla.server.services.tasks.text2text import Text2TextService from argilla.server.services.tasks.text2text.models import ( ServiceText2TextQuery, @@ -56,7 +56,6 @@ def configure_router(): TasksFactory.register_task( task_type=TaskType.text2text, - dataset_class=Text2TextDataset, query_request=Text2TextQuery, record_class=ServiceText2TextRecord, metrics=Text2TextMetrics, @@ -79,14 +78,13 @@ async def bulk_records( current_user: User = Security(auth.get_user, scopes=[]), ) -> BulkResponse: task = task_type - owner = current_user.check_workspace(common_params.workspace) + workspace = current_user.check_workspace(common_params.workspace) try: dataset = datasets.find_by_name( current_user, name=name, task=task, - workspace=owner, - as_dataset_class=TasksFactory.get_task_dataset(task_type), + workspace=workspace, ) datasets.update( user=current_user, @@ -95,10 +93,8 @@ async def bulk_records( metadata=bulk.metadata, ) except EntityNotFoundError: - dataset_class = TasksFactory.get_task_dataset(task) - dataset = dataset_class.parse_obj({**bulk.dict(), "name": name}) - dataset.owner = owner - datasets.create_dataset(user=current_user, dataset=dataset) + dataset = CreateDatasetRequest(name=name, workspace=workspace, task=task, **bulk.dict()) + dataset = datasets.create_dataset(user=current_user, dataset=dataset) result = await service.add_records( dataset=dataset, @@ -133,7 +129,6 @@ def search_records( name=name, task=task_type, workspace=common_params.workspace, - as_dataset_class=TasksFactory.get_task_dataset(task_type), ) result = service.search( dataset=dataset, @@ -225,7 +220,6 @@ async def stream_data( name=name, task=task_type, workspace=common_params.workspace, - as_dataset_class=TasksFactory.get_task_dataset(task_type), ) data_stream = map( Text2TextRecord.parse_obj, diff --git a/src/argilla/server/apis/v0/handlers/text_classification.py b/src/argilla/server/apis/v0/handlers/text_classification.py index 59f9436425..fb53afb55f 100644 --- a/src/argilla/server/apis/v0/handlers/text_classification.py +++ b/src/argilla/server/apis/v0/handlers/text_classification.py @@ -49,6 +49,7 @@ from argilla.server.errors import EntityNotFoundError from argilla.server.helpers import takeuntil from argilla.server.responses import StreamingResponseWithErrorHandling +from argilla.server.schemas.datasets import CreateDatasetRequest from argilla.server.security import auth from argilla.server.security.model import User from 
argilla.server.services.datasets import DatasetsService @@ -95,26 +96,23 @@ async def bulk_records( current_user: User = Security(auth.get_user, scopes=[]), ) -> BulkResponse: task = task_type - owner = current_user.check_workspace(common_params.workspace) + workspace = current_user.check_workspace(common_params.workspace) try: dataset = datasets.find_by_name( current_user, name=name, task=task, - workspace=owner, - as_dataset_class=TasksFactory.get_task_dataset(task_type), + workspace=workspace, ) - datasets.update( + dataset = datasets.update( user=current_user, dataset=dataset, tags=bulk.tags, metadata=bulk.metadata, ) except EntityNotFoundError: - dataset_class = TasksFactory.get_task_dataset(task) - dataset = dataset_class.parse_obj({**bulk.dict(), "name": name}) - dataset.owner = owner - datasets.create_dataset(user=current_user, dataset=dataset) + dataset = CreateDatasetRequest(name=name, workspace=workspace, task=task, **bulk.dict()) + dataset = datasets.create_dataset(user=current_user, dataset=dataset) # TODO(@frascuchon): Validator should be applied in the service layer records = [ServiceTextClassificationRecord.parse_obj(r) for r in bulk.records] @@ -182,7 +180,6 @@ def search_records( name=name, task=task_type, workspace=common_params.workspace, - as_dataset_class=TasksFactory.get_task_dataset(task_type), ) result = service.search( dataset=dataset, @@ -275,7 +272,6 @@ async def stream_data( name=name, task=task_type, workspace=common_params.workspace, - as_dataset_class=TasksFactory.get_task_dataset(task_type), ) data_stream = map( @@ -313,7 +309,7 @@ async def list_labeling_rules( name=name, task=task_type, workspace=common_params.workspace, - as_dataset_class=TasksFactory.get_task_dataset(task_type), + as_dataset_class=TextClassificationDataset, ) return [LabelingRule.parse_obj(rule) for rule in service.list_labeling_rules(dataset)] @@ -340,7 +336,7 @@ async def create_rule( name=name, task=task_type, workspace=common_params.workspace, - as_dataset_class=TasksFactory.get_task_dataset(task_type), + as_dataset_class=TextClassificationDataset, ) rule = ServiceLabelingRule( @@ -376,7 +372,7 @@ async def compute_rule_metrics( name=name, task=task_type, workspace=common_params.workspace, - as_dataset_class=TasksFactory.get_task_dataset(task_type), + as_dataset_class=TextClassificationDataset, ) return service.compute_labeling_rule(dataset, rule_query=query, labels=labels) @@ -402,7 +398,7 @@ async def compute_dataset_rules_metrics( name=name, task=task_type, workspace=common_params.workspace, - as_dataset_class=TasksFactory.get_task_dataset(task_type), + as_dataset_class=TextClassificationDataset, ) metrics = service.compute_all_labeling_rules(dataset) return DatasetLabelingRulesMetricsSummary.parse_obj(metrics) @@ -427,7 +423,7 @@ async def delete_labeling_rule( name=name, task=task_type, workspace=common_params.workspace, - as_dataset_class=TasksFactory.get_task_dataset(task_type), + as_dataset_class=TextClassificationDataset, ) service.delete_labeling_rule(dataset, rule_query=query) @@ -454,7 +450,7 @@ async def get_rule( name=name, task=task_type, workspace=common_params.workspace, - as_dataset_class=TasksFactory.get_task_dataset(task_type), + as_dataset_class=TextClassificationDataset, ) rule = service.find_labeling_rule(dataset, rule_query=query) return LabelingRule.parse_obj(rule) @@ -482,7 +478,7 @@ async def update_rule( name=name, task=task_type, workspace=common_params.workspace, - as_dataset_class=TasksFactory.get_task_dataset(task_type), + 
as_dataset_class=TextClassificationDataset, ) rule = service.update_labeling_rule( diff --git a/src/argilla/server/apis/v0/handlers/token_classification.py b/src/argilla/server/apis/v0/handlers/token_classification.py index b4a851b8de..5b8a0e6a1f 100644 --- a/src/argilla/server/apis/v0/handlers/token_classification.py +++ b/src/argilla/server/apis/v0/handlers/token_classification.py @@ -31,7 +31,6 @@ from argilla.server.apis.v0.models.token_classification import ( TokenClassificationAggregations, TokenClassificationBulkRequest, - TokenClassificationDataset, TokenClassificationQuery, TokenClassificationRecord, TokenClassificationSearchRequest, @@ -43,9 +42,10 @@ from argilla.server.errors import EntityNotFoundError from argilla.server.helpers import takeuntil from argilla.server.responses import StreamingResponseWithErrorHandling +from argilla.server.schemas.datasets import CreateDatasetRequest from argilla.server.security import auth from argilla.server.security.model import User -from argilla.server.services.datasets import DatasetsService +from argilla.server.services.datasets import DatasetsService, ServiceBaseDataset from argilla.server.services.tasks.token_classification import ( TokenClassificationService, ) @@ -65,7 +65,6 @@ def configure_router(): TasksFactory.register_task( task_type=task_type, - dataset_class=TokenClassificationDataset, query_request=TokenClassificationQuery, record_class=ServiceTokenClassificationRecord, metrics=TokenClassificationMetrics, @@ -89,14 +88,13 @@ async def bulk_records( current_user: User = Security(auth.get_user, scopes=[]), ) -> BulkResponse: task = task_type - owner = current_user.check_workspace(common_params.workspace) + workspace = current_user.check_workspace(common_params.workspace) try: dataset = datasets.find_by_name( current_user, name=name, task=task, - workspace=owner, - as_dataset_class=TasksFactory.get_task_dataset(task_type), + workspace=workspace, ) datasets.update( user=current_user, @@ -105,10 +103,8 @@ async def bulk_records( metadata=bulk.metadata, ) except EntityNotFoundError: - dataset_class = TasksFactory.get_task_dataset(task) - dataset = dataset_class.parse_obj({**bulk.dict(), "name": name}) - dataset.owner = owner - datasets.create_dataset(user=current_user, dataset=dataset) + dataset = CreateDatasetRequest(name=name, workspace=workspace, task=task, **bulk.dict()) + dataset = datasets.create_dataset(user=current_user, dataset=dataset) records = [ServiceTokenClassificationRecord.parse_obj(r) for r in bulk.records] # TODO(@frascuchon): validator can be applied in service layer @@ -155,7 +151,6 @@ def search_records( name=name, task=task_type, workspace=common_params.workspace, - as_dataset_class=TasksFactory.get_task_dataset(task_type), ) results = service.search( dataset=dataset, @@ -247,7 +242,6 @@ async def stream_data( name=name, task=task_type, workspace=common_params.workspace, - as_dataset_class=TasksFactory.get_task_dataset(task_type), ) data_stream = map( TokenClassificationRecord.parse_obj, diff --git a/src/argilla/server/apis/v0/models/text2text.py b/src/argilla/server/apis/v0/models/text2text.py index 1caa2b107d..5a5b5b843d 100644 --- a/src/argilla/server/apis/v0/models/text2text.py +++ b/src/argilla/server/apis/v0/models/text2text.py @@ -25,14 +25,13 @@ ScoreRange, SortableField, ) -from argilla.server.apis.v0.models.datasets import UpdateDatasetRequest from argilla.server.commons.models import PredictionStatus +from argilla.server.schemas.datasets import UpdateDatasetRequest from 
argilla.server.services.metrics.models import CommonTasksMetrics from argilla.server.services.search.model import ( ServiceBaseRecordsQuery, ServiceBaseSearchResultsAggregations, ) -from argilla.server.services.tasks.text2text.models import ServiceText2TextDataset class Text2TextPrediction(BaseModel): @@ -75,10 +74,6 @@ class Text2TextSearchResults(BaseSearchResults[Text2TextRecord, Text2TextSearchA pass -class Text2TextDataset(ServiceText2TextDataset): - pass - - class Text2TextMetrics(CommonTasksMetrics[Text2TextRecord]): pass diff --git a/src/argilla/server/apis/v0/models/text_classification.py b/src/argilla/server/apis/v0/models/text_classification.py index a4aa5621b4..48eae58c36 100644 --- a/src/argilla/server/apis/v0/models/text_classification.py +++ b/src/argilla/server/apis/v0/models/text_classification.py @@ -25,8 +25,8 @@ ScoreRange, SortableField, ) -from argilla.server.apis.v0.models.datasets import UpdateDatasetRequest from argilla.server.commons.models import PredictionStatus +from argilla.server.schemas.datasets import UpdateDatasetRequest from argilla.server.services.search.model import ( ServiceBaseRecordsQuery, ServiceBaseSearchResultsAggregations, diff --git a/src/argilla/server/apis/v0/models/token_classification.py b/src/argilla/server/apis/v0/models/token_classification.py index 4ec6608a85..356208ebf2 100644 --- a/src/argilla/server/apis/v0/models/token_classification.py +++ b/src/argilla/server/apis/v0/models/token_classification.py @@ -22,9 +22,9 @@ BaseSearchResults, ScoreRange, ) -from argilla.server.apis.v0.models.datasets import UpdateDatasetRequest from argilla.server.commons.models import PredictionStatus from argilla.server.daos.backend.search.model import SortableField +from argilla.server.schemas.datasets import UpdateDatasetRequest from argilla.server.services.search.model import ( ServiceBaseRecordsQuery, ServiceBaseSearchResultsAggregations, @@ -32,9 +32,6 @@ from argilla.server.services.tasks.token_classification.model import ( ServiceTokenClassificationAnnotation as _TokenClassificationAnnotation, ) -from argilla.server.services.tasks.token_classification.model import ( - ServiceTokenClassificationDataset, -) class TokenClassificationAnnotation(_TokenClassificationAnnotation): @@ -88,7 +85,3 @@ class TokenClassificationAggregations(ServiceBaseSearchResultsAggregations): class TokenClassificationSearchResults(BaseSearchResults[TokenClassificationRecord, TokenClassificationAggregations]): pass - - -class TokenClassificationDataset(ServiceTokenClassificationDataset): - pass diff --git a/src/argilla/server/apis/v0/validators/text_classification.py b/src/argilla/server/apis/v0/validators/text_classification.py index c7a3209a64..15e2d356db 100644 --- a/src/argilla/server/apis/v0/validators/text_classification.py +++ b/src/argilla/server/apis/v0/validators/text_classification.py @@ -17,9 +17,9 @@ from fastapi import Depends from argilla.server.apis.v0.models.dataset_settings import TextClassificationSettings -from argilla.server.apis.v0.models.datasets import Dataset from argilla.server.commons.models import TaskType from argilla.server.errors import BadRequestError, EntityNotFoundError +from argilla.server.schemas.datasets import Dataset from argilla.server.security.model import User from argilla.server.services.datasets import DatasetsService, ServiceBaseDatasetSettings from argilla.server.services.tasks.text_classification.metrics import DatasetLabels diff --git a/src/argilla/server/apis/v0/validators/token_classification.py 
b/src/argilla/server/apis/v0/validators/token_classification.py index 52616d366c..5880a6f380 100644 --- a/src/argilla/server/apis/v0/validators/token_classification.py +++ b/src/argilla/server/apis/v0/validators/token_classification.py @@ -17,9 +17,9 @@ from fastapi import Depends from argilla.server.apis.v0.models.dataset_settings import TokenClassificationSettings -from argilla.server.apis.v0.models.datasets import Dataset from argilla.server.commons.models import TaskType from argilla.server.errors import BadRequestError, EntityNotFoundError +from argilla.server.schemas.datasets import Dataset from argilla.server.security.model import User from argilla.server.services.datasets import DatasetsService, ServiceBaseDatasetSettings from argilla.server.services.tasks.token_classification.metrics import DatasetLabels diff --git a/src/argilla/server/commons/config.py b/src/argilla/server/commons/config.py index 47d068fc39..03ecf8d7d9 100644 --- a/src/argilla/server/commons/config.py +++ b/src/argilla/server/commons/config.py @@ -18,7 +18,7 @@ from argilla.server.commons.models import TaskType from argilla.server.errors import EntityNotFoundError, WrongTaskError -from argilla.server.services.datasets import ServiceDataset +from argilla.server.services.datasets import ServiceBaseDataset, ServiceDataset from argilla.server.services.metrics import ServiceBaseMetric from argilla.server.services.metrics.models import ServiceBaseTaskMetrics from argilla.server.services.search.model import ServiceRecordsQuery @@ -40,14 +40,14 @@ class TasksFactory: def register_task( cls, task_type: TaskType, - dataset_class: Type[ServiceDataset], query_request: Type[ServiceRecordsQuery], record_class: Type[ServiceRecord], + dataset_class: Optional[Type[ServiceDataset]] = None, metrics: Optional[Type[ServiceBaseTaskMetrics]] = None, ): cls.__REGISTERED_TASKS__[task_type] = TaskConfig( task=task_type, - dataset=dataset_class, + dataset=dataset_class or ServiceBaseDataset, query=query_request, record=record_class, metrics=metrics, diff --git a/src/argilla/server/daos/backend/generic_elastic.py b/src/argilla/server/daos/backend/generic_elastic.py index 61d2ece88a..179f1ed2c2 100644 --- a/src/argilla/server/daos/backend/generic_elastic.py +++ b/src/argilla/server/daos/backend/generic_elastic.py @@ -467,30 +467,11 @@ def update_record( def find_dataset( self, id: str, - name: Optional[str] = None, - owner: Optional[str] = None, ): document = self.client.get_index_document_by_id( index=DATASETS_INDEX_NAME, id=id, ) - if not document and owner is None and name: - # We must search by name since we have no owner - docs = self.client.list_index_documents( - index=DATASETS_INDEX_NAME, - query=BaseDatasetsQuery(name=name), - size=self.__MAX_NUMBER_OF_LISTED_DATASETS__, - fetch_once=True, - ) - docs = list(docs) - if len(docs) == 0: - return None - - if len(docs) > 1: - raise ValueError( - f"Ambiguous dataset info found for name {name}. " "Please provide a valid owner/workspace" - ) - document = docs[0] return document def compute_argilla_metric(self, metric_id): diff --git a/src/argilla/server/daos/backend/search/model.py b/src/argilla/server/daos/backend/search/model.py index 2f9941d8a1..08419cd899 100644 --- a/src/argilla/server/daos/backend/search/model.py +++ b/src/argilla/server/daos/backend/search/model.py @@ -60,11 +60,7 @@ class BaseQuery(BaseModel): class BaseDatasetsQuery(BaseQuery): tasks: Optional[List[str]] = None - owners: Optional[List[str]] = None - # This is used to fetch workspaces without owner/workspace. 
But this should be moved to - # a default workspace - # TODO: Should be deprecated - include_no_owner: bool = None + workspaces: Optional[List[str]] = None name: Optional[str] = None diff --git a/src/argilla/server/daos/backend/search/query_builder.py b/src/argilla/server/daos/backend/search/query_builder.py index 3da0591864..d093c6ca85 100644 --- a/src/argilla/server/daos/backend/search/query_builder.py +++ b/src/argilla/server/daos/backend/search/query_builder.py @@ -103,23 +103,13 @@ def _datasets_to_es_query(self, query: Optional[BackendDatasetsQuery] = None) -> return filters.match_all() query_filters = [] - if query.owners: - owners_filter = filters.terms_filter( - "owner.keyword", - query.owners, - ) - if query.include_no_owner: - query_filters.append( - filters.boolean_filter( - minimum_should_match=1, # OR Condition - should_filters=[ - owners_filter, - filters.boolean_filter(must_not_query=filters.exists_field("owner")), - ], - ) + if query.workspaces: + query_filters.append( + filters.terms_filter( + "owner.keyword", # This will be moved to "workspace.keyword" + query.workspaces, ) - else: - query_filters.append(owners_filter) + ) if query.tasks: query_filters.append( diff --git a/src/argilla/server/daos/datasets.py b/src/argilla/server/daos/datasets.py index 0601209301..6f6b0874a5 100644 --- a/src/argilla/server/daos/datasets.py +++ b/src/argilla/server/daos/datasets.py @@ -75,13 +75,13 @@ def init(self): def list_datasets( self, - owner_list: List[str] = None, + workspaces: List[str] = None, task2dataset_map: Dict[str, Type[DatasetDB]] = None, name: Optional[str] = None, ) -> List[DatasetDB]: - owner_list = owner_list or [] + workspaces = workspaces or [] query = BaseDatasetsQuery( - owners=owner_list, + workspaces=workspaces, tasks=[task for task in task2dataset_map] if task2dataset_map else None, name=name, ) @@ -121,20 +121,16 @@ def delete_dataset(self, dataset: DatasetDB): def find_by_name( self, name: str, - owner: Optional[str], + workspace: str, as_dataset_class: Type[DatasetDB] = BaseDatasetDB, - task: Optional[str] = None, ) -> Optional[DatasetDB]: dataset_id = BaseDatasetDB.build_dataset_id( name=name, - owner=owner, + workspace=workspace, ) - document = self._es.find_dataset(id=dataset_id, name=name, owner=owner) + document = self._es.find_dataset(id=dataset_id) if document is None: return None - base_ds = self._es_doc_to_instance(document) - if task and task != base_ds.task: - raise WrongTaskError(detail=f"Provided task {task} cannot be applied to dataset") dataset_type = as_dataset_class or BaseDatasetDB return self._es_doc_to_instance(document, ds_class=dataset_type) @@ -193,11 +189,6 @@ def open(self, dataset: DatasetDB): """Make available a dataset""" self._es.open(dataset.id) - def get_all_workspaces(self) -> List[str]: - """Get all datasets (Only for super users)""" - metric_data = self._es.compute_argilla_metric(metric_id="all_argilla_workspaces") - return [k for k in metric_data] - def save_settings( self, dataset: DatasetDB, diff --git a/src/argilla/server/daos/models/datasets.py b/src/argilla/server/daos/models/datasets.py index 6b3c55f6cf..c2ddb4a751 100644 --- a/src/argilla/server/daos/models/datasets.py +++ b/src/argilla/server/daos/models/datasets.py @@ -15,7 +15,7 @@ from datetime import datetime from typing import Any, Dict, Optional, TypeVar, Union -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, root_validator from argilla._constants import ES_INDEX_REGEX_PATTERN from argilla.server.commons.models import 
TaskType @@ -25,7 +25,7 @@ class BaseDatasetDB(BaseModel): name: str = Field(regex=ES_INDEX_REGEX_PATTERN) task: TaskType owner: Optional[str] = Field(description="Deprecated. Use `workspace` instead. Will be removed in v1.5.0") - workspace: Optional[str] = None + workspace: Optional[str] tags: Dict[str, str] = Field(default_factory=dict) metadata: Dict[str, Any] = Field(default_factory=dict) @@ -36,24 +36,32 @@ class BaseDatasetDB(BaseModel): ) last_updated: datetime = None - @validator("workspace", pre=True, always=True) - def set_workspace_defaults(cls, value, values): - if value: - return value - else: - return values.get("owner") + @root_validator(pre=True) + def set_defaults(cls, values): + workspace = values.get("workspace") or values.get("owner") + + cls._check_workspace(workspace) + + values["workspace"] = workspace + values["owner"] = workspace + + return values + + @classmethod + def _check_workspace(cls, workspace: str): + if not workspace: + raise ValueError("Missing workspace") @classmethod - def build_dataset_id(cls, name: str, owner: Optional[str] = None) -> str: - """Build a dataset id for a given name and owner""" - if owner: - return f"{owner}.{name}" - return name + def build_dataset_id(cls, name: str, workspace: str) -> str: + """Build a dataset id for a given name and workspace""" + cls._check_workspace(workspace) + return f"{workspace}.{name}" @property def id(self) -> str: - """The dataset id. Compounded by owner and name""" - return self.build_dataset_id(self.name, self.owner) + """The dataset id. Compounded by workspace and name""" + return self.build_dataset_id(self.name, self.workspace) def dict(self, *args, **kwargs) -> Dict[str, Any]: """ diff --git a/src/argilla/server/daos/backend/metrics/datasets.py b/src/argilla/server/schemas/__init__.py similarity index 74% rename from src/argilla/server/daos/backend/metrics/datasets.py rename to src/argilla/server/schemas/__init__.py index b261599aed..55be41799b 100644 --- a/src/argilla/server/daos/backend/metrics/datasets.py +++ b/src/argilla/server/schemas/__init__.py @@ -11,8 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -# All metrics related to the datasets index -from argilla.server.daos.backend.metrics.base import TermsAggregation - -METRICS = {"all_workspaces": TermsAggregation(id="all_workspaces", field="owner.keyword")} diff --git a/src/argilla/server/apis/v0/models/datasets.py b/src/argilla/server/schemas/datasets.py similarity index 72% rename from src/argilla/server/apis/v0/models/datasets.py rename to src/argilla/server/schemas/datasets.py index 65e4a2cc30..b411430b5f 100644 --- a/src/argilla/server/apis/v0/models/datasets.py +++ b/src/argilla/server/schemas/datasets.py @@ -16,14 +16,14 @@ """ Dataset models definition """ - -from typing import Any, Dict, Optional +from datetime import datetime +from typing import Any, Dict, Optional, Union +from uuid import UUID from pydantic import BaseModel, Field from argilla._constants import ES_INDEX_REGEX_PATTERN from argilla.server.commons.models import TaskType -from argilla.server.services.datasets import ServiceBaseDataset class UpdateDatasetRequest(BaseModel): @@ -48,6 +48,7 @@ class _BaseDatasetRequest(UpdateDatasetRequest): class CreateDatasetRequest(_BaseDatasetRequest): task: TaskType = Field(description="The dataset task") + workspace: Optional[str] = None class CopyDatasetRequest(_BaseDatasetRequest): @@ -58,20 +59,17 @@ class CopyDatasetRequest(_BaseDatasetRequest): target_workspace: Optional[str] = None -class Dataset(_BaseDatasetRequest, ServiceBaseDataset): - """ - Low level dataset data model +class Dataset(CreateDatasetRequest): + id: Union[str, UUID] + task: TaskType + owner: str = Field(description="Deprecated. Use `workspace` instead. Will be removed in v1.5.0") + workspace: str - Attributes: - ----------- - task: - The dataset task type. Deprecated - owner: - The dataset owner - created_at: - The dataset creation date - last_updated: - The last modification date - """ + tags: Dict[str, str] = Field(default_factory=dict) + metadata: Dict[str, Any] = Field(default_factory=dict) + created_at: datetime + created_by: Optional[str] = Field(description="The argilla user that created the dataset") + last_updated: datetime - task: TaskType + class Config: + orm_mode = True diff --git a/src/argilla/server/services/datasets.py b/src/argilla/server/services/datasets.py index 1c4e7f6843..da48e21d0d 100644 --- a/src/argilla/server/services/datasets.py +++ b/src/argilla/server/services/datasets.py @@ -26,6 +26,7 @@ ForbiddenOperationError, WrongTaskError, ) +from argilla.server.schemas.datasets import CreateDatasetRequest, Dataset from argilla.server.security.model import User @@ -53,82 +54,43 @@ def get_instance(cls, dao: DatasetsDAO = Depends(DatasetsDAO.get_instance)) -> " def __init__(self, dao: DatasetsDAO): self.__dao__ = dao - def create_dataset(self, user: User, dataset: ServiceDataset) -> ServiceDataset: - user.check_workspace(dataset.owner) - + def create_dataset(self, user: User, dataset: CreateDatasetRequest) -> BaseDatasetDB: + dataset.workspace = user.check_workspace(dataset.workspace) try: - self.find_by_name(user=user, name=dataset.name, task=dataset.task, workspace=dataset.owner) - raise EntityAlreadyExistsError(name=dataset.name, type=ServiceDataset, workspace=dataset.owner) + self.find_by_name(user=user, name=dataset.name, task=dataset.task, workspace=dataset.workspace) + raise EntityAlreadyExistsError(name=dataset.name, type=ServiceDataset, workspace=dataset.workspace) except WrongTaskError: # Found a dataset with same name but different task - raise EntityAlreadyExistsError(name=dataset.name, type=ServiceDataset, 
workspace=dataset.owner) + raise EntityAlreadyExistsError(name=dataset.name, type=ServiceDataset, workspace=dataset.workspace) except EntityNotFoundError: # The dataset does not exist -> create it ! date_now = datetime.utcnow() - dataset.created_by = user.username - dataset.created_at = date_now - dataset.last_updated = date_now - return self.__dao__.create_dataset(dataset) + + new_dataset = BaseDatasetDB.parse_obj(dataset.dict()) + new_dataset.created_by = user.username + new_dataset.created_at = date_now + new_dataset.last_updated = date_now + + return self.__dao__.create_dataset(new_dataset) def find_by_name( self, user: User, name: str, + workspace: str, as_dataset_class: Type[ServiceDataset] = ServiceBaseDataset, task: Optional[str] = None, - workspace: Optional[str] = None, ) -> ServiceDataset: - owner = user.check_workspace(workspace) - - if task is None: - found_ds = self.__find_by_name_with_superuser_fallback__( - user, name=name, owner=owner, as_dataset_class=as_dataset_class - ) - if found_ds: - task = found_ds.task - - found_ds = self.__find_by_name_with_superuser_fallback__( - user, name=name, owner=owner, task=task, as_dataset_class=as_dataset_class - ) - + workspace = user.check_workspace(workspace) + found_ds = self.__dao__.find_by_name(name=name, workspace=workspace, as_dataset_class=as_dataset_class) if found_ds is None: raise EntityNotFoundError(name=name, type=ServiceDataset) - if found_ds.owner and owner and found_ds.owner != owner: - raise EntityNotFoundError( - name=name, type=ServiceDataset - ) if user.is_superuser() else ForbiddenOperationError() - - return cast(ServiceDataset, found_ds) - - def __find_by_name_with_superuser_fallback__( - self, - user: User, - name: str, - owner: Optional[str], - as_dataset_class: Optional[Type[ServiceDataset]], - task: Optional[str] = None, - ): - found_ds = self.__dao__.find_by_name(name=name, owner=owner, task=task, as_dataset_class=as_dataset_class) - if not found_ds and user.is_superuser(): - try: - found_ds = self.__dao__.find_by_name( - name=name, owner=None, task=task, as_dataset_class=as_dataset_class - ) - except WrongTaskError: - # A dataset exists in a different workspace and with a different task - pass - return found_ds + elif task and found_ds.task != task: + raise WrongTaskError(detail=f"Provided task {task} cannot be applied to dataset") + else: + return cast(ServiceDataset, found_ds) def delete(self, user: User, dataset: ServiceDataset): - user.check_workspace(dataset.owner) - found = self.__find_by_name_with_superuser_fallback__( - user=user, - name=dataset.name, - owner=dataset.owner, - task=dataset.task, - as_dataset_class=None, - ) - if not found: - return + dataset = self.find_by_name(user=user, name=dataset.name, workspace=dataset.workspace, task=dataset.task) if user.is_superuser() or user.username == dataset.created_by: self.__dao__.delete_dataset(dataset) @@ -144,8 +106,8 @@ def update( dataset: ServiceDataset, tags: Dict[str, str], metadata: Dict[str, Any], - ) -> ServiceDataset: - found = self.find_by_name(user=user, name=dataset.name, task=dataset.task, workspace=dataset.owner) + ) -> Dataset: + found = self.find_by_name(user=user, name=dataset.name, task=dataset.task, workspace=dataset.workspace) dataset.tags = {**found.tags, **(tags or {})} dataset.metadata = {**found.metadata, **(metadata or {})} @@ -158,15 +120,15 @@ def list( workspaces: Optional[List[str]], task2dataset_map: Dict[str, Type[ServiceDataset]] = None, ) -> List[ServiceDataset]: - owners = user.check_workspaces(workspaces) - return 
self.__dao__.list_datasets(owner_list=owners, task2dataset_map=task2dataset_map) + workspaces = user.check_workspaces(workspaces) + return self.__dao__.list_datasets(workspaces=workspaces, task2dataset_map=task2dataset_map) def close(self, user: User, dataset: ServiceDataset): - found = self.find_by_name(user=user, name=dataset.name, workspace=dataset.owner) + found = self.find_by_name(user=user, name=dataset.name, workspace=dataset.workspace) self.__dao__.close(found) def open(self, user: User, dataset: ServiceDataset): - found = self.find_by_name(user=user, name=dataset.name, workspace=dataset.owner) + found = self.find_by_name(user=user, name=dataset.name, workspace=dataset.workspace) self.__dao__.open(found) def copy_dataset( @@ -178,7 +140,7 @@ def copy_dataset( copy_tags: Dict[str, Any] = None, copy_metadata: Dict[str, Any] = None, ) -> ServiceDataset: - dataset_workspace = copy_workspace or dataset.owner + dataset_workspace = copy_workspace or dataset.workspace dataset_workspace = user.check_workspace(dataset_workspace) self._validate_create_dataset( @@ -189,15 +151,17 @@ def copy_dataset( copy_dataset = dataset.copy() copy_dataset.name = copy_name - copy_dataset.owner = dataset_workspace + copy_dataset.workspace = dataset_workspace + date_now = datetime.utcnow() + copy_dataset.created_at = date_now copy_dataset.last_updated = date_now copy_dataset.tags = {**copy_dataset.tags, **(copy_tags or {})} copy_dataset.metadata = { **copy_dataset.metadata, **(copy_metadata or {}), - "source_workspace": dataset.owner, + "source_workspace": dataset.workspace, "copied_from": dataset.name, } diff --git a/src/argilla/server/services/tasks/text2text/models.py b/src/argilla/server/services/tasks/text2text/models.py index d07e0c1002..9c6ff0f334 100644 --- a/src/argilla/server/services/tasks/text2text/models.py +++ b/src/argilla/server/services/tasks/text2text/models.py @@ -76,7 +76,3 @@ def extended_fields(self) -> Dict[str, Any]: class ServiceText2TextQuery(ServiceBaseRecordsQuery): score: Optional[ServiceScoreRange] = Field(default=None) predicted: Optional[PredictionStatus] = Field(default=None, nullable=True) - - -class ServiceText2TextDataset(ServiceBaseDataset): - task: TaskType = Field(default=TaskType.text2text, const=True) diff --git a/src/argilla/server/services/tasks/text2text/service.py b/src/argilla/server/services/tasks/text2text/service.py index a83e219153..40d75a70ca 100644 --- a/src/argilla/server/services/tasks/text2text/service.py +++ b/src/argilla/server/services/tasks/text2text/service.py @@ -18,6 +18,7 @@ from fastapi import Depends from argilla.server.commons.config import TasksFactory +from argilla.server.services.datasets import ServiceDataset from argilla.server.services.search.model import ( ServiceSearchResults, ServiceSortableField, @@ -27,7 +28,6 @@ from argilla.server.services.storage.service import RecordsStorageService from argilla.server.services.tasks.commons import BulkResponse from argilla.server.services.tasks.text2text.models import ( - ServiceText2TextDataset, ServiceText2TextQuery, ServiceText2TextRecord, ) @@ -61,7 +61,7 @@ def __init__( async def add_records( self, - dataset: ServiceText2TextDataset, + dataset: ServiceDataset, records: List[ServiceText2TextRecord], ): failed = await self.__storage__.store_records( @@ -73,7 +73,7 @@ async def add_records( def search( self, - dataset: ServiceText2TextDataset, + dataset: ServiceDataset, query: ServiceText2TextQuery, sort_by: List[ServiceSortableField], record_from: int = 0, @@ -113,7 +113,7 @@ def search( def 
read_dataset( self, - dataset: ServiceText2TextDataset, + dataset: ServiceDataset, query: Optional[ServiceText2TextQuery] = None, id_from: Optional[str] = None, limit: int = 1000, diff --git a/src/argilla/server/services/tasks/text_classification/model.py b/src/argilla/server/services/tasks/text_classification/model.py index f8a82db328..aa95a09284 100644 --- a/src/argilla/server/services/tasks/text_classification/model.py +++ b/src/argilla/server/services/tasks/text_classification/model.py @@ -64,7 +64,7 @@ def strip_query(cls, query: str) -> str: class ServiceTextClassificationDataset(ServiceBaseDataset): - task: TaskType = Field(default=TaskType.text_classification, const=True) + task: TaskType = Field(default=TaskType.text_classification) rules: List[ServiceLabelingRule] = Field(default_factory=list) diff --git a/src/argilla/server/services/tasks/token_classification/model.py b/src/argilla/server/services/tasks/token_classification/model.py index ce01b96887..c52777a189 100644 --- a/src/argilla/server/services/tasks/token_classification/model.py +++ b/src/argilla/server/services/tasks/token_classification/model.py @@ -205,8 +205,3 @@ class ServiceTokenClassificationQuery(ServiceBaseRecordsQuery): annotated_as: List[str] = Field(default_factory=list) score: Optional[ServiceScoreRange] = Field(default=None) predicted: Optional[PredictionStatus] = Field(default=None, nullable=True) - - -class ServiceTokenClassificationDataset(ServiceBaseDataset): - task: TaskType = Field(default=TaskType.token_classification, const=True) - pass diff --git a/src/argilla/server/services/tasks/token_classification/service.py b/src/argilla/server/services/tasks/token_classification/service.py index 50ef8dae2d..12e4bd1576 100644 --- a/src/argilla/server/services/tasks/token_classification/service.py +++ b/src/argilla/server/services/tasks/token_classification/service.py @@ -19,12 +19,12 @@ from argilla.server.commons.config import TasksFactory from argilla.server.daos.backend.search.model import SortableField +from argilla.server.services.datasets import ServiceBaseDataset from argilla.server.services.search.model import ServiceSearchResults, ServiceSortConfig from argilla.server.services.search.service import SearchRecordsService from argilla.server.services.storage.service import RecordsStorageService from argilla.server.services.tasks.commons import BulkResponse from argilla.server.services.tasks.token_classification.model import ( - ServiceTokenClassificationDataset, ServiceTokenClassificationQuery, ServiceTokenClassificationRecord, ) @@ -58,7 +58,7 @@ def __init__( async def add_records( self, - dataset: ServiceTokenClassificationDataset, + dataset: ServiceBaseDataset, records: List[ServiceTokenClassificationRecord], ): failed = await self.__storage__.store_records( @@ -70,7 +70,7 @@ async def add_records( def search( self, - dataset: ServiceTokenClassificationDataset, + dataset: ServiceBaseDataset, query: ServiceTokenClassificationQuery, sort_by: List[SortableField], record_from: int = 0, @@ -138,7 +138,7 @@ def search( def read_dataset( self, - dataset: ServiceTokenClassificationDataset, + dataset: ServiceBaseDataset, query: ServiceTokenClassificationQuery, id_from: Optional[str] = None, limit: int = 1000, @@ -150,8 +150,6 @@ def read_dataset( ---------- dataset: The dataset name - owner: - The dataset owner query: If provided, scan will retrieve only records matching the provided query filters. 
Optional diff --git a/tests/client/sdk/datasets/test_models.py b/tests/client/sdk/datasets/test_models.py index a1281e782b..9ddb208460 100644 --- a/tests/client/sdk/datasets/test_models.py +++ b/tests/client/sdk/datasets/test_models.py @@ -14,14 +14,14 @@ # limitations under the License. import pytest from argilla.client.sdk.datasets.models import Dataset, TaskType -from argilla.server.apis.v0.models.datasets import Dataset as ServerDataset +from argilla.server.schemas.datasets import Dataset as ServerDataset def test_dataset_schema(helpers): client_schema = Dataset.schema() server_schema = helpers.remove_key(ServerDataset.schema(), key="created_by") # don't care about creator here - assert helpers.remove_description(client_schema) == helpers.remove_description(server_schema) + assert helpers.are_compatible_api_schemas(client_schema, server_schema) def test_TaskType_enum(): diff --git a/tests/functional_tests/search/test_search_service.py b/tests/functional_tests/search/test_search_service.py index 90ea518ae1..6b4004f7ea 100644 --- a/tests/functional_tests/search/test_search_service.py +++ b/tests/functional_tests/search/test_search_service.py @@ -15,7 +15,6 @@ import argilla import pytest from argilla.server.apis.v0.models.commons.model import ScoreRange -from argilla.server.apis.v0.models.datasets import Dataset from argilla.server.apis.v0.models.text_classification import ( TextClassificationQuery, TextClassificationRecord, @@ -24,7 +23,9 @@ from argilla.server.commons.models import TaskType from argilla.server.daos.backend import GenericElasticEngineBackend from argilla.server.daos.backend.search.query_builder import EsQueryBuilder +from argilla.server.daos.models.datasets import BaseDatasetDB from argilla.server.daos.records import DatasetRecordsDAO +from argilla.server.schemas.datasets import Dataset from argilla.server.services.metrics import MetricsService, ServicePythonMetric from argilla.server.services.search.model import ServiceSortConfig from argilla.server.services.search.service import SearchRecordsService @@ -71,9 +72,9 @@ def test_query_builder_with_query_range(backend: GenericElasticEngineBackend): def test_query_builder_with_nested(mocked_client, dao, backend: GenericElasticEngineBackend): - dataset = Dataset( + dataset = BaseDatasetDB( name="test_query_builder_with_nested", - owner=argilla.get_workspace(), + workspace=argilla.get_workspace(), task=TaskType.token_classification, ) argilla.delete(dataset.name) @@ -117,17 +118,14 @@ def test_query_builder_with_nested(mocked_client, dao, backend: GenericElasticEn def test_failing_metrics(service, mocked_client): - dataset = Dataset( + dataset = BaseDatasetDB( name="test_failing_metrics", - owner=argilla.get_workspace(), + workspace=argilla.get_workspace(), task=TaskType.text_classification, ) argilla.delete(dataset.name) - argilla.log( - argilla.TextClassificationRecord(text="This is a text, yeah!"), - name=dataset.name, - ) + argilla.log(argilla.TextClassificationRecord(text="This is a text, yeah!"), name=dataset.name) results = service.search( dataset=dataset, query=TextClassificationQuery(), diff --git a/tests/server/commons/test_records_dao.py b/tests/server/commons/test_records_dao.py index f86d68d084..1ddf52dc37 100644 --- a/tests/server/commons/test_records_dao.py +++ b/tests/server/commons/test_records_dao.py @@ -26,6 +26,7 @@ def test_raise_proper_error(): dao.search_records( dataset=BaseDatasetDB( name="mock-notfound", + workspace="workspace-name", task=TaskType.text_classification, ) ) diff --git 
a/tests/server/datasets/test_api.py b/tests/server/datasets/test_api.py index 5660a2393f..3ab15acb2f 100644 --- a/tests/server/datasets/test_api.py +++ b/tests/server/datasets/test_api.py @@ -14,11 +14,11 @@ # limitations under the License. from typing import Optional -from argilla.server.apis.v0.models.datasets import Dataset from argilla.server.apis.v0.models.text_classification import ( TextClassificationBulkRequest, ) from argilla.server.commons.models import TaskType +from argilla.server.schemas.datasets import Dataset from tests.helpers import SecuredClient @@ -54,6 +54,7 @@ def test_create_dataset(mocked_client): ) assert response.status_code == 200 dataset = Dataset.parse_obj(response.json()) + assert dataset.id assert dataset.created_by == "argilla" assert dataset.metadata == request["metadata"] assert dataset.tags == request["tags"] @@ -131,7 +132,7 @@ def test_dataset_naming_validation(mocked_client): "type": "value_error.str.regex", } ], - "model": "TextClassificationDataset", + "model": "CreateDatasetRequest", }, } } @@ -153,7 +154,7 @@ def test_dataset_naming_validation(mocked_client): "type": "value_error.str.regex", } ], - "model": "TokenClassificationDataset", + "model": "CreateDatasetRequest", }, } } @@ -171,6 +172,8 @@ def test_list_datasets(mocked_client): datasets = [Dataset.parse_obj(item) for item in response.json()] assert len(datasets) > 0 assert dataset in [ds.name for ds in datasets] + for ds in datasets: + assert ds.id def test_update_dataset(mocked_client): diff --git a/tests/server/datasets/test_dao.py b/tests/server/datasets/test_dao.py index 962df52564..6f2a734009 100644 --- a/tests/server/datasets/test_dao.py +++ b/tests/server/datasets/test_dao.py @@ -29,11 +29,10 @@ def test_retrieve_ownered_dataset_for_no_owner_user(): dataset = "test_retrieve_owned_dataset_for_no_owner_user" created = dao.create_dataset( - BaseDatasetDB(name=dataset, owner="other", task=TaskType.text_classification), + BaseDatasetDB(name=dataset, workspace="other", task=TaskType.text_classification), ) - assert dao.find_by_name(created.name, owner=created.owner) == created - assert dao.find_by_name(created.name, owner=None) == created - assert dao.find_by_name(created.name, owner="me") is None + assert dao.find_by_name(created.name, workspace=created.workspace) == created + assert dao.find_by_name(created.name, workspace="me") is None def test_list_datasets_by_task(): @@ -46,7 +45,7 @@ def test_list_datasets_by_task(): created_text = dao.create_dataset( BaseDatasetDB( name=dataset + "_text", - owner="other", + workspace="other", task=TaskType.text_classification, ), ) @@ -54,7 +53,7 @@ def test_list_datasets_by_task(): created_token = dao.create_dataset( BaseDatasetDB( name=dataset + "_token", - owner="other", + workspace="other", task=TaskType.token_classification, ), ) @@ -76,7 +75,7 @@ def test_close_dataset(): dataset = "test_close_dataset" created = dao.create_dataset( - BaseDatasetDB(name=dataset, owner="other", task=TaskType.text_classification), + BaseDatasetDB(name=dataset, workspace="other", task=TaskType.text_classification), ) dao.close(created) diff --git a/tests/server/datasets/test_model.py b/tests/server/datasets/test_model.py index 165203cd7f..9aae961c76 100644 --- a/tests/server/datasets/test_model.py +++ b/tests/server/datasets/test_model.py @@ -12,10 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import datetime +import uuid import pytest -from argilla.server.apis.v0.models.datasets import CreateDatasetRequest from argilla.server.commons.models import TaskType +from argilla.server.daos.models.datasets import BaseDatasetDB +from argilla.server.schemas.datasets import CreateDatasetRequest, Dataset from pydantic import ValidationError @@ -43,3 +46,38 @@ def test_dataset_naming_ko(name): with pytest.raises(ValidationError, match="string does not match regex"): CreateDatasetRequest(name=name, task=TaskType.token_classification) + + +@pytest.mark.parametrize( + ("dataset", "expected_workspace"), + [ + (BaseDatasetDB(name="ds", workspace="ws", task=TaskType.text_classification), "ws"), + (BaseDatasetDB(name="ds", owner="owner", task=TaskType.text_classification), "owner"), + (BaseDatasetDB(name="ds", workspace="ws", owner="owner", task=TaskType.text_classification), "ws"), + (BaseDatasetDB(name="ds", workspace=None, owner="ws", task=TaskType.text_classification), "ws"), + ], +) +def test_dataset_creation_sync(dataset, expected_workspace): + assert dataset.workspace == expected_workspace + assert dataset.owner == dataset.workspace + assert dataset.id == f"{dataset.workspace}.{dataset.name}" + + +def test_dataset_creation_fails_on_no_workspace_and_owner(): + with pytest.raises(ValueError, match="Missing workspace"): + BaseDatasetDB(task=TaskType.text_classification, name="tedb", workspace=None, owner=None) + + +def test_accept_create_dataset_without_created_by(): + ds = Dataset( + name="a-dataset", + id=uuid.uuid4(), + task=TaskType.text_classification, + owner="dd", + workspace="dd", + created_at=datetime.datetime.utcnow(), + last_updated=datetime.datetime.utcnow(), + ) + + assert ds + assert ds.created_by is None diff --git a/tests/server/text_classification/test_api.py b/tests/server/text_classification/test_api.py index 79210fdf55..8f63c76f82 100644 --- a/tests/server/text_classification/test_api.py +++ b/tests/server/text_classification/test_api.py @@ -17,7 +17,6 @@ import pytest from argilla.server.apis.v0.models.commons.model import BulkResponse -from argilla.server.apis.v0.models.datasets import Dataset from argilla.server.apis.v0.models.text_classification import ( TextClassificationAnnotation, TextClassificationBulkRequest, @@ -27,6 +26,7 @@ TextClassificationSearchResults, ) from argilla.server.commons.models import PredictionStatus +from argilla.server.schemas.datasets import Dataset from tests.client.conftest import SUPPORTED_VECTOR_SEARCH From 40ca933c265c8bbebd54755b91df33edaabf1161 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Fri, 3 Mar 2023 17:28:16 +0100 Subject: [PATCH 29/45] refactor: Make workspace required in requests (#2471) # Description In preparation for the new DB integration, some workspace validation logic will be removed. This PR prepares for that integration by requiring the workspace to be sent on each request instead of falling back to the `default_workspace`. This won't be a problem for current clients and UIs, since the workspace is already sent on each request.
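For illustration, a minimal sketch of the new contract (not part of the patch; the username and workspace names are made up, and the behavior mirrors the updated `check_workspace` and tests in the diff below):

```python
from argilla.server.errors import BadRequestError, EntityNotFoundError
from argilla.server.security.model import User

user = User(username="mock", workspaces=["team-a"])

# A workspace the user belongs to is returned unchanged
assert user.check_workspace("team-a") == "team-a"

# An unknown workspace is still rejected
try:
    user.check_workspace("other-team")
except EntityNotFoundError:
    print("unknown workspace")

# New behavior: a missing workspace no longer falls back to the user's
# default workspace; it raises BadRequestError instead
try:
    user.check_workspace(None)
except BadRequestError:
    print("workspace is now required")
```

Clients that relied on the implicit fallback must now send the workspace explicitly; the test helpers do this with a hard-coded workspace header, as shown in the diff.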
**Type of change** - [x] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [x] Refactor (change restructuring the codebase without changing functionality) **Checklist** - [x] I have merged the original branch into my forked branch - [ ] I added relevant documentation - [x] follows the style guidelines of this project - [x] I did a self-review of my code - [ ] I made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works --- src/argilla/server/security/model.py | 9 ++---- tests/helpers.py | 5 +++- tests/server/security/test_model.py | 42 +++++++++++++--------------- 3 files changed, 25 insertions(+), 31 deletions(-) diff --git a/src/argilla/server/security/model.py b/src/argilla/server/security/model.py index 5e7f77a491..58d07c2dce 100644 --- a/src/argilla/server/security/model.py +++ b/src/argilla/server/security/model.py @@ -18,7 +18,7 @@ from pydantic import BaseModel, Field, root_validator, validator from argilla._constants import ES_INDEX_REGEX_PATTERN -from argilla.server.errors import EntityNotFoundError +from argilla.server.errors import BadRequestError, EntityNotFoundError WORKSPACE_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_\-]*$") _EMAIL_REGEX_PATTERN = r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}" @@ -79,11 +79,6 @@ def _set_default_workspace(cls, value, values): return list(set(value)) - @property - def default_workspace(self) -> Optional[str]: - """Get the default user workspace""" - return self.username - def check_workspaces(self, workspaces: List[str]) -> List[str]: """ Given a list of workspaces, apply a belongs to validation for each one. Then, return @@ -122,7 +117,7 @@ def check_workspace(self, workspace: str) -> str: """ if not workspace: - return self.default_workspace + raise BadRequestError("Missing workspace. A workspace must be provided") elif workspace not in self.workspaces: raise EntityNotFoundError(name=workspace, type="Workspace") return workspace diff --git a/tests/helpers.py b/tests/helpers.py index 06c1d1825f..49c7641a1b 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -26,7 +26,10 @@ class SecuredClient: def __init__(self, client: TestClient): self._client = client - self._header = {API_KEY_HEADER_NAME: settings.default_apikey} + self._header = { + API_KEY_HEADER_NAME: settings.default_apikey, + WORKSPACE_HEADER_NAME: "argilla", # Hard-coded default workspace + } self._current_user = None @property diff --git a/tests/server/security/test_model.py b/tests/server/security/test_model.py index 07d0dfa5e8..6fb7da4386 100644 --- a/tests/server/security/test_model.py +++ b/tests/server/security/test_model.py @@ -13,7 +13,7 @@ # limitations under the License.
import pytest -from argilla.server.errors import EntityNotFoundError +from argilla.server.errors import BadRequestError, EntityNotFoundError from argilla.server.security.model import User from pydantic import ValidationError @@ -67,24 +67,14 @@ def test_check_user_workspaces(): assert user.check_workspaces(["not-found-ws"]) -def test_default_workspace(): - user = User(username="admin") - assert user.default_workspace == "admin" - - test_user = User(username="test", workspaces=["ws"]) - assert test_user.default_workspace == test_user.username - - def test_workspace_for_superuser(): user = User(username="admin") - assert user.default_workspace == "admin" + + assert user.check_workspace("admin") == "admin" with pytest.raises(EntityNotFoundError): assert user.check_workspace("some") == "some" - assert user.check_workspace(None) == "admin" - assert user.check_workspace("") == "admin" - user.workspaces = ["some"] assert user.check_workspaces(["some"]) == ["some"] @@ -111,17 +101,23 @@ def test_is_superuser(): assert user.is_superuser() +@pytest.mark.parametrize("workspaces", [None, [], ["a"]]) +def test_check_workspaces_with_default(workspaces): + user = User(username="user", workspaces=workspaces) + assert set(user.check_workspace(user.username)) == set(user.username) + + @pytest.mark.parametrize( - "workspaces, expected", + "user", [ - (None, {"user"}), - ([], {"user"}), - (["a"], {"user", "a"}), + User(username="admin", workspaces=None, superuser=True), + User(username="mock", workspaces=None, superuser=False), + User(username="user", workspaces=["ab"], superuser=True), ], ) -def test_check_workspaces_with_default(workspaces, expected): - user = User(username="user", workspaces=workspaces) - assert set(user.check_workspaces([])) == expected - assert set(user.check_workspaces(None)) == expected - assert set(user.check_workspaces([None])) == expected - assert set(user.check_workspace(user.username)) == set(user.username) +def test_check_workspace_without_workspace(user): + with pytest.raises(BadRequestError): + user.check_workspace("") + + with pytest.raises(BadRequestError): + user.check_workspace(None) From b3b897ac731d0113ea5e078b5bf5974f351935fd Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Mon, 6 Mar 2023 13:59:30 +0100 Subject: [PATCH 30/45] feat: Allow passing workspace as client param for `rg.log` or `rg.load` (#2425) # Description Allow passing the workspace as a client param for `rg.log` or `rg.load`. I also enabled this for `rg.delete` and `rg.delete_records`. Closes #2059 **Type of change** (Please delete options that are not relevant. Remember to title the PR according to the type of change) - [X] New feature (non-breaking change which adds functionality) - [X] Refactor (change restructuring the codebase without changing functionality) - [X] Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** (Please describe the tests that you ran to verify your changes.
And ideally, reference `tests`) - [Test A](https://github.com/argilla-io/argilla/blob/b227b09a423dae8792e520ce4f9da3dcb5fb0d0d/tests/client/test_api.py#L364) **Checklist** - [X] I have merged the original branch into my forked branch - [X] follows the style guidelines of this project - [X] I did a self-review of my code - [] I made corresponding changes to the documentation - [X] My changes generate no new warnings - [X] I have added tests that prove my fix is effective or that my feature works --- src/argilla/client/api.py | 22 +++++++++++++++-- src/argilla/client/client.py | 47 +++++++++++++++++++++++++----------- tests/client/test_api.py | 22 +++++++++++++++++ 3 files changed, 75 insertions(+), 16 deletions(-) diff --git a/src/argilla/client/api.py b/src/argilla/client/api.py index 3bce617e35..cb55e0a7cf 100644 --- a/src/argilla/client/api.py +++ b/src/argilla/client/api.py @@ -109,6 +109,7 @@ def init( def log( records: Union[Record, Iterable[Record], Dataset], name: str, + workspace: Optional[str] = None, tags: Optional[Dict[str, str]] = None, metadata: Optional[Dict[str, Any]] = None, chunk_size: int = 500, @@ -122,6 +123,8 @@ def log( Args: records: The record, an iterable of records, or a dataset to log. name: The dataset name. + workspace: The workspace to which records will be logged/loaded. If `None` (default) and the + env variable ``ARGILLA_WORKSPACE`` is not set, it will default to the private user workspace. tags: A dictionary of tags related to the dataset. metadata: A dictionary of extra info for the dataset. chunk_size: The chunk size for a data bulk. @@ -150,6 +153,7 @@ def log( return ArgillaSingleton.get().log( records=records, name=name, + workspace=workspace, tags=tags, metadata=metadata, chunk_size=chunk_size, @@ -161,6 +165,7 @@ def log( async def log_async( records: Union[Record, Iterable[Record], Dataset], name: str, + workspace: Optional[str] = None, tags: Optional[Dict[str, str]] = None, metadata: Optional[Dict[str, Any]] = None, chunk_size: int = 500, @@ -171,6 +176,8 @@ async def log_async( Args: records: The record, an iterable of records, or a dataset to log. name: The dataset name. + workspace: The workspace to which records will be logged/loaded. If `None` (default) and the + env variable ``ARGILLA_WORKSPACE`` is not set, it will default to the private user workspace. tags: A dictionary of tags related to the dataset. metadata: A dictionary of extra info for the dataset. chunk_size: The chunk size for a data bulk. @@ -192,6 +199,7 @@ async def log_async( return await ArgillaSingleton.get().log_async( records=records, name=name, + workspace=workspace, tags=tags, metadata=metadata, chunk_size=chunk_size, @@ -201,6 +209,7 @@ async def log_async( def load( name: str, + workspace: Optional[str] = None, query: Optional[str] = None, vector: Optional[Tuple[str, List[float]]] = None, ids: Optional[List[Union[str, int]]] = None, @@ -212,6 +221,8 @@ def load( Args: name: The dataset name. + workspace: The workspace to which records will be logged/loaded. If `None` (default) and the + env variable ``ARGILLA_WORKSPACE`` is not set, it will default to the private user workspace. query: An ElasticSearch query with the `query string syntax `_ vector: Vector configuration for a semantic search @@ -246,6 +257,7 @@ def load( """ return ArgillaSingleton.get().load( name=name, + workspace=workspace, query=query, vector=vector, ids=ids, @@ -280,22 +292,25 @@ def copy( ) -def delete(name: str): +def delete(name: str, workspace: Optional[str] = None): """ Deletes a dataset. 
Args: name: The dataset name. + workspace: The workspace to which records will be logged/loaded. If `None` (default) and the + env variable ``ARGILLA_WORKSPACE`` is not set, it will default to the private user workspace. Examples: >>> import argilla as rg >>> rg.delete(name="example-dataset") """ - ArgillaSingleton.get().delete(name) + ArgillaSingleton.get().delete(name=name, workspace=workspace) def delete_records( name: str, + workspace: Optional[str] = None, query: Optional[str] = None, ids: Optional[List[Union[str, int]]] = None, discard_only: bool = False, @@ -305,6 +320,8 @@ def delete_records( Args: name: The dataset name. + workspace: The workspace to which records will be logged/loaded. If `None` (default) and the + env variable ``ARGILLA_WORKSPACE`` is not set, it will default to the private user workspace. query: An ElasticSearch query with the `query string syntax `_ ids: If provided, deletes dataset records with given ids. @@ -329,6 +346,7 @@ def delete_records( """ return ArgillaSingleton.get().delete_records( name=name, + workspace=workspace, query=query, ids=ids, discard_only=discard_only, diff --git a/src/argilla/client/client.py b/src/argilla/client/client.py index 1df03aca60..62e0d18c36 100644 --- a/src/argilla/client/client.py +++ b/src/argilla/client/client.py @@ -100,9 +100,7 @@ async def __log_internal__(api: "Argilla", *args, **kwargs): return await api.log_async(*args, **kwargs) except Exception as ex: dataset = kwargs["name"] - _LOGGER.error( - f"\nCannot log data in dataset '{dataset}'\n" f"Error: {type(ex).__name__}\n" f"Details: {ex}" - ) + _LOGGER.error(f"\nCannot log data in dataset '{dataset}'\nError: {type(ex).__name__}\nDetails: {ex}") raise ex def log(self, *args, **kwargs) -> Future: @@ -169,7 +167,7 @@ def __del__(self): def client(self) -> AuthenticatedClient: """The underlying authenticated HTTP client""" warnings.warn( - message=("This prop will be removed in next release. " "Please use the http_client prop instead."), + message="This prop will be removed in next release. Please use the http_client prop instead.", category=UserWarning, ) return self._client @@ -251,18 +249,22 @@ def copy(self, dataset: str, name_of_copy: str, workspace: str = None): ), ) - def delete(self, name: str): + def delete(self, name: str, workspace: Optional[str] = None): """Deletes a dataset. Args: name: The dataset name. """ + if workspace is not None: + self.set_workspace(workspace) + datasets_api.delete_dataset(client=self._client, name=name) def log( self, records: Union[Record, Iterable[Record], Dataset], name: str, + workspace: Optional[str] = None, tags: Optional[Dict[str, str]] = None, metadata: Optional[Dict[str, Any]] = None, chunk_size: int = 500, @@ -290,6 +292,9 @@ def log( will be returned instead. 
""" + if workspace is not None: + self.set_workspace(workspace) + future = self._agent.log( records=records, name=name, @@ -310,6 +315,7 @@ async def log_async( self, records: Union[Record, Iterable[Record], Dataset], name: str, + workspace: Optional[str] = None, tags: Optional[Dict[str, str]] = None, metadata: Optional[Dict[str, Any]] = None, chunk_size: int = 500, @@ -332,6 +338,9 @@ async def log_async( tags = tags or {} metadata = metadata or {} + if workspace is not None: + self.set_workspace(workspace) + if not name: raise InputValueError("Empty dataset name has been passed as argument.") @@ -370,7 +379,7 @@ async def log_async( bulk_class = Text2TextBulkData creation_class = CreationText2TextRecord else: - raise InputValueError(f"Unknown record type {record_type}. Available values are" f" {Record.__args__}") + raise InputValueError(f"Unknown record type {record_type}. Available values are {Record.__args__}") processed, failed = 0, 0 with Progress() as progress_bar: @@ -408,6 +417,7 @@ async def log_async( def delete_records( self, name: str, + workspace: Optional[str] = None, query: Optional[str] = None, ids: Optional[List[Union[str, int]]] = None, discard_only: bool = False, @@ -432,6 +442,9 @@ def delete_records( deletion). """ + if workspace is not None: + self.set_workspace(workspace) + return self.datasets.delete_records( name=name, mark_as_discarded=discard_only, @@ -443,6 +456,7 @@ def delete_records( def load( self, name: str, + workspace: Optional[str] = None, query: Optional[str] = None, vector: Optional[Tuple[str, List[float]]] = None, ids: Optional[List[Union[str, int]]] = None, @@ -469,6 +483,9 @@ def load( A argilla dataset. """ + if workspace is not None: + self.set_workspace(workspace) + if as_pandas is False: warnings.warn( "The argument `as_pandas` is deprecated and will be removed in a future" @@ -479,7 +496,7 @@ def load( raise ValueError( "The argument `as_pandas` is deprecated and will be removed in a future" " version. Please adapt your code accordingly. ", - "If you want a pandas DataFrame do" " `rg.load('my_dataset').to_pandas()`.", + "If you want a pandas DataFrame do `rg.load('my_dataset').to_pandas()`.", ) try: @@ -495,13 +512,15 @@ def load( from argilla import __version__ as version warnings.warn( - message=f"Using python client argilla=={version}," - f" however deployed server version is {err.api_version}." - " This might lead to compatibility issues.\n" - f" Preferably, update your server version to {version}" - " or downgrade your Python API at the loss" - " of functionality and robustness via\n" - f"`pip install argilla=={err.api_version}`", + message=( + f"Using python client argilla=={version}," + f" however deployed server version is {err.api_version}." 
+ " This might lead to compatibility issues.\n" + f" Preferably, update your server version to {version}" + " or downgrade your Python API at the loss" + " of functionality and robustness via\n" + f"`pip install argilla=={err.api_version}`" + ), category=UserWarning, ) diff --git a/tests/client/test_api.py b/tests/client/test_api.py index 7aaff4e86c..63299145df 100644 --- a/tests/client/test_api.py +++ b/tests/client/test_api.py @@ -372,6 +372,28 @@ def test_general_log_load(mocked_client, monkeypatch, request, records, dataset_ assert record == expected +@pytest.mark.parametrize( + "records, dataset_class", + [ + ("singlelabel_textclassification_records", rg.DatasetForTextClassification), + ], +) +def test_log_load_with_workspace(mocked_client, monkeypatch, request, records, dataset_class): + dataset_names = [ + f"test_general_log_load_{dataset_class.__name__.lower()}_" + input_type + for input_type in ["single", "list", "dataset"] + ] + for name in dataset_names: + mocked_client.delete(f"/api/datasets/{name}") + + records = request.getfixturevalue(records) + + api.log(records, name=dataset_names[0], workspace="argilla") + ds = api.load(dataset_names[0], workspace="argilla") + api.delete_records(dataset_names[0], ids=[rec.id for rec in ds][:1], workspace="argilla") + api.delete(dataset_names[0], workspace="argilla") + + def test_passing_wrong_iterable_data(mocked_client): dataset_name = "test_log_single_records" mocked_client.delete(f"/api/datasets/{dataset_name}") From 3ebea7644bf21cf9e74f376225252c95ed82707f Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Mon, 6 Mar 2023 14:00:29 +0100 Subject: [PATCH 31/45] feat: Deprecate `chunk_size` in favor of `batch_size` for `rg.log` (#2455) Closes #2453 ## Pull Request overview * Deprecate `chunk_size` in favor of `batch_size` in `rg.log`: * Move `chunk_size` to the end of all related signatures. * Set `chunk_size` default to None and update the typing accordingly. * Introduce `batch_size` in the old position in the signature. * Update docstrings accordingly. * Introduce a `FutureWarning` if `chunk_size` is used. * Introduce test showing that `rg.log(..., chunk_size=100)` indeed throws a warning. * Update a warning to no longer include a newline and a lot of spaces (see first comment of this PR) ## Details Note that this deprecation is non-breaking: Code that uses `chunk_size` will still work, as `batch_size = chunk_size` after a FutureWarning is given, if `chunk_size` is not `None`. ## Discussion * Should I use a FutureWarning? Or a DeprecationWarning? Or a PendingDeprecationWarning? The last two make sense, but they are [ignored by default](https://docs.python.org/3/library/warnings.html#default-warning-filter), I'm afraid. * Is the deprecation message in the format that we like? --- **Type of change** - [x] New feature (non-breaking change which adds functionality) **How Has This Been Tested** I introduced a test, and ran all tests. 
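As a usage sketch of the deprecation path introduced by this PR (illustrative only, not part of the patch; the dataset name is made up):

```python
import argilla as rg

record = rg.TextClassificationRecord(text="My text")

# Deprecated: still accepted, but now emits a FutureWarning
rg.log(records=[record], name="my-dataset", chunk_size=100)

# Preferred: pass batch_size instead
rg.log(records=[record], name="my-dataset", batch_size=100)
```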
**Checklist** - [x] I have merged the original branch into my forked branch - [x] I added relevant documentation - [x] follows the style guidelines of this project - [x] I did a self-review of my code - [x] I made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works --- - Tom Aarsen --- src/argilla/client/api.py | 18 ++++++++++------ src/argilla/client/client.py | 40 +++++++++++++++++++++++------------- tests/client/test_api.py | 17 +++++++++++++++ 3 files changed, 55 insertions(+), 20 deletions(-) diff --git a/src/argilla/client/api.py b/src/argilla/client/api.py index cb55e0a7cf..c7cf8552ed 100644 --- a/src/argilla/client/api.py +++ b/src/argilla/client/api.py @@ -112,9 +112,10 @@ def log( workspace: Optional[str] = None, tags: Optional[Dict[str, str]] = None, metadata: Optional[Dict[str, Any]] = None, - chunk_size: int = 500, + batch_size: int = 500, verbose: bool = True, background: bool = False, + chunk_size: Optional[int] = None, ) -> Union[BulkResponse, Future]: """Logs Records to argilla. @@ -127,10 +128,11 @@ def log( env variable ``ARGILLA_WORKSPACE`` is not set, it will default to the private user workspace. tags: A dictionary of tags related to the dataset. metadata: A dictionary of extra info for the dataset. - chunk_size: The chunk size for a data bulk. + batch_size: The batch size for a data bulk. verbose: If True, shows a progress bar and prints out a quick summary at the end. background: If True, we will NOT wait for the logging process to finish and return an ``asyncio.Future`` object. You probably want to set ``verbose`` to False in that case. + chunk_size: DEPRECATED! Use `batch_size` instead. Returns: Summary of the response from the REST API. @@ -156,9 +158,10 @@ def log( workspace=workspace, tags=tags, metadata=metadata, - chunk_size=chunk_size, + batch_size=batch_size, verbose=verbose, background=background, + chunk_size=chunk_size, ) @@ -168,8 +171,9 @@ async def log_async( workspace: Optional[str] = None, tags: Optional[Dict[str, str]] = None, metadata: Optional[Dict[str, Any]] = None, - chunk_size: int = 500, + batch_size: int = 500, verbose: bool = True, + chunk_size: Optional[int] = None, ) -> BulkResponse: """Logs Records to argilla with asyncio. @@ -180,8 +184,9 @@ async def log_async( env variable ``ARGILLA_WORKSPACE`` is not set, it will default to the private user workspace. tags: A dictionary of tags related to the dataset. metadata: A dictionary of extra info for the dataset. - chunk_size: The chunk size for a data bulk. + batch_size: The batch size for a data bulk. verbose: If True, shows a progress bar and prints out a quick summary at the end. + chunk_size: DEPRECATED! Use `batch_size` instead. 
Returns: Summary of the response from the REST API @@ -202,8 +207,9 @@ async def log_async( workspace=workspace, tags=tags, metadata=metadata, - chunk_size=chunk_size, + batch_size=batch_size, verbose=verbose, + chunk_size=chunk_size, ) diff --git a/src/argilla/client/client.py b/src/argilla/client/client.py index 62e0d18c36..0eb7820833 100644 --- a/src/argilla/client/client.py +++ b/src/argilla/client/client.py @@ -113,7 +113,7 @@ class Argilla: """ # Larger sizes will trigger a warning - _MAX_CHUNK_SIZE = 5000 + _MAX_BATCH_SIZE = 5000 def __init__( self, @@ -267,9 +267,10 @@ def log( workspace: Optional[str] = None, tags: Optional[Dict[str, str]] = None, metadata: Optional[Dict[str, Any]] = None, - chunk_size: int = 500, + batch_size: int = 500, verbose: bool = True, background: bool = False, + chunk_size: Optional[int] = None, ) -> Union[BulkResponse, Future]: """Logs Records to argilla. @@ -280,11 +281,12 @@ def log( name: The dataset name. tags: A dictionary of tags related to the dataset. metadata: A dictionary of extra info for the dataset. - chunk_size: The chunk size for a data bulk. + batch_size: The batch size for a data bulk. verbose: If True, shows a progress bar and prints out a quick summary at the end. background: If True, we will NOT wait for the logging process to finish and return an ``asyncio.Future`` object. You probably want to set ``verbose`` to False in that case. + chunk_size: DEPRECATED! Use `batch_size` instead. Returns: Summary of the response from the REST API. @@ -300,8 +302,9 @@ def log( name=name, tags=tags, metadata=metadata, - chunk_size=chunk_size, + batch_size=batch_size, verbose=verbose, + chunk_size=chunk_size, ) if background: return future @@ -318,8 +321,9 @@ async def log_async( workspace: Optional[str] = None, tags: Optional[Dict[str, str]] = None, metadata: Optional[Dict[str, Any]] = None, - chunk_size: int = 500, + batch_size: int = 500, verbose: bool = True, + chunk_size: Optional[int] = None, ) -> BulkResponse: """Logs Records to argilla with asyncio. @@ -328,8 +332,9 @@ async def log_async( name: The dataset name. tags: A dictionary of tags related to the dataset. metadata: A dictionary of extra info for the dataset. - chunk_size: The chunk size for a data bulk. + batch_size: The batch size for a data bulk. verbose: If True, shows a progress bar and prints out a quick summary at the end. + chunk_size: DEPRECATED! Use `batch_size` instead. Returns: Summary of the response from the REST API @@ -353,11 +358,18 @@ async def log_async( " https://www.elastic.co/guide/en/elasticsearch/reference/current/indices-create-index.html" ) - if chunk_size > self._MAX_CHUNK_SIZE: + if chunk_size is not None: + warnings.warn( + "The argument `chunk_size` is deprecated and will be removed in a future" + " version. Please use `batch_size` instead.", + FutureWarning, + ) + batch_size = chunk_size + + if batch_size > self._MAX_BATCH_SIZE: _LOGGER.warning( - """The introduced chunk size is noticeably large, timeout errors may occur. - Consider a chunk size smaller than %s""", - self._MAX_CHUNK_SIZE, + "The requested batch size is noticeably large, timeout errors may occur. 
" + f"Consider a batch size smaller than {self._MAX_BATCH_SIZE}", ) if isinstance(records, Record.__args__): @@ -385,8 +397,8 @@ async def log_async( with Progress() as progress_bar: task = progress_bar.add_task("Logging...", total=len(records), visible=verbose) - for i in range(0, len(records), chunk_size): - chunk = records[i : i + chunk_size] + for i in range(0, len(records), batch_size): + batch = records[i : i + batch_size] response = await async_bulk( client=self._client, @@ -394,14 +406,14 @@ async def log_async( json_body=bulk_class( tags=tags, metadata=metadata, - records=[creation_class.from_client(r) for r in chunk], + records=[creation_class.from_client(r) for r in batch], ), ) processed += response.parsed.processed failed += response.parsed.failed - progress_bar.update(task, advance=len(chunk)) + progress_bar.update(task, advance=len(batch)) # TODO: improve logging policy in library if verbose: diff --git a/tests/client/test_api.py b/tests/client/test_api.py index 63299145df..217c01a03f 100644 --- a/tests/client/test_api.py +++ b/tests/client/test_api.py @@ -245,6 +245,23 @@ def test_log_passing_empty_records_list(mocked_client): api.log(records=[], name="ds") +def test_log_deprecated_chunk_size(mocked_client): + dataset_name = "test_log_deprecated_chunk_size" + mocked_client.delete(f"/api/datasets/{dataset_name}") + record = rg.TextClassificationRecord(text="My text") + with pytest.warns(FutureWarning, match="`chunk_size`.*`batch_size`"): + api.log(records=[record], name=dataset_name, chunk_size=100) + + +def test_large_batch_size_warning(mocked_client, caplog: pytest.LogCaptureFixture): + dataset_name = "test_large_batch_size_warning" + mocked_client.delete(f"/api/datasets/{dataset_name}") + record = rg.TextClassificationRecord(text="My text") + api.log(records=[record], name=dataset_name, batch_size=10000) + assert len(caplog.record_tuples) == 1 + assert "batch size is noticeably large" in caplog.record_tuples[0][2] + + def test_log_background(mocked_client): """Verify that logs can be delayed via the background parameter.""" dataset_name = "test_log_background" From e25be3e5e2d0b886e15433d109cc24bebc6cf93d Mon Sep 17 00:00:00 2001 From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com> Date: Mon, 6 Mar 2023 14:01:59 +0100 Subject: [PATCH 32/45] feat: Expose `batch_size` parameter for `rg.load` (#2460) Closes #2454 Hello! ## Pull Request overview * Extend my previous work to expose a `batch_size` parameter to `rg.load` (and `rg.scan`, `_load_records_new_fashion`). * Update the docstrings for the public methods. * The actual implementation is very simple due to my earlier work in #2434: * `batch_size = self.DEFAULT_SCAN_SIZE` -> `batch_size = batch_size or self.DEFAULT_SCAN_SIZE` * Add an (untested!) warning if `batch_size` is supplied with an old server version. * (We don't have compatibility tests for old server versions) * I've simplified `test_scan_efficient_limiting` with `batch_size` rather than using a mocker to update `DEFAULT_SCAN_SIZE`. * I've also extended this test to test `rg.load` in addition to `active_api().datasets.scan`. ## Note Unlike previously discussed with @davidberenstein1957, I have *not* added a warning like so: https://github.com/argilla-io/argilla/blob/f5834a5408051bf1fa60d42aede6b325dc07fdbd/src/argilla/client/client.py#L338-L343 The server limits the maximum batch size to 500 for loading, so there is no need to have a warning when using a batch size of over 5000. ## Discussion 1. 
Should I include in the docstring that the default batch size is 250? That would be "hardcoded" into the docs, so if we ever change the default (`self.DEFAULT_SCAN_SIZE`), then we would have to remember to update the docs too. 2. Alternatively, should we deprecate `self.DEFAULT_SCAN_SIZE` and enforce that `batch_size` must be set for `datasets.scan`, with a default size of 250 everywhere? Then people can see the default batch size in the signature? I think we should do option 2. Let me know how you feel. --- **Type of change** - [x] New feature (non-breaking change which adds functionality) **How Has This Been Tested** Modified and updated a test, ran all tests. **Checklist** - [x] I have merged the original branch into my forked branch - [x] I added relevant documentation - [x] follows the style guidelines of this project - [x] I did a self-review of my code - [x] I made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works --- - Tom Aarsen --- src/argilla/client/api.py | 4 ++++ src/argilla/client/apis/datasets.py | 6 +++--- src/argilla/client/client.py | 11 +++++++++++ src/argilla/server/apis/v0/handlers/records_search.py | 2 +- .../client/functional_tests/test_scan_raw_records.py | 10 +++++----- 5 files changed, 24 insertions(+), 9 deletions(-) diff --git a/src/argilla/client/api.py b/src/argilla/client/api.py index c7cf8552ed..01fd3b22ee 100644 --- a/src/argilla/client/api.py +++ b/src/argilla/client/api.py @@ -221,6 +221,7 @@ def load( ids: Optional[List[Union[str, int]]] = None, limit: Optional[int] = None, id_from: Optional[str] = None, + batch_size: int = 250, as_pandas=None, ) -> Dataset: """Loads a argilla dataset. @@ -237,6 +238,8 @@ def load( id_from: If provided, starts gathering the records starting from that Record. As the Records returned with the load method are sorted by ID, ´id_from´ can be used to load using batches. + batch_size: If provided, load `batch_size` samples per request. A lower batch + size may help avoid timeouts. as_pandas: DEPRECATED! To get a pandas DataFrame do ``rg.load('my_dataset').to_pandas()``. @@ -269,6 +272,7 @@ def load( ids=ids, limit=limit, id_from=id_from, + batch_size=batch_size, as_pandas=as_pandas, ) diff --git a/src/argilla/client/apis/datasets.py b/src/argilla/client/apis/datasets.py index cf717a4089..5c58d19e83 100644 --- a/src/argilla/client/apis/datasets.py +++ b/src/argilla/client/apis/datasets.py @@ -102,8 +102,6 @@ class Datasets(AbstractApi): __SETTINGS_MIN_API_VERSION__ = "0.15" - DEFAULT_SCAN_SIZE = 250 - class _DatasetApiModel(BaseModel): name: str task: TaskType @@ -155,6 +153,7 @@ def scan( projection: Optional[Set[str]] = None, limit: Optional[int] = None, id_from: Optional[str] = None, + batch_size: int = 250, **query, ) -> Iterable[dict]: """ @@ -169,6 +168,8 @@ def scan( id_from: If provided, starts gathering the records starting from that Record. As the Records returned with the load method are sorted by ID, ´id_from´ can be used to load using batches. + batch_size: If provided, load `batch_size` samples per request. A lower batch + size may help avoid timeouts. 
Returns: An iterable of raw object containing per-record info @@ -177,7 +178,6 @@ def scan( if limit and limit < 0: raise ValueError("The scan limit must be non-negative.") - batch_size = self.DEFAULT_SCAN_SIZE limit = limit if limit else math.inf url = f"{self._API_PREFIX}/{name}/records/:search?limit={{limit}}" query = self._parse_query(query=query) diff --git a/src/argilla/client/client.py b/src/argilla/client/client.py index 0eb7820833..277f5bf380 100644 --- a/src/argilla/client/client.py +++ b/src/argilla/client/client.py @@ -474,6 +474,7 @@ def load( ids: Optional[List[Union[str, int]]] = None, limit: Optional[int] = None, id_from: Optional[str] = None, + batch_size: int = 250, as_pandas=None, ) -> Dataset: """Loads a argilla dataset. @@ -519,6 +520,7 @@ def load( ids=ids, limit=limit, id_from=id_from, + batch_size=batch_size, ) except ApiCompatibilityError as err: # Api backward compatibility from argilla import __version__ as version @@ -535,6 +537,13 @@ def load( ), category=UserWarning, ) + if batch_size is not None: + warnings.warn( + message="The `batch_size` parameter is not supported" + f" on server version {err.api_version}. Consider" + f" updating your server version to {version} to" + " take advantage of this functionality." + ) return self._load_records_old_fashion( name=name, @@ -700,6 +709,7 @@ def _load_records_new_fashion( ids: Optional[List[Union[str, int]]] = None, limit: Optional[int] = None, id_from: Optional[str] = None, + batch_size: int = 250, ) -> Dataset: dataset = self.datasets.find_by_name(name=name) task = dataset.task @@ -747,6 +757,7 @@ def _load_records_new_fashion( projection={"*"}, limit=limit, id_from=id_from, + batch_size=batch_size, # Query query_text=query, ids=ids, diff --git a/src/argilla/server/apis/v0/handlers/records_search.py b/src/argilla/server/apis/v0/handlers/records_search.py index 18d2f1c0b7..44a87b3d8d 100644 --- a/src/argilla/server/apis/v0/handlers/records_search.py +++ b/src/argilla/server/apis/v0/handlers/records_search.py @@ -57,7 +57,7 @@ async def search_dataset_records( limit: int = Query( default=100, gte=0, - le=500, + le=1000, description="Number of records to retrieve", ), request_deps: CommonTaskHandlerDependencies = Depends(), diff --git a/tests/client/functional_tests/test_scan_raw_records.py b/tests/client/functional_tests/test_scan_raw_records.py index ea3224b3f3..7c4fbe272b 100644 --- a/tests/client/functional_tests/test_scan_raw_records.py +++ b/tests/client/functional_tests/test_scan_raw_records.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
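An illustrative usage sketch of the new `batch_size` parameter (not part of the diff above; it assumes a reachable Argilla server and a made-up dataset name — lower values trade throughput for fewer timeout errors):

```python
import argilla as rg
from argilla.client.api import active_api

# Fetch up to 1000 records, 100 per request instead of the default 250.
records = rg.load("my-dataset", limit=1000, batch_size=100)

# The lower-level scan API exposes the same knob.
raw = list(
    active_api().datasets.scan(
        name="my-dataset",
        projection={"text"},
        limit=1000,
        batch_size=100,
    )
)
```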
+import argilla as rg import pytest from argilla.client.api import active_api from argilla.client.sdk.token_classification.models import TokenClassificationRecord @@ -83,16 +84,15 @@ def test_scan_fail_negative_limit( @pytest.mark.parametrize(("limit"), [6, 23, 20]) +@pytest.mark.parametrize(("load_method"), [lambda: active_api().datasets.scan, lambda: rg.load]) def test_scan_efficient_limiting( monkeypatch: pytest.MonkeyPatch, limit, gutenberg_spacy_ner, + load_method, ): - client_datasets = active_api().datasets - # Reduce the default scan size to something small to better test the situation - # where limit > DEFAULT_SCAN_SIZE + method = load_method() batch_size = 10 - monkeypatch.setattr(client_datasets, "DEFAULT_SCAN_SIZE", batch_size) # Monkeypatch the .post() call to track with what URLs the server is called called_paths = [] @@ -105,7 +105,7 @@ def tracked_post(path, *args, **kwargs): monkeypatch.setattr(active_api().http_client, "post", tracked_post) # Try to fetch `limit` samples from the 100 - data = client_datasets.scan(name=gutenberg_spacy_ner, limit=limit) + data = method(name=gutenberg_spacy_ner, limit=limit, batch_size=10) data = list(data) # Ensure that `limit` samples were indeed fetched From 71997809dbe1d5330d907d5c81195e13979d8205 Mon Sep 17 00:00:00 2001 From: Daniel Vila Suero Date: Mon, 6 Mar 2023 21:18:10 +0100 Subject: [PATCH 33/45] docs: Add AutoTrain to readme --- README.md | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index b034e12f15..e6450e6d82 100644 --- a/README.md +++ b/README.md @@ -21,9 +21,18 @@

-Open-source framework for data-centric NLP
+Open-source platform for data-centric NLP
+Data Labeling for MLOps & Feedback Loops

+ + + +https://user-images.githubusercontent.com/1107111/223220683-fbfa63da-367c-4cfa-bda5-66f47413b6b0.mp4 + +
+ +> 🆕 🔥 Train custom transformers models with no-code: [Argilla + AutoTrain](https://www.argilla.io/blog/argilla-meets-autotrain) + > 🆕 🔥 Deploy [Argilla on Spaces](https://huggingface.co/new-space?template=argilla/argilla-template-space) > 🆕 🔥 Since `1.2.0` Argilla supports vector search for finding the most similar records to a given one. This feature uses vector or semantic search combined with more traditional search (keyword and filter based). Learn more on this [deep-dive guide](https://docs.argilla.io/en/latest/guides/features/semantic-search.html) From 5600301a59899bc6fbc5d40384478e90cb2e21d7 Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Mon, 6 Mar 2023 22:40:21 +0100 Subject: [PATCH 34/45] fix: added flexible app redirect to docs page (#2428) # Description I added more flexible app redirect to `api/docs` Closes #2377 **Type of change** (Please delete options that are not relevant. Remember to title the PR according to the type of change) - [X] Bug fix (non-breaking change which fixes an issue) **How Has This Been Tested** N.A. **Checklist** N.A. --------- Co-authored-by: frascuchon --- src/argilla/server/server.py | 14 +++++++++++++- tests/conftest.py | 6 ++++++ tests/server/test_api.py | 11 +++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/argilla/server/server.py b/src/argilla/server/server.py index def025bf53..ca106a0a63 100644 --- a/src/argilla/server/server.py +++ b/src/argilla/server/server.py @@ -29,6 +29,7 @@ from fastapi import FastAPI from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import RedirectResponse from pydantic import ConfigError from argilla import __version__ as argilla_version @@ -202,7 +203,18 @@ async def check_telemetry(): version=str(argilla_version), ) -app = FastAPI() + +@argilla_app.get("/docs", include_in_schema=False) +async def redirect_docs(): + return RedirectResponse(url=f"{settings.base_url}api/docs") + + +@argilla_app.get("/api", include_in_schema=False) +async def redirect_api(): + return RedirectResponse(url=f"{settings.base_url}api/docs") + + +app = FastAPI(docs_url=None) app.mount(settings.base_url, argilla_app) configure_app_logging(app) diff --git a/tests/conftest.py b/tests/conftest.py index f8495bff2b..c9b1d52432 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -35,6 +35,12 @@ def telemetry_track_data(mocker): return spy +@pytest.fixture(scope="session") +def test_client(): + with TestClient(app) as client: + yield client + + @pytest.fixture def mocked_client( monkeypatch, diff --git a/tests/server/test_api.py b/tests/server/test_api.py index 72743a4762..f0b87ec894 100644 --- a/tests/server/test_api.py +++ b/tests/server/test_api.py @@ -20,6 +20,7 @@ TextClassificationRecord, ) from argilla.server.commons.models import TaskStatus, TaskType +from starlette.testclient import TestClient def create_some_data_for_text_classification( @@ -109,3 +110,13 @@ def uri_2_path(uri: str): p = urlparse(uri) return os.path.abspath(os.path.join(p.netloc, p.path)) + + +def test_docs_redirect(test_client: TestClient): + response = test_client.get("/docs", follow_redirects=False) + assert response.status_code == 307 + assert response.next_request.url.path == "/api/docs" + + response = test_client.get("/api", follow_redirects=False) + assert response.status_code == 307 + assert response.next_request.url.path == "/api/docs" From 21efb839e051f1cf48c16a5d25373f25da192d53 Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Mon, 6 Mar 
2023 22:53:01 +0100 Subject: [PATCH 35/45] feat: Add text2text support for prepare for training spark nlp (#2466) # Description Added text2text support for spark-nlp training, plus a small bug-fix for prepare-for-training with spaCy textcat. Closes #2465 Closes #2482 **Type of change** (Please delete options that are not relevant. Remember to title the PR according to the type of change) - [X] New feature (non-breaking change which adds functionality) **How Has This Been Tested** N.A. **Checklist** N.A. --------- Co-authored-by: dvsrepo Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/argilla/client/datasets.py | 32 +++++++++-- tests/client/test_dataset.py | 101 +++++++++++++++++++++++++++------ 2 files changed, 113 insertions(+), 20 deletions(-) diff --git a/src/argilla/client/datasets.py b/src/argilla/client/datasets.py index 83d15348f9..fd1eb386bd 100644 --- a/src/argilla/client/datasets.py +++ b/src/argilla/client/datasets.py @@ -789,8 +789,12 @@ def _prepare_for_training_with_spacy(self, nlp: "spacy.Language", records: List[ doc = nlp.make_doc(text) cats = dict.fromkeys(all_labels, 0) - for anno in record.annotation: - cats[anno] = 1 + + if isinstance(record.annotation, list): + for anno in record.annotation: + cats[anno] = 1 + else: + cats[record.annotation] = 1 doc.cats = cats db.add(doc) @@ -821,8 +825,15 @@ def _prepare_for_training_with_spark_nlp(self, records: List[Record]) -> "pandas def __all_labels__(self): all_labels = set() for record in self._records: - if record.annotation: - all_labels.update(record.annotation) + if record.annotation is None: + continue + elif isinstance(record.annotation, str): + all_labels.add(record.annotation) + elif isinstance(record.annotation, list): + all_labels.update((tuple(record.annotation))) + else: + # this is highly unlikely + raise TypeError("Record.annotation contains an unsupported type: {}".format(type(record.annotation))) return list(all_labels) @@ -1258,6 +1269,19 @@ def _prepare_for_training_with_transformers( return ds + def _prepare_for_training_with_spark_nlp(self, records: List[Record]) -> "pandas.DataFrame": + spark_nlp_data = [] + for record in records: + if record.annotation is None: + continue + if record.id is None: + record.id = str(uuid.uuid4()) + text = record.text + + spark_nlp_data.append([record.id, text, record.annotation]) + + return pd.DataFrame(spark_nlp_data, columns=["id", "text", "target"]) + Dataset = Union[DatasetForTextClassification, DatasetForTokenClassification, DatasetForText2Text] diff --git a/tests/client/test_dataset.py b/tests/client/test_dataset.py index 79a1246e5c..163460d44d 100644 --- a/tests/client/test_dataset.py +++ b/tests/client/test_dataset.py @@ -339,7 +339,7 @@ def test_prepare_for_training(self, request, records): records = request.getfixturevalue(records) ds = rg.DatasetForTextClassification(records) - train = ds.prepare_for_training() + train = ds.prepare_for_training(seed=42) if not ds[0].multi_label: column_names = ["text", "context", "label"] @@ -357,12 +357,73 @@ def test_prepare_for_training(self, request, records): else: assert train.features["label"] == datasets.ClassLabel(names=["a"]) - train_test = ds.prepare_for_training(train_size=0.5) + train_test = ds.prepare_for_training(train_size=0.5, seed=42) assert len(train_test["train"]) == 1 assert len(train_test["test"]) == 1 for split in ["train", "test"]: assert train_test[split].column_names == column_names + @pytest.mark.parametrize( + "records", + [ +
"singlelabel_textclassification_records", + "multilabel_textclassification_records", + ], + ) + def test_prepare_for_training_with_spacy(self, request, records): + records = request.getfixturevalue(records) + + ds = rg.DatasetForTextClassification(records) + with pytest.raises(ValueError): + train = ds.prepare_for_training(framework="spacy", seed=42) + nlp = spacy.blank("en") + doc_bin = ds.prepare_for_training(framework="spacy", lang=nlp, seed=42) + + assert isinstance(doc_bin, spacy.tokens.DocBin) + docs = list(doc_bin.get_docs(nlp.vocab)) + assert len(docs) == 2 + + if records[0].multi_label: + assert set(list(docs[0].cats.keys())) == set(["a", "b"]) + else: + assert isinstance(docs[0].cats, dict) + + train, test = ds.prepare_for_training(train_size=0.5, framework="spacy", lang=nlp, seed=42) + docs_train = list(train.get_docs(nlp.vocab)) + docs_test = list(train.get_docs(nlp.vocab)) + assert len(list(docs_train)) == 1 + assert len(list(docs_test)) == 1 + + @pytest.mark.parametrize( + "records", + [ + "singlelabel_textclassification_records", + "multilabel_textclassification_records", + ], + ) + def test_prepare_for_training_with_spark_nlp(self, request, records): + records = request.getfixturevalue(records) + + ds = rg.DatasetForTextClassification(records) + df = ds.prepare_for_training("spark-nlp", train_size=1, seed=42) + + if ds[0].multi_label: + column_names = ["id", "text", "labels"] + else: + column_names = ["id", "text", "label"] + + assert isinstance(df, pd.DataFrame) + assert list(df.columns) == column_names + assert len(df) == 2 + + df_train, df_test = ds.prepare_for_training("spark-nlp", train_size=0.5, seed=42) + assert len(df_train) == 1 + assert len(df_test) == 1 + assert isinstance(df_train, pd.DataFrame) + assert isinstance(df_test, pd.DataFrame) + assert list(df_train.columns) == column_names + assert list(df_test.columns) == column_names + @pytest.mark.skipif( _HF_HUB_ACCESS_TOKEN is None, reason="You need a HF Hub access token to test the push_to_hub feature", @@ -574,13 +635,15 @@ def test_prepare_for_training_with_spacy(self): r.annotation = [(label, start, end) for label, start, end, _ in r.prediction] with pytest.raises(ValueError): - train = rb_dataset.prepare_for_training(framework="spacy") + train = rb_dataset.prepare_for_training(framework="spacy", seed=42) - train = rb_dataset.prepare_for_training(framework="spacy", lang=spacy.blank("en")) + train = rb_dataset.prepare_for_training(framework="spacy", lang=spacy.blank("en"), seed=42) assert isinstance(train, spacy.tokens.DocBin) assert len(train) == 100 - train, test = rb_dataset.prepare_for_training(framework="spacy", lang=spacy.blank("en"), train_size=0.8) + train, test = rb_dataset.prepare_for_training( + framework="spacy", lang=spacy.blank("en"), train_size=0.8, seed=42 + ) assert isinstance(train, spacy.tokens.DocBin) assert isinstance(test, spacy.tokens.DocBin) assert len(train) == 80 @@ -601,11 +664,11 @@ def test_prepare_for_training_with_spark_nlp(self): for r in rb_dataset: r.annotation = [(label, start, end) for label, start, end, _ in r.prediction] - train = rb_dataset.prepare_for_training(framework="spark-nlp") + train = rb_dataset.prepare_for_training(framework="spark-nlp", seed=42) assert isinstance(train, pd.DataFrame) assert len(train) == 100 - train, test = rb_dataset.prepare_for_training(framework="spark-nlp", train_size=0.8) + train, test = rb_dataset.prepare_for_training(framework="spark-nlp", train_size=0.8, seed=42) assert isinstance(train, pd.DataFrame) assert isinstance(test, 
pd.DataFrame) assert len(train) == 80 @@ -788,8 +851,10 @@ def test_to_from_pandas(self, text2text_records): assert rec == expected def test_prepare_for_training(self): - ds = rg.DatasetForText2Text([rg.Text2TextRecord(text="mock", annotation="mock")] * 10) - train = ds.prepare_for_training(train_size=1) + ds = rg.DatasetForText2Text( + [rg.Text2TextRecord(text="mock", annotation="mock"), rg.Text2TextRecord(text="mock")] * 10 + ) + train = ds.prepare_for_training(train_size=1, seed=42) assert isinstance(train, datasets.Dataset) assert train.column_names == ["text", "target"] @@ -799,21 +864,25 @@ def test_prepare_for_training(self): assert train.features["text"] == datasets.Value("string") assert train.features["target"] == datasets.Value("string") - train_test = ds.prepare_for_training(train_size=0.5) + train_test = ds.prepare_for_training(train_size=0.5, seed=42) assert len(train_test["train"]) == 5 assert len(train_test["test"]) == 5 for split in ["train", "test"]: assert train_test[split].column_names == ["text", "target"] - def test_prepare_for_training_spacy(self): - ds = rg.DatasetForText2Text([rg.Text2TextRecord(text="mock", annotation="mock")] * 10) + def test_prepare_for_training_with_spacy(self): + ds = rg.DatasetForText2Text( + [rg.Text2TextRecord(text="mock", annotation="mock"), rg.Text2TextRecord(text="mock")] * 10 + ) with pytest.raises(NotImplementedError): ds.prepare_for_training("spacy", lang=spacy.blank("en"), train_size=1) - def test_prepare_for_training_spark_nlp(self): - ds = rg.DatasetForText2Text([rg.Text2TextRecord(text="mock", annotation="mock")] * 10) - with pytest.raises(NotImplementedError): - ds.prepare_for_training("spark-nlp", train_size=1) + def test_prepare_for_training_with_spark_nlp(self): + ds = rg.DatasetForText2Text( + [rg.Text2TextRecord(text="mock", annotation="mock"), rg.Text2TextRecord(text="mock")] * 10 + ) + df = ds.prepare_for_training("spark-nlp", train_size=1) + assert list(df.columns) == ["id", "text", "target"] @pytest.mark.skipif( _HF_HUB_ACCESS_TOKEN is None, From 58aa9f9354775b457843e42b89c99d411aa88370 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Mar 2023 13:01:39 +0100 Subject: [PATCH 36/45] [pre-commit.ci] pre-commit autoupdate (#2490) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/charliermarsh/ruff-pre-commit: v0.0.253 → v0.0.254](https://github.com/charliermarsh/ruff-pre-commit/compare/v0.0.253...v0.0.254) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0395640409..b835cfa39f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,7 +26,7 @@ repos: additional_dependencies: ["click==8.0.4"] - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.253 + rev: v0.0.254 hooks: # Simulate isort via (the much faster) ruff - id: ruff From 4aecb134c9655c51d097ce04230dd756b4c5a509 Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Tue, 7 Mar 2023 13:11:00 +0100 Subject: [PATCH 37/45] Feat/python api sort support (#2487) # Description Added sort functionality to `rg.load` like so. 
`rg.load("name", limit=100, sort=[("field", "asc|desc")])` Closes https://github.com/argilla-io/argilla/issues/2433 **Type of change** - [X] New feature (non-breaking change which adds functionality) - [X] Documentation update **How Has This Been Tested** (Please describe the tests that you ran to verify your changes. And ideally, reference `tests`) - [X] [Test A](https://github.com/argilla-io/argilla/blob/65031609fbad263e41ead233f7335bed83f22268/tests/client/test_api.py#L636) **Checklist** - [X] I have merged the original branch into my forked branch - [X] I added relevant documentation - [X] follows the style guidelines of this project - [X] I did a self-review of my code - [X] I made corresponding changes to the documentation - [X] My changes generate no new warnings - [X] I have added tests that prove my fix is effective or that my feature works --------- Co-authored-by: Francisco Aranda --- .../guides/log_load_and_prepare_data.ipynb | 3 +- src/argilla/client/api.py | 3 + src/argilla/client/apis/datasets.py | 37 ++++++--- src/argilla/client/client.py | 21 ++--- .../server/apis/v0/handlers/records_search.py | 83 +++++++++++-------- .../daos/backend/client_adapters/base.py | 9 +- .../backend/client_adapters/opensearch.py | 35 ++++---- .../server/daos/backend/generic_elastic.py | 28 ++++--- .../server/daos/backend/mappings/helpers.py | 4 +- .../daos/backend/search/query_builder.py | 32 +++---- src/argilla/server/daos/records.py | 13 ++- .../server/services/metrics/service.py | 4 +- tests/client/test_api.py | 23 ++++- .../text_classification/test_label_errors.py | 2 +- tests/metrics/test_token_classification.py | 2 +- tests/server/backend/test_query_builder.py | 19 ++++- 16 files changed, 192 insertions(+), 126 deletions(-) diff --git a/docs/_source/guides/log_load_and_prepare_data.ipynb b/docs/_source/guides/log_load_and_prepare_data.ipynb index c826285f5c..64f8ccc563 100644 --- a/docs/_source/guides/log_load_and_prepare_data.ipynb +++ b/docs/_source/guides/log_load_and_prepare_data.ipynb @@ -233,7 +233,8 @@ " query=\"my AND query\",\n", " limit=42\n", " ids=[\"id1\", \"id2\", \"id3\"],\n", - " vectors=[\"vector1\", \"vector2\", \"vector3\"], \n", + " vectors=(\"vector1\", [0, 42, 1957]), \n", + " sort=[(\"event_timestamp\", \"desc\")]\n", ")" ] }, diff --git a/src/argilla/client/api.py b/src/argilla/client/api.py index 01fd3b22ee..6c21fc5529 100644 --- a/src/argilla/client/api.py +++ b/src/argilla/client/api.py @@ -220,6 +220,7 @@ def load( vector: Optional[Tuple[str, List[float]]] = None, ids: Optional[List[Union[str, int]]] = None, limit: Optional[int] = None, + sort: Optional[List[Tuple[str, str]]] = None, id_from: Optional[str] = None, batch_size: int = 250, as_pandas=None, @@ -235,6 +236,7 @@ def load( vector: Vector configuration for a semantic search ids: If provided, load dataset records with given ids. limit: The number of records to retrieve. + sort: The fields on which to sort [(, 'asc|decs')]. id_from: If provided, starts gathering the records starting from that Record. As the Records returned with the load method are sorted by ID, ´id_from´ can be used to load using batches. 
@@ -271,6 +273,7 @@ def load( vector=vector, ids=ids, limit=limit, + sort=sort, id_from=id_from, batch_size=batch_size, as_pandas=as_pandas, ) diff --git a/src/argilla/client/apis/datasets.py index 5c58d19e83..0b86a0ab57 100644 --- a/src/argilla/client/apis/datasets.py +++ b/src/argilla/client/apis/datasets.py @@ -16,7 +16,7 @@ import warnings from dataclasses import dataclass from datetime import datetime -from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from pydantic import BaseModel, Field @@ -152,6 +152,7 @@ def scan( name: str, projection: Optional[Set[str]] = None, limit: Optional[int] = None, + sort: Optional[List[Tuple[str, str]]] = None, id_from: Optional[str] = None, batch_size: int = 250, **query, @@ -163,8 +164,9 @@ def scan( name: the dataset query: the search query projection: a subset of record fields to retrieve. If not provided, - only id's will be returned - limit: The number of records to retrieve. + only id's will be returned + sort: The fields on which to sort [(<field>, 'asc|desc')]. + limit: The number of records to retrieve id_from: If provided, starts gathering the records starting from that Record. As the Records returned with the load method are sorted by ID, ´id_from´ can be used to load using batches. @@ -187,7 +189,18 @@ "query": query, } - if id_from: + if sort is not None: + try: + if isinstance(sort, list): + assert all([(isinstance(item, tuple)) and (item[-1] in ["asc", "desc"]) for item in sort]) + else: + raise Exception() + except Exception: + raise ValueError("sort must be a list formatted as List[Tuple[<field>, 'asc|desc']]") + request["sort_by"] = [{"id": item[0], "order": item[-1]} for item in sort] + + elif id_from: + # TODO: Show a message, since sort and next_idx are not compatible (next_idx assumes a fixed sort by id) request["next_idx"] = id_from with api_compatibility(self, min_version="1.2.0"): @@ -203,13 +216,15 @@ if limit <= 0: return - next_idx = response.get("next_idx") - if next_idx: - request_limit = min(limit, batch_size) - response = self.http_client.post( - path=url.format(limit=request_limit), - json={**request, "next_idx": next_idx}, - ) + next_request_params = {k: response[k] for k in ["next_idx", "next_page_cfg"] if response.get(k)} + if not next_request_params: + return + + request_limit = min(limit, batch_size) + response = self.http_client.post( + path=url.format(limit=request_limit), + json={**request, **next_request_params}, + ) def update_record( self, diff --git a/src/argilla/client/client.py b/src/argilla/client/client.py index 277f5bf380..3e892a9930 100644 --- a/src/argilla/client/client.py +++ b/src/argilla/client/client.py @@ -473,6 +473,7 @@ def load( vector: Optional[Tuple[str, List[float]]] = None, ids: Optional[List[Union[str, int]]] = None, limit: Optional[int] = None, + sort: Optional[List[Tuple[str, str]]] = None, id_from: Optional[str] = None, batch_size: int = 250, as_pandas=None, ) -> Dataset: @@ -486,6 +487,7 @@ def load( vector: Vector configuration for a semantic search ids: If provided, load dataset records with given ids. limit: The number of records to retrieve. + sort: The fields on which to sort [(<field>, 'asc|desc')]. id_from: If provided, starts gathering the records starting from that Record. As the Records returned with the load method are sorted by ID, ´id_from´ can be used to load using batches.
@@ -519,6 +521,7 @@ def load( vector=vector, ids=ids, limit=limit, + sort=sort, id_from=id_from, batch_size=batch_size, ) @@ -550,6 +553,7 @@ def load( query=query, ids=ids, limit=limit, + sort=sort, id_from=id_from, ) @@ -699,7 +703,7 @@ def _load_records_old_fashion( ) records = [sdk_record.to_client() for sdk_record in response.parsed] - return dataset_class(self.__sort_records_by_id__(records)) + return dataset_class(records) def _load_records_new_fashion( self, @@ -708,6 +712,7 @@ def _load_records_new_fashion( vector: Optional[Tuple[str, List[float]]] = None, ids: Optional[List[Union[str, int]]] = None, limit: Optional[int] = None, + sort: Optional[List[Tuple[str, str]]] = None, id_from: Optional[str] = None, batch_size: int = 250, ) -> Dataset: @@ -738,6 +743,9 @@ def _load_records_new_fashion( ) if vector: + if sort is not None: + _LOGGER.warning("Results are sorted by vector similarity, so 'sort' parameter is ignored.") + vector_search = VectorSearch( name=vector[0], value=vector[1], @@ -756,6 +764,7 @@ def _load_records_new_fashion( name=name, projection={"*"}, limit=limit, + sort=sort, id_from=id_from, batch_size=batch_size, # Query @@ -763,12 +772,4 @@ def _load_records_new_fashion( ids=ids, ) records = [sdk_record_class.parse_obj(r).to_client() for r in records] - return dataset_class(self.__sort_records_by_id__(records)) - - def __sort_records_by_id__(self, records: list) -> list: - try: - records_sorted_by_id = sorted(records, key=lambda x: x.id) - # record ids can be a mix of int/str -> sort all as str type - except TypeError: - records_sorted_by_id = sorted(records, key=lambda x: str(x.id)) - return records_sorted_by_id + return dataset_class(records) diff --git a/src/argilla/server/apis/v0/handlers/records_search.py b/src/argilla/server/apis/v0/handlers/records_search.py index 44a87b3d8d..9043db3f48 100644 --- a/src/argilla/server/apis/v0/handlers/records_search.py +++ b/src/argilla/server/apis/v0/handlers/records_search.py @@ -11,20 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import json from typing import List, Optional, Union from fastapi import APIRouter, Depends, Query, Security -from pydantic import BaseModel +from pydantic import BaseModel, Field from argilla.client.sdk.token_classification.models import TokenClassificationQuery +from argilla.server.apis.v0.models.commons.model import SortableField from argilla.server.apis.v0.models.commons.params import CommonTaskHandlerDependencies from argilla.server.apis.v0.models.text2text import Text2TextQuery from argilla.server.apis.v0.models.text_classification import TextClassificationQuery +from argilla.server.daos.backend import GenericElasticEngineBackend +from argilla.server.daos.backend.generic_elastic import PaginatedSortInfo from argilla.server.security import auth from argilla.server.security.model import User from argilla.server.services.datasets import DatasetsService -from argilla.server.services.search.service import SearchRecordsService # TODO(@frascuchon): This will be merged with `records.py` # once the similarity search feature is merged into develop @@ -37,23 +39,30 @@ def configure_router(router: APIRouter): Text2TextQuery, ] - class SearchRecordsRequest(BaseModel): - query: Optional[QueryType] = None - next_idx: Optional[str] = None - fields: Optional[List[str]] = None + class ScanDatasetRecordsRequest(BaseModel): + query: Optional[QueryType] + next_idx: Optional[str] = Field(description="Field to fetch data from a record id") + fields: Optional[List[str]] + sort_by: List[SortableField] = Field( + default_factory=list, + description="Set a sort config for records scan. " + "The `next_id` field will be ignored if a sort configuration is found", + ) + next_page_cfg: Optional[str] = Field( + description="Field to paginate over scan results. Use value fetched from previous response" + ) - class SearchRecordsResponse(BaseModel): - next_idx: Optional[str] + class ScanDatasetRecordsResponse(BaseModel): records: List[dict] + next_idx: Optional[str] + next_page_cfg: Optional[str] @router.post( - "/{name}/records/:search", - operation_id="search_dataset_records", - response_model=SearchRecordsResponse, + "/{name}/records/:search", operation_id="scan_dataset_records", response_model=ScanDatasetRecordsResponse ) - async def search_dataset_records( + async def scan_dataset_records( name: str, - request: Optional[SearchRecordsRequest] = None, + request: Optional[ScanDatasetRecordsRequest] = None, limit: int = Query( default=100, gte=0, @@ -62,35 +71,37 @@ async def search_dataset_records( ), request_deps: CommonTaskHandlerDependencies = Depends(), service: DatasetsService = Depends(DatasetsService.get_instance), - search: SearchRecordsService = Depends(SearchRecordsService.get_instance), + engine: GenericElasticEngineBackend = Depends(GenericElasticEngineBackend.get_instance), current_user: User = Security(auth.get_user, scopes=[]), ): - found = service.find_by_name( - user=current_user, - name=name, - workspace=request_deps.workspace, - ) + found = service.find_by_name(user=current_user, name=name, workspace=request_deps.workspace) - request = request or SearchRecordsRequest() + request = request or ScanDatasetRecordsRequest() + paginated_sort = PaginatedSortInfo(sort_by=request.sort_by or [SortableField(id="id")]) + if request.next_page_cfg: + try: + data = json.loads(request.next_page_cfg) + paginated_sort = PaginatedSortInfo.parse_obj(data) + except Exception: + pass + elif request.next_idx and not request.sort_by: + paginated_sort.next_search_params = [request.next_idx] - docs = search.scan_records( 
- dataset=found, - query=request.query, - id_from=request.next_idx, - projection=request.fields, - limit=limit, + docs = engine.scan_records( + id=found.id, query=request.query, sort=paginated_sort, include_fields=request.fields, limit=limit ) + docs = list(docs) - last_doc = docs[-1] if docs else {} + for doc in docs: + # Removing sort config for each document and keep the last one, used for next page configuration + paginated_sort.next_search_params = doc.pop("sort", None) - return SearchRecordsResponse( - next_idx=last_doc.get("id"), - records=docs, - ) + next_idx = None + if paginated_sort.next_search_params and not request.sort_by: + next_idx = paginated_sort.next_search_params[0] + + return ScanDatasetRecordsResponse(next_idx=next_idx, next_page_cfg=paginated_sort.json(), records=docs) -router = APIRouter( - tags=["datasets"], - prefix="/datasets", -) +router = APIRouter(tags=["datasets"], prefix="/datasets") configure_router(router) diff --git a/src/argilla/server/daos/backend/client_adapters/base.py b/src/argilla/server/daos/backend/client_adapters/base.py index 43c7210f30..62d15b2c44 100644 --- a/src/argilla/server/daos/backend/client_adapters/base.py +++ b/src/argilla/server/daos/backend/client_adapters/base.py @@ -62,18 +62,17 @@ def delete_docs_by_query( pass @abstractmethod - def list_index_documents( + def scan_docs( self, index: str, - query: Optional[BaseQuery] = None, + query: BaseQuery, + sort: SortConfig, size: Optional[int] = None, fetch_once: bool = False, - id_from: Optional[str] = None, + search_from_params: Optional[Any] = None, enable_highlight: bool = False, - sort: Optional[SortConfig] = None, include_fields: Optional[List[str]] = None, exclude_fields: Optional[List[str]] = None, - shuffle: bool = False, ) -> Iterable[Dict[str, Any]]: pass diff --git a/src/argilla/server/daos/backend/client_adapters/opensearch.py b/src/argilla/server/daos/backend/client_adapters/opensearch.py index 83629300cb..32f188d821 100644 --- a/src/argilla/server/daos/backend/client_adapters/opensearch.py +++ b/src/argilla/server/daos/backend/client_adapters/opensearch.py @@ -309,37 +309,29 @@ def _update_by_query( "updated": response["updated"], } - def list_index_documents( + def scan_docs( self, index: str, - query: Optional[BaseQuery] = None, + query: BaseQuery, + sort: SortConfig, size: Optional[int] = None, fetch_once: bool = False, - id_from: Optional[str] = None, + search_from_params: Optional[Any] = None, enable_highlight: bool = False, - sort: Optional[SortConfig] = None, include_fields: Optional[List[str]] = None, exclude_fields: Optional[List[str]] = None, - shuffle: bool = False, ) -> Iterable[Dict[str, Any]]: batch_size = size or 500 highlight = self.highlight if enable_highlight else None - index_schema = self.get_index_schema(index=index) - if index_schema and "id" in index_schema["mappings"]["properties"]: - sort = SortConfig(sort_by=[SortableField(id="id")]) - else: - fetch_once = True - es_query = self.query_builder.map_2_es_query( query=query, - schema=index_schema, + schema=self.get_index_schema(index=index), sort=sort, - id_from=id_from, + search_after_param=search_from_params, include_fields=include_fields, exclude_fields=exclude_fields, highlight=highlight, - shuffle=shuffle, ) es_query = es_query.copy() or {} response = self.__client__.search( @@ -352,18 +344,15 @@ def list_index_documents( while response["hits"]["hits"]: hit = None for hit in response["hits"]["hits"]: - yield self._normalize_document( - document=hit, - highlight=highlight, - ) + yield 
self._normalize_document(document=hit, highlight=highlight, add_sort_info=True) records_yield += 1 if fetch_once or (size and size >= records_yield): break - last_id = hit["_id"] - es_query["search_after"] = [last_id] - response = self.__client__.search(index=index, body=es_query, size=size) + next_search_from = hit["sort"] + es_query["search_after"] = next_search_from + response = self.__client__.search(index=index, body=es_query, size=size, track_total_hits=False) def _process_search_results( self, @@ -766,12 +755,16 @@ def _normalize_document( document: Dict[str, Any], highlight: Optional[HighlightParser] = None, is_phrase_query: bool = True, + add_sort_info: bool = False, ): data = { **document["_source"], "id": document["_id"], } + if add_sort_info and "sort" in document: + data["sort"] = document["sort"] + if highlight: keywords = highlight.parse_highligth_results( doc=document, diff --git a/src/argilla/server/daos/backend/generic_elastic.py b/src/argilla/server/daos/backend/generic_elastic.py index 179f1ed2c2..bbbf0172bc 100644 --- a/src/argilla/server/daos/backend/generic_elastic.py +++ b/src/argilla/server/daos/backend/generic_elastic.py @@ -14,6 +14,8 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple +from pydantic import BaseModel, Field + from argilla.logging import LoggingMixin from argilla.server.commons.models import TaskType from argilla.server.daos.backend.base import IndexNotFoundError, InvalidSearchError @@ -54,6 +56,12 @@ def dataset_records_index(dataset_id: str) -> str: return index_mame_template.format(dataset_id) +class PaginatedSortInfo(BaseModel): + shuffle: bool = False + sort_by: List[SortableField] = Field(default_factory=list) + next_search_params: Optional[Any] = None + + class GenericElasticEngineBackend(LoggingMixin): """ Encapsulates logic about the communication, queries and index mapping @@ -204,29 +212,27 @@ def search_records( return total, records - # TODO(@frascuchon): Include sort parameter def scan_records( self, id: str, - query: Optional[BackendRecordsQuery] = None, - id_from: Optional[str] = None, + query: BackendRecordsQuery, + sort: PaginatedSortInfo, limit: Optional[int] = None, - shuffle: bool = False, include_fields: Optional[List[str]] = None, exclude_fields: Optional[List[str]] = None, ) -> Iterable[Dict[str, Any]]: index = dataset_records_index(id) - yield from self.client.list_index_documents( + yield from self.client.scan_docs( index=index, query=query, + sort=SortConfig(shuffle=sort.shuffle, sort_by=sort.sort_by), size=limit, - id_from=id_from, - fetch_once=shuffle, + search_from_params=sort.next_search_params, + fetch_once=sort.shuffle, include_fields=include_fields, exclude_fields=exclude_fields, enable_highlight=True, - shuffle=shuffle, ) def open(self, id: str): @@ -379,6 +385,7 @@ def create_datasets_index(self, force_recreate: bool = False): force_recreate=force_recreate, mappings=datasets_index_mappings(), ) + # TODO: Remove this section of code if settings.enable_migration: try: self._migrate_from_rubrix() @@ -395,7 +402,7 @@ def _migrate_from_rubrix(self): target_index=target_index, reindex=True, ) - for doc in self.client.list_index_documents(index=source_index): + for doc in self.client.scan_docs(index=source_index, query=BaseQuery(), sort=SortConfig()): dataset_id = doc["id"] index = self._old_dataset_index(dataset_id) alias = dataset_records_index(dataset_id=dataset_id) @@ -414,9 +421,10 @@ def _update_dynamic_mapping(self, index: str): ) def list_datasets(self, query: BaseDatasetsQuery): - return 
self.client.list_index_documents( + return self.client.scan_docs( index=DATASETS_INDEX_NAME, query=query, + sort=SortConfig(), fetch_once=True, size=self.__MAX_NUMBER_OF_LISTED_DATASETS__, ) diff --git a/src/argilla/server/daos/backend/mappings/helpers.py b/src/argilla/server/daos/backend/mappings/helpers.py index 4b32da9244..b8f0441f1b 100644 --- a/src/argilla/server/daos/backend/mappings/helpers.py +++ b/src/argilla/server/daos/backend/mappings/helpers.py @@ -204,8 +204,8 @@ def tasks_common_mappings(): "predictions": mappings.dynamic_field(), "annotations": mappings.dynamic_field(), "status": mappings.keyword_field(), - "event_timestamp": {"type": "date"}, - "last_updated": {"type": "date"}, + "event_timestamp": {"type": "date_nanos"}, + "last_updated": {"type": "date_nanos"}, "annotated_by": mappings.keyword_field(enable_text_search=True), "predicted_by": mappings.keyword_field(enable_text_search=True), "metrics": mappings.dynamic_field(), diff --git a/src/argilla/server/daos/backend/search/query_builder.py b/src/argilla/server/daos/backend/search/query_builder.py index d093c6ca85..354dcb8ec7 100644 --- a/src/argilla/server/daos/backend/search/query_builder.py +++ b/src/argilla/server/daos/backend/search/query_builder.py @@ -166,16 +166,15 @@ def _search_to_es_query( def map_2_es_query( self, - schema: Optional[Dict[str, Any]] = None, - query: Optional[BackendQuery] = None, - sort: Optional[SortConfig] = None, + schema: Dict[str, Any], + query: BackendQuery, + sort: SortConfig = SortConfig(), exclude_fields: Optional[List[str]] = None, include_fields: List[str] = None, doc_from: Optional[int] = None, highlight: Optional[HighlightParser] = None, size: Optional[int] = None, - id_from: Optional[str] = None, - shuffle: bool = False, + search_after_param: Optional[Any] = None, ) -> Dict[str, Any]: if query and query.raw_query: es_query = {"query": query.raw_query} @@ -186,16 +185,15 @@ def map_2_es_query( else {"query": self._search_to_es_query(schema, query)} ) - if id_from: - es_query["search_after"] = [id_from] - sort = SortConfig() # sort by id as default + if search_after_param: + es_query["search_after"] = search_after_param - if shuffle: + if sort.shuffle: self._setup_random_score(es_query) - - es_sort = self.map_2_es_sort_configuration(schema=schema, sort=sort) - if es_sort and not shuffle: - es_query["sort"] = es_sort + else: + es_sort = self.map_2_es_sort_configuration(schema=schema, sort=sort) + if es_sort: + es_query["sort"] = es_sort if doc_from: es_query["from"] = doc_from @@ -227,12 +225,8 @@ def map_2_es_query( return es_query - def map_2_es_sort_configuration( - self, - schema: Optional[Dict[str, Any]] = None, - sort: Optional[SortConfig] = None, - ) -> Optional[List[Dict[str, Any]]]: - if not sort: + def map_2_es_sort_configuration(self, schema: Dict[str, Any], sort: SortConfig) -> Optional[List[Dict[str, Any]]]: + if not sort.sort_by or sort.shuffle: return None # TODO(@frascuchon): compute valid list from the schema diff --git a/src/argilla/server/daos/records.py b/src/argilla/server/daos/records.py index 83ea922db1..73e6da72a2 100644 --- a/src/argilla/server/daos/records.py +++ b/src/argilla/server/daos/records.py @@ -20,7 +20,8 @@ from argilla.server.daos.backend import GenericElasticEngineBackend from argilla.server.daos.backend.base import ClosedIndexError, IndexNotFoundError -from argilla.server.daos.backend.search.model import BaseRecordsQuery +from argilla.server.daos.backend.generic_elastic import PaginatedSortInfo +from argilla.server.daos.backend.search.model 
import BaseRecordsQuery, SortableField from argilla.server.daos.models.datasets import DatasetDB from argilla.server.daos.models.records import ( DaoRecordsSearch, @@ -177,7 +178,6 @@ def scan_dataset( search: Optional[DaoRecordsSearch] = None, limit: Optional[int] = 1000, id_from: Optional[str] = None, - shuffle: bool = False, include_fields: Optional[Set[str]] = None, exclude_fields: Optional[Set[str]] = None, ) -> Iterable[Dict[str, Any]]: @@ -203,13 +203,18 @@ def scan_dataset( ------- An iterable over found dataset records """ + search = search or DaoRecordsSearch() + next_search_params = [id_from] if id_from else None + paginated_sort = PaginatedSortInfo( + sort_by=[SortableField(id="id")], shuffle=search.sort.shuffle, next_search_params=next_search_params + ) + return self._es.scan_records( id=dataset.id, query=search.query, + sort=paginated_sort, limit=limit, - id_from=id_from, - shuffle=shuffle, include_fields=list(include_fields) if include_fields else None, exclude_fields=list(exclude_fields) if exclude_fields else None, ) diff --git a/src/argilla/server/services/metrics/service.py b/src/argilla/server/services/metrics/service.py index 90b86cf294..8b89bf5ae1 100644 --- a/src/argilla/server/services/metrics/service.py +++ b/src/argilla/server/services/metrics/service.py @@ -16,6 +16,7 @@ from fastapi import Depends +from argilla.server.daos.backend.search.model import SortConfig from argilla.server.daos.models.records import DaoRecordsSearch from argilla.server.daos.records import DatasetRecordsDAO from argilla.server.services.datasets import ServiceDataset @@ -99,8 +100,7 @@ def summarize_metric( query = metric.prepare_query(query) records = self.__dao__.scan_dataset( dataset, - search=DaoRecordsSearch(query=query), - shuffle=metric.shuffle_records, + search=DaoRecordsSearch(query=query, sort=SortConfig(shuffle=metric.shuffle_records)), limit=metric.records_to_fetch, exclude_fields={ "vectors", diff --git a/tests/client/test_api.py b/tests/client/test_api.py index 217c01a03f..d6dd610800 100644 --- a/tests/client/test_api.py +++ b/tests/client/test_api.py @@ -14,6 +14,7 @@ # limitations under the License. 
import concurrent.futures import datetime +import re from time import sleep from typing import Any, Iterable @@ -22,6 +23,7 @@ import httpx import pandas as pd import pytest +from argilla import TextClassificationRecord from argilla._constants import ( _OLD_WORKSPACE_HEADER_NAME, DEFAULT_API_KEY, @@ -631,6 +633,25 @@ def test_load_with_query(mocked_client, supported_vector_search): assert ds.id.iloc[0] == 1 +def test_load_with_sort(mocked_client, supported_vector_search): + dataset = "test_load_with_sort" + mocked_client.delete(f"/api/datasets/{dataset}") + sleep(1) + + expected_data = 4 + api.log([TextClassificationRecord(text=text) for text in ["This is my text"] * expected_data], name=dataset) + with pytest.raises( + ValueError, match=re.escape("sort must be a list formatted as List[Tuple[<field>, 'asc|desc']]") + ): + api.load(name=dataset, sort=[("event_timestamp", "ascc")]) + + ds = api.load(name=dataset, sort=[("event_timestamp", "asc")]) + assert all([(ds[idx].event_timestamp <= ds[idx + 1].event_timestamp) for idx in range(len(ds) - 1)]) + + ds = api.load(name=dataset, sort=[("event_timestamp", "desc")]) + assert all([(ds[idx].event_timestamp >= ds[idx + 1].event_timestamp) for idx in range(len(ds) - 1)]) + + def test_load_as_pandas(mocked_client, supported_vector_search): dataset = "test_load_as_pandas" mocked_client.delete(f"/api/datasets/{dataset}") @@ -755,7 +776,7 @@ def test_load_sort(mocked_client): assert list(df.id) == [1, 11, "11str", "1str", 2, "2str"] ds = api.load(name=dataset, ids=[1, 2, 11]) df = ds.to_pandas() - assert list(df.id) == [1, 2, 11] + assert list(df.id) == [1, 11, 2] ds = api.load(name=dataset, ids=["1str", "2str", "11str"]) df = ds.to_pandas() assert list(df.id) == ["11str", "1str", "2str"] diff --git a/tests/labeling/text_classification/test_label_errors.py b/tests/labeling/text_classification/test_label_errors.py index 00d9eff1d4..3093b94b21 100644 --- a/tests/labeling/text_classification/test_label_errors.py +++ b/tests/labeling/text_classification/test_label_errors.py @@ -219,4 +219,4 @@ def dataset(mocked_client, records): def test_find_label_errors_integration(dataset): records = rg.load(dataset) recs = find_label_errors(records) - assert [rec.id for rec in recs] == list(range(0, 11, 2)) + list(range(1, 12, 2)) + assert [rec.id for rec in recs] == [0, 10, 2, 4, 6, 8, 1, 11, 3, 5, 7, 9] diff --git a/tests/metrics/test_token_classification.py b/tests/metrics/test_token_classification.py index 466b71ecca..0fc27b1a16 100644 --- a/tests/metrics/test_token_classification.py +++ b/tests/metrics/test_token_classification.py @@ -305,7 +305,7 @@ def test_metrics_without_data(mocked_client, metric, expected_results, monkeypat results.visualize() -def test_metrics_for_text_classification(mocked_client): +def test_metrics_for_token_classification(mocked_client): dataset = "test_metrics_for_token_classification" text = "test the f1 metric of the token classification task" diff --git a/tests/server/backend/test_query_builder.py b/tests/server/backend/test_query_builder.py index 4db8dff1c3..85386ad762 100644 --- a/tests/server/backend/test_query_builder.py +++ b/tests/server/backend/test_query_builder.py @@ -69,10 +69,25 @@ def test_build_sort_configuration(index_schema, sort_cfg, expected_sort): def test_build_sort_with_wrong_field_name(): builder = EsQueryBuilder() + index_schema = { + "mappings": { + "properties": { + "id": {"type": "keyword"}, + } + } + } + with pytest.raises(Exception):
-        builder.map_2_es_sort_configuration(sort=SortConfig(sort_by=[SortableField(id="wat?!")]))
+        builder.map_2_es_sort_configuration(schema=index_schema, sort=SortConfig(sort_by=[SortableField(id="wat?!")]))


 def test_build_sort_without_sort_config():
     builder = EsQueryBuilder()
-    assert builder.map_2_es_sort_configuration() is None
+    index_schema = {
+        "mappings": {
+            "properties": {
+                "id": {"type": "keyword"},
+            }
+        }
+    }
+    assert builder.map_2_es_sort_configuration(sort=SortConfig(), schema=index_schema) is None

From 3fce9151d82f943e5956137544a5ea4f434f5410 Mon Sep 17 00:00:00 2001
From: leiyre
Date: Wed, 8 Mar 2023 09:30:27 +0100
Subject: [PATCH 38/45] feat: Bulk annotation improvement (#2437)

# Description

This PR improves annotation at the record level and at the bulk level for the three different tasks:

- Normalize the 2-step validation flow for all tasks except single-label text classification
- Enhance the bulk annotation labeling for multi-label text classification
- Disable automatic validation for multi-label text classification
- Include reset and clear actions for all tasks
- New styles

Closes [#2264](https://github.com/argilla-io/argilla/issues/2264)

**Type of change**

(Please delete options that are not relevant. Remember to title the PR according to the type of change)

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Refactor (change restructuring the codebase without changing functionality)
- [ ] Improvement (change adding some improvement to an existing functionality)
- [ ] Documentation update

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes.
And ideally, reference `tests`)

- [ ] Test A
- [ ] Test B

**Checklist**

- [x] I have merged the original branch into my forked branch
- [ ] I added relevant documentation
- [x] My code follows the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my feature works

---------

Co-authored-by: Keith Cuniah <88380932+keithCuniah@users.noreply.github.com>
Co-authored-by: keithCuniah
---
 frontend/assets/icons/clear.js | 27 +
 frontend/assets/icons/discard.js | 27 +
 frontend/assets/icons/index.js | 5 +
 frontend/assets/icons/pen.js | 27 +
 frontend/assets/icons/reset.js | 27 +
 frontend/assets/icons/time.js | 8 +-
 frontend/assets/icons/validate.js | 27 +
 .../scss/abstract/variables/_variables.scss | 34 +-
 frontend/components/base/BaseButton.vue | 9 +
 frontend/components/base/BaseToast/Toast.vue | 18 +-
 .../base/pagination/BasePagination.vue | 35 +-
 .../commons/header/filters/FiltersArea.vue | 2 +-
 .../commons/header/filters/SearchBar.vue | 2 +-
 .../header/filters/SelectOptionsSearch.vue | 2 +-
 .../header/global-actions/GlobalActions.vue | 24 +-
 .../global-actions/ValidateDiscardAction.vue | 258 ++++++---
 .../BulkAnnotation.component.vue | 154 ++++++
 .../BulkAnnotationForm.component.vue | 158 ++++++
 .../validateDiscardAction.spec.js | 277 ++++++++++
 .../results/RecordActionButtons.spec.js | 70 +++
 .../commons/results/RecordActionButtons.vue | 108 ++++
 .../commons/results/RecordExtraActions.vue | 21 -
 .../commons/results/ResultsLoading.vue | 2 +-
 .../commons/results/ResultsRecord.vue | 67 +--
 .../components/commons/results/StatusTag.vue | 15 +-
 .../SimilarityRecordReference.component.vue | 2 +-
 .../header/TextClassificationHeader.vue | 174 ++++--
 .../AnnotationLabelSelector.vue | 133 -----
 ...TextClassificationBulkAnnotationSingle.vue | 133 +++++
 ...ClassificationBulkAnnotation.component.vue | 86 +++
 .../textClassificationBulkAnnotation.spec.js | 160 ++++++
 .../labeling-rules/RuleLabelsDefinition.vue | 6 +-
 .../results/ClassifierExplorationArea.vue | 4 +-
 .../text-classifier/results/LabelPill.vue | 2 +-
 .../text-classifier/results/RecordInputs.vue | 7 +-
 .../results/RecordTextClassification.vue | 170 ++++--
 .../ClassifierAnnotationArea.vue | 16 +-
 .../ClassifierAnnotationButton.vue | 7 +-
 .../classifierAnnotationArea.spec.js | 23 +-
 .../text2text/header/Text2TextHeader.vue | 28 +-
 .../results/RecordStringText2Text.vue | 7 +-
 .../text2text/results/RecordText2Text.vue | 65 +--
 .../results/Text2TextContentEditable.vue | 97 ++--
 .../text2text/results/Text2TextList.vue | 511 ++++++++----------
 .../results/Text2TextPredictions.spec.js | 43 ++
 .../results/Text2TextPredictions.vue | 226 ++++++++
 .../header/EntitiesHeader.vue | 11 +-
 .../header/TokenClassificationHeader.vue | 57 +-
 .../results/RecordTokenClassification.vue | 112 ++--
 .../token-classifier/results/TextSpan.vue | 2 +-
 .../results/TextSpanStatic.vue | 2 +-
 .../sidebar/SidebarCollapsableMentions.vue | 2 +-
 frontend/database/modules/datasets.js | 238 ++++++--
 frontend/database/modules/notifications.js | 4 +-
 frontend/models/DatasetViewSettings.js | 2 +
 frontend/models/Text2Text.js | 13 +-
 frontend/models/TextClassification.js | 11 +-
 frontend/models/TokenClassification.js | 3 +-
 frontend/models/Workspace.js | 1 -
 frontend/nuxt.config.js | 1 +
 .../custom-directives/badge.directive.js | 54 ++
 .../RecordExtraAtions.spec.js.snap | 3 +-
 frontend/static/icons/clear.svg | 3 +
frontend/static/icons/discard.svg | 3 + frontend/static/icons/pen.svg | 4 + frontend/static/icons/reset.svg | 3 + frontend/static/icons/time.svg | 6 +- frontend/static/icons/validate.svg | 3 + 68 files changed, 2891 insertions(+), 951 deletions(-) create mode 100644 frontend/assets/icons/clear.js create mode 100644 frontend/assets/icons/discard.js create mode 100644 frontend/assets/icons/pen.js create mode 100644 frontend/assets/icons/reset.js create mode 100644 frontend/assets/icons/validate.js create mode 100644 frontend/components/commons/header/global-actions/bulk-annotation/BulkAnnotation.component.vue create mode 100644 frontend/components/commons/header/global-actions/bulk-annotation/BulkAnnotationForm.component.vue create mode 100644 frontend/components/commons/header/global-actions/validateDiscardAction.spec.js create mode 100644 frontend/components/commons/results/RecordActionButtons.spec.js create mode 100644 frontend/components/commons/results/RecordActionButtons.vue delete mode 100644 frontend/components/text-classifier/header/global-actions/AnnotationLabelSelector.vue create mode 100644 frontend/components/text-classifier/header/global-actions/TextClassificationBulkAnnotationSingle.vue create mode 100644 frontend/components/text-classifier/header/global-actions/text-classification-bulk-annotation/TextClassificationBulkAnnotation.component.vue create mode 100644 frontend/components/text-classifier/header/global-actions/text-classification-bulk-annotation/textClassificationBulkAnnotation.spec.js create mode 100644 frontend/components/text2text/results/Text2TextPredictions.spec.js create mode 100644 frontend/components/text2text/results/Text2TextPredictions.vue create mode 100644 frontend/plugins/custom-directives/badge.directive.js create mode 100644 frontend/static/icons/clear.svg create mode 100644 frontend/static/icons/discard.svg create mode 100644 frontend/static/icons/pen.svg create mode 100644 frontend/static/icons/reset.svg create mode 100644 frontend/static/icons/validate.svg diff --git a/frontend/assets/icons/clear.js b/frontend/assets/icons/clear.js new file mode 100644 index 0000000000..7a85909e89 --- /dev/null +++ b/frontend/assets/icons/clear.js @@ -0,0 +1,27 @@ +/* + * coding=utf-8 + * Copyright 2021-present, the Recognai S.L. team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* eslint-disable */ +var icon = require('vue-svgicon') +icon.register({ + 'clear': { + width: 18, + height: 18, + viewBox: '0 0 18 18', + data: '' + } +}) \ No newline at end of file diff --git a/frontend/assets/icons/discard.js b/frontend/assets/icons/discard.js new file mode 100644 index 0000000000..fbcbce6696 --- /dev/null +++ b/frontend/assets/icons/discard.js @@ -0,0 +1,27 @@ +/* + * coding=utf-8 + * Copyright 2021-present, the Recognai S.L. team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* eslint-disable */ +var icon = require('vue-svgicon') +icon.register({ + 'discard': { + width: 16, + height: 16, + viewBox: '0 0 16 16', + data: '' + } +}) \ No newline at end of file diff --git a/frontend/assets/icons/index.js b/frontend/assets/icons/index.js index a95244b0a8..b9ae1630f4 100644 --- a/frontend/assets/icons/index.js +++ b/frontend/assets/icons/index.js @@ -6,9 +6,11 @@ require('./chevron-down') require('./chevron-left') require('./chevron-right') require('./chevron-up') +require('./clear') require('./close') require('./copy') require('./danger') +require('./discard') require('./exploration') require('./external') require('./filter') @@ -20,8 +22,10 @@ require('./matching') require('./math-plus') require('./meatballs') require('./no-matching') +require('./pen') require('./progress') require('./refresh') +require('./reset') require('./row-last') require('./search') require('./similarity') @@ -32,4 +36,5 @@ require('./support') require('./time') require('./trash-empty') require('./unavailable') +require('./validate') require('./weak-labeling') diff --git a/frontend/assets/icons/pen.js b/frontend/assets/icons/pen.js new file mode 100644 index 0000000000..6b832e330f --- /dev/null +++ b/frontend/assets/icons/pen.js @@ -0,0 +1,27 @@ +/* + * coding=utf-8 + * Copyright 2021-present, the Recognai S.L. team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* eslint-disable */ +var icon = require('vue-svgicon') +icon.register({ + 'pen': { + width: 16, + height: 16, + viewBox: '0 0 16 16', + data: '' + } +}) \ No newline at end of file diff --git a/frontend/assets/icons/reset.js b/frontend/assets/icons/reset.js new file mode 100644 index 0000000000..a08bbe1fc1 --- /dev/null +++ b/frontend/assets/icons/reset.js @@ -0,0 +1,27 @@ +/* + * coding=utf-8 + * Copyright 2021-present, the Recognai S.L. team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* eslint-disable */ +var icon = require('vue-svgicon') +icon.register({ + 'reset': { + width: 16, + height: 18, + viewBox: '0 0 16 18', + data: '' + } +}) \ No newline at end of file diff --git a/frontend/assets/icons/time.js b/frontend/assets/icons/time.js index 4773ec809c..924037798a 100644 --- a/frontend/assets/icons/time.js +++ b/frontend/assets/icons/time.js @@ -19,9 +19,9 @@ var icon = require('vue-svgicon') icon.register({ 'time': { - width: 40, - height: 40, - viewBox: '0 0 40 40', - data: '' + width: 30, + height: 30, + viewBox: '0 0 30 30', + data: '' } }) \ No newline at end of file diff --git a/frontend/assets/icons/validate.js b/frontend/assets/icons/validate.js new file mode 100644 index 0000000000..771557c84c --- /dev/null +++ b/frontend/assets/icons/validate.js @@ -0,0 +1,27 @@ +/* + * coding=utf-8 + * Copyright 2021-present, the Recognai S.L. team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* eslint-disable */ +var icon = require('vue-svgicon') +icon.register({ + 'validate': { + width: 31, + height: 31, + viewBox: '0 0 31 31', + data: '' + } +}) \ No newline at end of file diff --git a/frontend/assets/scss/abstract/variables/_variables.scss b/frontend/assets/scss/abstract/variables/_variables.scss index 0f31b013d9..bde01ed5e3 100644 --- a/frontend/assets/scss/abstract/variables/_variables.scss +++ b/frontend/assets/scss/abstract/variables/_variables.scss @@ -15,7 +15,6 @@ * limitations under the License. 
*/ - // Color map //----------- $palettes: ( @@ -27,18 +26,20 @@ $palettes: ( ), grey: ( 100: #212121, - 200: #4D4D4D, + 200: #4d4d4d, 300: #686a6d, - 600: #E6E6E6, - 700: #F5F5F5, - 800: #FAFAFA, + 400: #838589, + 600: #e6e6e6, + 700: #f5f5f5, + 800: #fafafa, ), blue: ( - 500: #3E5CC9, + 100: #52a3ed, + 500: #3e5cc9, ), purple: ( 100: #3b3269, - 200: #4C4EA3, + 200: #4c4ea3, ), pink: ( base: #f2067a, @@ -49,12 +50,12 @@ $palettes: ( dark: #1d1d1d, ), orange-red-crayola: ( - base: #FF675F + base: #ff675f, ), apricot: ( - light: #FFFBFA, - base: #F8C0A7, - dark: #FD8F5E, + light: #fffbfa, + base: #f8c0a7, + dark: #fd8f5e, ), brown: ( base: #bb720a, @@ -80,23 +81,24 @@ $bg: palette(grey, 800); // Black opacity //----------- $black-87: rgba(0, 0, 0, 0.87); -$black-54: rgba(0, 0, 0, 0.60); +$black-54: rgba(0, 0, 0, 0.6); $black-37: rgba(0, 0, 0, 0.37); -$black-20: rgba(0, 0, 0, 0.20); -$black-10: rgba(0, 0, 0, 0.10); +$black-20: rgba(0, 0, 0, 0.2); +$black-10: rgba(0, 0, 0, 0.1); $black-4: rgba(0, 0, 0, 0.04); // States //----------- $error: palette(pink); -$warning: palette(brown); +$warning: palette(apricot, dark); $success: palette(green); $info: palette(grey, 600); // Fonts //----------- $primary-font-family: "Inter", "Helvetica", "Arial", sans-serif; -$secondary-font-family: "raptor_v2_premiumbold", "Helvetica", "Arial", sans-serif; +$secondary-font-family: "raptor_v2_premiumbold", "Helvetica", "Arial", + sans-serif; $base-font-size: 14px; $base-line-height: 1.4em; diff --git a/frontend/components/base/BaseButton.vue b/frontend/components/base/BaseButton.vue index 978c241a63..2229247647 100644 --- a/frontend/components/base/BaseButton.vue +++ b/frontend/components/base/BaseButton.vue @@ -206,6 +206,15 @@ export default { fill: $black-54; } } + &.text { + background: none; + color: $black-54; + padding-left: 0; + padding-right: 0; + &:hover { + color: $black-87; + } + } } .clear { background: none; diff --git a/frontend/components/base/BaseToast/Toast.vue b/frontend/components/base/BaseToast/Toast.vue index f33a6ebf9e..8815b93dd7 100644 --- a/frontend/components/base/BaseToast/Toast.vue +++ b/frontend/components/base/BaseToast/Toast.vue @@ -26,12 +26,16 @@ class="toast" :class="[`toast-${type}`, `is-${position}`]" @mouseover="toggleTimer(true)" + @mouseleave="toggleTimer(false)" >

-        {{
-          buttonText
-        }}
+        {{ buttonText }}
+
@@ -83,7 +87,7 @@ export default { }, onClick: { type: Function, - default: () => {}, + default: async () => {}, }, queue: Boolean, pauseOnHover: { @@ -205,7 +209,7 @@ export default { this.isActive = true; this.timer = new Timer(this.close, this.duration); }, - whenClicked(...arg) { + async whenClicked(...arg) { if (!this.dismissible) return; this.onClick.apply(null, arg); this.close(); @@ -247,7 +251,7 @@ $toast-colors: map-merge( // Colors @each $color, $value in $toast-colors { .toast-#{$color} { - border: 2px solid $value; + border: 1px solid $value; } } @@ -303,9 +307,7 @@ $toast-colors: map-merge( } } &__button { - color: palette(blue, 300); margin: 0 3em; - cursor: pointer; } &__close { margin-right: 1em; diff --git a/frontend/components/base/pagination/BasePagination.vue b/frontend/components/base/pagination/BasePagination.vue index 6b57ac8d5f..0e45fe131f 100644 --- a/frontend/components/base/pagination/BasePagination.vue +++ b/frontend/components/base/pagination/BasePagination.vue @@ -92,11 +92,7 @@
@@ -158,9 +154,28 @@ export default { (p) => p !== this.paginationSize ); }, + currentPaginationPosition() { + return this.isSingleRecordPage + ? this.currentPage + : this.currentPaginationRange; + }, + currentPaginationRange() { + return `${this.currentRangeFrom} - ${this.currentRangeTo}`; + }, + currentRangeFrom() { + return this.paginationSize * this.currentPage - (this.paginationSize - 1); + }, + currentRangeTo() { + return this.paginationSize * this.currentPage > this.totalItems + ? this.totalItems + : this.paginationSize * this.currentPage; + }, currentPage() { return this.paginationSettings.page; }, + isSingleRecordPage() { + return this.paginationSize === 1; + }, pages() { const rangeOfPages = (this.visiblePagesRange - 1) / 2; let start = this.currentPage - rangeOfPages; @@ -232,9 +247,9 @@ export default { keyDown(event) { const arrowRight = event.keyCode === 39; const arrowLeft = event.keyCode === 37; - const focusInInput = event.target.tagName?.toLowerCase() === "input"; const allowPagination = - !this.paginationSettings.disabledShortCutPagination && !focusInInput; + !this.paginationSettings.disabledShortCutPagination && + !this.isEditableAreaFocused(event); if (allowPagination) { if (arrowRight) { this.nextPage(); @@ -243,6 +258,12 @@ export default { } } }, + isEditableAreaFocused(event) { + return ( + event.target.tagName?.toLowerCase() === "input" || + event.target.contentEditable?.toLowerCase() === "true" + ); + }, showNotification() { Notification.dispatch("notify", { message: this.message, diff --git a/frontend/components/commons/header/filters/FiltersArea.vue b/frontend/components/commons/header/filters/FiltersArea.vue index a81b0afd90..f038c2305e 100644 --- a/frontend/components/commons/header/filters/FiltersArea.vue +++ b/frontend/components/commons/header/filters/FiltersArea.vue @@ -162,7 +162,7 @@ export default { align-items: center; } &__content { - padding: $base-space * 4 0; + padding: $base-space * 2 0; position: relative; width: 100%; } diff --git a/frontend/components/commons/header/filters/SearchBar.vue b/frontend/components/commons/header/filters/SearchBar.vue index 80092f8212..ffcd267660 100644 --- a/frontend/components/commons/header/filters/SearchBar.vue +++ b/frontend/components/commons/header/filters/SearchBar.vue @@ -123,7 +123,7 @@ export default { margin: auto 1em auto 1em; } &:hover { - box-shadow: $shadow; + box-shadow: 0 6px 10px 0 rgba(0, 0, 0, 0.1); } } diff --git a/frontend/components/commons/header/filters/SelectOptionsSearch.vue b/frontend/components/commons/header/filters/SelectOptionsSearch.vue index 9e4610c5c5..4eaa79b748 100644 --- a/frontend/components/commons/header/filters/SelectOptionsSearch.vue +++ b/frontend/components/commons/header/filters/SelectOptionsSearch.vue @@ -9,7 +9,7 @@ /> diff --git a/frontend/components/commons/header/global-actions/validateDiscardAction.spec.js b/frontend/components/commons/header/global-actions/validateDiscardAction.spec.js new file mode 100644 index 0000000000..d401605851 --- /dev/null +++ b/frontend/components/commons/header/global-actions/validateDiscardAction.spec.js @@ -0,0 +1,277 @@ +import { shallowMount } from "@vue/test-utils"; +import ValidateDiscardActionComponent from "./ValidateDiscardAction"; + +let wrapper = null; +const options = { + stubs: ["base-checkbox", "annotation-label-selector", "base-button"], + directives: { + badge() { + /* stub */ + }, + }, + propsData: { + datasetId: ["owner", "name"], + datasetTask: "TextClassification", + visibleRecords: [ + { + id: 
"b5a23810-10e9-4bff-adf3-447a45667299", + metadata: {}, + annotation: { + agent: "recognai", + labels: [ + { + class: "Aplazamiento de pago", + score: 1, + }, + ], + }, + status: "Edited", + selected: true, + vectors: {}, + last_updated: "2023-02-14T13:38:00.319183", + search_keywords: [], + inputs: { + text: "Esto es un registro sin predicciones ni anotaciones", + }, + multi_label: true, + currentAnnotation: { + agent: "recognai", + labels: [ + { + class: "Aplazamiento de pago", + score: 1, + }, + { + class: "Alcantarillado/Pluviales", + score: 1, + }, + ], + }, + originStatus: "Discarded", + }, + { + id: "llamadas_correos_1238", + metadata: { + Fuente: "Correo", + }, + prediction: { + agent: "facsa_categories_v4", + labels: [ + { + class: "Otros", + score: 0.7129160762, + }, + { + class: "Problema calidad agua", + score: 0.0819710568, + }, + { + class: "Calidad del servicio", + score: 0.0295235571, + }, + { + class: "Cortes falta de pago", + score: 0.0239840839, + }, + { + class: "Alcantarillado/Pluviales", + score: 0.0225764699, + }, + { + class: "Contratación", + score: 0.013704461000000001, + }, + { + class: "Solicitan presencia personal FACSA instalación", + score: 0.0132667627, + }, + { + class: "Consulta administrativa oficinas", + score: 0.012263773, + }, + { + class: "Baja presión", + score: 0.0107420115, + }, + { + class: "Funcionamiento del contador", + score: 0.0099040149, + }, + { + class: "Recibos", + score: 0.0091423625, + }, + { + class: "Reposición obra civil", + score: 0.0085537238, + }, + { + class: "Filtración en garaje/bajo", + score: 0.0082879839, + }, + { + class: "Fuga en la vía pública", + score: 0.0069583142, + }, + { + class: "Información/Consultas", + score: 0.0060733184, + }, + { + class: "Vulnerabilidad", + score: 0.0049630683, + }, + { + class: "No tiene agua", + score: 0.0042145588000000005, + }, + { + class: "Facturación errónea", + score: 0.0037822824000000002, + }, + { + class: "Refacturación por fuga", + score: 0.0030815087, + }, + { + class: "Atención recibida", + score: 0.0030129601000000003, + }, + { + class: "Error de lectura", + score: 0.0026326664000000002, + }, + { + class: "Fuga en instalación interior", + score: 0.0024313715, + }, + { + class: "Rotura provocada", + score: 0.0015803818000000001, + }, + { + class: "Presupuestos", + score: 0.0011857918000000001, + }, + { + class: "Protección de datos", + score: 0.0009309166, + }, + { + class: "Aplazamiento de pago", + score: 0.0006013829, + }, + { + class: "descartado", + score: 0.0005980578, + }, + { + class: "Reparto de correspondencia", + score: 0.0005848766, + }, + { + class: "Solicitan cierre agua maniobras instalación abonado", + score: 0.0005321669, + }, + { + class: "Alta", + score: 0, + }, + { + class: "Baja", + score: 0, + }, + { + class: "Cambio de titular", + score: 0, + }, + ], + }, + annotation: { + agent: "recognai", + labels: [ + { + class: "Alcantarillado/Pluviales", + score: 1, + }, + { + class: "Contratación", + score: 1, + }, + { + class: "Problema calidad agua", + score: 1, + }, + { + class: "Funcionamiento del contador", + score: 1, + }, + ], + }, + status: "Validated", + selected: false, + vectors: {}, + last_updated: "2023-02-14T12:35:27.656875", + search_keywords: [], + inputs: { + text: "ANULACION TEMA CONTRAINCENDIOS NAVE SANTA Nº4 ROBERTO SURNAME-JORGE BADENES\n\nEscolta verificam que el que es l’aigua normal si va a\nnom nostre i tenim aigua en esta nau, per favor, jo crec que si, crec que es\nnau nº 4 i si paga jo l’aigua, pero verificameu per favor, 
el\ncontraincendis eixe, DONEM DE BAIXA!! No sabemn si la alquilarem ni si el\nproper negoci li fara falta, aixi que anulem-lo en cas de alquiler algu que li\nfage falta…pues ja seu pagara ell…el inquilino…com va fer\nJorge en el seu dia, ok?\nORDEN ANULACION, EN CASO DE QUE HAYA CONTADOR MINIMO DE\nFACSA AGUA QUE ES LO UNICO QUE QUEREMOS.\nGRACIAS\nVanessa SURNAME\nROBERTO SURNAME SL\nTelf. PHONE_NUMBER\nwww.rserrano.com\nEste correo electronico y, en su caso,\ncualquier archivo adjunto al mismo, contiene informacion de caracter\nconfidencial exclusivamente dirigida a su destinatario. Queda prohibida su\ndivulgacion, copia o distribucion a terceros sin la previa autorizacion escrita\nde ROBERTO SURNAME SL. En el caso de haber recibido este correo electronico por\nerror, por favor, eliminelo de inmediato y rogamos nos notifique esta\ncircunstancia mediante reenvio a la direccion electronica del remitente. De\nconformidad con lo establecido en las normativas vigentes de Proteccion de\nDatos a nivel nacional y europeo, ROBERTO SURNAME SL garantiza la adopcion de\nlas medidas tecnicas y organizativas necesarias para asegurar el tratamiento\nconfidencial de los datos de caracter personal. Asi mismo le informamos que su\ndireccion de email ha sido recabada del propio interesado o de fuentes de\nacceso publico, y esta incluida en nuestros ficheros con la finalidad de\nmantener contacto con usted para el envio de comunicaciones sobre nuestro\nproductos y servicios mientras exista un interes mutuo para ello o una relacion\nnegocial o contractual. No obstante, usted puede ejercer sus derechos de\nacceso, rectificacion y supresion de sus datos, asi como los derechos de\nlimitacion y oposicion a su tratamiento o solicitar mas informacion sobre\nnuestra politica de privacidad utilizando la siguiente informacion:\nResponsable: ROBERTO SURNAME SL\nDireccion: PARTIDA\nLA TORRETA\n,\nNAVE 9. 
12110, Alcora (Castellon), Espana\nTelefono: PHONE_NUMBER\nE-mail:\nEMAIL_ADDRESS\nSi considera que el tratamiento no se ajusta a la normativa vigente de\nProteccion de Datos, podra presentar una reclamacion ante la autoridad de\ncontrol: Agencia Espanola de Proteccion de Datos (\nhttps://www.agpd.es\n)", + }, + multi_label: true, + currentAnnotation: { + agent: "recognai", + labels: [ + { + class: "Alcantarillado/Pluviales", + score: 1, + }, + { + class: "Contratación", + score: 1, + }, + { + class: "Problema calidad agua", + score: 1, + }, + { + class: "Funcionamiento del contador", + score: 1, + }, + ], + }, + originStatus: "Validated", + }, + ], + isMultiLabel: true, + }, +}; + +beforeEach(() => { + wrapper = shallowMount(ValidateDiscardActionComponent, options); +}); + +afterEach(() => { + wrapper.destroy(); +}); +describe("ValidateDiscardAction", () => { + it("render the component", () => { + expect(wrapper.is(ValidateDiscardActionComponent)).toBe(true); + }); + it("render a validate button", () => { + const validateButtonId = "validateButton"; + testIfValidateButtonIsRendered(validateButtonId); + }); + it.skip("render a badge on a validate button if there is pending records", () => { + //FIXME - test that the validate button contains a v-badge attributes if there is any record with pending states + //NOTE - a record in pending status have the attribute "status:'Edited'" + testIfThereIsPendingStatus(true); + }); +}); + +const testIfValidateButtonIsRendered = (validateButtonId) => { + const validateButtonWrapper = wrapper.find(`#${validateButtonId}`); + expect(validateButtonWrapper.exists()).toBe(true); +}; + +const testIfThereIsPendingStatus = (isPendingStatus) => { + expect(wrapper.vm.isAnyPendingStatusRecord).toBe(isPendingStatus); +}; diff --git a/frontend/components/commons/results/RecordActionButtons.spec.js b/frontend/components/commons/results/RecordActionButtons.spec.js new file mode 100644 index 0000000000..3a440c32c9 --- /dev/null +++ b/frontend/components/commons/results/RecordActionButtons.spec.js @@ -0,0 +1,70 @@ +import { shallowMount } from "@vue/test-utils"; +import RecordActionButtons from "./RecordActionButtons"; +import BaseButton from "@/components/base/BaseButton"; + +let wrapper = null; +const options = { + stubs: { + "base-button": BaseButton, + }, + propsData: { + actions: [ + { + id: "validate", + name: "Validate", + allow: true, + active: true, + }, + { + id: "discard", + name: "Discard", + allow: true, + active: true, + }, + { + id: "clear", + name: "Clear", + allow: false, + active: false, + }, + ], + }, +}; +beforeEach(() => { + wrapper = shallowMount(RecordActionButtons, options); +}); + +afterEach(() => { + wrapper.destroy(); +}); + +describe("RecordActionButtonsComponent", () => { + it("render the component", () => { + expect(wrapper.is(RecordActionButtons)).toBe(true); + }); + it("expect to show validate button active", async () => { + testIfButtonIsDisabled("validate", undefined); + }); + it("expect to show discard button active", async () => { + testIfButtonIsDisabled("discard", undefined); + }); + it("expect to emit validate on click validate button", async () => { + testIfEmittedIsCorrect("validate"); + }); + it("expect to emit discard on click discard button", async () => { + testIfEmittedIsCorrect("discard"); + }); + it("expect not to render clear button", async () => { + const clearButton = wrapper.find(`.record__actions-button--clear`); + expect(clearButton.exists()).toBe(false); + }); +}); + +const testIfButtonIsDisabled = async (button, 
disabled) => { + const actionButton = wrapper.find(`.record__actions-button--${button}`); + expect(actionButton.attributes().disabled).toBe(disabled); +}; +const testIfEmittedIsCorrect = async (button) => { + wrapper.find(`.record__actions-button--${button}`).vm.$emit("click"); + expect(wrapper.emitted()).toHaveProperty(button); +}; diff --git a/frontend/components/commons/results/RecordActionButtons.vue b/frontend/components/commons/results/RecordActionButtons.vue new file mode 100644 index 0000000000..22ddc101f1 --- /dev/null +++ b/frontend/components/commons/results/RecordActionButtons.vue @@ -0,0 +1,108 @@ + + + + + diff --git a/frontend/components/commons/results/RecordExtraActions.vue b/frontend/components/commons/results/RecordExtraActions.vue index 8d7d392b64..111386d146 100644 --- a/frontend/components/commons/results/RecordExtraActions.vue +++ b/frontend/components/commons/results/RecordExtraActions.vue @@ -29,13 +29,6 @@ Copy text
-
- Discard record -
@@ -52,17 +45,10 @@ export default { }), ], props: { - allowChangeStatus: { - type: Boolean, - default: false, - }, recordId: { type: String | Number, required: true, }, - recordStatus: { - type: String, - }, recordClipboardText: { type: Array | String, required: true, @@ -88,13 +74,6 @@ export default { }, }, methods: { - // TODO: call vuex-actions here instead of trigger event - onChangeRecordStatus(status) { - if (this.recordStatus !== status) { - this.$emit("on-change-record-status", status); - } - this.close(); - }, showRecordInfoModal() { this.$emit("show-record-info-modal"); this.close(); diff --git a/frontend/components/commons/results/ResultsLoading.vue b/frontend/components/commons/results/ResultsLoading.vue index e8ba76c3b0..eb6ddb39b2 100644 --- a/frontend/components/commons/results/ResultsLoading.vue +++ b/frontend/components/commons/results/ResultsLoading.vue @@ -43,7 +43,7 @@ export default { width: 100%; background: white; border: 1px solid palette(grey, 600); - margin-bottom: $base-space-between-records; + margin-top: $base-space-between-records; border-radius: $border-radius-m; } diff --git a/frontend/components/commons/results/ResultsRecord.vue b/frontend/components/commons/results/ResultsRecord.vue index dd85d35f50..7eaa2ba960 100644 --- a/frontend/components/commons/results/ResultsRecord.vue +++ b/frontend/components/commons/results/ResultsRecord.vue @@ -17,17 +17,13 @@ @@ -152,11 +145,42 @@ export default { ); return visualTokens; }, + tokenClassifierActionButtons() { + return [ + { + id: "validate", + name: "Validate", + allow: true, + active: this.record.status === "Validated", + }, + { + id: "discard", + name: "Discard", + allow: true, + active: this.record.status === "Discarded", + }, + { + id: "clear", + name: "Clear", + allow: true, + disable: !this.record.annotatedEntities?.length || false, + }, + { + id: "reset", + name: "Reset", + allow: true, + disable: this.record.status !== "Edited", + }, + ]; + }, }, methods: { ...mapActions({ validate: "entities/datasets/validateAnnotations", + discard: "entities/datasets/discardAnnotations", updateRecords: "entities/datasets/updateDatasetRecords", + changeStatusToDefault: "entities/datasets/changeStatusToDefault", + resetRecords: "entities/datasets/resetRecords", }), getEntitiesByOrigin(origin) { if (this.interactionsEnabled) { @@ -172,23 +196,44 @@ export default { : []; } }, - async onValidate(record) { + async toggleValidateRecord() { + if (this.record.status === "Validated") { + await this.onChangeStatusToDefault(); + } else { + await this.onValidate(); + } + }, + async toggleDiscardRecord() { + if (this.record.status === "Discarded") { + await this.onChangeStatusToDefault(); + } else { + this.onDiscard(); + } + }, + async onValidate() { await this.validate({ // TODO: Move this as part of token classification dataset logic dataset: this.getTokenClassificationDataset(), agent: this.$auth.user.username, records: [ { - ...record, + ...this.record, annotatedEntities: undefined, annotation: { - entities: record.annotatedEntities, + entities: this.record.annotatedEntities, origin: "annotation", }, }, ], }); }, + async onChangeStatusToDefault() { + const currentRecordAndDataset = { + dataset: this.getTokenClassificationDataset(), + records: [this.record], + }; + await this.changeStatusToDefault(currentRecordAndDataset); + }, onClearAnnotations() { this.updateRecords({ dataset: this.getTokenClassificationDataset(), @@ -202,6 +247,20 @@ export default { ], }); }, + async onReset() { + await this.resetRecords({ + dataset: 
this.getTokenClassificationDataset(), + records: [ + { + ...this.record, + annotatedEntities: this.record.annotation?.entities, + }, + ], + }); + }, + onDiscard() { + this.$emit("discard"); + }, getTokenClassificationDataset() { return getTokenClassificationDatasetById(this.datasetId); }, @@ -211,11 +270,14 @@ export default {