From e8a8f8939ceea006d76566f34e69ea8fd37c82dd Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
<66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 7 Feb 2023 17:27:39 +0100
Subject: [PATCH 01/45] [pre-commit.ci] pre-commit autoupdate (#2299)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
updates:
- [github.com/psf/black: 22.12.0 →
23.1.0](https://github.com/psf/black/compare/22.12.0...23.1.0)
- [github.com/pycqa/isort: 5.11.5 →
5.12.0](https://github.com/pycqa/isort/compare/5.11.5...5.12.0)
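Most of the Python changes in this patch are mechanical consequences of the black 23.1.0 bump: the updated stable style removes blank lines that sit directly after a `class` statement or a `def` signature, and the same bump adds a trailing comma after `**kwargs` in multi-line signatures. Schematically (an illustrative before/after sketch reusing a class name from the diff below, not a complete file from this repo):

```python
# Before: black 22.12.0 accepted a blank line right after the class statement.
class MetricsAPI:

    API_URL = "/api/metrics"


# After: black 23.1.0 deletes that blank line, producing the one-line
# removals seen throughout this diff.
class MetricsAPI:
    API_URL = "/api/metrics"
```

The same reformat accounts for the `**kwargs` → `**kwargs,` hunks in `monitoring/_transformers.py`, `monitoring/asgi.py`, and `security/auth_provider/base.py`.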
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
.pre-commit-config.yaml | 4 ++--
scripts/migrations/es_migration_25042021.py | 1 -
src/argilla/client/apis/metrics.py | 1 -
src/argilla/client/apis/search.py | 1 -
src/argilla/client/apis/status.py | 1 -
src/argilla/client/datasets.py | 1 +
src/argilla/client/sdk/commons/api.py | 1 -
src/argilla/client/sdk/commons/errors.py | 2 --
src/argilla/client/sdk/datasets/api.py | 1 -
src/argilla/client/sdk/text_classification/api.py | 2 --
src/argilla/logging.py | 2 --
src/argilla/monitoring/_flair.py | 1 -
src/argilla/monitoring/_spacy.py | 1 -
src/argilla/monitoring/_transformers.py | 3 +--
src/argilla/monitoring/asgi.py | 2 +-
src/argilla/server/apis/v0/handlers/datasets.py | 2 --
src/argilla/server/apis/v0/handlers/metrics.py | 3 ---
src/argilla/server/apis/v0/handlers/records_update.py | 2 --
src/argilla/server/apis/v0/handlers/text2text.py | 2 --
src/argilla/server/apis/v0/handlers/text_classification.py | 7 -------
.../v0/handlers/text_classification_dataset_settings.py | 3 ---
.../server/apis/v0/handlers/token_classification.py | 2 --
.../v0/handlers/token_classification_dataset_settings.py | 2 --
src/argilla/server/apis/v0/models/text2text.py | 1 -
src/argilla/server/apis/v0/models/text_classification.py | 4 ----
src/argilla/server/apis/v0/models/token_classification.py | 2 --
.../server/apis/v0/validators/text_classification.py | 1 -
src/argilla/server/commons/config.py | 3 ---
src/argilla/server/commons/models.py | 1 -
src/argilla/server/commons/telemetry.py | 2 --
src/argilla/server/daos/backend/client_adapters/factory.py | 1 -
.../server/daos/backend/client_adapters/opensearch.py | 3 ---
src/argilla/server/daos/backend/generic_elastic.py | 2 --
.../server/daos/backend/mappings/token_classification.py | 1 -
src/argilla/server/daos/backend/metrics/base.py | 7 ++++---
.../server/daos/backend/metrics/text_classification.py | 1 -
src/argilla/server/daos/backend/query_helpers.py | 1 -
src/argilla/server/daos/backend/search/model.py | 2 --
src/argilla/server/daos/backend/search/query_builder.py | 3 ---
src/argilla/server/daos/datasets.py | 3 ---
src/argilla/server/daos/models/records.py | 3 ---
src/argilla/server/daos/records.py | 2 --
src/argilla/server/errors/base_errors.py | 1 -
src/argilla/server/security/auth_provider/base.py | 2 +-
src/argilla/server/server.py | 1 -
src/argilla/server/services/datasets.py | 3 ---
src/argilla/server/services/search/model.py | 1 -
src/argilla/server/services/search/service.py | 1 -
src/argilla/server/services/storage/service.py | 2 --
src/argilla/server/services/tasks/text2text/service.py | 1 -
.../tasks/text_classification/labeling_rules_service.py | 1 -
.../server/services/tasks/text_classification/model.py | 3 ---
.../server/services/tasks/text_classification/service.py | 1 -
.../server/services/tasks/token_classification/metrics.py | 1 -
.../server/services/tasks/token_classification/model.py | 4 ----
.../server/services/tasks/token_classification/service.py | 1 -
tests/client/functional_tests/test_record_update.py | 1 -
tests/client/functional_tests/test_scan_raw_records.py | 1 -
tests/client/sdk/commons/test_client.py | 2 --
tests/client/test_api.py | 4 ----
tests/client/test_client_errors.py | 1 -
tests/conftest.py | 2 --
tests/functional_tests/search/test_search_service.py | 1 -
tests/functional_tests/test_log_for_text_classification.py | 2 --
.../functional_tests/test_log_for_token_classification.py | 1 -
tests/labeling/text_classification/test_rule.py | 5 -----
tests/monitoring/test_flair_monitoring.py | 3 +--
tests/monitoring/test_monitor.py | 1 -
tests/monitoring/test_transformers_monitoring.py | 2 --
tests/server/backend/test_query_builder.py | 1 -
tests/server/daos/models/test_records.py | 1 -
tests/server/info/test_api.py | 2 --
tests/server/security/test_model.py | 2 --
tests/server/test_errors.py | 1 -
tests/server/text2text/test_model.py | 1 -
tests/server/text_classification/test_model.py | 3 ---
tests/server/token_classification/test_model.py | 4 ----
77 files changed, 11 insertions(+), 143 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index af7daea006..f67cae6ff1 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -20,13 +20,13 @@ repos:
# - --remove-header
- repo: https://github.com/psf/black
- rev: 22.12.0
+ rev: 23.1.0
hooks:
- id: black
additional_dependencies: ["click==8.0.4"]
- repo: https://github.com/pycqa/isort
- rev: 5.11.5
+ rev: 5.12.0
hooks:
- id: isort
diff --git a/scripts/migrations/es_migration_25042021.py b/scripts/migrations/es_migration_25042021.py
index 01d4ec2a14..63776fe977 100644
--- a/scripts/migrations/es_migration_25042021.py
+++ b/scripts/migrations/es_migration_25042021.py
@@ -103,7 +103,6 @@ def map_doc_2_action(
if __name__ == "__main__":
-
client = Elasticsearch(hosts=settings.elasticsearch)
for dataset in settings.migration_datasets:
diff --git a/src/argilla/client/apis/metrics.py b/src/argilla/client/apis/metrics.py
index 50bc63dc5c..abb4cf88d2 100644
--- a/src/argilla/client/apis/metrics.py
+++ b/src/argilla/client/apis/metrics.py
@@ -19,7 +19,6 @@
class MetricsAPI(AbstractApi):
-
_API_URL_PATTERN = "/api/datasets/{task}/{name}/metrics/{metric}:summary"
def metric_summary(
diff --git a/src/argilla/client/apis/search.py b/src/argilla/client/apis/search.py
index 9ac808ad25..f8a1b66a98 100644
--- a/src/argilla/client/apis/search.py
+++ b/src/argilla/client/apis/search.py
@@ -37,7 +37,6 @@ class SearchResults:
class Search(AbstractApi):
-
_API_URL_PATTERN = "/api/datasets/{name}/{task}:search"
def search_records(
diff --git a/src/argilla/client/apis/status.py b/src/argilla/client/apis/status.py
index 4d29c03afd..ce5d42372e 100644
--- a/src/argilla/client/apis/status.py
+++ b/src/argilla/client/apis/status.py
@@ -26,7 +26,6 @@
@dataclasses.dataclass(frozen=True)
class ApiInfo:
-
version: Optional[str] = None
diff --git a/src/argilla/client/datasets.py b/src/argilla/client/datasets.py
index 0a05463d8a..3791292571 100644
--- a/src/argilla/client/datasets.py
+++ b/src/argilla/client/datasets.py
@@ -1137,6 +1137,7 @@ def __only_annotations__(self, data) -> bool:
def _to_datasets_dict(self) -> Dict:
"""Helper method to put token classification records in a `datasets.Dataset`"""
+
# create a dict first, where we make the necessary transformations
def entities_to_dict(
entities: Optional[
diff --git a/src/argilla/client/sdk/commons/api.py b/src/argilla/client/sdk/commons/api.py
index 1974d443a7..9ccc3eb3bb 100644
--- a/src/argilla/client/sdk/commons/api.py
+++ b/src/argilla/client/sdk/commons/api.py
@@ -103,7 +103,6 @@ async def async_bulk(
def build_bulk_response(
response: httpx.Response, name: str, body: Any
) -> Response[BulkResponse]:
-
if 200 <= response.status_code < 400:
return Response(
status_code=response.status_code,
diff --git a/src/argilla/client/sdk/commons/errors.py b/src/argilla/client/sdk/commons/errors.py
index bb38fa0432..003226f40a 100644
--- a/src/argilla/client/sdk/commons/errors.py
+++ b/src/argilla/client/sdk/commons/errors.py
@@ -64,7 +64,6 @@ def __str__(self):
class ArApiResponseError(BaseClientError):
-
HTTP_STATUS: int
def __init__(self, **ctx):
@@ -105,7 +104,6 @@ class ValidationApiError(ArApiResponseError):
HTTP_STATUS = 422
def __init__(self, client_ctx, params, **ctx):
-
for error in params.get("errors", []):
current_level = client_ctx
for loc in error["loc"]:
diff --git a/src/argilla/client/sdk/datasets/api.py b/src/argilla/client/sdk/datasets/api.py
index 08ab036e8f..b836b949d0 100644
--- a/src/argilla/client/sdk/datasets/api.py
+++ b/src/argilla/client/sdk/datasets/api.py
@@ -89,7 +89,6 @@ def delete_dataset(
def _build_response(
response: httpx.Response, name: str
) -> Response[Union[Dataset, ErrorMessage, HTTPValidationError]]:
-
if response.status_code == 200:
parsed_response = Dataset(**response.json())
return Response(
diff --git a/src/argilla/client/sdk/text_classification/api.py b/src/argilla/client/sdk/text_classification/api.py
index 5699566eba..d6cd7b64a5 100644
--- a/src/argilla/client/sdk/text_classification/api.py
+++ b/src/argilla/client/sdk/text_classification/api.py
@@ -97,7 +97,6 @@ def fetch_dataset_labeling_rules(
client: AuthenticatedClient,
name: str,
) -> Response[Union[List[LabelingRule], HTTPValidationError, ErrorMessage]]:
-
url = "{}/api/datasets/TextClassification/{name}/labeling/rules".format(
client.base_url, name=name
)
@@ -118,7 +117,6 @@ def dataset_rule_metrics(
query: str,
label: str,
) -> Response[Union[LabelingRuleMetricsSummary, HTTPValidationError, ErrorMessage]]:
-
url = "{}/api/datasets/TextClassification/{name}/labeling/rules/{query}/metrics?label={label}".format(
client.base_url, name=name, query=query, label=label
)
diff --git a/src/argilla/logging.py b/src/argilla/logging.py
index 076467aa7e..94b1cc249e 100644
--- a/src/argilla/logging.py
+++ b/src/argilla/logging.py
@@ -84,7 +84,6 @@ def __init__(self, *args, **kwargs):
self.emit = lambda record: None
def emit(self, record: logging.LogRecord):
-
try:
level = logger.level(record.levelname).name
except AttributeError:
@@ -100,7 +99,6 @@ def emit(self, record: logging.LogRecord):
def configure_logging():
-
"""Normalizes logging configuration for argilla and its dependencies"""
intercept_handler = LoguruLoggerHandler()
if not intercept_handler.is_available:
diff --git a/src/argilla/monitoring/_flair.py b/src/argilla/monitoring/_flair.py
index b9deeb6bf4..6330b9fc78 100644
--- a/src/argilla/monitoring/_flair.py
+++ b/src/argilla/monitoring/_flair.py
@@ -87,7 +87,6 @@ def flair_monitor(
sample_rate: float,
log_interval: float,
) -> Optional[SequenceTagger]:
-
return FlairMonitor(
pl,
api=api,
diff --git a/src/argilla/monitoring/_spacy.py b/src/argilla/monitoring/_spacy.py
index a238223e56..a1c80833ac 100644
--- a/src/argilla/monitoring/_spacy.py
+++ b/src/argilla/monitoring/_spacy.py
@@ -64,7 +64,6 @@ def doc2token_classification(
def _prepare_log_data(
self, docs_info: Tuple[Doc, Optional[Dict[str, Any]]]
) -> Dict[str, Any]:
-
return dict(
records=[
self.doc2token_classification(
diff --git a/src/argilla/monitoring/_transformers.py b/src/argilla/monitoring/_transformers.py
index 8c6a1d4300..46fe973f26 100644
--- a/src/argilla/monitoring/_transformers.py
+++ b/src/argilla/monitoring/_transformers.py
@@ -56,7 +56,6 @@ def _prepare_log_data(
data: List[Tuple[str, Dict[str, Any], List[LabelPrediction]]],
multi_label: bool = False,
) -> Dict[str, Any]:
-
agent = self.model_config.name_or_path
records = []
@@ -100,7 +99,7 @@ def __call__(
sequences: Union[str, List[str]],
candidate_labels: List[str],
*args,
- **kwargs
+ **kwargs,
):
metadata = (kwargs.pop("metadata", None) or {}).copy()
hypothesis_template = kwargs.get("hypothesis_template", "@default")
diff --git a/src/argilla/monitoring/asgi.py b/src/argilla/monitoring/asgi.py
index 48d0dcbfcb..4b2119be43 100644
--- a/src/argilla/monitoring/asgi.py
+++ b/src/argilla/monitoring/asgi.py
@@ -112,7 +112,7 @@ def __init__(
agent: Optional[str] = None,
tags: Dict[str, str] = None,
*args,
- **kwargs
+ **kwargs,
):
BaseHTTPMiddleware.__init__(self, *args, **kwargs)
diff --git a/src/argilla/server/apis/v0/handlers/datasets.py b/src/argilla/server/apis/v0/handlers/datasets.py
index bb66e9dbbc..01d7d75478 100644
--- a/src/argilla/server/apis/v0/handlers/datasets.py
+++ b/src/argilla/server/apis/v0/handlers/datasets.py
@@ -69,7 +69,6 @@ async def create_dataset(
datasets: DatasetsService = Depends(DatasetsService.get_instance),
user: User = Security(auth.get_user, scopes=["create:datasets"]),
) -> Dataset:
-
owner = user.check_workspace(ws_params.workspace)
dataset_class = TasksFactory.get_task_dataset(request.task)
@@ -114,7 +113,6 @@ def update_dataset(
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> Dataset:
-
found_ds = service.find_by_name(
user=current_user, name=name, workspace=ds_params.workspace
)
diff --git a/src/argilla/server/apis/v0/handlers/metrics.py b/src/argilla/server/apis/v0/handlers/metrics.py
index 51c3659506..96e0364697 100644
--- a/src/argilla/server/apis/v0/handlers/metrics.py
+++ b/src/argilla/server/apis/v0/handlers/metrics.py
@@ -29,7 +29,6 @@
class MetricInfo(BaseModel):
-
id: str = Field(description="The metric id")
name: str = Field(description="The metric name")
description: Optional[str] = Field(
@@ -39,7 +38,6 @@ class MetricInfo(BaseModel):
@dataclass
class MetricSummaryParams:
-
interval: Optional[float] = Query(
default=None,
gt=0.0,
@@ -53,7 +51,6 @@ class MetricSummaryParams:
def configure_router(router: APIRouter, cfg: TaskConfig):
-
base_metrics_endpoint = f"/{cfg.task}/{{name}}/metrics"
new_base_metrics_endpoint = f"/{{name}}/{cfg.task}/metrics"
diff --git a/src/argilla/server/apis/v0/handlers/records_update.py b/src/argilla/server/apis/v0/handlers/records_update.py
index 87d25dde54..f77d391e53 100644
--- a/src/argilla/server/apis/v0/handlers/records_update.py
+++ b/src/argilla/server/apis/v0/handlers/records_update.py
@@ -36,7 +36,6 @@
def configure_router(router: APIRouter):
-
RecordType = Union[
TextClassificationRecord,
TokenClassificationRecord,
@@ -61,7 +60,6 @@ async def partial_update_dataset_record(
storage: RecordsStorageService = Depends(RecordsStorageService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> RecordType:
-
dataset = service.find_by_name(
user=current_user,
name=name,
diff --git a/src/argilla/server/apis/v0/handlers/text2text.py b/src/argilla/server/apis/v0/handlers/text2text.py
index 2f34a6fb3b..d0ca90bb62 100644
--- a/src/argilla/server/apis/v0/handlers/text2text.py
+++ b/src/argilla/server/apis/v0/handlers/text2text.py
@@ -78,7 +78,6 @@ async def bulk_records(
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> BulkResponse:
-
task = task_type
owner = current_user.check_workspace(common_params.workspace)
try:
@@ -129,7 +128,6 @@ def search_records(
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> Text2TextSearchResults:
-
search = search or Text2TextSearchRequest()
query = search.query or Text2TextQuery()
dataset = datasets.find_by_name(
diff --git a/src/argilla/server/apis/v0/handlers/text_classification.py b/src/argilla/server/apis/v0/handlers/text_classification.py
index 857318415e..414529dbbf 100644
--- a/src/argilla/server/apis/v0/handlers/text_classification.py
+++ b/src/argilla/server/apis/v0/handlers/text_classification.py
@@ -96,7 +96,6 @@ async def bulk_records(
validator: DatasetValidator = Depends(DatasetValidator.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> BulkResponse:
-
task = task_type
owner = current_user.check_workspace(common_params.workspace)
try:
@@ -325,7 +324,6 @@ async def list_labeling_rules(
),
current_user: User = Security(auth.get_user, scopes=[]),
) -> List[LabelingRule]:
-
dataset = datasets.find_by_name(
user=current_user,
name=name,
@@ -357,7 +355,6 @@ async def create_rule(
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> LabelingRule:
-
dataset = datasets.find_by_name(
user=current_user,
name=name,
@@ -398,7 +395,6 @@ async def compute_rule_metrics(
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> LabelingRuleMetricsSummary:
-
dataset = datasets.find_by_name(
user=current_user,
name=name,
@@ -454,7 +450,6 @@ async def delete_labeling_rule(
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> None:
-
dataset = datasets.find_by_name(
user=current_user,
name=name,
@@ -484,7 +479,6 @@ async def get_rule(
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> LabelingRule:
-
dataset = datasets.find_by_name(
user=current_user,
name=name,
@@ -518,7 +512,6 @@ async def update_rule(
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> LabelingRule:
-
dataset = datasets.find_by_name(
user=current_user,
name=name,
diff --git a/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py b/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py
index fc9fad8ec4..e40c97e079 100644
--- a/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py
+++ b/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py
@@ -36,7 +36,6 @@
def configure_router(router: APIRouter):
-
task = TaskType.text_classification
base_endpoint = f"/{task}/{{name}}/settings"
new_base_endpoint = f"/{{name}}/{task}/settings"
@@ -57,7 +56,6 @@ async def get_dataset_settings(
datasets: DatasetsService = Depends(DatasetsService.get_instance),
user: User = Security(auth.get_user, scopes=["read:dataset.settings"]),
) -> TextClassificationSettings:
-
found_ds = datasets.find_by_name(
user=user,
name=name,
@@ -90,7 +88,6 @@ async def save_settings(
validator: DatasetValidator = Depends(DatasetValidator.get_instance),
user: User = Security(auth.get_user, scopes=["write:dataset.settings"]),
) -> TextClassificationSettings:
-
found_ds = datasets.find_by_name(
user=user,
name=name,
diff --git a/src/argilla/server/apis/v0/handlers/token_classification.py b/src/argilla/server/apis/v0/handlers/token_classification.py
index afa59342f7..940775a132 100644
--- a/src/argilla/server/apis/v0/handlers/token_classification.py
+++ b/src/argilla/server/apis/v0/handlers/token_classification.py
@@ -91,7 +91,6 @@ async def bulk_records(
validator: DatasetValidator = Depends(DatasetValidator.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> BulkResponse:
-
task = task_type
owner = current_user.check_workspace(common_params.workspace)
try:
@@ -153,7 +152,6 @@ def search_records(
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> TokenClassificationSearchResults:
-
search = search or TokenClassificationSearchRequest()
query = search.query or TokenClassificationQuery()
diff --git a/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py b/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py
index 24e9048003..4756aeb9da 100644
--- a/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py
+++ b/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py
@@ -56,7 +56,6 @@ async def get_dataset_settings(
datasets: DatasetsService = Depends(DatasetsService.get_instance),
user: User = Security(auth.get_user, scopes=["read:dataset.settings"]),
) -> TokenClassificationSettings:
-
found_ds = datasets.find_by_name(
user=user,
name=name,
@@ -89,7 +88,6 @@ async def save_settings(
validator: DatasetValidator = Depends(DatasetValidator.get_instance),
user: User = Security(auth.get_user, scopes=["write:dataset.settings"]),
) -> TokenClassificationSettings:
-
found_ds = datasets.find_by_name(
user=user,
name=name,
diff --git a/src/argilla/server/apis/v0/models/text2text.py b/src/argilla/server/apis/v0/models/text2text.py
index ff8845a8eb..e2730f46e3 100644
--- a/src/argilla/server/apis/v0/models/text2text.py
+++ b/src/argilla/server/apis/v0/models/text2text.py
@@ -51,7 +51,6 @@ def sort_sentences_by_score(cls, sentences: List[Text2TextPrediction]):
class Text2TextRecordInputs(BaseRecordInputs[Text2TextAnnotation]):
-
text: str
diff --git a/src/argilla/server/apis/v0/models/text_classification.py b/src/argilla/server/apis/v0/models/text_classification.py
index cabe2f9c7a..84f1868f78 100644
--- a/src/argilla/server/apis/v0/models/text_classification.py
+++ b/src/argilla/server/apis/v0/models/text_classification.py
@@ -76,7 +76,6 @@ def initialize_labels(cls, values):
class CreateLabelingRule(UpdateLabelingRule):
-
query: str = Field(description="The es rule query")
@validator("query")
@@ -109,7 +108,6 @@ class TextClassificationAnnotation(_TextClassificationAnnotation):
class TextClassificationRecordInputs(BaseRecordInputs[TextClassificationAnnotation]):
-
inputs: Dict[str, Union[str, List[str]]]
multi_label: bool = False
explanation: Optional[Dict[str, List[TokenAttributions]]] = None
@@ -122,7 +120,6 @@ class TextClassificationRecord(
class TextClassificationBulkRequest(UpdateDatasetRequest):
-
records: List[TextClassificationRecordInputs]
@validator("records")
@@ -138,7 +135,6 @@ def check_multi_label_integrity(cls, records: List[TextClassificationRecord]):
class TextClassificationQuery(ServiceBaseRecordsQuery):
-
predicted_as: List[str] = Field(default_factory=list)
annotated_as: List[str] = Field(default_factory=list)
score: Optional[ScoreRange] = Field(default=None)
diff --git a/src/argilla/server/apis/v0/models/token_classification.py b/src/argilla/server/apis/v0/models/token_classification.py
index 8e9bbbf47c..63eaaa9632 100644
--- a/src/argilla/server/apis/v0/models/token_classification.py
+++ b/src/argilla/server/apis/v0/models/token_classification.py
@@ -42,7 +42,6 @@ class TokenClassificationAnnotation(_TokenClassificationAnnotation):
class TokenClassificationRecordInputs(BaseRecordInputs[TokenClassificationAnnotation]):
-
text: str = Field()
tokens: List[str] = Field(min_items=1)
# TODO(@frascuchon): Delete this field and all related logic
@@ -73,7 +72,6 @@ class TokenClassificationBulkRequest(UpdateDatasetRequest):
class TokenClassificationQuery(ServiceBaseRecordsQuery):
-
predicted_as: List[str] = Field(default_factory=list)
annotated_as: List[str] = Field(default_factory=list)
score: Optional[ScoreRange] = Field(default=None)
diff --git a/src/argilla/server/apis/v0/validators/text_classification.py b/src/argilla/server/apis/v0/validators/text_classification.py
index 02787fa1fc..e225749bed 100644
--- a/src/argilla/server/apis/v0/validators/text_classification.py
+++ b/src/argilla/server/apis/v0/validators/text_classification.py
@@ -38,7 +38,6 @@
# TODO(@frascuchon): Move validator and its models to the service layer
class DatasetValidator:
-
_INSTANCE = None
def __init__(self, datasets: DatasetsService, metrics: MetricsService):
diff --git a/src/argilla/server/commons/config.py b/src/argilla/server/commons/config.py
index 988c18359b..b12dfc368c 100644
--- a/src/argilla/server/commons/config.py
+++ b/src/argilla/server/commons/config.py
@@ -34,7 +34,6 @@ class TaskConfig(BaseModel):
class TasksFactory:
-
__REGISTERED_TASKS__ = dict()
@classmethod
@@ -46,7 +45,6 @@ def register_task(
record_class: Type[ServiceRecord],
metrics: Optional[Type[ServiceBaseTaskMetrics]] = None,
):
-
cls.__REGISTERED_TASKS__[task_type] = TaskConfig(
task=task_type,
dataset=dataset_class,
@@ -99,7 +97,6 @@ def find_task_metric(
def find_task_metrics(
cls, task: TaskType, metric_ids: Set[str]
) -> List[ServiceBaseMetric]:
-
if not metric_ids:
return []
diff --git a/src/argilla/server/commons/models.py b/src/argilla/server/commons/models.py
index ff51220f44..c0d82925ad 100644
--- a/src/argilla/server/commons/models.py
+++ b/src/argilla/server/commons/models.py
@@ -23,7 +23,6 @@ class TaskStatus(str, Enum):
class TaskType(str, Enum):
-
text_classification = "TextClassification"
token_classification = "TokenClassification"
text2text = "Text2Text"
diff --git a/src/argilla/server/commons/telemetry.py b/src/argilla/server/commons/telemetry.py
index 24e0a86a0e..3f39929033 100644
--- a/src/argilla/server/commons/telemetry.py
+++ b/src/argilla/server/commons/telemetry.py
@@ -51,7 +51,6 @@ def _configure_analytics(disable_send: bool = False) -> Client:
@dataclasses.dataclass
class _TelemetryClient:
-
client: Client
__INSTANCE__: "_TelemetryClient" = None
@@ -76,7 +75,6 @@ def get(cls):
return cls.__INSTANCE__
def __post_init__(self):
-
from argilla import __version__
self.__server_id__ = uuid.UUID(int=uuid.getnode())
diff --git a/src/argilla/server/daos/backend/client_adapters/factory.py b/src/argilla/server/daos/backend/client_adapters/factory.py
index ad145f0eba..2ab8ac2d33 100644
--- a/src/argilla/server/daos/backend/client_adapters/factory.py
+++ b/src/argilla/server/daos/backend/client_adapters/factory.py
@@ -38,7 +38,6 @@ def get(
retry_on_timeout: bool = True,
max_retries: int = 5,
) -> IClientAdapter:
-
(
client_class,
support_vector_search,
diff --git a/src/argilla/server/daos/backend/client_adapters/opensearch.py b/src/argilla/server/daos/backend/client_adapters/opensearch.py
index b71323a401..088b48da84 100644
--- a/src/argilla/server/daos/backend/client_adapters/opensearch.py
+++ b/src/argilla/server/daos/backend/client_adapters/opensearch.py
@@ -41,7 +41,6 @@
@dataclasses.dataclass
class OpenSearchClient(IClientAdapter):
-
index_shards: int
config_backend: Dict[str, Any]
@@ -110,7 +109,6 @@ def search_docs(
enable_highlight: bool = True,
routing: str = None,
) -> Dict[str, Any]:
-
with self.error_handling(index=index):
highlight = self.highlight if enable_highlight else None
es_query = self.query_builder.map_2_es_query(
@@ -782,7 +780,6 @@ def _normalize_document(
highlight: Optional[HighlightParser] = None,
is_phrase_query: bool = True,
):
-
data = {
**document["_source"],
"id": document["_id"],
diff --git a/src/argilla/server/daos/backend/generic_elastic.py b/src/argilla/server/daos/backend/generic_elastic.py
index 888b555552..9a1bdd103b 100644
--- a/src/argilla/server/daos/backend/generic_elastic.py
+++ b/src/argilla/server/daos/backend/generic_elastic.py
@@ -184,7 +184,6 @@ def search_records(
exclude_fields: List[str] = None,
enable_highlight: bool = True,
) -> Tuple[int, List[Dict[str, Any]]]:
-
index = dataset_records_index(id)
if not sort.sort_by and sort.shuffle is False:
@@ -248,7 +247,6 @@ def create_dataset(
vectors_cfg: Optional[Dict[str, Any]] = None,
force_recreate: bool = False,
) -> None:
-
_mappings = self._common_records_mappings
task_mappings = self.get_task_mapping(task).copy()
for k in task_mappings:
diff --git a/src/argilla/server/daos/backend/mappings/token_classification.py b/src/argilla/server/daos/backend/mappings/token_classification.py
index c51f2dee49..6685560968 100644
--- a/src/argilla/server/daos/backend/mappings/token_classification.py
+++ b/src/argilla/server/daos/backend/mappings/token_classification.py
@@ -38,7 +38,6 @@ class TokenTagMetrics(BaseModel):
class TokenMetrics(BaseModel):
-
idx: int
value: str
char_start: int
diff --git a/src/argilla/server/daos/backend/metrics/base.py b/src/argilla/server/daos/backend/metrics/base.py
index 5f9ed46df8..80122add68 100644
--- a/src/argilla/server/daos/backend/metrics/base.py
+++ b/src/argilla/server/daos/backend/metrics/base.py
@@ -125,7 +125,6 @@ def _build_aggregation(self, interval: Optional[float] = None) -> Dict[str, Any]
@dataclasses.dataclass
class TermsAggregation(ElasticsearchMetric):
-
field: str = None
script: Union[str, Dict[str, Any]] = None
fixed_size: Optional[int] = None
@@ -206,7 +205,10 @@ def _build_aggregation(
) -> Dict[str, Any]:
field = text_field or self.default_field
terms_id = f"{self.id}_{field}" if text_field else self.id
- return TermsAggregation(id=terms_id, field=field,).aggregation_request(
+ return TermsAggregation(
+ id=terms_id,
+ field=field,
+ ).aggregation_request(
size=size or self.DEFAULT_WORDCOUNT_SIZE
)[terms_id]
@@ -223,7 +225,6 @@ def aggregation_request(
index: str,
size: int = None,
) -> List[Dict[str, Any]]:
-
schema = client.get_property_type(
index=index,
property_name="metadata",
diff --git a/src/argilla/server/daos/backend/metrics/text_classification.py b/src/argilla/server/daos/backend/metrics/text_classification.py
index 70fffc4ed7..fa59630558 100644
--- a/src/argilla/server/daos/backend/metrics/text_classification.py
+++ b/src/argilla/server/daos/backend/metrics/text_classification.py
@@ -48,7 +48,6 @@ class LabelingRulesMetric(ElasticsearchMetric):
def _build_aggregation(
self, rule_query: str, labels: Optional[List[str]]
) -> Dict[str, Any]:
-
annotated_records_filter = filters.exists_field("annotated_as")
rule_query_filter = filters.text_query(rule_query)
aggr_filters = {
diff --git a/src/argilla/server/daos/backend/query_helpers.py b/src/argilla/server/daos/backend/query_helpers.py
index 731edd6b8f..1c57c5ccee 100644
--- a/src/argilla/server/daos/backend/query_helpers.py
+++ b/src/argilla/server/daos/backend/query_helpers.py
@@ -297,7 +297,6 @@ def histogram_aggregation(
script: Union[str, Dict[str, Any]] = None,
interval: float = 0.1,
):
-
assert field_name or script, "Either field name or script must be provided"
if script:
diff --git a/src/argilla/server/daos/backend/search/model.py b/src/argilla/server/daos/backend/search/model.py
index 395bc702e6..fcaad0e5fa 100644
--- a/src/argilla/server/daos/backend/search/model.py
+++ b/src/argilla/server/daos/backend/search/model.py
@@ -26,7 +26,6 @@ class SortOrder(str, Enum):
class QueryRange(BaseModel):
-
range_from: float = Field(default=0.0, alias="from")
range_to: float = Field(default=None, alias="to")
@@ -77,7 +76,6 @@ class VectorSearch(BaseModel):
class BaseRecordsQuery(BaseQuery):
-
query_text: Optional[str] = None
advanced_query_dsl: bool = False
diff --git a/src/argilla/server/daos/backend/search/query_builder.py b/src/argilla/server/daos/backend/search/query_builder.py
index ac47a55c08..a05ed872b9 100644
--- a/src/argilla/server/daos/backend/search/query_builder.py
+++ b/src/argilla/server/daos/backend/search/query_builder.py
@@ -32,7 +32,6 @@
class HighlightParser:
-
_SEARCH_KEYWORDS_FIELD = "search_keywords"
__HIGHLIGHT_PRE_TAG__ = "<@@-ar-key>"
@@ -196,7 +195,6 @@ def map_2_es_query(
id_from: Optional[str] = None,
shuffle: bool = False,
) -> Dict[str, Any]:
-
if query and query.raw_query:
es_query = {"query": query.raw_query}
else:
@@ -252,7 +250,6 @@ def map_2_es_sort_configuration(
schema: Optional[Dict[str, Any]] = None,
sort: Optional[SortConfig] = None,
) -> Optional[List[Dict[str, Any]]]:
-
if not sort:
return None
diff --git a/src/argilla/server/daos/datasets.py b/src/argilla/server/daos/datasets.py
index 512e201835..c45f8ddd20 100644
--- a/src/argilla/server/daos/datasets.py
+++ b/src/argilla/server/daos/datasets.py
@@ -117,7 +117,6 @@ def update_dataset(
self,
dataset: DatasetDB,
) -> DatasetDB:
-
self._es.update_dataset_document(
id=dataset.id, document=self._dataset_to_es_doc(dataset)
)
@@ -133,7 +132,6 @@ def find_by_name(
as_dataset_class: Type[DatasetDB] = BaseDatasetDB,
task: Optional[str] = None,
) -> Optional[DatasetDB]:
-
dataset_id = BaseDatasetDB.build_dataset_id(
name=name,
owner=owner,
@@ -220,7 +218,6 @@ def save_settings(
dataset: DatasetDB,
settings: DatasetSettingsDB,
) -> BaseDatasetSettingsDB:
-
self._configure_vectors(dataset, settings)
self._es.update_dataset_document(
id=dataset.id,
diff --git a/src/argilla/server/daos/models/records.py b/src/argilla/server/daos/models/records.py
index 4bdd0b18ea..43776057e6 100644
--- a/src/argilla/server/daos/models/records.py
+++ b/src/argilla/server/daos/models/records.py
@@ -29,7 +29,6 @@
class DaoRecordsSearch(BaseModel):
-
query: Optional[BaseRecordsQuery] = None
sort: SortConfig = Field(default_factory=SortConfig)
@@ -119,7 +118,6 @@ def update_annotation(values, annotation_field: str):
@root_validator()
def prepare_record_for_db(cls, values):
-
values = cls.update_annotation(values, "prediction")
values = cls.update_annotation(values, "annotation")
@@ -239,7 +237,6 @@ def dict(self, *args, **kwargs) -> "DictStrAny":
class BaseRecordDB(BaseRecordInDB, Generic[AnnotationDB]):
-
# Read only ones
metrics: Dict[str, Any] = Field(default_factory=dict)
search_keywords: Optional[List[str]] = None
diff --git a/src/argilla/server/daos/records.py b/src/argilla/server/daos/records.py
index 4de51d78f6..1ab091d885 100644
--- a/src/argilla/server/daos/records.py
+++ b/src/argilla/server/daos/records.py
@@ -152,7 +152,6 @@ def search_records(
exclude_fields: List[str] = None,
highligth_results: bool = True,
) -> DaoRecordsSearchResults:
-
try:
search = search or DaoRecordsSearch()
@@ -261,7 +260,6 @@ async def get_record_by_id(
dataset: DatasetDB,
id: str,
) -> Optional[Dict[str, Any]]:
-
return self._es.find_record_by_id(
dataset_id=dataset.id,
record_id=id,
diff --git a/src/argilla/server/errors/base_errors.py b/src/argilla/server/errors/base_errors.py
index db63f5826f..ab69d7ff7b 100644
--- a/src/argilla/server/errors/base_errors.py
+++ b/src/argilla/server/errors/base_errors.py
@@ -19,7 +19,6 @@
class ServerError(Exception):
-
HTTP_STATUS: int = status.HTTP_500_INTERNAL_SERVER_ERROR
@classmethod
diff --git a/src/argilla/server/security/auth_provider/base.py b/src/argilla/server/security/auth_provider/base.py
index 2c5aa7ecb9..268cd9beb7 100644
--- a/src/argilla/server/security/auth_provider/base.py
+++ b/src/argilla/server/security/auth_provider/base.py
@@ -32,6 +32,6 @@ async def get_user(
self,
security_scopes: SecurityScopes,
api_key: Optional[str] = Depends(api_key_header),
- **kwargs
+ **kwargs,
) -> User:
raise NotImplementedError()
diff --git a/src/argilla/server/server.py b/src/argilla/server/server.py
index 1dd7667ccd..7e1ece9e26 100644
--- a/src/argilla/server/server.py
+++ b/src/argilla/server/server.py
@@ -162,7 +162,6 @@ async def setup_elasticsearch():
def configure_app_security(app: FastAPI):
-
if hasattr(auth, "router"):
app.include_router(auth.router)
diff --git a/src/argilla/server/services/datasets.py b/src/argilla/server/services/datasets.py
index c5f834fb68..2a3bce6f9a 100644
--- a/src/argilla/server/services/datasets.py
+++ b/src/argilla/server/services/datasets.py
@@ -44,7 +44,6 @@ class ServiceBaseDatasetSettings(BaseDatasetSettingsDB):
class DatasetsService:
-
_INSTANCE: "DatasetsService" = None
@classmethod
@@ -197,7 +196,6 @@ def copy_dataset(
copy_tags: Dict[str, Any] = None,
copy_metadata: Dict[str, Any] = None,
) -> ServiceDataset:
-
dataset_workspace = copy_workspace or dataset.owner
dataset_workspace = user.check_workspace(dataset_workspace)
@@ -253,7 +251,6 @@ async def get_settings(
async def save_settings(
self, user: User, dataset: ServiceDataset, settings: ServiceDatasetSettings
) -> ServiceDatasetSettings:
-
self.__dao__.save_settings(dataset=dataset, settings=settings)
return settings
diff --git a/src/argilla/server/services/search/model.py b/src/argilla/server/services/search/model.py
index de02d46356..076d751f99 100644
--- a/src/argilla/server/services/search/model.py
+++ b/src/argilla/server/services/search/model.py
@@ -48,7 +48,6 @@ class ServiceScoreRange(ServiceQueryRange):
class ServiceBaseSearchResultsAggregations(BaseModel):
-
predicted_as: Dict[str, int] = Field(default_factory=dict)
annotated_as: Dict[str, int] = Field(default_factory=dict)
annotated_by: Dict[str, int] = Field(default_factory=dict)
diff --git a/src/argilla/server/services/search/service.py b/src/argilla/server/services/search/service.py
index 8097be98a4..b543a50cf7 100644
--- a/src/argilla/server/services/search/service.py
+++ b/src/argilla/server/services/search/service.py
@@ -67,7 +67,6 @@ def search(
exclude_metrics: bool = True,
metrics: Optional[List[ServiceMetric]] = None,
) -> ServiceSearchResults:
-
if record_from > 0:
metrics = None
diff --git a/src/argilla/server/services/storage/service.py b/src/argilla/server/services/storage/service.py
index 2866fb37ba..27897efe11 100644
--- a/src/argilla/server/services/storage/service.py
+++ b/src/argilla/server/services/storage/service.py
@@ -37,7 +37,6 @@ class DeleteRecordsOut:
class RecordsStorageService:
-
_INSTANCE: "RecordsStorageService" = None
@classmethod
@@ -115,7 +114,6 @@ async def update_record(
record: ServiceRecord,
**data,
) -> ServiceRecord:
-
if data.get("metadata"):
record.metadata = {
**(record.metadata or {}),
diff --git a/src/argilla/server/services/tasks/text2text/service.py b/src/argilla/server/services/tasks/text2text/service.py
index 870eca8788..48f5af61b3 100644
--- a/src/argilla/server/services/tasks/text2text/service.py
+++ b/src/argilla/server/services/tasks/text2text/service.py
@@ -81,7 +81,6 @@ def search(
size: int = 100,
exclude_metrics: bool = True,
) -> ServiceSearchResults:
-
metrics = TasksFactory.find_task_metrics(
dataset.task,
metric_ids={
diff --git a/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py b/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py
index 63bf5b77d8..9b0b4871c0 100644
--- a/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py
+++ b/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py
@@ -42,7 +42,6 @@ class LabelingRuleSummary(BaseModel):
class LabelingService:
-
_INSTANCE = None
@classmethod
diff --git a/src/argilla/server/services/tasks/text_classification/model.py b/src/argilla/server/services/tasks/text_classification/model.py
index 9b8a272846..2450d9731d 100644
--- a/src/argilla/server/services/tasks/text_classification/model.py
+++ b/src/argilla/server/services/tasks/text_classification/model.py
@@ -71,7 +71,6 @@ def strip_query(cls, query: str) -> str:
class ServiceTextClassificationDataset(ServiceBaseDataset):
-
task: TaskType = Field(default=TaskType.text_classification, const=True)
rules: List[ServiceLabelingRule] = Field(default_factory=list)
@@ -300,7 +299,6 @@ def flatten_text(cls, text: Dict[str, Any]):
def _labels_from_annotation(
cls, annotation: TextClassificationAnnotation, multi_label: bool
) -> Union[List[str], List[int]]:
-
if not annotation:
return []
@@ -334,7 +332,6 @@ def extended_fields(self) -> Dict[str, Any]:
class ServiceTextClassificationQuery(ServiceBaseRecordsQuery):
-
predicted_as: List[str] = Field(default_factory=list)
annotated_as: List[str] = Field(default_factory=list)
score: Optional[ServiceScoreRange] = Field(default=None)
diff --git a/src/argilla/server/services/tasks/text_classification/service.py b/src/argilla/server/services/tasks/text_classification/service.py
index 0176bd6b25..e6351d13e9 100644
--- a/src/argilla/server/services/tasks/text_classification/service.py
+++ b/src/argilla/server/services/tasks/text_classification/service.py
@@ -215,7 +215,6 @@ def _is_dataset_multi_label(
def get_labeling_rules(
self, dataset: ServiceTextClassificationDataset
) -> Iterable[ServiceLabelingRule]:
-
return self.__labeling__.list_rules(dataset)
def add_labeling_rule(
diff --git a/src/argilla/server/services/tasks/token_classification/metrics.py b/src/argilla/server/services/tasks/token_classification/metrics.py
index a6b70288ae..e210b0960f 100644
--- a/src/argilla/server/services/tasks/token_classification/metrics.py
+++ b/src/argilla/server/services/tasks/token_classification/metrics.py
@@ -270,7 +270,6 @@ def build_tokens_metrics(
record: ServiceTokenClassificationRecord,
tags: Optional[List[str]] = None,
) -> List[TokenMetrics]:
-
return [
TokenMetrics(
idx=token_idx,
diff --git a/src/argilla/server/services/tasks/token_classification/model.py b/src/argilla/server/services/tasks/token_classification/model.py
index fb2491d0de..926783ecbd 100644
--- a/src/argilla/server/services/tasks/token_classification/model.py
+++ b/src/argilla/server/services/tasks/token_classification/model.py
@@ -77,7 +77,6 @@ class ServiceTokenClassificationAnnotation(ServiceBaseAnnotation):
class ServiceTokenClassificationRecord(
ServiceBaseRecord[ServiceTokenClassificationAnnotation]
):
-
tokens: List[str] = Field(min_items=1)
text: str = Field()
_raw_text: Optional[str] = Field(alias="raw_text")
@@ -87,7 +86,6 @@ class ServiceTokenClassificationRecord(
_predicted: Optional[PredictionStatus] = Field(alias="predicted")
def extended_fields(self) -> Dict[str, Any]:
-
return {
**super().extended_fields(),
# See ../service/service.py
@@ -144,7 +142,6 @@ def span_utils(self) -> SpanUtils:
@property
def predicted(self) -> Optional[PredictionStatus]:
if self.annotation and self.prediction:
-
annotated_entities = self.annotation.entities
predicted_entities = self.prediction.entities
if len(annotated_entities) != len(predicted_entities):
@@ -223,7 +220,6 @@ class Config:
class ServiceTokenClassificationQuery(ServiceBaseRecordsQuery):
-
predicted_as: List[str] = Field(default_factory=list)
annotated_as: List[str] = Field(default_factory=list)
score: Optional[ServiceScoreRange] = Field(default=None)
diff --git a/src/argilla/server/services/tasks/token_classification/service.py b/src/argilla/server/services/tasks/token_classification/service.py
index c422ec2ce0..534ff6920b 100644
--- a/src/argilla/server/services/tasks/token_classification/service.py
+++ b/src/argilla/server/services/tasks/token_classification/service.py
@@ -77,7 +77,6 @@ def search(
size: int = 100,
exclude_metrics: bool = True,
) -> ServiceSearchResults:
-
"""
Run a search in a dataset
diff --git a/tests/client/functional_tests/test_record_update.py b/tests/client/functional_tests/test_record_update.py
index d2046c1f8a..345149ca69 100644
--- a/tests/client/functional_tests/test_record_update.py
+++ b/tests/client/functional_tests/test_record_update.py
@@ -34,7 +34,6 @@ def test_partial_record_update(
mocked_client,
gutenberg_spacy_ner,
):
-
expected_id = "00c27206-da48-4fc3-aab7-4b730628f8ac"
record = record_data_by_id(
diff --git a/tests/client/functional_tests/test_scan_raw_records.py b/tests/client/functional_tests/test_scan_raw_records.py
index 967b8851c5..f58b18006a 100644
--- a/tests/client/functional_tests/test_scan_raw_records.py
+++ b/tests/client/functional_tests/test_scan_raw_records.py
@@ -28,7 +28,6 @@ def test_scan_records(
gutenberg_spacy_ner,
fields,
):
-
import pandas as pd
import argilla as rg
diff --git a/tests/client/sdk/commons/test_client.py b/tests/client/sdk/commons/test_client.py
index 1fe621df33..5b1dd8bedd 100644
--- a/tests/client/sdk/commons/test_client.py
+++ b/tests/client/sdk/commons/test_client.py
@@ -32,7 +32,6 @@ def test_wrong_hostname_values(
url: str,
raises_error: bool,
):
-
if raises_error:
with pytest.raises(Exception):
Client(base_url=url)
@@ -42,7 +41,6 @@ def test_wrong_hostname_values(
def test_http_calls(mocked_client):
-
rb_api = active_api()
data = rb_api.http_client.get("/api/_info")
assert data.get("version"), data
diff --git a/tests/client/test_api.py b/tests/client/test_api.py
index 23d02253f6..4f80a66deb 100644
--- a/tests/client/test_api.py
+++ b/tests/client/test_api.py
@@ -208,7 +208,6 @@ def test_log_records_with_too_long_text(mocked_client):
def test_not_found_response(mocked_client):
-
with pytest.raises(NotFoundApiError):
api.load(name="not-found")
@@ -227,7 +226,6 @@ def test_log_without_name(mocked_client):
def test_log_passing_empty_records_list(mocked_client):
-
with pytest.raises(
InputValueError,
match="Empty record list has been passed as argument.",
@@ -285,7 +283,6 @@ def raise_http_error(*args, **kwargs):
)
with pytest.raises(BaseClientError):
-
try:
future.result()
finally:
@@ -639,7 +636,6 @@ def test_token_classification_spans(span, valid):
def test_load_text2text(mocked_client, supported_vector_search):
-
vectors = {"bert_uncased": [1.2, 3.4, 6.4, 6.4]}
records = []
diff --git a/tests/client/test_client_errors.py b/tests/client/test_client_errors.py
index d0bf562067..a19ff3dd12 100644
--- a/tests/client/test_client_errors.py
+++ b/tests/client/test_client_errors.py
@@ -18,7 +18,6 @@
def test_unauthorized_response_error(mocked_client):
-
with pytest.raises(UnauthorizedApiError, match="Could not validate credentials"):
import argilla as ar
diff --git a/tests/conftest.py b/tests/conftest.py
index 72f6d86241..a9b28296e0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -33,7 +33,6 @@
@pytest.fixture
def telemetry_track_data(mocker):
-
client = telemetry._TelemetryClient.get()
if client:
# Disable sending data for tests
@@ -48,7 +47,6 @@ def mocked_client(
monkeypatch,
telemetry_track_data,
) -> SecuredClient:
-
with TestClient(app, raise_server_exceptions=False) as _client:
client_ = SecuredClient(_client)
diff --git a/tests/functional_tests/search/test_search_service.py b/tests/functional_tests/search/test_search_service.py
index 1b09a87d7c..92eb8062bb 100644
--- a/tests/functional_tests/search/test_search_service.py
+++ b/tests/functional_tests/search/test_search_service.py
@@ -132,7 +132,6 @@ def test_query_builder_with_nested(
def test_failing_metrics(service, mocked_client):
-
dataset = Dataset(
name="test_failing_metrics",
owner=argilla.get_workspace(),
diff --git a/tests/functional_tests/test_log_for_text_classification.py b/tests/functional_tests/test_log_for_text_classification.py
index 3ece320078..0a85ec53e8 100644
--- a/tests/functional_tests/test_log_for_text_classification.py
+++ b/tests/functional_tests/test_log_for_text_classification.py
@@ -185,7 +185,6 @@ def test_log_data_with_vectors_and_update_ko(mocked_client: SecuredClient):
def test_log_data_in_several_workspaces(mocked_client: SecuredClient):
-
workspace = "test-ws"
dataset = "test_log_data_in_several_workspaces"
text = "This is a text"
@@ -314,7 +313,6 @@ def test_dynamics_metadata(mocked_client):
def test_log_with_bulk_error(mocked_client):
-
dataset = "test_log_with_bulk_error"
ar.delete(dataset)
try:
diff --git a/tests/functional_tests/test_log_for_token_classification.py b/tests/functional_tests/test_log_for_token_classification.py
index 08dddb73fe..f4bea51e2c 100644
--- a/tests/functional_tests/test_log_for_token_classification.py
+++ b/tests/functional_tests/test_log_for_token_classification.py
@@ -54,7 +54,6 @@ def test_log_with_empty_tokens_list(mocked_client):
def test_call_metrics_with_no_api_client_initialized(mocked_client):
-
for metric in ALL_METRICS:
if metric == entity_consistency:
continue
diff --git a/tests/labeling/text_classification/test_rule.py b/tests/labeling/text_classification/test_rule.py
index 00d2c37f85..3c1d29a326 100644
--- a/tests/labeling/text_classification/test_rule.py
+++ b/tests/labeling/text_classification/test_rule.py
@@ -178,7 +178,6 @@ def test_create_rules_with_update(
def test_load_rules(mocked_client, log_dataset):
-
mocked_client.post(
f"/api/datasets/TextClassification/{log_dataset}/labeling/rules",
json={"query": "a query", "label": "LALA"},
@@ -191,7 +190,6 @@ def test_load_rules(mocked_client, log_dataset):
def test_add_rules(mocked_client, log_dataset):
-
expected_rules = [
Rule(query="a query", label="La La"),
Rule(query="another query", label="La La"),
@@ -209,7 +207,6 @@ def test_add_rules(mocked_client, log_dataset):
def test_delete_rules(mocked_client, log_dataset):
-
rules = [
Rule(query="a query", label="La La"),
Rule(query="another query", label="La La"),
@@ -235,7 +232,6 @@ def test_delete_rules(mocked_client, log_dataset):
def test_update_rules(mocked_client, log_dataset):
-
rules = [
Rule(query="a query", label="La La"),
Rule(query="another query", label="La La"),
@@ -316,7 +312,6 @@ def test_copy_dataset_with_rules(mocked_client, log_dataset):
],
)
def test_rule_metrics(mocked_client, log_dataset, rule, expected_metrics):
-
delete_rule_silently(mocked_client, log_dataset, rule)
mocked_client.post(
diff --git a/tests/monitoring/test_flair_monitoring.py b/tests/monitoring/test_flair_monitoring.py
index e944c08aed..e8cb51b6ad 100644
--- a/tests/monitoring/test_flair_monitoring.py
+++ b/tests/monitoring/test_flair_monitoring.py
@@ -15,7 +15,6 @@
def test_flair_monitoring(mocked_client, monkeypatch):
-
from flair.data import Sentence
from flair.models import SequenceTagger
@@ -52,7 +51,7 @@ def test_flair_monitoring(mocked_client, monkeypatch):
assert record.tokens == [token.text for token in sentence.tokens]
assert len(record.prediction) == len(detected_labels)
- for ((label, start, end, score), span) in zip(record.prediction, detected_labels):
+ for (label, start, end, score), span in zip(record.prediction, detected_labels):
assert label == span.value
assert start == span.span.start_pos
assert end == span.span.end_pos
diff --git a/tests/monitoring/test_monitor.py b/tests/monitoring/test_monitor.py
index 95c28ea9e1..18240551af 100644
--- a/tests/monitoring/test_monitor.py
+++ b/tests/monitoring/test_monitor.py
@@ -38,7 +38,6 @@ def test_monitor_with_non_supported_model():
def test_monitor_non_supported_huggingface_model():
with warnings.catch_warnings(record=True) as warning_list:
-
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
diff --git a/tests/monitoring/test_transformers_monitoring.py b/tests/monitoring/test_transformers_monitoring.py
index 3ea8eb28cf..ee0487c98f 100644
--- a/tests/monitoring/test_transformers_monitoring.py
+++ b/tests/monitoring/test_transformers_monitoring.py
@@ -242,7 +242,6 @@ def test_monitor_zero_short_passing_labels_keyword_arg(
mocked_monitor,
dataset,
):
-
argilla.delete(dataset)
predictions = mocked_monitor(
text,
@@ -307,7 +306,6 @@ def test_monitor_zero_shot_with_text_array(
mocked_monitor,
dataset,
):
-
argilla.delete(dataset)
predictions = mocked_monitor(
[text], candidate_labels=labels, hypothesis_template=hypothesis
diff --git a/tests/server/backend/test_query_builder.py b/tests/server/backend/test_query_builder.py
index edadba6a9e..2ceef02733 100644
--- a/tests/server/backend/test_query_builder.py
+++ b/tests/server/backend/test_query_builder.py
@@ -61,7 +61,6 @@
],
)
def test_build_sort_configuration(index_schema, sort_cfg, expected_sort):
-
builder = EsQueryBuilder()
es_sort = builder.map_2_es_sort_configuration(
diff --git a/tests/server/daos/models/test_records.py b/tests/server/daos/models/test_records.py
index 21d8980d2a..e2542c804f 100644
--- a/tests/server/daos/models/test_records.py
+++ b/tests/server/daos/models/test_records.py
@@ -21,7 +21,6 @@
def test_metadata_limit():
-
long_value = "a" * (settings.metadata_field_length + 1)
short_value = "a" * (settings.metadata_field_length - 1)
diff --git a/tests/server/info/test_api.py b/tests/server/info/test_api.py
index 6c22ead5e5..d25b143555 100644
--- a/tests/server/info/test_api.py
+++ b/tests/server/info/test_api.py
@@ -18,7 +18,6 @@
def test_api_info(mocked_client):
-
response = mocked_client.get("/api/_info")
assert response.status_code == 200
@@ -28,7 +27,6 @@ def test_api_info(mocked_client):
def test_api_status(mocked_client):
-
response = mocked_client.get("/api/_status")
assert response.status_code == 200
diff --git a/tests/server/security/test_model.py b/tests/server/security/test_model.py
index 88ff875bdc..4c4428ed06 100644
--- a/tests/server/security/test_model.py
+++ b/tests/server/security/test_model.py
@@ -64,7 +64,6 @@ def test_check_non_provided_workspaces():
def test_check_user_workspaces():
-
a_ws = "A-workspace"
expected_workspaces = [a_ws, "B-ws"]
user = User(username="test-user", workspaces=[a_ws, "B-ws", "C-ws"])
@@ -76,7 +75,6 @@ def test_check_user_workspaces():
def test_default_workspace():
-
user = User(username="admin")
assert user.default_workspace == "admin"
diff --git a/tests/server/test_errors.py b/tests/server/test_errors.py
index 838cca3956..0090ab2e31 100644
--- a/tests/server/test_errors.py
+++ b/tests/server/test_errors.py
@@ -16,7 +16,6 @@
def test_generic_error():
-
err = GenericServerError(error=ValueError("this is an error"))
assert (
str(err)
diff --git a/tests/server/text2text/test_model.py b/tests/server/text2text/test_model.py
index 5bead4fd23..c8fa2fd30d 100644
--- a/tests/server/text2text/test_model.py
+++ b/tests/server/text2text/test_model.py
@@ -22,7 +22,6 @@
def test_sentences_sorted_by_score():
-
record = Text2TextRecord(
text="The inpu2 text",
prediction=Text2TextAnnotation(
diff --git a/tests/server/text_classification/test_model.py b/tests/server/text_classification/test_model.py
index ebbc61af47..18a967a637 100644
--- a/tests/server/text_classification/test_model.py
+++ b/tests/server/text_classification/test_model.py
@@ -134,7 +134,6 @@ def test_model_with_annotations():
def test_single_label_with_multiple_annotation():
-
with pytest.raises(
ValidationError,
match="Single label record must include only one annotation label",
@@ -213,7 +212,6 @@ def test_score_integrity():
def test_prediction_ok_cases():
-
data = {
"multi_label": True,
"inputs": {"data": "My cool data"},
@@ -330,7 +328,6 @@ def test_validate_without_labels_for_single_label(annotation):
def test_query_with_uncovered_by_rules():
-
query = TextClassificationQuery(uncovered_by_rules=["query", "other*"])
assert EsQueryBuilder._to_es_query(query) == {
diff --git a/tests/server/token_classification/test_model.py b/tests/server/token_classification/test_model.py
index 2f710c7d84..11701da917 100644
--- a/tests/server/token_classification/test_model.py
+++ b/tests/server/token_classification/test_model.py
@@ -31,7 +31,6 @@
def test_char_position():
-
with pytest.raises(
ValidationError,
match="End character cannot be placed before the starting character,"
@@ -68,7 +67,6 @@ def test_fix_substrings():
def test_entities_with_spaces():
-
text = "This is a great space"
ServiceTokenClassificationRecord(
text=text,
@@ -284,7 +282,6 @@ def test_annotated_without_entities():
def test_adjust_spans():
-
text = "A text with some empty spaces that could bring not cleany annotated spans"
record = ServiceTokenClassificationRecord(
text=text,
@@ -338,7 +335,6 @@ def test_whitespace_in_tokens():
def test_predicted_ok_ko_computation():
-
text = "A text with some empty spaces that could bring not cleanly annotated spans"
record = ServiceTokenClassificationRecord(
text=text,
From a09c74d7cdb05df4f823d419898df03b3a392f8c Mon Sep 17 00:00:00 2001
From: Daniel Vila Suero
Date: Tue, 7 Feb 2023 18:16:34 +0100
Subject: [PATCH 02/45] Add deploy to readme (#2307)
# Description
Please include a summary of the changes and the related issue. Please
also include relevant motivation and context. List any dependencies that
are required for this change.
Closes #
**Type of change**
(Please delete options that are not relevant. Remember to title the PR
according to the type of change)
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] Refactor (change restructuring the codebase without changing
functionality)
- [ ] Improvement (change adding some improvement to an existing
functionality)
- [ ] Documentation update
**How Has This Been Tested**
(Please describe the tests that you ran to verify your changes and,
ideally, reference `tests`.)
- [ ] Test A
- [ ] Test B
**Checklist**
- [ ] I have merged the original branch into my forked branch
- [ ] I added relevant documentation
- [ ] My code follows the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] I added comments to my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
---
README.md | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/README.md b/README.md
index c7f402073e..8e9f7af9a2 100644
--- a/README.md
+++ b/README.md
@@ -16,20 +16,20 @@
+
+
+
Open-source framework for data-centric NLP
Data Labeling, curation, and Inference Store
Designed for MLOps & Feedback Loops
-
> 🆕 🔥 Play with Argilla UI with this [live-demo](https://argilla-live-demo.hf.space) powered by Hugging Face Spaces (login:`argilla`, password:`1234`)
> 🆕 🔥 Since `1.2.0` Argilla supports vector search for finding the most similar records to a given one. This feature uses vector or semantic search combined with more traditional search (keyword and filter based). Learn more on this [deep-dive guide](https://docs.argilla.io/en/latest/guides/features/semantic-search.html)
-![imagen](https://user-images.githubusercontent.com/1107111/204772677-facee627-9b3b-43ca-8533-bbc9b4e2d0aa.png)
-
From 2bc97c62380bb46aad6439323588051c005202f6 Mon Sep 17 00:00:00 2001
From: Daniel Vila Suero
Date: Tue, 7 Feb 2023 18:19:35 +0100
Subject: [PATCH 03/45] Update README.md (#2308)
---
README.md | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/README.md b/README.md
index 8e9f7af9a2..ff816f08ae 100644
--- a/README.md
+++ b/README.md
@@ -16,16 +16,15 @@
-
+
Open-source framework for data-centric NLP
-Data Labeling, curation, and Inference Store
-Designed for MLOps & Feedback Loops
+Data Labeling for MLOps & Feedback Loops
-> 🆕 🔥 Play with Argilla UI with this [live-demo](https://argilla-live-demo.hf.space) powered by Hugging Face Spaces (login:`argilla`, password:`1234`)
+> 🆕 🔥 Deploy [Argilla on Spaces](https://huggingface.co/new-space?template=argilla/argilla-template-space)
> 🆕 🔥 Since `1.2.0` Argilla supports vector search for finding the most similar records to a given one. This feature uses vector or semantic search combined with more traditional search (keyword and filter based). Learn more on this [deep-dive guide](https://docs.argilla.io/en/latest/guides/features/semantic-search.html)
From d3542748b1fbaa9b9b1d56f6e5b70027fded3a90 Mon Sep 17 00:00:00 2001
From: Francisco Aranda
Date: Thu, 9 Feb 2023 17:11:52 +0100
Subject: [PATCH 04/45] chore: Upgrade package version
---
frontend/package.json | 2 +-
src/argilla/_version.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/frontend/package.json b/frontend/package.json
index 7a87342bbf..9fd9999c0e 100644
--- a/frontend/package.json
+++ b/frontend/package.json
@@ -1,6 +1,6 @@
{
"name": "argilla",
- "version": "1.3.0-dev0",
+ "version": "1.4.0-dev0",
"private": true,
"eslintIgnore": [
"node_modules/**/*",
diff --git a/src/argilla/_version.py b/src/argilla/_version.py
index 90d7c9110a..7208766f3a 100644
--- a/src/argilla/_version.py
+++ b/src/argilla/_version.py
@@ -13,4 +13,4 @@
# limitations under the License.
# coding: utf-8
-version = "1.3.0-dev0"
+version = "1.4.0-dev0"
From 5f0627cf523fdfcb2efeb7ae5663c3e9a845ecb9 Mon Sep 17 00:00:00 2001
From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Date: Tue, 14 Feb 2023 09:35:45 +0100
Subject: [PATCH 05/45] ci: Replace `isort` by `ruff` in `pre-commit` (#2325)
Hello!
# Pull Request overview
* Replaced `isort` by `ruff` in `pre-commit`
* This has fixed the import sorting throughout the repo.
* Manually went through `ruff` errors and fixed a bunch:
* Unused imports
* Added `if TYPE_CHECKING:` statements
* Placed `# noqa: ` where the warned-about statement is the
desired behaviour
* Add basic `ruff` configuration to `pyproject.toml`
## Details
This PR focuses on replacing `isort` by `ruff` in `pre-commit`. The
motivation for this is that:
* `isort` frequently breaks. On two separate occasions in the last few
months alone, the latest `isort` release has broken my CI runs in NLTK and
SetFit.
* `isort` no longer supports Python 3.7, whereas Argilla still
supports 3.7 for now.
* `ruff` is absurdly fast; I actually can't believe how quick it is.
This PR consists of 3 commits at this time, and I would advise looking
at them commit-by-commit rather than at the PR as a whole. I'll also
explain each commit individually.
## [Add ruff basic
configuration](https://github.com/argilla-io/argilla/commit/497420e7097b7039d479df6bf431ea73c370f90b)
I've added basic configuration for
[`ruff`](https://github.com/charliermarsh/ruff), a very efficient
linter. I recommend the following commands:
```
# Get all [F]ailures and [E]rrors
ruff .
# See all import sort errors
ruff . --select I
# Fix all import sort errors
ruff . --select I --fix
```
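(In `ruff`'s rule codes, `F` covers the Pyflakes checks, `E` covers the
pycodestyle errors, and `I` covers the import-sorting rules modelled on
`isort`.)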
## [Remove unused imports, apply some noqa's, add
TYPE_CHECKING](https://github.com/argilla-io/argilla/commit/f219acbaba8f4b7fffaea70e8e1e9d1fe3a28475)
The unused imports speak for themselves.
As for the `noqa`'s: `ruff`, like most linters, respects the `# noqa` (no
quality assurance) comment. I've used it in various locations where a linter
would warn even though the behaviour is actually the desired one. As a
result, the output of `ruff .` now only points to genuinely questionable code.
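To make that concrete, here is a small sketch of my own (not code from this
PR) showing the two suppressions it uses most often, `F401` and `E722`:
```python
import json  # noqa: F401  (kept for re-export; silences "imported but unused")


def parse_port(value: str) -> int:
    """Parse a port number, falling back to a default on any bad input."""
    try:
        return int(value)
    except:  # noqa: E722  (the bare except is the desired behaviour here)
        return 8080
```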
Lastly, I added `TYPE_CHECKING` in some locations. When a type hint refers to
an object that does not need to be imported at run time, it's common to write
the hint as a string, e.g. `arr: "numpy.ndarray"`. On its own, though, IDEs
can't resolve what `arr` is. Python provides `TYPE_CHECKING`, which can be
used to import such modules *only* while type checking. The guarded block is
never executed in practice, but its presence lets IDEs (and type checkers)
fully support development.
See an example here:
```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import numpy


def func(arr: "numpy.ndarray") -> None:
    ...
```
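A quick way to convince yourself of this (my own sketch, not part of the PR):
`TYPE_CHECKING` is `False` at run time, so the guarded import never executes
and the hint survives only as a string:
```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import numpy  # only type checkers ever "run" this line


def func(arr: "numpy.ndarray") -> None:
    ...


print(TYPE_CHECKING)         # False
print(func.__annotations__)  # {'arr': 'numpy.ndarray', 'return': None}
```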
## [Replace isort with ruff in
CI](https://github.com/argilla-io/argilla/commit/e05f30efb65c9e9b3921a551e1f4b581fcc14db4)
I've replaced `isort` (which is both [broken in
5.11.*](https://github.com/PyCQA/isort/issues/2077) and [drops Python 3.7
support in
5.12.*](https://github.com/PyCQA/isort/releases/tag/5.12.0)) with `ruff`
in the CI, using `--select I` to select only the `isort`-style warnings and
`--fix` to fix them immediately.
Then I ran `pre-commit run --all-files` to fix the ~67 outstanding issues in
the repository.
---
**Type of change**
- [x] Refactor (change restructuring the codebase without changing
functionality)
**How Has This Been Tested**
I verified that the behaviour did not change using `pytest tests`.
**Checklist**
- [x] I have merged the original branch into my forked branch
- [ ] I added relevant documentation
- [x] follows the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I added comments to my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- Tom Aarsen
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
.pre-commit-config.yaml | 10 ++++--
pyproject.toml | 31 +++++++++++++++++++
scripts/load_data.py | 5 ++-
src/argilla/__init__.py | 4 +--
src/argilla/client/apis/datasets.py | 2 +-
src/argilla/client/apis/search.py | 2 +-
src/argilla/client/datasets.py | 11 +++++--
src/argilla/client/sdk/client.py | 4 +--
src/argilla/client/sdk/commons/api.py | 2 +-
src/argilla/client/sdk/text2text/api.py | 1 -
.../client/sdk/text_classification/api.py | 2 +-
.../client/sdk/token_classification/api.py | 1 -
src/argilla/client/sdk/users/api.py | 3 --
.../text_classification/label_models.py | 11 +++----
src/argilla/listeners/listener.py | 6 ++--
src/argilla/listeners/models.py | 5 ++-
src/argilla/metrics/__init__.py | 4 +--
src/argilla/monitoring/asgi.py | 2 +-
src/argilla/monitoring/base.py | 1 -
.../server/apis/v0/handlers/datasets.py | 2 +-
.../server/apis/v0/handlers/metrics.py | 4 +--
.../server/apis/v0/handlers/records_update.py | 11 ++-----
.../apis/v0/handlers/token_classification.py | 1 -
.../server/apis/v0/models/text2text.py | 3 +-
.../apis/v0/models/text_classification.py | 7 ++---
.../server/daos/backend/metrics/base.py | 5 ++-
.../daos/backend/search/query_builder.py | 2 +-
src/argilla/server/daos/models/datasets.py | 2 +-
src/argilla/server/daos/models/records.py | 2 +-
src/argilla/server/errors/api_errors.py | 2 +-
src/argilla/server/errors/base_errors.py | 2 +-
src/argilla/server/server.py | 1 -
src/argilla/server/services/datasets.py | 2 +-
.../server/services/storage/service.py | 4 +--
.../server/services/tasks/text2text/models.py | 4 ---
.../services/tasks/text2text/service.py | 3 +-
.../tasks/text_classification/model.py | 2 +-
src/argilla/utils/span_utils.py | 2 +-
tests/client/apis/test_base.py | 1 -
tests/client/conftest.py | 3 +-
.../functional_tests/test_record_update.py | 1 -
.../functional_tests/test_scan_raw_records.py | 5 +--
tests/client/sdk/commons/api.py | 3 +-
tests/client/sdk/commons/test_client.py | 1 -
tests/client/sdk/conftest.py | 3 +-
tests/client/sdk/datasets/test_api.py | 9 +-----
tests/client/sdk/datasets/test_models.py | 1 -
tests/client/sdk/text2text/test_models.py | 1 -
.../sdk/text_classification/test_models.py | 1 -
.../sdk/token_classification/test_models.py | 1 -
tests/client/sdk/users/test_api.py | 1 -
tests/client/test_api.py | 6 ++--
tests/client/test_asgi.py | 9 +++---
tests/client/test_client_errors.py | 1 -
tests/client/test_dataset.py | 3 +-
tests/client/test_models.py | 4 +--
tests/conftest.py | 4 +--
tests/datasets/test_datasets.py | 3 +-
.../test_delete_records_from_datasets.py | 1 -
.../datasets/test_update_record.py | 1 -
.../search/test_search_service.py | 3 +-
.../test_log_for_text_classification.py | 6 ++--
.../test_log_for_token_classification.py | 4 +--
tests/helpers.py | 5 ++-
.../text_classification/test_label_errors.py | 5 ++-
.../text_classification/test_label_models.py | 2 --
.../labeling/text_classification/test_rule.py | 1 -
.../text_classification/test_weak_labels.py | 3 +-
tests/listeners/test_listener.py | 3 +-
tests/metrics/test_common_metrics.py | 3 +-
tests/metrics/test_token_classification.py | 3 +-
tests/monitoring/test_base_monitor.py | 2 +-
tests/monitoring/test_flair_monitoring.py | 3 +-
.../test_transformers_monitoring.py | 3 +-
tests/server/backend/test_query_builder.py | 1 -
tests/server/commons/test_records_dao.py | 1 -
tests/server/commons/test_settings.py | 3 +-
tests/server/commons/test_telemetry.py | 3 +-
tests/server/daos/models/test_records.py | 1 -
tests/server/datasets/test_api.py | 3 +-
tests/server/datasets/test_dao.py | 1 -
tests/server/datasets/test_model.py | 3 +-
.../datasets/test_get_record.py | 1 -
tests/server/info/test_api.py | 2 +-
tests/server/security/test_dao.py | 1 -
tests/server/security/test_model.py | 3 +-
tests/server/security/test_provider.py | 3 +-
tests/server/security/test_service.py | 1 -
tests/server/test_app.py | 1 -
tests/server/text2text/test_api.py | 3 +-
tests/server/text_classification/test_api.py | 2 +-
.../text_classification/test_api_rules.py | 1 -
.../server/text_classification/test_model.py | 3 +-
tests/server/token_classification/test_api.py | 2 +-
.../server/token_classification/test_model.py | 3 +-
tests/utils/test_span_utils.py | 1 -
tests/utils/test_utils.py | 1 -
97 files changed, 139 insertions(+), 182 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f67cae6ff1..5ba8658505 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -25,10 +25,14 @@ repos:
- id: black
additional_dependencies: ["click==8.0.4"]
- - repo: https://github.com/pycqa/isort
- rev: 5.12.0
+ - repo: https://github.com/charliermarsh/ruff-pre-commit
+ rev: v0.0.244
hooks:
- - id: isort
+ # Simulate isort via (the much faster) ruff
+ - id: ruff
+ args:
+ - --select=I
+ - --fix
- repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook
rev: v9.4.0
diff --git a/pyproject.toml b/pyproject.toml
index 5052f1d547..f28608d42e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -108,3 +108,34 @@ exclude_lines = [
[tool.isort]
profile = "black"
+
+[tool.ruff]
+# Ignore line length violations
+ignore = ["E501"]
+
+# Exclude a variety of commonly ignored directories.
+exclude = [
+ ".bzr",
+ ".direnv",
+ ".eggs",
+ ".git",
+ ".hg",
+ ".mypy_cache",
+ ".nox",
+ ".pants.d",
+ ".ruff_cache",
+ ".svn",
+ ".tox",
+ ".venv",
+ "__pypackages__",
+ "_build",
+ "buck-out",
+ "build",
+ "dist",
+ "node_modules",
+ "venv",
+]
+
+[tool.ruff.per-file-ignores]
+# Ignore imported but unused;
+"__init__.py" = ["F401"]
diff --git a/scripts/load_data.py b/scripts/load_data.py
index 0295872157..06405d7650 100644
--- a/scripts/load_data.py
+++ b/scripts/load_data.py
@@ -15,12 +15,11 @@
import sys
import time
+import argilla as rg
import pandas as pd
import requests
-from datasets import load_dataset
-
-import argilla as rg
from argilla.labeling.text_classification import Rule, add_rules
+from datasets import load_dataset
class LoadDatasets:
diff --git a/src/argilla/__init__.py b/src/argilla/__init__.py
index c9fb4f6787..1dbd654c9f 100644
--- a/src/argilla/__init__.py
+++ b/src/argilla/__init__.py
@@ -47,12 +47,10 @@
read_datasets,
read_pandas,
)
- from argilla.client.models import (
- TextGenerationRecord, # TODO Remove TextGenerationRecord
- )
from argilla.client.models import (
Text2TextRecord,
TextClassificationRecord,
+ TextGenerationRecord, # TODO Remove TextGenerationRecord
TokenAttributions,
TokenClassificationRecord,
)
diff --git a/src/argilla/client/apis/datasets.py b/src/argilla/client/apis/datasets.py
index 55079c7a95..5fff021362 100644
--- a/src/argilla/client/apis/datasets.py
+++ b/src/argilla/client/apis/datasets.py
@@ -15,7 +15,7 @@
import warnings
from dataclasses import dataclass
from datetime import datetime
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
from pydantic import BaseModel, Field
diff --git a/src/argilla/client/apis/search.py b/src/argilla/client/apis/search.py
index f8a1b66a98..07de6bb71e 100644
--- a/src/argilla/client/apis/search.py
+++ b/src/argilla/client/apis/search.py
@@ -13,7 +13,7 @@
# limitations under the License.
import dataclasses
-from typing import List, Optional, Union
+from typing import List, Optional
from argilla.client.apis import AbstractApi
from argilla.client.models import Record
diff --git a/src/argilla/client/datasets.py b/src/argilla/client/datasets.py
index 944e7d93f2..952a3a96ef 100644
--- a/src/argilla/client/datasets.py
+++ b/src/argilla/client/datasets.py
@@ -16,7 +16,7 @@
import logging
import random
import uuid
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
import pandas as pd
from pkg_resources import parse_version
@@ -32,6 +32,11 @@
from argilla.client.sdk.datasets.models import TaskType
from argilla.utils.span_utils import SpanUtils
+if TYPE_CHECKING:
+ import datasets
+ import pandas
+ import spacy
+
_LOGGER = logging.getLogger(__name__)
@@ -60,7 +65,7 @@ def _requires_spacy(func):
@functools.wraps(func)
def check_if_spacy_installed(*args, **kwargs):
try:
- import spacy
+ import spacy # noqa: F401
except ModuleNotFoundError:
raise ModuleNotFoundError(
f"'spacy' must be installed to use `{func.__name__}`"
@@ -1007,7 +1012,7 @@ def from_datasets(
for row in dataset:
# TODO: fails with a KeyError if no tokens column is present and no mapping is indicated
if not row["tokens"]:
- _LOGGER.warning(f"Ignoring row with no tokens.")
+ _LOGGER.warning("Ignoring row with no tokens.")
continue
if row.get("tags"):
diff --git a/src/argilla/client/sdk/client.py b/src/argilla/client/sdk/client.py
index ee244c7697..be4d2983b9 100644
--- a/src/argilla/client/sdk/client.py
+++ b/src/argilla/client/sdk/client.py
@@ -119,7 +119,7 @@ async def inner_async(self, *args, **kwargs):
try:
result = await func(self, *args, **kwargs)
return result
- except httpx.ConnectError as err:
+ except httpx.ConnectError as err: # noqa: F841
return wrap_error(self.base_url)
@functools.wraps(func)
@@ -127,7 +127,7 @@ def inner(self, *args, **kwargs):
try:
result = func(self, *args, **kwargs)
return result
- except httpx.ConnectError as err:
+ except httpx.ConnectError as err: # noqa: F841
return wrap_error(self.base_url)
is_coroutine = inspect.iscoroutinefunction(func)
diff --git a/src/argilla/client/sdk/commons/api.py b/src/argilla/client/sdk/commons/api.py
index 9ccc3eb3bb..3810aebdf1 100644
--- a/src/argilla/client/sdk/commons/api.py
+++ b/src/argilla/client/sdk/commons/api.py
@@ -126,7 +126,7 @@ def build_data_response(
parsed_record = json.loads(r)
try:
parsed_response = data_type(**parsed_record)
- except Exception as err:
+ except Exception as err: # noqa: F841
raise GenericApiError(**parsed_record) from None
parsed_responses.append(parsed_response)
return Response(
diff --git a/src/argilla/client/sdk/text2text/api.py b/src/argilla/client/sdk/text2text/api.py
index 2baab48725..5b5aca9568 100644
--- a/src/argilla/client/sdk/text2text/api.py
+++ b/src/argilla/client/sdk/text2text/api.py
@@ -33,7 +33,6 @@ def data(
limit: Optional[int] = None,
id_from: Optional[str] = None,
) -> Response[Union[List[Text2TextRecord], HTTPValidationError, ErrorMessage]]:
-
path = f"/api/datasets/{name}/Text2Text/data"
params = build_param_dict(id_from, limit)
diff --git a/src/argilla/client/sdk/text_classification/api.py b/src/argilla/client/sdk/text_classification/api.py
index 8a5161fe4a..39c603411e 100644
--- a/src/argilla/client/sdk/text_classification/api.py
+++ b/src/argilla/client/sdk/text_classification/api.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, List, Optional, Union
+from typing import List, Optional, Union
import httpx
diff --git a/src/argilla/client/sdk/token_classification/api.py b/src/argilla/client/sdk/token_classification/api.py
index f5a95ba36d..10a91d0f59 100644
--- a/src/argilla/client/sdk/token_classification/api.py
+++ b/src/argilla/client/sdk/token_classification/api.py
@@ -39,7 +39,6 @@ def data(
) -> Response[
Union[List[TokenClassificationRecord], HTTPValidationError, ErrorMessage]
]:
-
path = f"/api/datasets/{name}/TokenClassification/data"
params = build_param_dict(id_from, limit)
diff --git a/src/argilla/client/sdk/users/api.py b/src/argilla/client/sdk/users/api.py
index 964b6f4480..c22d704b13 100644
--- a/src/argilla/client/sdk/users/api.py
+++ b/src/argilla/client/sdk/users/api.py
@@ -12,10 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import httpx
-
from argilla.client.sdk.client import AuthenticatedClient
-from argilla.client.sdk.commons.errors_handler import handle_response_error
from argilla.client.sdk.users.models import User
diff --git a/src/argilla/labeling/text_classification/label_models.py b/src/argilla/labeling/text_classification/label_models.py
index 27f2e742ea..7186c46b9e 100644
--- a/src/argilla/labeling/text_classification/label_models.py
+++ b/src/argilla/labeling/text_classification/label_models.py
@@ -20,7 +20,6 @@
import numpy as np
from argilla import DatasetForTextClassification, TextClassificationRecord
-from argilla.client.datasets import Dataset
from argilla.labeling.text_classification.weak_labels import WeakLabels, WeakMultiLabels
_LOGGER = logging.getLogger(__name__)
@@ -368,7 +367,7 @@ def score(
MissingAnnotationError: If the ``weak_labels`` do not contain annotated records.
"""
try:
- import sklearn
+ import sklearn # noqa: F401
except ModuleNotFoundError:
raise ModuleNotFoundError(
"'sklearn' must be installed to compute the metrics! "
@@ -501,7 +500,7 @@ def __init__(
self, weak_labels: WeakLabels, verbose: bool = True, device: str = "cpu"
):
try:
- import snorkel
+ import snorkel # noqa: F401
except ModuleNotFoundError:
raise ModuleNotFoundError(
"'snorkel' must be installed to use the `Snorkel` label model! "
@@ -764,8 +763,8 @@ class FlyingSquid(LabelModel):
def __init__(self, weak_labels: WeakLabels, **kwargs):
try:
- import flyingsquid
- import pgmpy
+ import flyingsquid # noqa: F401
+ import pgmpy # noqa: F401
except ModuleNotFoundError:
raise ModuleNotFoundError(
"'flyingsquid' must be installed to use the `FlyingSquid` label model!"
@@ -1024,7 +1023,7 @@ def score(
MissingAnnotationError: If the ``weak_labels`` do not contain annotated records.
"""
try:
- import sklearn
+ import sklearn # noqa: F401
except ModuleNotFoundError:
raise ModuleNotFoundError(
"'sklearn' must be installed to compute the metrics! "
diff --git a/src/argilla/listeners/listener.py b/src/argilla/listeners/listener.py
index 77e48e7896..d836bbf5e6 100644
--- a/src/argilla/listeners/listener.py
+++ b/src/argilla/listeners/listener.py
@@ -98,7 +98,7 @@ def catch_exceptions_decorator(job_func):
def wrapper(*args, **kwargs):
try:
return job_func(*args, **kwargs)
- except:
+ except: # noqa: E722
import traceback
print(traceback.format_exc())
@@ -208,7 +208,7 @@ def __listener_iteration_job__(self, *args, **kwargs):
self._LOGGER.debug(f"Evaluate condition with arguments: {condition_args}")
if self.condition(*condition_args):
- self._LOGGER.debug(f"Condition passed! Running action...")
+ self._LOGGER.debug("Condition passed! Running action...")
return self.__run_action__(ctx, *args, **kwargs)
def __compute_metrics__(self, current_api, dataset, query: str) -> Metrics:
@@ -235,7 +235,7 @@ def __run_action__(self, ctx: Optional[RGListenerContext] = None, *args, **kwarg
)
self._LOGGER.debug(f"Running action with arguments: {action_args}")
return self.action(*args, *action_args, **kwargs)
- except:
+ except: # noqa: E722
import traceback
print(traceback.format_exc())
diff --git a/src/argilla/listeners/models.py b/src/argilla/listeners/models.py
index c61f603177..5c276b9a3f 100644
--- a/src/argilla/listeners/models.py
+++ b/src/argilla/listeners/models.py
@@ -13,12 +13,15 @@
# limitations under the License.
import dataclasses
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from prodict import Prodict
from argilla.client.models import Record
+if TYPE_CHECKING:
+ from argilla.listeners import RGDatasetListener
+
@dataclasses.dataclass
class Search:
diff --git a/src/argilla/metrics/__init__.py b/src/argilla/metrics/__init__.py
index 53ef179b65..df7a911c61 100644
--- a/src/argilla/metrics/__init__.py
+++ b/src/argilla/metrics/__init__.py
@@ -19,15 +19,13 @@
entity_consistency,
entity_density,
entity_labels,
-)
-from .token_classification import f1 as ner_f1
-from .token_classification import (
mention_length,
token_capitalness,
token_frequency,
token_length,
tokens_length,
)
+from .token_classification import f1 as ner_f1
__all__ = [
text_length,
diff --git a/src/argilla/monitoring/asgi.py b/src/argilla/monitoring/asgi.py
index 4b2119be43..f861eee958 100644
--- a/src/argilla/monitoring/asgi.py
+++ b/src/argilla/monitoring/asgi.py
@@ -21,7 +21,7 @@
from argilla.monitoring.base import BaseMonitor
try:
- import starlette
+ import starlette # noqa: F401
except ModuleNotFoundError:
raise ModuleNotFoundError(
"'starlette' must be installed to use the middleware feature! "
diff --git a/src/argilla/monitoring/base.py b/src/argilla/monitoring/base.py
index 2e8002de46..59fc8976d4 100644
--- a/src/argilla/monitoring/base.py
+++ b/src/argilla/monitoring/base.py
@@ -13,7 +13,6 @@
# limitations under the License.
import atexit
-import dataclasses
import logging
import random
import threading
diff --git a/src/argilla/server/apis/v0/handlers/datasets.py b/src/argilla/server/apis/v0/handlers/datasets.py
index 01d7d75478..0d1470f0f9 100644
--- a/src/argilla/server/apis/v0/handlers/datasets.py
+++ b/src/argilla/server/apis/v0/handlers/datasets.py
@@ -64,7 +64,7 @@ async def list_datasets(
description="Create a new dataset",
)
async def create_dataset(
- request: CreateDatasetRequest = Body(..., description=f"The request dataset info"),
+ request: CreateDatasetRequest = Body(..., description="The request dataset info"),
ws_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
user: User = Security(auth.get_user, scopes=["create:datasets"]),
diff --git a/src/argilla/server/apis/v0/handlers/metrics.py b/src/argilla/server/apis/v0/handlers/metrics.py
index 96e0364697..68b92bbb37 100644
--- a/src/argilla/server/apis/v0/handlers/metrics.py
+++ b/src/argilla/server/apis/v0/handlers/metrics.py
@@ -58,7 +58,7 @@ def configure_router(router: APIRouter, cfg: TaskConfig):
path=base_metrics_endpoint,
new_path=new_base_metrics_endpoint,
router_method=router.get,
- operation_id=f"get_dataset_metrics",
+ operation_id="get_dataset_metrics",
name="get_dataset_metrics",
)
def get_dataset_metrics(
@@ -84,7 +84,7 @@ def get_dataset_metrics(
path=base_metrics_endpoint + "/{metric}:summary",
new_path=new_base_metrics_endpoint + "/{metric}:summary",
router_method=router.post,
- operation_id=f"metric_summary",
+ operation_id="metric_summary",
name="metric_summary",
)
def metric_summary(
diff --git a/src/argilla/server/apis/v0/handlers/records_update.py b/src/argilla/server/apis/v0/handlers/records_update.py
index f77d391e53..b2034265ad 100644
--- a/src/argilla/server/apis/v0/handlers/records_update.py
+++ b/src/argilla/server/apis/v0/handlers/records_update.py
@@ -14,17 +14,12 @@
from typing import Any, Dict, Optional, Union
-from fastapi import APIRouter, Depends, Query, Security
+from fastapi import APIRouter, Depends, Security
from pydantic import BaseModel
-from argilla.client.sdk.token_classification.models import TokenClassificationQuery
-from argilla.server.apis.v0.helpers import deprecate_endpoint
from argilla.server.apis.v0.models.commons.params import CommonTaskHandlerDependencies
-from argilla.server.apis.v0.models.text2text import Text2TextQuery, Text2TextRecord
-from argilla.server.apis.v0.models.text_classification import (
- TextClassificationQuery,
- TextClassificationRecord,
-)
+from argilla.server.apis.v0.models.text2text import Text2TextRecord
+from argilla.server.apis.v0.models.text_classification import TextClassificationRecord
from argilla.server.apis.v0.models.token_classification import TokenClassificationRecord
from argilla.server.commons.config import TasksFactory
from argilla.server.commons.models import TaskStatus
diff --git a/src/argilla/server/apis/v0/handlers/token_classification.py b/src/argilla/server/apis/v0/handlers/token_classification.py
index 940775a132..7a21310e1d 100644
--- a/src/argilla/server/apis/v0/handlers/token_classification.py
+++ b/src/argilla/server/apis/v0/handlers/token_classification.py
@@ -23,7 +23,6 @@
metrics,
token_classification_dataset_settings,
)
-from argilla.server.apis.v0.helpers import deprecate_endpoint
from argilla.server.apis.v0.models.commons.model import BulkResponse
from argilla.server.apis.v0.models.commons.params import (
CommonTaskHandlerDependencies,
diff --git a/src/argilla/server/apis/v0/models/text2text.py b/src/argilla/server/apis/v0/models/text2text.py
index e2730f46e3..5345f95746 100644
--- a/src/argilla/server/apis/v0/models/text2text.py
+++ b/src/argilla/server/apis/v0/models/text2text.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from datetime import datetime
from typing import Dict, List, Optional
from pydantic import BaseModel, Field, validator
@@ -27,7 +26,7 @@
SortableField,
)
from argilla.server.apis.v0.models.datasets import UpdateDatasetRequest
-from argilla.server.commons.models import PredictionStatus, TaskType
+from argilla.server.commons.models import PredictionStatus
from argilla.server.services.metrics.models import CommonTasksMetrics
from argilla.server.services.search.model import (
ServiceBaseRecordsQuery,
diff --git a/src/argilla/server/apis/v0/models/text_classification.py b/src/argilla/server/apis/v0/models/text_classification.py
index 84f1868f78..b2886fe56b 100644
--- a/src/argilla/server/apis/v0/models/text_classification.py
+++ b/src/argilla/server/apis/v0/models/text_classification.py
@@ -14,7 +14,7 @@
# limitations under the License.
from datetime import datetime
-from typing import Any, Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field, root_validator, validator
@@ -39,14 +39,11 @@
)
from argilla.server.services.tasks.text_classification.model import (
ServiceTextClassificationDataset,
-)
-from argilla.server.services.tasks.text_classification.model import (
- ServiceTextClassificationQuery as _TextClassificationQuery,
+ TokenAttributions,
)
from argilla.server.services.tasks.text_classification.model import (
TextClassificationAnnotation as _TextClassificationAnnotation,
)
-from argilla.server.services.tasks.text_classification.model import TokenAttributions
class UpdateLabelingRule(BaseModel):
diff --git a/src/argilla/server/daos/backend/metrics/base.py b/src/argilla/server/daos/backend/metrics/base.py
index 80122add68..256ffcdbf1 100644
--- a/src/argilla/server/daos/backend/metrics/base.py
+++ b/src/argilla/server/daos/backend/metrics/base.py
@@ -13,11 +13,14 @@
# limitations under the License.
import dataclasses
-from typing import Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from argilla.server.daos.backend.query_helpers import aggregations
from argilla.server.helpers import unflatten_dict
+if TYPE_CHECKING:
+ from argilla.server.daos.backend.client_adapters.base import IClientAdapter
+
@dataclasses.dataclass
class ElasticsearchMetric:
diff --git a/src/argilla/server/daos/backend/search/query_builder.py b/src/argilla/server/daos/backend/search/query_builder.py
index a05ed872b9..76d9e989ff 100644
--- a/src/argilla/server/daos/backend/search/query_builder.py
+++ b/src/argilla/server/daos/backend/search/query_builder.py
@@ -277,7 +277,7 @@ def map_2_es_sort_configuration(
es_sort = []
for sortable_field in sort.sort_by or [SortableField(id="id")]:
if valid_fields:
- if not sortable_field.id.split(".")[0] in valid_fields:
+ if sortable_field.id.split(".")[0] not in valid_fields:
raise AssertionError(
f"Wrong sort id {sortable_field.id}. Valid values are: "
f"{[str(v) for v in valid_fields]}"
diff --git a/src/argilla/server/daos/models/datasets.py b/src/argilla/server/daos/models/datasets.py
index 22a57e2019..027e4a7011 100644
--- a/src/argilla/server/daos/models/datasets.py
+++ b/src/argilla/server/daos/models/datasets.py
@@ -46,7 +46,7 @@ def id(self) -> str:
"""The dataset id. Compounded by owner and name"""
return self.build_dataset_id(self.name, self.owner)
- def dict(self, *args, **kwargs) -> "DictStrAny":
+ def dict(self, *args, **kwargs) -> Dict[str, Any]:
"""
Extends base component dict extending object properties
and user defined extended fields
diff --git a/src/argilla/server/daos/models/records.py b/src/argilla/server/daos/models/records.py
index 43776057e6..117d39581e 100644
--- a/src/argilla/server/daos/models/records.py
+++ b/src/argilla/server/daos/models/records.py
@@ -225,7 +225,7 @@ def extended_fields(self) -> Dict[str, Any]:
"score": self.scores,
}
- def dict(self, *args, **kwargs) -> "DictStrAny":
+ def dict(self, *args, **kwargs) -> Dict[str, Any]:
"""
Extends base component dict extending object properties
and user defined extended fields
diff --git a/src/argilla/server/errors/api_errors.py b/src/argilla/server/errors/api_errors.py
index 01fb75c81f..1cc0f6eb50 100644
--- a/src/argilla/server/errors/api_errors.py
+++ b/src/argilla/server/errors/api_errors.py
@@ -15,7 +15,7 @@
import logging
from typing import Any, Dict
-from fastapi import HTTPException, Request, status
+from fastapi import HTTPException, Request
from fastapi.exception_handlers import http_exception_handler
from pydantic import BaseModel
diff --git a/src/argilla/server/errors/base_errors.py b/src/argilla/server/errors/base_errors.py
index ab69d7ff7b..36ed5bd700 100644
--- a/src/argilla/server/errors/base_errors.py
+++ b/src/argilla/server/errors/base_errors.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional, Type, Union
+from typing import Any, Optional, Type, Union
import pydantic
from starlette import status
diff --git a/src/argilla/server/server.py b/src/argilla/server/server.py
index 7e1ece9e26..67e146d670 100644
--- a/src/argilla/server/server.py
+++ b/src/argilla/server/server.py
@@ -16,7 +16,6 @@
"""
This module configures the global fastapi application
"""
-import fileinput
import glob
import inspect
import logging
diff --git a/src/argilla/server/services/datasets.py b/src/argilla/server/services/datasets.py
index 2a3bce6f9a..d67914d354 100644
--- a/src/argilla/server/services/datasets.py
+++ b/src/argilla/server/services/datasets.py
@@ -146,7 +146,7 @@ def delete(self, user: User, dataset: ServiceDataset):
self.__dao__.delete_dataset(dataset)
else:
raise ForbiddenOperationError(
- f"You don't have the necessary permissions to delete this dataset. "
+ "You don't have the necessary permissions to delete this dataset. "
"Only dataset creators or administrators can delete datasets"
)
diff --git a/src/argilla/server/services/storage/service.py b/src/argilla/server/services/storage/service.py
index 27897efe11..703adaab55 100644
--- a/src/argilla/server/services/storage/service.py
+++ b/src/argilla/server/services/storage/service.py
@@ -13,7 +13,7 @@
# limitations under the License.
import dataclasses
-from typing import Any, Dict, List, Optional, Type
+from typing import List, Optional, Type
from fastapi import Depends
@@ -94,7 +94,7 @@ async def delete_records(
else:
if not user.is_superuser() and user.username != dataset.created_by:
raise ForbiddenOperationError(
- f"You don't have the necessary permissions to delete records on this dataset. "
+ "You don't have the necessary permissions to delete records on this dataset. "
"Only dataset creators or administrators can delete datasets"
)
diff --git a/src/argilla/server/services/tasks/text2text/models.py b/src/argilla/server/services/tasks/text2text/models.py
index bccc2b912d..f8d1e8438c 100644
--- a/src/argilla/server/services/tasks/text2text/models.py
+++ b/src/argilla/server/services/tasks/text2text/models.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
@@ -21,9 +20,7 @@
from argilla.server.services.datasets import ServiceBaseDataset
from argilla.server.services.search.model import (
ServiceBaseRecordsQuery,
- ServiceBaseSearchResultsAggregations,
ServiceScoreRange,
- ServiceSearchResults,
)
from argilla.server.services.tasks.commons import (
ServiceBaseAnnotation,
@@ -91,4 +88,3 @@ class ServiceText2TextQuery(ServiceBaseRecordsQuery):
class ServiceText2TextDataset(ServiceBaseDataset):
task: TaskType = Field(default=TaskType.text2text, const=True)
- pass
diff --git a/src/argilla/server/services/tasks/text2text/service.py b/src/argilla/server/services/tasks/text2text/service.py
index 48f5af61b3..a83e219153 100644
--- a/src/argilla/server/services/tasks/text2text/service.py
+++ b/src/argilla/server/services/tasks/text2text/service.py
@@ -13,12 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Iterable, List, Optional, Type
+from typing import Iterable, List, Optional
from fastapi import Depends
from argilla.server.commons.config import TasksFactory
-from argilla.server.services.metrics.models import ServiceBaseTaskMetrics
from argilla.server.services.search.model import (
ServiceSearchResults,
ServiceSortableField,
diff --git a/src/argilla/server/services/tasks/text_classification/model.py b/src/argilla/server/services/tasks/text_classification/model.py
index 2450d9731d..a94a2632a4 100644
--- a/src/argilla/server/services/tasks/text_classification/model.py
+++ b/src/argilla/server/services/tasks/text_classification/model.py
@@ -99,7 +99,7 @@ def check_label_length(cls, class_label):
assert 1 <= len(class_label) <= DEFAULT_MAX_KEYWORD_LENGTH, (
f"Class name '{class_label}' exceeds max length of {DEFAULT_MAX_KEYWORD_LENGTH}"
if len(class_label) > DEFAULT_MAX_KEYWORD_LENGTH
- else f"Class name must not be empty"
+ else "Class name must not be empty"
)
return class_label
diff --git a/src/argilla/utils/span_utils.py b/src/argilla/utils/span_utils.py
index 201039cb4d..4a236c6083 100644
--- a/src/argilla/utils/span_utils.py
+++ b/src/argilla/utils/span_utils.py
@@ -111,7 +111,7 @@ def validate(self, spans: List[Tuple[str, int, int]]):
if misaligned_spans_errors:
spans = "\n".join(misaligned_spans_errors)
- message += f"Following entity spans are not aligned with provided tokenization\n"
+ message += "Following entity spans are not aligned with provided tokenization\n"
message += f"Spans:\n{spans}\n"
message += f"Tokens:\n{self.tokens}"
diff --git a/tests/client/apis/test_base.py b/tests/client/apis/test_base.py
index 189969dbb6..1101d9849a 100644
--- a/tests/client/apis/test_base.py
+++ b/tests/client/apis/test_base.py
@@ -13,7 +13,6 @@
# limitations under the License.
import pytest
-
from argilla.client import api
from argilla.client.apis import AbstractApi, api_compatibility
from argilla.client.sdk._helpers import handle_response_error
diff --git a/tests/client/conftest.py b/tests/client/conftest.py
index ae01679272..ab95abdd6f 100644
--- a/tests/client/conftest.py
+++ b/tests/client/conftest.py
@@ -15,10 +15,9 @@
import datetime
from typing import List
-import pytest
-
import argilla
import argilla as ar
+import pytest
from argilla.client.sdk.datasets.models import TaskType
from argilla.client.sdk.text2text.models import (
CreationText2TextRecord,
diff --git a/tests/client/functional_tests/test_record_update.py b/tests/client/functional_tests/test_record_update.py
index 345149ca69..8c52cb4058 100644
--- a/tests/client/functional_tests/test_record_update.py
+++ b/tests/client/functional_tests/test_record_update.py
@@ -13,7 +13,6 @@
# limitations under the License.
import pytest
-
from argilla.client.api import active_api
from argilla.client.sdk.commons.errors import NotFoundApiError
diff --git a/tests/client/functional_tests/test_scan_raw_records.py b/tests/client/functional_tests/test_scan_raw_records.py
index f58b18006a..7af6af9b6f 100644
--- a/tests/client/functional_tests/test_scan_raw_records.py
+++ b/tests/client/functional_tests/test_scan_raw_records.py
@@ -13,8 +13,6 @@
# limitations under the License.
import pytest
-
-import argilla
from argilla.client.api import active_api
from argilla.client.sdk.token_classification.models import TokenClassificationRecord
@@ -28,9 +26,8 @@ def test_scan_records(
gutenberg_spacy_ner,
fields,
):
- import pandas as pd
-
import argilla as rg
+ import pandas as pd
data = active_api().datasets.scan(
name=gutenberg_spacy_ner,
diff --git a/tests/client/sdk/commons/api.py b/tests/client/sdk/commons/api.py
index 92eeaa9ef6..b7ca635486 100644
--- a/tests/client/sdk/commons/api.py
+++ b/tests/client/sdk/commons/api.py
@@ -14,8 +14,6 @@
# limitations under the License.
import httpx
import pytest
-from httpx import Response as HttpxResponse
-
from argilla.client.sdk.commons.api import (
build_bulk_response,
build_data_response,
@@ -29,6 +27,7 @@
ValidationError,
)
from argilla.client.sdk.text_classification.models import TextClassificationRecord
+from httpx import Response as HttpxResponse
def test_text2text_bulk(sdk_client, mocked_client, bulk_text2text_data, monkeypatch):
diff --git a/tests/client/sdk/commons/test_client.py b/tests/client/sdk/commons/test_client.py
index 5b1dd8bedd..cb801f12ef 100644
--- a/tests/client/sdk/commons/test_client.py
+++ b/tests/client/sdk/commons/test_client.py
@@ -13,7 +13,6 @@
# limitations under the License.
import pytest
-
from argilla.client.api import active_api
from argilla.client.sdk.client import Client
diff --git a/tests/client/sdk/conftest.py b/tests/client/sdk/conftest.py
index 7186d02816..75ef26042b 100644
--- a/tests/client/sdk/conftest.py
+++ b/tests/client/sdk/conftest.py
@@ -16,9 +16,8 @@
from datetime import datetime
from typing import Any, Dict, List
-import pytest
-
import argilla as ar
+import pytest
from argilla._constants import DEFAULT_API_KEY
from argilla.client.sdk.client import AuthenticatedClient
from argilla.client.sdk.text2text.models import (
diff --git a/tests/client/sdk/datasets/test_api.py b/tests/client/sdk/datasets/test_api.py
index ec46434227..7a5fba08d0 100644
--- a/tests/client/sdk/datasets/test_api.py
+++ b/tests/client/sdk/datasets/test_api.py
@@ -14,7 +14,6 @@
# limitations under the License.
import httpx
import pytest
-
from argilla._constants import DEFAULT_API_KEY
from argilla.client.sdk.client import AuthenticatedClient
from argilla.client.sdk.commons.errors import (
@@ -22,14 +21,8 @@
NotFoundApiError,
ValidationApiError,
)
-from argilla.client.sdk.commons.models import (
- ErrorMessage,
- HTTPValidationError,
- Response,
- ValidationError,
-)
from argilla.client.sdk.datasets.api import _build_response, get_dataset
-from argilla.client.sdk.datasets.models import Dataset, TaskType
+from argilla.client.sdk.datasets.models import Dataset
from argilla.client.sdk.text_classification.models import TextClassificationBulkData
diff --git a/tests/client/sdk/datasets/test_models.py b/tests/client/sdk/datasets/test_models.py
index a10ea4acd6..c490db0e9b 100644
--- a/tests/client/sdk/datasets/test_models.py
+++ b/tests/client/sdk/datasets/test_models.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
-
from argilla.client.sdk.datasets.models import Dataset, TaskType
from argilla.server.apis.v0.models.datasets import Dataset as ServerDataset
diff --git a/tests/client/sdk/text2text/test_models.py b/tests/client/sdk/text2text/test_models.py
index 811d56f61a..57f2e70178 100644
--- a/tests/client/sdk/text2text/test_models.py
+++ b/tests/client/sdk/text2text/test_models.py
@@ -16,7 +16,6 @@
from datetime import datetime
import pytest
-
from argilla.client.models import Text2TextRecord
from argilla.client.sdk.text2text.models import (
CreationText2TextRecord,
diff --git a/tests/client/sdk/text_classification/test_models.py b/tests/client/sdk/text_classification/test_models.py
index 24bee1eed1..9bebb747c5 100644
--- a/tests/client/sdk/text_classification/test_models.py
+++ b/tests/client/sdk/text_classification/test_models.py
@@ -16,7 +16,6 @@
from datetime import datetime
import pytest
-
from argilla.client.models import TextClassificationRecord, TokenAttributions
from argilla.client.sdk.text_classification.models import (
ClassPrediction,
diff --git a/tests/client/sdk/token_classification/test_models.py b/tests/client/sdk/token_classification/test_models.py
index bf0af8e502..4e419b07d2 100644
--- a/tests/client/sdk/token_classification/test_models.py
+++ b/tests/client/sdk/token_classification/test_models.py
@@ -16,7 +16,6 @@
from datetime import datetime
import pytest
-
from argilla.client.models import TokenClassificationRecord
from argilla.client.sdk.token_classification.models import (
CreationTokenClassificationRecord,
diff --git a/tests/client/sdk/users/test_api.py b/tests/client/sdk/users/test_api.py
index 34c053ff8a..8eb8f2b3c7 100644
--- a/tests/client/sdk/users/test_api.py
+++ b/tests/client/sdk/users/test_api.py
@@ -13,7 +13,6 @@
# limitations under the License.
import pytest
-
from argilla.client.sdk.client import AuthenticatedClient
from argilla.client.sdk.commons.errors import BaseClientError, UnauthorizedApiError
from argilla.client.sdk.users.api import whoami
diff --git a/tests/client/test_api.py b/tests/client/test_api.py
index 4f80a66deb..0c52482c27 100644
--- a/tests/client/test_api.py
+++ b/tests/client/test_api.py
@@ -15,14 +15,13 @@
import concurrent.futures
import datetime
from time import sleep
-from typing import Any, Iterable, List
+from typing import Any, Iterable
+import argilla as ar
import datasets
import httpx
import pandas as pd
import pytest
-
-import argilla as ar
from argilla._constants import (
_OLD_WORKSPACE_HEADER_NAME,
DEFAULT_API_KEY,
@@ -46,6 +45,7 @@
from argilla.server.apis.v0.models.text_classification import (
TextClassificationSearchResults,
)
+
from tests.helpers import SecuredClient
from tests.server.test_api import create_some_data_for_text_classification
diff --git a/tests/client/test_asgi.py b/tests/client/test_asgi.py
index 41565b7aa8..465e20136d 100644
--- a/tests/client/test_asgi.py
+++ b/tests/client/test_asgi.py
@@ -16,17 +16,16 @@
import time
from typing import Any, Dict
-from fastapi import FastAPI
-from starlette.applications import Starlette
-from starlette.responses import JSONResponse, PlainTextResponse
-from starlette.testclient import TestClient
-
import argilla
from argilla.monitoring.asgi import (
ArgillaLogHTTPMiddleware,
text_classification_mapper,
token_classification_mapper,
)
+from fastapi import FastAPI
+from starlette.applications import Starlette
+from starlette.responses import JSONResponse, PlainTextResponse
+from starlette.testclient import TestClient
def test_argilla_middleware_for_text_classification(
diff --git a/tests/client/test_client_errors.py b/tests/client/test_client_errors.py
index a19ff3dd12..f9f0cd4703 100644
--- a/tests/client/test_client_errors.py
+++ b/tests/client/test_client_errors.py
@@ -13,7 +13,6 @@
# limitations under the License.
import pytest
-
from argilla.client.sdk.commons.errors import UnauthorizedApiError
diff --git a/tests/client/test_dataset.py b/tests/client/test_dataset.py
index e335e419b3..83ae429af0 100644
--- a/tests/client/test_dataset.py
+++ b/tests/client/test_dataset.py
@@ -17,12 +17,11 @@
import sys
from time import sleep
+import argilla as ar
import datasets
import pandas as pd
import pytest
import spacy
-
-import argilla as ar
from argilla.client.datasets import (
DatasetBase,
DatasetForTokenClassification,
diff --git a/tests/client/test_models.py b/tests/client/test_models.py
index 6d5d4b052b..75d37bde11 100644
--- a/tests/client/test_models.py
+++ b/tests/client/test_models.py
@@ -19,15 +19,13 @@
import numpy
import pandas as pd
import pytest
-from pydantic import ValidationError
-
-from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH
from argilla.client.models import (
Text2TextRecord,
TextClassificationRecord,
TokenClassificationRecord,
_Validators,
)
+from pydantic import ValidationError
@pytest.mark.parametrize(
diff --git a/tests/conftest.py b/tests/conftest.py
index a9b28296e0..33738a6db6 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -15,7 +15,6 @@
import httpx
import pytest
from _pytest.logging import LogCaptureFixture
-
from argilla.client.sdk.users import api as users_api
from argilla.server.commons import telemetry
@@ -23,10 +22,9 @@
from loguru import logger
except ModuleNotFoundError:
logger = None
-from starlette.testclient import TestClient
-
from argilla import app
from argilla.client.api import active_api
+from starlette.testclient import TestClient
from .helpers import SecuredClient
diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py
index 3edfdccf73..4499d25964 100644
--- a/tests/datasets/test_datasets.py
+++ b/tests/datasets/test_datasets.py
@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import pytest
-
import argilla as ar
+import pytest
from argilla import TextClassificationSettings, TokenClassificationSettings
from argilla.client import api
from argilla.client.sdk.commons.errors import ForbiddenApiError
diff --git a/tests/functional_tests/datasets/test_delete_records_from_datasets.py b/tests/functional_tests/datasets/test_delete_records_from_datasets.py
index 5279ebb567..238c358338 100644
--- a/tests/functional_tests/datasets/test_delete_records_from_datasets.py
+++ b/tests/functional_tests/datasets/test_delete_records_from_datasets.py
@@ -15,7 +15,6 @@
import time
import pytest
-
from argilla.client.sdk.commons.errors import ForbiddenApiError
diff --git a/tests/functional_tests/datasets/test_update_record.py b/tests/functional_tests/datasets/test_update_record.py
index 8ffd74fabc..7eb050a4c5 100644
--- a/tests/functional_tests/datasets/test_update_record.py
+++ b/tests/functional_tests/datasets/test_update_record.py
@@ -13,7 +13,6 @@
# limitations under the License.
import pytest
-
from argilla.server.apis.v0.models.text2text import Text2TextRecord
from argilla.server.apis.v0.models.text_classification import TextClassificationRecord
from argilla.server.apis.v0.models.token_classification import TokenClassificationRecord
diff --git a/tests/functional_tests/search/test_search_service.py b/tests/functional_tests/search/test_search_service.py
index 92eb8062bb..9fc0b9b439 100644
--- a/tests/functional_tests/search/test_search_service.py
+++ b/tests/functional_tests/search/test_search_service.py
@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import pytest
-
import argilla
+import pytest
from argilla.server.apis.v0.models.commons.model import ScoreRange
from argilla.server.apis.v0.models.datasets import Dataset
from argilla.server.apis.v0.models.text_classification import (
diff --git a/tests/functional_tests/test_log_for_text_classification.py b/tests/functional_tests/test_log_for_text_classification.py
index 0a85ec53e8..5e2a16f8bb 100644
--- a/tests/functional_tests/test_log_for_text_classification.py
+++ b/tests/functional_tests/test_log_for_text_classification.py
@@ -12,16 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import pytest
-
import argilla as ar
+import pytest
from argilla.client.sdk.commons.errors import (
BadRequestApiError,
GenericApiError,
ValidationApiError,
)
from argilla.server.settings import settings
-from tests.client.conftest import SUPPORTED_VECTOR_SEARCH, supported_vector_search
+
+from tests.client.conftest import SUPPORTED_VECTOR_SEARCH
from tests.helpers import SecuredClient
diff --git a/tests/functional_tests/test_log_for_token_classification.py b/tests/functional_tests/test_log_for_token_classification.py
index f4bea51e2c..2df8c37c5c 100644
--- a/tests/functional_tests/test_log_for_token_classification.py
+++ b/tests/functional_tests/test_log_for_token_classification.py
@@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import pytest
-
import argilla
+import pytest
from argilla import TokenClassificationRecord
from argilla.client import api
from argilla.client.sdk.commons.errors import NotFoundApiError
from argilla.metrics import __all__ as ALL_METRICS
from argilla.metrics import entity_consistency
+
from tests.client.conftest import SUPPORTED_VECTOR_SEARCH
from tests.helpers import SecuredClient
diff --git a/tests/helpers.py b/tests/helpers.py
index 3999376e6f..4246a8fd45 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -14,14 +14,13 @@
from typing import List
-from fastapi import FastAPI
-from starlette.testclient import TestClient
-
from argilla._constants import API_KEY_HEADER_NAME, WORKSPACE_HEADER_NAME
from argilla.client.api import active_api
from argilla.server.security import auth
from argilla.server.security.auth_provider.local.settings import settings
from argilla.server.security.auth_provider.local.users.model import UserInDB
+from fastapi import FastAPI
+from starlette.testclient import TestClient
class SecuredClient:
diff --git a/tests/labeling/text_classification/test_label_errors.py b/tests/labeling/text_classification/test_label_errors.py
index 71933a9828..f33f9d8ab6 100644
--- a/tests/labeling/text_classification/test_label_errors.py
+++ b/tests/labeling/text_classification/test_label_errors.py
@@ -14,11 +14,9 @@
# limitations under the License.
import sys
+import argilla as ar
import cleanlab
import pytest
-from pkg_resources import parse_version
-
-import argilla as ar
from argilla.labeling.text_classification import find_label_errors
from argilla.labeling.text_classification.label_errors import (
MissingPredictionError,
@@ -26,6 +24,7 @@
SortBy,
_construct_s_and_psx,
)
+from pkg_resources import parse_version
@pytest.fixture(
diff --git a/tests/labeling/text_classification/test_label_models.py b/tests/labeling/text_classification/test_label_models.py
index c21bac70a2..0cde14796f 100644
--- a/tests/labeling/text_classification/test_label_models.py
+++ b/tests/labeling/text_classification/test_label_models.py
@@ -13,11 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
-from types import SimpleNamespace
import numpy as np
import pytest
-
from argilla import TextClassificationRecord
from argilla.labeling.text_classification import (
FlyingSquid,
diff --git a/tests/labeling/text_classification/test_rule.py b/tests/labeling/text_classification/test_rule.py
index 3c1d29a326..c46c7cdf3e 100644
--- a/tests/labeling/text_classification/test_rule.py
+++ b/tests/labeling/text_classification/test_rule.py
@@ -14,7 +14,6 @@
# limitations under the License.
import httpx
import pytest
-
from argilla import load
from argilla.client.models import TextClassificationRecord
from argilla.client.sdk.text_classification.models import (
diff --git a/tests/labeling/text_classification/test_weak_labels.py b/tests/labeling/text_classification/test_weak_labels.py
index 0ad4cf6e97..8a42263881 100644
--- a/tests/labeling/text_classification/test_weak_labels.py
+++ b/tests/labeling/text_classification/test_weak_labels.py
@@ -18,8 +18,6 @@
import numpy as np
import pandas as pd
import pytest
-from pandas.testing import assert_frame_equal
-
from argilla import TextClassificationRecord
from argilla.client.sdk.text_classification.models import (
CreationTextClassificationRecord,
@@ -35,6 +33,7 @@
NoRulesFoundError,
WeakLabelsBase,
)
+from pandas.testing import assert_frame_equal
@pytest.fixture
diff --git a/tests/listeners/test_listener.py b/tests/listeners/test_listener.py
index 1aa6bca55a..8549036869 100644
--- a/tests/listeners/test_listener.py
+++ b/tests/listeners/test_listener.py
@@ -15,9 +15,8 @@
import time
from typing import List
-import pytest
-
import argilla as ar
+import pytest
from argilla import RGListenerContext, listener
from argilla.client.models import Record
diff --git a/tests/metrics/test_common_metrics.py b/tests/metrics/test_common_metrics.py
index b1ac68084c..4cbb6f80f6 100644
--- a/tests/metrics/test_common_metrics.py
+++ b/tests/metrics/test_common_metrics.py
@@ -12,10 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import pytest
-
import argilla
import argilla as ar
+import pytest
from argilla.metrics.commons import keywords, records_status, text_length
diff --git a/tests/metrics/test_token_classification.py b/tests/metrics/test_token_classification.py
index 443b7a17c8..89ee11c961 100644
--- a/tests/metrics/test_token_classification.py
+++ b/tests/metrics/test_token_classification.py
@@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import pytest
-
import argilla
import argilla as ar
+import pytest
from argilla.metrics import entity_consistency
from argilla.metrics.token_classification import (
Annotations,
diff --git a/tests/monitoring/test_base_monitor.py b/tests/monitoring/test_base_monitor.py
index 96dc9c349d..f3baa32730 100644
--- a/tests/monitoring/test_base_monitor.py
+++ b/tests/monitoring/test_base_monitor.py
@@ -13,7 +13,7 @@
# limitations under the License.
from time import sleep
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
from argilla import TextClassificationRecord
from argilla.client.api import Api, active_api
diff --git a/tests/monitoring/test_flair_monitoring.py b/tests/monitoring/test_flair_monitoring.py
index e8cb51b6ad..18aa38201e 100644
--- a/tests/monitoring/test_flair_monitoring.py
+++ b/tests/monitoring/test_flair_monitoring.py
@@ -15,11 +15,10 @@
def test_flair_monitoring(mocked_client, monkeypatch):
+ import argilla as ar
from flair.data import Sentence
from flair.models import SequenceTagger
- import argilla as ar
-
dataset = "test_flair_monitoring"
model = "flair/ner-english"
diff --git a/tests/monitoring/test_transformers_monitoring.py b/tests/monitoring/test_transformers_monitoring.py
index ee0487c98f..225f606518 100644
--- a/tests/monitoring/test_transformers_monitoring.py
+++ b/tests/monitoring/test_transformers_monitoring.py
@@ -14,9 +14,8 @@
from time import sleep
from typing import List, Union
-import pytest
-
import argilla
+import pytest
from argilla import TextClassificationRecord
diff --git a/tests/server/backend/test_query_builder.py b/tests/server/backend/test_query_builder.py
index 2ceef02733..0276da15eb 100644
--- a/tests/server/backend/test_query_builder.py
+++ b/tests/server/backend/test_query_builder.py
@@ -13,7 +13,6 @@
# limitations under the License.
import pytest
-
from argilla.server.daos.backend.search.model import (
SortableField,
SortConfig,
diff --git a/tests/server/commons/test_records_dao.py b/tests/server/commons/test_records_dao.py
index 2dd227267a..f86d68d084 100644
--- a/tests/server/commons/test_records_dao.py
+++ b/tests/server/commons/test_records_dao.py
@@ -13,7 +13,6 @@
# limitations under the License.
import pytest
-
from argilla.server.commons.models import TaskType
from argilla.server.daos.backend import GenericElasticEngineBackend
from argilla.server.daos.models.datasets import BaseDatasetDB
diff --git a/tests/server/commons/test_settings.py b/tests/server/commons/test_settings.py
index 563ebdd3ce..9380f00eba 100644
--- a/tests/server/commons/test_settings.py
+++ b/tests/server/commons/test_settings.py
@@ -15,9 +15,8 @@
import os
import pytest
-from pydantic import ValidationError
-
from argilla.server.settings import ApiSettings
+from pydantic import ValidationError
@pytest.mark.parametrize("bad_namespace", ["Badns", "bad-ns", "12-bad-ns", "@bad"])
diff --git a/tests/server/commons/test_telemetry.py b/tests/server/commons/test_telemetry.py
index e892042649..f3ce903918 100644
--- a/tests/server/commons/test_telemetry.py
+++ b/tests/server/commons/test_telemetry.py
@@ -15,11 +15,10 @@
import uuid
import pytest
-from fastapi import Request
-
from argilla.server.commons import telemetry
from argilla.server.commons.models import TaskType
from argilla.server.errors import ServerError
+from fastapi import Request
mock_request = Request(scope={"type": "http", "headers": {}})
diff --git a/tests/server/daos/models/test_records.py b/tests/server/daos/models/test_records.py
index e2542c804f..5ebac261fc 100644
--- a/tests/server/daos/models/test_records.py
+++ b/tests/server/daos/models/test_records.py
@@ -15,7 +15,6 @@
import warnings
import pytest
-
from argilla.server.daos.models.records import BaseRecordInDB
from argilla.server.settings import settings
diff --git a/tests/server/datasets/test_api.py b/tests/server/datasets/test_api.py
index 416cf16275..2d0ebebf0b 100644
--- a/tests/server/datasets/test_api.py
+++ b/tests/server/datasets/test_api.py
@@ -19,6 +19,7 @@
TextClassificationBulkRequest,
)
from argilla.server.commons.models import TaskType
+
from tests.helpers import SecuredClient
@@ -97,7 +98,7 @@ def test_fetch_dataset_using_workspaces(mocked_client: SecuredClient):
assert response.status_code == 409, response.json()
response = mocked_client.post(
- f"/api/datasets",
+ "/api/datasets",
json=request,
)
diff --git a/tests/server/datasets/test_dao.py b/tests/server/datasets/test_dao.py
index 2b208c73fd..962df52564 100644
--- a/tests/server/datasets/test_dao.py
+++ b/tests/server/datasets/test_dao.py
@@ -14,7 +14,6 @@
# limitations under the License.
import pytest
-
from argilla.server.commons.models import TaskType
from argilla.server.daos.backend import GenericElasticEngineBackend
from argilla.server.daos.datasets import DatasetsDAO
diff --git a/tests/server/datasets/test_model.py b/tests/server/datasets/test_model.py
index b664522d78..165203cd7f 100644
--- a/tests/server/datasets/test_model.py
+++ b/tests/server/datasets/test_model.py
@@ -14,10 +14,9 @@
# limitations under the License.
import pytest
-from pydantic import ValidationError
-
from argilla.server.apis.v0.models.datasets import CreateDatasetRequest
from argilla.server.commons.models import TaskType
+from pydantic import ValidationError
@pytest.mark.parametrize(
diff --git a/tests/server/functional_tests/datasets/test_get_record.py b/tests/server/functional_tests/datasets/test_get_record.py
index 5bc7deba30..ee95849bef 100644
--- a/tests/server/functional_tests/datasets/test_get_record.py
+++ b/tests/server/functional_tests/datasets/test_get_record.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
-
from argilla.server.apis.v0.models.text2text import (
Text2TextBulkRequest,
Text2TextRecord,
diff --git a/tests/server/info/test_api.py b/tests/server/info/test_api.py
index d25b143555..41b6c23518 100644
--- a/tests/server/info/test_api.py
+++ b/tests/server/info/test_api.py
@@ -36,7 +36,7 @@ def test_api_status(mocked_client):
assert info.version == argilla_version
# Check that we do not get the error dictionary that service.py includes whenever something goes wrong
- assert not "error" in info.elasticsearch
+ assert "error" not in info.elasticsearch
# Checking that the first key in the mem_info dictionary has a non-none value
assert "rss" in info.mem_info is not None
diff --git a/tests/server/security/test_dao.py b/tests/server/security/test_dao.py
index cda8ea5d1b..41e03f5bc3 100644
--- a/tests/server/security/test_dao.py
+++ b/tests/server/security/test_dao.py
@@ -13,7 +13,6 @@
# limitations under the License.
import pytest
-
from argilla._constants import DEFAULT_API_KEY
from argilla.server.security.auth_provider.local.users.service import create_users_dao
diff --git a/tests/server/security/test_model.py b/tests/server/security/test_model.py
index 4c4428ed06..757349d6a0 100644
--- a/tests/server/security/test_model.py
+++ b/tests/server/security/test_model.py
@@ -13,10 +13,9 @@
# limitations under the License.
import pytest
-from pydantic import ValidationError
-
from argilla.server.errors import EntityNotFoundError
from argilla.server.security.model import User
+from pydantic import ValidationError
@pytest.mark.parametrize("email", ["my@email.com", "infra@recogn.ai"])
diff --git a/tests/server/security/test_provider.py b/tests/server/security/test_provider.py
index f06b64dcf7..6d169e74dc 100644
--- a/tests/server/security/test_provider.py
+++ b/tests/server/security/test_provider.py
@@ -13,12 +13,11 @@
# limitations under the License.
import pytest
-from fastapi.security import SecurityScopes
-
from argilla._constants import DEFAULT_API_KEY
from argilla.server.security.auth_provider.local.provider import (
create_local_auth_provider,
)
+from fastapi.security import SecurityScopes
localAuth = create_local_auth_provider()
security_Scopes = SecurityScopes
diff --git a/tests/server/security/test_service.py b/tests/server/security/test_service.py
index 9bc5f315a3..8985b556a2 100644
--- a/tests/server/security/test_service.py
+++ b/tests/server/security/test_service.py
@@ -13,7 +13,6 @@
# limitations under the License.
import pytest
-
from argilla._constants import DEFAULT_API_KEY
from argilla.server.security.auth_provider.local.users.dao import create_users_dao
from argilla.server.security.auth_provider.local.users.service import UsersService
diff --git a/tests/server/test_app.py b/tests/server/test_app.py
index 21d336d081..06e6ca9f05 100644
--- a/tests/server/test_app.py
+++ b/tests/server/test_app.py
@@ -16,7 +16,6 @@
from importlib import reload
import pytest
-
from argilla.server import app
diff --git a/tests/server/text2text/test_api.py b/tests/server/text2text/test_api.py
index a5f8c36e73..c5db714f9d 100644
--- a/tests/server/text2text/test_api.py
+++ b/tests/server/text2text/test_api.py
@@ -14,14 +14,13 @@
from typing import List, Optional
import pytest
-
from argilla.server.apis.v0.models.commons.model import BulkResponse
from argilla.server.apis.v0.models.text2text import (
Text2TextBulkRequest,
- Text2TextRecord,
Text2TextRecordInputs,
Text2TextSearchResults,
)
+
from tests.client.conftest import SUPPORTED_VECTOR_SEARCH
diff --git a/tests/server/text_classification/test_api.py b/tests/server/text_classification/test_api.py
index 5ffb8c5300..6791614471 100644
--- a/tests/server/text_classification/test_api.py
+++ b/tests/server/text_classification/test_api.py
@@ -16,7 +16,6 @@
from datetime import datetime
import pytest
-
from argilla.server.apis.v0.models.commons.model import BulkResponse
from argilla.server.apis.v0.models.datasets import Dataset
from argilla.server.apis.v0.models.text_classification import (
@@ -28,6 +27,7 @@
TextClassificationSearchResults,
)
from argilla.server.commons.models import PredictionStatus
+
from tests.client.conftest import SUPPORTED_VECTOR_SEARCH
diff --git a/tests/server/text_classification/test_api_rules.py b/tests/server/text_classification/test_api_rules.py
index 6c95ed80c0..464ab427af 100644
--- a/tests/server/text_classification/test_api_rules.py
+++ b/tests/server/text_classification/test_api_rules.py
@@ -13,7 +13,6 @@
# limitations under the License.
import pytest
-
from argilla.server.apis.v0.models.text_classification import (
CreateLabelingRule,
LabelingRule,
diff --git a/tests/server/text_classification/test_model.py b/tests/server/text_classification/test_model.py
index 18a967a637..87e344cb02 100644
--- a/tests/server/text_classification/test_model.py
+++ b/tests/server/text_classification/test_model.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
-from pydantic import ValidationError
-
from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH
from argilla.server.apis.v0.models.text_classification import (
TextClassificationAnnotation,
@@ -27,6 +25,7 @@
ClassPrediction,
ServiceTextClassificationRecord,
)
+from pydantic import ValidationError
def test_flatten_metadata():
diff --git a/tests/server/token_classification/test_api.py b/tests/server/token_classification/test_api.py
index 29ee17ec24..8260300012 100644
--- a/tests/server/token_classification/test_api.py
+++ b/tests/server/token_classification/test_api.py
@@ -15,7 +15,6 @@
from typing import Callable
import pytest
-
from argilla.server.apis.v0.models.commons.model import BulkResponse, SortableField
from argilla.server.apis.v0.models.token_classification import (
TokenClassificationBulkRequest,
@@ -24,6 +23,7 @@
TokenClassificationSearchRequest,
TokenClassificationSearchResults,
)
+
from tests.client.conftest import SUPPORTED_VECTOR_SEARCH
diff --git a/tests/server/token_classification/test_model.py b/tests/server/token_classification/test_model.py
index 11701da917..8d0c12bbff 100644
--- a/tests/server/token_classification/test_model.py
+++ b/tests/server/token_classification/test_model.py
@@ -14,8 +14,6 @@
# limitations under the License.
import pytest
-from pydantic import ValidationError
-
from argilla._constants import DEFAULT_MAX_KEYWORD_LENGTH
from argilla.server.apis.v0.models.token_classification import (
TokenClassificationAnnotation,
@@ -28,6 +26,7 @@
EntitySpan,
ServiceTokenClassificationRecord,
)
+from pydantic import ValidationError
def test_char_position():
diff --git a/tests/utils/test_span_utils.py b/tests/utils/test_span_utils.py
index eb53a0b038..e3ddbdc6cf 100644
--- a/tests/utils/test_span_utils.py
+++ b/tests/utils/test_span_utils.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
-
from argilla.utils.span_utils import SpanUtils
diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py
index 47804bd72e..85fe24e9f5 100644
--- a/tests/utils/test_utils.py
+++ b/tests/utils/test_utils.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
-
from argilla.utils import LazyargillaModule
From 91a77ad85244cd885fafd40c38251ecdf2e37147 Mon Sep 17 00:00:00 2001
From: Daniel Vila Suero
Date: Tue, 14 Feb 2023 18:29:43 +0100
Subject: [PATCH 06/45] Docs: Update readme with quickstart section and new
links to guides (#2333)
# Description
- [x] Documentation update
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
README.md | 81 ++++++++-----------------------------------------------
1 file changed, 11 insertions(+), 70 deletions(-)
diff --git a/README.md b/README.md
index 87474b4e3c..b034e12f15 100644
--- a/README.md
+++ b/README.md
@@ -61,7 +61,7 @@
### Advanced NLP labeling
-- Programmatic labeling using [weak supervision](https://docs.argilla.io/en/latest/guides/techniques/weak_supervision.html). Built-in label models (Snorkel, Flyingsquid)
+- Programmatic labeling using [rules and weak supervision](https://docs.argilla.io/en/latest/guides/programmatic_labeling_with_rules.html). Built-in label models (Snorkel, Flyingsquid)
- [Bulk-labeling](https://docs.argilla.io/en/latest/reference/webapp/features.html#bulk-annotate) and [search-driven annotation](https://docs.argilla.io/en/latest/guides/features/queries.html)
- Iterate on training data with any [pre-trained model](https://docs.argilla.io/en/latest/tutorials/libraries/huggingface.html) or [library](https://docs.argilla.io/en/latest/tutorials/libraries/libraries.html)
- Efficiently review and refine annotations in the UI and with Python
@@ -71,93 +71,34 @@
### Monitoring
- Close the gap between production data and data collection activities
-- [Auto-monitoring](https://docs.argilla.io/en/latest/guides/steps/3_deploying.html) for [major NLP libraries and pipelines](https://docs.argilla.io/en/latest/tutorials/libraries/libraries.html) (spaCy, Hugging Face, FlairNLP)
+- [Auto-monitoring](https://docs.argilla.io/en/latest/guides/log_load_and_prepare_data.html) for [major NLP libraries and pipelines](https://docs.argilla.io/en/latest/tutorials/libraries/libraries.html) (spaCy, Hugging Face, FlairNLP)
- [ASGI middleware](https://docs.argilla.io/en/latest/tutorials/notebooks/deploying-texttokenclassification-fastapi.html) for HTTP endpoints
-- Argilla Metrics to understand data and model issues, [like entity consistency for NER models](https://docs.argilla.io/en/latest/guides/steps/4_monitoring.html)
+- Argilla Metrics to understand data and model issues, [like entity consistency for NER models](https://docs.argilla.io/en/latest/guides/measure_datasets_with_metrics.html)
- Integrated with Kibana for custom dashboards
### Team workspaces
- Bring different users and roles into the NLP data and model lifecycles
-- Organize data collection, review and monitoring into different [workspaces](https://docs.argilla.io/en/latest/getting_started/installation/user_management.html#workspace)
+- Organize data collection, review and monitoring into different [workspaces](https://docs.argilla.io/en/latest/getting_started/installation/configurations/user_management.html)
- Manage workspace access for different users
## Quickstart
-Argilla is composed of a `Python Server` with Elasticsearch as the database layer, and a `Python Client` to create and manage datasets.
+👋 Welcome! If you have just discovered Argilla, this is the best place to get started. Argilla is composed of:
-To get started you need to **install the client and the server** with `pip`:
-```bash
-
-pip install "argilla[server]"
-
-```
-
-Then you need to **run [Elasticsearch (ES)](https://www.elastic.co/elasticsearch)**.
-
-The simplest way is to use`Docker` by running:
-
-```bash
+* Argilla Client: a powerful Python library for reading and writing data into Argilla, using all the libraries you love (transformers, spaCy, datasets, and any others).
-docker run -d --name elasticsearch-for-argilla --network argilla-net -p 9200:9200 -p 9300:9300 -e "ES_JAVA_OPTS=-Xms512m -Xmx512m" -e "discovery.type=single-node" docker.elastic.co/elasticsearch/elasticsearch:8.5.3
-
-```
-> :information_source: **Check [the docs](https://docs.argilla.io/en/latest/getting_started/quickstart.html) for further options and configurations for Elasticsearch.**
-
-Finally you can **launch the server**:
-
-```bash
-
-python -m argilla
-
-```
-> :information_source: The most common error message after this step is related to the Elasticsearch instance not running. Make sure your Elasticsearch instance is running on http://localhost:9200/. If you already have an Elasticsearch instance or cluster, you point the server to its URL by using [ENV variables](#)
+* Argilla Server and UI: the API and UI for data annotation and curation.
+To get started you need to:
-🎉 You can now access Argilla UI pointing your browser at http://localhost:6900/.
+1. Launch the Argilla Server and UI.
-**The default username and password are** `argilla` **and** `1234`.
-
-Your workspace will contain no datasets. So let's use the `datasets` library to create our first datasets!
-
-First, you need to install `datasets`:
-```bash
-
-pip install datasets
-
-```
-
-Then go to your Python IDE of choice and run:
-```python
-
-import pandas as pd
-import argilla as rg
-from datasets import load_dataset
-
-# load dataset from the hub
-dataset = load_dataset("argilla/gutenberg_spacy-ner", split="train")
-
-# read in dataset, assuming its a dataset for text classification
-dataset_rg = rg.read_datasets(dataset, task="TokenClassification")
-
-# log the dataset to the Argilla web app
-rg.log(dataset_rg, "gutenberg_spacy-ner")
-
-# load dataset from json
-my_dataframe = pd.read_json(
- "https://raw.githubusercontent.com/recognai/datasets/main/sst-sentimentclassification.json")
-
-# convert pandas dataframe to DatasetForTextClassification
-dataset_rg = rg.DatasetForTextClassification.from_pandas(my_dataframe)
-
-# log the dataset to the Argilla web app
-rg.log(dataset_rg, name="sst-sentimentclassification")
-```
+2. Pick a tutorial and start rocking with Argilla using Jupyter Notebooks or Google Colab.
-This will create two datasets that you can use to do a quick tour of the core features of Argilla.
+To get started follow the steps [on the Quickstart docs page](https://docs.argilla.io/en/latest/getting_started/quickstart.html).
> 🚒 **If you find issues, get direct support from the team and other community members on the [Slack Community](https://join.slack.com/t/rubrixworkspace/shared_invite/zt-whigkyjn-a3IUJLD7gDbTZ0rKlvcJ5g)**
-For getting started with your own use cases, [go to the docs](https://docs.argilla.io).
## Principles
- **Open**: Argilla is free, open-source, and 100% compatible with major NLP libraries (Hugging Face transformers, spaCy, Stanford Stanza, Flair, etc.). In fact, you can **use and combine your preferred libraries** without implementing any specific interface.
From 6d2885f8760747dfe6c0e4918a37203dc7c1d248 Mon Sep 17 00:00:00 2001
From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Date: Wed, 15 Feb 2023 14:33:33 +0100
Subject: [PATCH 07/45] Enhancement: Also validate records on assignment of
variables (#2337)
Closes #2257
Hello!
## Pull Request overview
* Add `validate_assignment = True` to the `pydantic` Config to also
apply validation on assignment e.g. via `record.prediction = ...`.
* Add tests to ensure the correct behaviour.
## What was wrong?
See #2257 for details on the issue. The core issue is that the following
script would fail with a server-side error on `rg.log`:
```python
import argilla as rg
record = rg.TextClassificationRecord(
text="Hello world, this is me!",
prediction=[("LABEL1", 0.8), ("LABEL2", 0.2)],
annotation="LABEL1",
multi_label=False,
)
rg.log(record, "test_item_assignment")
record.prediction = "rubbish"
rg.log(record, "test_item_assignment")
```
Failure logs when using the `develop` branch:
```
2023-02-14 11:19:40.794 | ERROR | argilla.client.client:__log_internal__:103 -
Cannot log data in dataset 'test_item_assignment'
Error: ValueError
Details: not enough values to unpack (expected 2, got 1)
Traceback (most recent call last):
File "c:/code/argilla/demo_2257.py", line 13, in
rg.log(record, "test_item_assignment")
File "C:\code\argilla\src\argilla\client\api.py", line 157, in log
background=background,
File "C:\code\argilla\src\argilla\client\client.py", line 305, in log
return future.result()
File "C:\Users\tom\.conda\envs\argilla\lib\concurrent\futures\_base.py", line 435, in result
return self.__get_result()
File "C:\Users\tom\.conda\envs\argilla\lib\concurrent\futures\_base.py", line 384, in __get_result
raise self._exception
File "C:\code\argilla\src\argilla\client\client.py", line 107, in __log_internal__
raise ex
File "C:\code\argilla\src\argilla\client\client.py", line 99, in __log_internal__
return await api.log_async(*args, **kwargs)
File "C:\code\argilla\src\argilla\client\client.py", line 389, in log_async
records=[creation_class.from_client(r) for r in chunk],
File "C:\code\argilla\src\argilla\client\client.py", line 389, in
records=[creation_class.from_client(r) for r in chunk],
File "C:\code\argilla\src\argilla\client\sdk\text_classification\models.py", line 65, in from_client
for label, score in record.prediction
File "C:\code\argilla\src\argilla\client\sdk\text_classification\models.py", line 64, in
ClassPrediction(**{"class": label, "score": score})
ValueError: not enough values to unpack (expected 2, got 1)
0%|          | 0/1 [00:00<?, ?it/s]
```
This is unnecessary, as we can create a client-side error using our
validators, too.
## The fix
The fix is as simple as instructing `pydantic` to also trigger the
validation on assignment, e.g. on `record.prediction = ...`.
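Outside of Argilla, the option behaves like this (a minimal sketch with a hypothetical model, assuming the pydantic v1 `Config` style used in this codebase):

```python
from pydantic import BaseModel, ValidationError


class Item(BaseModel):
    score: float  # hypothetical field, just for illustration

    class Config:
        # Without this, pydantic only validates at construction time;
        # with it, assignments like `item.score = ...` are validated too.
        validate_assignment = True


item = Item(score=0.5)
try:
    item.score = "rubbish"  # now raises instead of silently breaking the model
except ValidationError as err:
    print(err)  # value is not a valid float
```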
## Relevant documentation
See `validate_assignment` in
https://docs.pydantic.dev/usage/model_config/#options.
## After the fix
Using the same script as before, the error now becomes:
```
Traceback (most recent call last):
File "c:/code/argilla/demo_2257.py", line 11, in
record.prediction = "rubbish"
File "C:\code\argilla\src\argilla\client\models.py", line 280, in __setattr__
super().__setattr__(name, value)
File "pydantic\main.py", line 445, in pydantic.main.BaseModel.__setattr__
pydantic.error_wrappers.ValidationError: 1 validation error for TextClassificationRecord
prediction
value is not a valid list (type=type_error.list)
```
This error triggers directly on the assignment, rather than only when the
record is logged.
---
**Type of change**
- [x] Improvement (change adding some improvement to an existing
functionality)
**How Has This Been Tested**
All existing tests passed with my changes. Beyond that, I added some new
tests of my own. These tests fail on the `develop` branch but pass on
this branch, because previously `record.prediction = "rubbish"` was
allowed: no validation was executed on assignment.
**Checklist**
- [x] I have merged the original branch into my forked branch
- [ ] I added relevant documentation
- [x] follows the style guidelines of this project
- [x] I did a self-review of my code
- [x] I added comments to my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
---
- Tom Aarsen
---
src/argilla/client/models.py | 2 ++
tests/client/test_models.py | 21 +++++++++++++++++++++
2 files changed, 23 insertions(+)
diff --git a/src/argilla/client/models.py b/src/argilla/client/models.py
index 609a92762b..845e7f8eb0 100644
--- a/src/argilla/client/models.py
+++ b/src/argilla/client/models.py
@@ -121,7 +121,9 @@ def _check_and_update_status(cls, values):
return values
class Config:
+ # https://docs.pydantic.dev/usage/model_config/#options
extra = "forbid"
+ validate_assignment = True
class BulkResponse(BaseModel):
diff --git a/tests/client/test_models.py b/tests/client/test_models.py
index 75d37bde11..bf3d72c366 100644
--- a/tests/client/test_models.py
+++ b/tests/client/test_models.py
@@ -306,3 +306,24 @@ class MockRecord(_Validators):
def test_text2text_prediction_validator(prediction, expected):
record = Text2TextRecord(text="mock", prediction=prediction)
assert record.prediction == expected
+
+
+@pytest.mark.parametrize(
+ "record",
+ [
+ TextClassificationRecord(text="This is a test"),
+ TokenClassificationRecord(
+ text="This is a test", tokens="This is a test".split()
+ ),
+ Text2TextRecord(text="This is a test"),
+ ],
+)
+def test_record_validation_on_assignment(record):
+ with pytest.raises(ValidationError):
+ record.prediction = "rubbish"
+
+ with pytest.raises(ValidationError):
+ record.annotation = [("rubbish",)]
+
+ with pytest.raises(ValidationError):
+ record.vectors = "rubbish"
From 179ffb92c73fba9b785503ee8c61f2bea290787a Mon Sep 17 00:00:00 2001
From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Date: Wed, 15 Feb 2023 16:13:42 +0100
Subject: [PATCH 08/45] Enhancement: Distinguish between error message and
context in validation of spans (#2329)
# Description
I slightly expanded the error message raised when the provided spans
do not match the tokenization.
Consider the following example script:
```python
import argilla as rg
record = rg.TokenClassificationRecord(
text = "This is my text",
tokens=["This", "is", "my", "text"],
prediction=[("ORG", 0, 6), ("PER", 8, 14)],
)
```
The (truncated) output on the `develop` branch:
```
ValueError: Following entity spans are not aligned with provided tokenization
Spans:
- [This i] defined in This is my ...
- [my tex] defined in ...s is my text
Tokens:
['This', 'is', 'my', 'text']
```
It is unclear where the `defined in` message ends and the snippet `This is` begins. I've
addressed this.
The (truncated) output after this PR:
```
ValueError: Following entity spans are not aligned with provided tokenization
Spans:
- [This i] defined in 'This is my ...'
- [my tex] defined in '...s is my text'
Tokens:
['This', 'is', 'my', 'text']
```
Note the additional `'`. The changes rely on `repr`, so if the
snippet itself contains a `'`, it uses `"` instead, e.g.:
```python
import argilla as rg
record = rg.TokenClassificationRecord(
text = "This is Tom's text",
tokens=["This", "is", "Tom", "'s", "text"],
prediction=[("ORG", 0, 6), ("PER", 8, 16)],
)
```
```
ValueError: Following entity spans are not aligned with provided tokenization
Spans:
- [This i] defined in 'This is Tom...'
- [Tom's te] defined in "...s is Tom's text"
Tokens:
['This', 'is', 'Tom', "'s", 'text']
```
**Type of change**
- [x] Improvement (change adding some improvement to an existing
functionality)
**How Has This Been Tested**
Modified the relevant tests, ensured they worked.
**Checklist**
- [x] I have merged the original branch into my forked branch
- [ ] I added relevant documentation
- [x] follows the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I added comments to my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- Tom Aarsen
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
src/argilla/utils/span_utils.py | 9 ++-------
tests/client/test_api.py | 2 +-
tests/utils/test_span_utils.py | 4 ++--
3 files changed, 5 insertions(+), 10 deletions(-)
diff --git a/src/argilla/utils/span_utils.py b/src/argilla/utils/span_utils.py
index 4a236c6083..bed9e172f7 100644
--- a/src/argilla/utils/span_utils.py
+++ b/src/argilla/utils/span_utils.py
@@ -93,13 +93,8 @@ def validate(self, spans: List[Tuple[str, int, int]]):
self._start_to_token_idx.get(char_start),
self._end_to_token_idx.get(char_end),
):
- message = f"- [{self.text[char_start:char_end]}] defined in "
- if char_start - 5 > 0:
- message += "..."
- message += self.text[max(char_start - 5, 0) : char_end + 5]
- if char_end + 5 < len(self.text):
- message += "..."
-
+ span_str = self.text[char_start:char_end]
+ message = f"{span} - {repr(span_str)}"
misaligned_spans_errors.append(message)
if not_valid_spans_errors or misaligned_spans_errors:
diff --git a/tests/client/test_api.py b/tests/client/test_api.py
index 0c52482c27..12e66b634c 100644
--- a/tests/client/test_api.py
+++ b/tests/client/test_api.py
@@ -625,7 +625,7 @@ def test_token_classification_spans(span, valid):
with pytest.raises(
ValueError,
match="Following entity spans are not aligned with provided tokenization\n"
- r"Spans:\n- \[s\] defined in Esto es...\n"
+ r"Spans:\n\('test', 1, 2\) - 's'\n"
r"Tokens:\n\['Esto', 'es', 'una', 'prueba'\]",
):
ar.TokenClassificationRecord(
diff --git a/tests/utils/test_span_utils.py b/tests/utils/test_span_utils.py
index e3ddbdc6cf..3cde38131c 100644
--- a/tests/utils/test_span_utils.py
+++ b/tests/utils/test_span_utils.py
@@ -67,7 +67,7 @@ def test_validate_misaligned_spans():
with pytest.raises(
ValueError,
match="Following entity spans are not aligned with provided tokenization\n"
- r"Spans:\n- \[test \] defined in test this.\n"
+ r"Spans:\n\('mock', 0, 5\) - 'test '\n"
r"Tokens:\n\['test', 'this', '.'\]",
):
span_utils.validate([("mock", 0, 5)])
@@ -79,7 +79,7 @@ def test_validate_not_valid_and_misaligned_spans():
ValueError,
match=r"Following entity spans are not valid: \[\('mock', 2, 1\)\]\n"
"Following entity spans are not aligned with provided tokenization\n"
- r"Spans:\n- \[test \] defined in test this.\n"
+ r"Spans:\n\('mock', 0, 5\) - 'test '\n"
r"Tokens:\n\['test', 'this', '.'\]",
):
span_utils.validate([("mock", 2, 1), ("mock", 0, 5)])
From c83ec9e75560294d8ab58ad86c4b4b1c347b5491 Mon Sep 17 00:00:00 2001
From: Francisco Aranda
Date: Thu, 16 Feb 2023 10:39:37 +0100
Subject: [PATCH 09/45] ci: Setup black line-length in toml file (#2352)
Set the black `line-length` parameter in the project's toml file.
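For reference, the resulting configuration amounts to this (a sketch of only the relevant `pyproject.toml` section; the full diff follows):

```toml
[tool.black]
line-length = 120
```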
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
pyproject.toml | 3 +
scripts/load_data.py | 16 +-
scripts/migrations/es_migration_25042021.py | 4 +-
src/argilla/client/apis/datasets.py | 15 +-
src/argilla/client/apis/search.py | 4 +-
src/argilla/client/apis/status.py | 4 +-
src/argilla/client/client.py | 67 ++----
src/argilla/client/datasets.py | 195 ++++-------------
src/argilla/client/models.py | 49 ++---
src/argilla/client/sdk/client.py | 5 +-
src/argilla/client/sdk/commons/api.py | 16 +-
src/argilla/client/sdk/commons/errors.py | 9 +-
.../client/sdk/commons/errors_handler.py | 4 +-
src/argilla/client/sdk/commons/models.py | 4 +-
src/argilla/client/sdk/datasets/api.py | 4 +-
src/argilla/client/sdk/metrics/api.py | 4 +-
src/argilla/client/sdk/metrics/models.py | 4 +-
src/argilla/client/sdk/text2text/models.py | 5 +-
.../client/sdk/text_classification/api.py | 16 +-
.../client/sdk/text_classification/models.py | 30 +--
.../client/sdk/token_classification/api.py | 8 +-
.../client/sdk/token_classification/models.py | 18 +-
.../text_classification/label_errors.py | 30 +--
.../text_classification/label_models.py | 200 +++++-------------
.../labeling/text_classification/rule.py | 24 +--
.../text_classification/weak_labels.py | 183 ++++------------
src/argilla/listeners/listener.py | 22 +-
src/argilla/metrics/commons.py | 12 +-
src/argilla/metrics/helpers.py | 11 +-
src/argilla/metrics/models.py | 9 +-
.../metrics/token_classification/metrics.py | 32 +--
src/argilla/monitoring/_flair.py | 6 +-
src/argilla/monitoring/_spacy.py | 8 +-
src/argilla/monitoring/_transformers.py | 8 +-
src/argilla/monitoring/asgi.py | 36 +---
src/argilla/monitoring/base.py | 4 +-
src/argilla/monitoring/model_monitor.py | 3 +-
.../server/apis/v0/handlers/datasets.py | 20 +-
.../server/apis/v0/handlers/metrics.py | 4 +-
.../server/apis/v0/handlers/text2text.py | 12 +-
.../apis/v0/handlers/text_classification.py | 64 ++----
.../text_classification_dataset_settings.py | 12 +-
.../apis/v0/handlers/token_classification.py | 20 +-
.../token_classification_dataset_settings.py | 12 +-
src/argilla/server/apis/v0/handlers/users.py | 8 +-
src/argilla/server/apis/v0/helpers.py | 4 +-
.../server/apis/v0/models/commons/model.py | 4 +-
.../server/apis/v0/models/commons/params.py | 17 +-
.../server/apis/v0/models/text2text.py | 4 +-
.../apis/v0/models/text_classification.py | 23 +-
.../apis/v0/models/token_classification.py | 8 +-
.../apis/v0/validators/text_classification.py | 12 +-
.../v0/validators/token_classification.py | 22 +-
src/argilla/server/commons/config.py | 8 +-
src/argilla/server/commons/telemetry.py | 17 +-
.../backend/client_adapters/opensearch.py | 23 +-
.../server/daos/backend/generic_elastic.py | 6 +-
.../server/daos/backend/mappings/helpers.py | 6 +-
.../server/daos/backend/metrics/base.py | 8 +-
.../server/daos/backend/metrics/datasets.py | 4 +-
.../backend/metrics/text_classification.py | 16 +-
.../backend/metrics/token_classification.py | 19 +-
.../server/daos/backend/query_helpers.py | 55 ++---
.../server/daos/backend/search/model.py | 3 +-
.../daos/backend/search/query_builder.py | 23 +-
src/argilla/server/daos/datasets.py | 33 +--
src/argilla/server/daos/models/records.py | 16 +-
src/argilla/server/daos/records.py | 12 +-
src/argilla/server/errors/base_errors.py | 12 +-
src/argilla/server/helpers.py | 8 +-
src/argilla/server/responses/api_responses.py | 4 +-
src/argilla/server/routes.py | 4 +-
.../security/auth_provider/local/provider.py | 16 +-
.../security/auth_provider/local/settings.py | 8 +-
.../security/auth_provider/local/users/dao.py | 4 +-
src/argilla/server/security/model.py | 6 +-
src/argilla/server/server.py | 10 +-
src/argilla/server/services/datasets.py | 36 +---
src/argilla/server/services/info.py | 4 +-
src/argilla/server/services/search/service.py | 8 +-
.../server/services/storage/service.py | 4 +-
.../server/services/tasks/commons/models.py | 4 +-
.../server/services/tasks/text2text/models.py | 12 +-
.../labeling_rules_service.py | 20 +-
.../tasks/text_classification/metrics.py | 26 +--
.../tasks/text_classification/model.py | 59 ++----
.../tasks/text_classification/service.py | 37 +---
.../tasks/token_classification/metrics.py | 50 ++---
.../tasks/token_classification/model.py | 33 +--
.../tasks/token_classification/service.py | 8 +-
src/argilla/utils/span_utils.py | 8 +-
src/argilla/utils/utils.py | 16 +-
tests/client/conftest.py | 47 +---
tests/client/sdk/commons/api.py | 12 +-
tests/client/sdk/conftest.py | 19 +-
tests/client/sdk/datasets/test_models.py | 8 +-
tests/client/sdk/text2text/test_models.py | 4 +-
.../sdk/text_classification/test_models.py | 20 +-
tests/client/sdk/users/test_api.py | 8 +-
tests/client/test_api.py | 16 +-
tests/client/test_dataset.py | 167 ++++-----------
tests/client/test_init.py | 4 +-
tests/client/test_models.py | 88 ++------
tests/datasets/test_datasets.py | 4 +-
.../test_delete_records_from_datasets.py | 14 +-
.../search/test_search_service.py | 20 +-
.../test_log_for_text_classification.py | 13 +-
.../test_log_for_token_classification.py | 8 +-
.../text_classification/test_label_errors.py | 45 +---
.../text_classification/test_label_models.py | 109 +++-------
.../labeling/text_classification/test_rule.py | 16 +-
.../text_classification/test_weak_labels.py | 198 ++++-------------
tests/listeners/test_listener.py | 4 +-
tests/metrics/test_common_metrics.py | 4 +-
tests/monitoring/test_monitor.py | 6 +-
.../test_transformers_monitoring.py | 9 +-
tests/server/backend/test_query_builder.py | 8 +-
tests/server/commons/test_telemetry.py | 4 +-
tests/server/datasets/test_api.py | 15 +-
tests/server/metrics/test_api.py | 16 +-
tests/server/security/test_model.py | 12 +-
tests/server/security/test_provider.py | 4 +-
tests/server/test_errors.py | 5 +-
tests/server/text2text/test_api.py | 4 +-
tests/server/text_classification/test_api.py | 37 +---
.../text_classification/test_api_rules.py | 88 ++------
.../text_classification/test_api_settings.py | 12 +-
.../server/text_classification/test_model.py | 24 +--
tests/server/token_classification/test_api.py | 17 +-
.../token_classification/test_api_settings.py | 12 +-
.../server/token_classification/test_model.py | 12 +-
tests/utils/test_span_utils.py | 8 +-
tests/utils/test_utils.py | 8 +-
133 files changed, 750 insertions(+), 2343 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index f28608d42e..42133cb2c7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -139,3 +139,6 @@ exclude = [
[tool.ruff.per-file-ignores]
# Ignore imported but unused;
"__init__.py" = ["F401"]
+
+[tool.black]
+line-length = 120
diff --git a/scripts/load_data.py b/scripts/load_data.py
index 06405d7650..46ae50a959 100644
--- a/scripts/load_data.py
+++ b/scripts/load_data.py
@@ -64,9 +64,7 @@ def load_news_text_summarization():
rg.log(
dataset_rg,
name="news-text-summarization",
- tags={
- "description": "A text summarization dataset with news pieces and their predicted summaries."
- },
+ tags={"description": "A text summarization dataset with news pieces and their predicted summaries."},
)
@staticmethod
@@ -80,18 +78,14 @@ def load_news_programmatic_labeling():
)
# Define labeling schema to avoid UI user modification
- settings = rg.TextClassificationSettings(
- label_schema={"World", "Sports", "Sci/Tech", "Business"}
- )
+ settings = rg.TextClassificationSettings(label_schema={"World", "Sports", "Sci/Tech", "Business"})
rg.configure_dataset(name="news-programmatic-labeling", settings=settings)
# Log the dataset
rg.log(
dataset_rg,
name="news-programmatic-labeling",
- tags={
- "description": "The AG News with programmatic labeling rules (see weak labeling mode in the UI)."
- },
+ tags={"description": "The AG News with programmatic labeling rules (see weak labeling mode in the UI)."},
)
# Define queries and patterns for each category (using Elasticsearch DSL)
@@ -103,9 +97,7 @@ def load_news_programmatic_labeling():
]
# Define rules
- rules = [
- Rule(query=term, label=label) for terms, label in queries for term in terms
- ]
+ rules = [Rule(query=term, label=label) for terms, label in queries for term in terms]
# Add rules to the dataset
add_rules(dataset="news-programmatic-labeling", rules=rules)
diff --git a/scripts/migrations/es_migration_25042021.py b/scripts/migrations/es_migration_25042021.py
index 63776fe977..1f5a04d602 100644
--- a/scripts/migrations/es_migration_25042021.py
+++ b/scripts/migrations/es_migration_25042021.py
@@ -47,9 +47,7 @@ def batcher(iterable, n, fillvalue=None):
return zip_longest(*args, fillvalue=fillvalue)
-def map_doc_2_action(
- index: str, doc: Dict[str, Any], task: TaskType
-) -> Optional[Dict[str, Any]]:
+def map_doc_2_action(index: str, doc: Dict[str, Any], task: TaskType) -> Optional[Dict[str, Any]]:
"""Configures bulk action"""
doc_data = doc["_source"]
new_record = {
diff --git a/src/argilla/client/apis/datasets.py b/src/argilla/client/apis/datasets.py
index 5fff021362..c887300cb1 100644
--- a/src/argilla/client/apis/datasets.py
+++ b/src/argilla/client/apis/datasets.py
@@ -174,9 +174,7 @@ def scan(
"""
- url = (
- f"{self._API_PREFIX}/{name}/records/:search?limit={self.DEFAULT_SCAN_SIZE}"
- )
+ url = f"{self._API_PREFIX}/{name}/records/:search?limit={self.DEFAULT_SCAN_SIZE}"
query = self._parse_query(query=query)
if limit == 0:
@@ -278,13 +276,10 @@ def delete_records(
def __save_settings__(self, dataset: _DatasetApiModel, settings: Settings):
if __TASK_TO_SETTINGS__.get(dataset.task) != type(settings):
raise ValueError(
- f"The provided settings type {type(settings)} cannot be applied to dataset."
- " Task type mismatch"
+ f"The provided settings type {type(settings)} cannot be applied to dataset." " Task type mismatch"
)
- settings_ = self._SettingsApiModel(
- label_schema={"labels": [label for label in settings.label_schema]}
- )
+ settings_ = self._SettingsApiModel(label_schema={"labels": [label for label in settings.label_schema]})
with api_compatibility(self, min_version=self.__SETTINGS_MIN_API_VERSION__):
self.http_client.put(
@@ -305,9 +300,7 @@ def load_settings(self, name: str) -> Optional[Settings]:
dataset = self.find_by_name(name)
try:
with api_compatibility(self, min_version=self.__SETTINGS_MIN_API_VERSION__):
- response = self.http_client.get(
- f"{self._API_PREFIX}/{dataset.task}/{dataset.name}/settings"
- )
+ response = self.http_client.get(f"{self._API_PREFIX}/{dataset.task}/{dataset.name}/settings")
return __TASK_TO_SETTINGS__.get(dataset.task).from_dict(response)
except NotFoundApiError:
return None
diff --git a/src/argilla/client/apis/search.py b/src/argilla/client/apis/search.py
index 07de6bb71e..7567c78030 100644
--- a/src/argilla/client/apis/search.py
+++ b/src/argilla/client/apis/search.py
@@ -80,7 +80,5 @@ def search_records(
return SearchResults(
total=response["total"],
- records=[
- record_class.parse_obj(r).to_client() for r in response["records"]
- ],
+ records=[record_class.parse_obj(r).to_client() for r in response["records"]],
)
diff --git a/src/argilla/client/apis/status.py b/src/argilla/client/apis/status.py
index d03dcae852..87529e5da4 100644
--- a/src/argilla/client/apis/status.py
+++ b/src/argilla/client/apis/status.py
@@ -63,9 +63,7 @@ def __enter__(self):
if api_version.is_devrelease:
api_version = parse(api_version.base_version)
if not api_version >= self._min_version:
- raise ApiCompatibilityError(
- str(self._min_version), api_version=api_version
- )
+ raise ApiCompatibilityError(str(self._min_version), api_version=api_version)
pass
def __exit__(
diff --git a/src/argilla/client/client.py b/src/argilla/client/client.py
index 5a3974ede3..ef36f35800 100644
--- a/src/argilla/client/client.py
+++ b/src/argilla/client/client.py
@@ -100,16 +100,12 @@ async def __log_internal__(api: "Argilla", *args, **kwargs):
except Exception as ex:
dataset = kwargs["name"]
_LOGGER.error(
- f"\nCannot log data in dataset '{dataset}'\n"
- f"Error: {type(ex).__name__}\n"
- f"Details: {ex}"
+ f"\nCannot log data in dataset '{dataset}'\n" f"Error: {type(ex).__name__}\n" f"Details: {ex}"
)
raise ex
def log(self, *args, **kwargs) -> Future:
- return asyncio.run_coroutine_threadsafe(
- self.__log_internal__(self.__api__, *args, **kwargs), self.__loop__
- )
+ return asyncio.run_coroutine_threadsafe(self.__log_internal__(self.__api__, *args, **kwargs), self.__loop__)
class Argilla:
@@ -172,10 +168,7 @@ def __del__(self):
def client(self) -> AuthenticatedClient:
"""The underlying authenticated HTTP client"""
warnings.warn(
- message=(
- "This prop will be removed in next release. "
- "Please use the http_client prop instead."
- ),
+ message=("This prop will be removed in next release. " "Please use the http_client prop instead."),
category=UserWarning,
)
return self._client
@@ -217,10 +210,7 @@ def set_workspace(self, workspace: str):
if workspace != self.get_workspace():
if workspace == self._user.username:
self._client.headers.pop(WORKSPACE_HEADER_NAME, workspace)
- elif (
- self._user.workspaces is not None
- and workspace not in self._user.workspaces
- ):
+ elif self._user.workspaces is not None and workspace not in self._user.workspaces:
raise Exception(f"Wrong provided workspace {workspace}")
self._client.headers[WORKSPACE_HEADER_NAME] = workspace
self._client.headers[_OLD_WORKSPACE_HEADER_NAME] = workspace
@@ -370,10 +360,7 @@ async def log_async(
bulk_class = Text2TextBulkData
creation_class = CreationText2TextRecord
else:
- raise InputValueError(
- f"Unknown record type {record_type}. Available values are"
- f" {Record.__args__}"
- )
+ raise InputValueError(f"Unknown record type {record_type}. Available values are" f" {Record.__args__}")
processed, failed = 0, 0
progress_bar = tqdm(total=len(records), disable=not verbose)
@@ -398,18 +385,11 @@ async def log_async(
# TODO: improve logging policy in library
if verbose:
- _LOGGER.info(
- f"Processed {processed} records in dataset {name}. Failed: {failed}"
- )
+ _LOGGER.info(f"Processed {processed} records in dataset {name}. Failed: {failed}")
workspace = self.get_workspace()
- if (
- not workspace
- ): # Just for backward comp. with datasets with no workspaces
+ if not workspace: # Just for backward comp. with datasets with no workspaces
workspace = "-"
- print(
- f"{processed} records logged to"
- f" {self._client.base_url}/datasets/{workspace}/{name}"
- )
+ print(f"{processed} records logged to" f" {self._client.base_url}/datasets/{workspace}/{name}")
# Creating a composite BulkResponse with the total processed and failed
return BulkResponse(dataset=name, processed=processed, failed=failed)
@@ -488,8 +468,7 @@ def load(
raise ValueError(
"The argument `as_pandas` is deprecated and will be removed in a future"
" version. Please adapt your code accordingly. ",
- "If you want a pandas DataFrame do"
- " `rg.load('my_dataset').to_pandas()`.",
+ "If you want a pandas DataFrame do" " `rg.load('my_dataset').to_pandas()`.",
)
try:
@@ -525,9 +504,7 @@ def load(
def dataset_metrics(self, name: str) -> List[MetricInfo]:
response = datasets_api.get_dataset(self._client, name)
- response = metrics_api.get_dataset_metrics(
- self._client, name=name, task=response.parsed.task
- )
+ response = metrics_api.get_dataset_metrics(self._client, name=name, task=response.parsed.task)
return response.parsed
@@ -572,9 +549,7 @@ def add_dataset_labeling_rules(self, dataset: str, rules: List[LabelingRule]):
rule=rule,
)
except AlreadyExistsApiError:
- _LOGGER.warning(
- f"Rule {rule} already exists. Please, update the rule instead."
- )
+ _LOGGER.warning(f"Rule {rule} already exists. Please, update the rule instead.")
except Exception as ex:
_LOGGER.warning(f"Cannot create rule {rule}: {ex}")
@@ -593,36 +568,26 @@ def update_dataset_labeling_rules(
)
except NotFoundApiError:
_LOGGER.info(f"Rule {rule} does not exists, creating...")
- text_classification_api.add_dataset_labeling_rule(
- self._client, name=dataset, rule=rule
- )
+ text_classification_api.add_dataset_labeling_rule(self._client, name=dataset, rule=rule)
except Exception as ex:
_LOGGER.warning(f"Cannot update rule {rule}: {ex}")
def delete_dataset_labeling_rules(self, dataset: str, rules: List[LabelingRule]):
for rule in rules:
try:
- text_classification_api.delete_dataset_labeling_rule(
- self._client, name=dataset, rule=rule
- )
+ text_classification_api.delete_dataset_labeling_rule(self._client, name=dataset, rule=rule)
except Exception as ex:
_LOGGER.warning(f"Cannot delete rule {rule}: {ex}")
"""Deletes the dataset labeling rules"""
for rule in rules:
- text_classification_api.delete_dataset_labeling_rule(
- self._client, name=dataset, rule=rule
- )
+ text_classification_api.delete_dataset_labeling_rule(self._client, name=dataset, rule=rule)
def fetch_dataset_labeling_rules(self, dataset: str) -> List[LabelingRule]:
- response = text_classification_api.fetch_dataset_labeling_rules(
- self._client, name=dataset
- )
+ response = text_classification_api.fetch_dataset_labeling_rules(self._client, name=dataset)
return [LabelingRule.parse_obj(data) for data in response.parsed]
- def rule_metrics_for_dataset(
- self, dataset: str, rule: LabelingRule
- ) -> LabelingRuleMetricsSummary:
+ def rule_metrics_for_dataset(self, dataset: str, rule: LabelingRule) -> LabelingRuleMetricsSummary:
response = text_classification_api.dataset_rule_metrics(
self._client, name=dataset, query=rule.query, label=rule.label
)
diff --git a/src/argilla/client/datasets.py b/src/argilla/client/datasets.py
index 952a3a96ef..35cf564883 100644
--- a/src/argilla/client/datasets.py
+++ b/src/argilla/client/datasets.py
@@ -103,9 +103,7 @@ def _record_init_args(cls) -> List[str]:
def __init__(self, records: Optional[List[Record]] = None):
if self._RECORD_TYPE is None:
- raise NotImplementedError(
- "A Dataset implementation has to define a `_RECORD_TYPE`!"
- )
+ raise NotImplementedError("A Dataset implementation has to define a `_RECORD_TYPE`!")
self._records = records or []
if self._records:
@@ -179,8 +177,7 @@ def to_datasets(self) -> "datasets.Dataset":
del ds_dict["metadata"]
dataset = datasets.Dataset.from_dict(ds_dict)
_LOGGER.warning(
- "The 'metadata' of the records were removed, since it was incompatible"
- " with the 'datasets' format."
+ "The 'metadata' of the records were removed, since it was incompatible" " with the 'datasets' format."
)
return dataset
@@ -221,15 +218,10 @@ def _prepare_dataset_and_column_mapping(
import datasets
if isinstance(dataset, datasets.DatasetDict):
- raise ValueError(
- "`datasets.DatasetDict` are not supported. Please, select the dataset"
- " split before."
- )
+ raise ValueError("`datasets.DatasetDict` are not supported. Please, select the dataset" " split before.")
# clean column mappings
- column_mapping = {
- key: val for key, val in column_mapping.items() if val is not None
- }
+ column_mapping = {key: val for key, val in column_mapping.items() if val is not None}
cols_to_be_renamed, cols_to_be_joined = {}, {}
for field, col in column_mapping.items():
@@ -263,9 +255,7 @@ def _remove_unsupported_columns(
The dataset with unsupported columns removed.
"""
not_supported_columns = [
- col
- for col in dataset.column_names
- if col not in cls._record_init_args() + extra_columns
+ col for col in dataset.column_names if col not in cls._record_init_args() + extra_columns
]
if not_supported_columns:
@@ -279,9 +269,7 @@ def _remove_unsupported_columns(
return dataset
@staticmethod
- def _join_datasets_columns_and_delete(
- row: Dict[str, Any], columns: List[str]
- ) -> Dict[str, Any]:
+ def _join_datasets_columns_and_delete(row: Dict[str, Any], columns: List[str]) -> Dict[str, Any]:
"""Joins columns of a `datasets.Dataset` row into a dict, and deletes the single columns.
Updates the ``row`` dictionary!
@@ -354,9 +342,7 @@ def from_pandas(cls, dataframe: pd.DataFrame) -> "Dataset":
Returns:
The imported records in a argilla Dataset.
"""
- not_supported_columns = [
- col for col in dataframe.columns if col not in cls._record_init_args()
- ]
+ not_supported_columns = [col for col in dataframe.columns if col not in cls._record_init_args()]
if not_supported_columns:
_LOGGER.warning(
"Following columns are not supported by the"
@@ -461,14 +447,10 @@ def prepare_for_training(
)
# check if train sizes sum up to 1
- assert (train_size + test_size) == 1, ValueError(
- "`train_size` and `test_size` must sum to 1."
- )
+ assert (train_size + test_size) == 1, ValueError("`train_size` and `test_size` must sum to 1.")
# check for annotations
- assert any([rec.annotation for rec in self._records]), ValueError(
- "Dataset has no annotations."
- )
+ assert any([rec.annotation for rec in self._records]), ValueError("Dataset has no annotations.")
# shuffle records
shuffled_records = self._records.copy()
@@ -484,13 +466,10 @@ def prepare_for_training(
# prepare for training for the right method
if framework is Framework.TRANSFORMERS:
- return self._prepare_for_training_with_transformers(
- train_size=train_size, test_size=test_size, seed=seed
- )
+ return self._prepare_for_training_with_transformers(train_size=train_size, test_size=test_size, seed=seed)
elif framework is Framework.SPACY and lang is None:
raise ValueError(
- "Please provide a spacy language model to prepare the"
- " dataset for training with the spacy framework."
+ "Please provide a spacy language model to prepare the" " dataset for training with the spacy framework."
)
elif framework in [Framework.SPACY, Framework.SPARK_NLP]:
if train_size and test_size:
@@ -503,12 +482,8 @@ def prepare_for_training(
random_state=seed,
)
if framework is Framework.SPACY:
- train_docbin = self._prepare_for_training_with_spacy(
- nlp=lang, records=records_train
- )
- test_docbin = self._prepare_for_training_with_spacy(
- nlp=lang, records=records_test
- )
+ train_docbin = self._prepare_for_training_with_spacy(nlp=lang, records=records_train)
+ test_docbin = self._prepare_for_training_with_spacy(nlp=lang, records=records_test)
return train_docbin, test_docbin
else:
train_df = self._prepare_for_training_with_spark_nlp(records_train)
@@ -517,18 +492,11 @@ def prepare_for_training(
return train_df, test_df
else:
if framework is Framework.SPACY:
- return self._prepare_for_training_with_spacy(
- nlp=lang, records=shuffled_records
- )
+ return self._prepare_for_training_with_spacy(nlp=lang, records=shuffled_records)
else:
- return self._prepare_for_training_with_spark_nlp(
- records=shuffled_records
- )
+ return self._prepare_for_training_with_spark_nlp(records=shuffled_records)
else:
- raise NotImplementedError(
- f"Framework {framework} is not supported. Choose from:"
- f" {list(Framework)}"
- )
+ raise NotImplementedError(f"Framework {framework} is not supported. Choose from:" f" {list(Framework)}")
@_requires_spacy
def _prepare_for_training_with_spacy(
@@ -674,9 +642,7 @@ def from_datasets(
records = []
for row in dataset:
- row["inputs"] = cls._parse_inputs_field(
- row, cols_to_be_joined.get("inputs")
- )
+ row["inputs"] = cls._parse_inputs_field(row, cols_to_be_joined.get("inputs"))
if row.get("inputs") is not None and row.get("text") is not None:
del row["text"]
@@ -701,10 +667,7 @@ def from_datasets(
if row.get("explanation"):
row["explanation"] = (
{
- key: [
- TokenAttributions(**tokattr_kwargs)
- for tokattr_kwargs in val
- ]
+ key: [TokenAttributions(**tokattr_kwargs) for tokattr_kwargs in val]
for key, val in row["explanation"].items()
}
if row["explanation"] is not None
@@ -712,9 +675,7 @@ def from_datasets(
)
if cols_to_be_joined.get("metadata"):
- row["metadata"] = cls._join_datasets_columns_and_delete(
- row, cols_to_be_joined["metadata"]
- )
+ row["metadata"] = cls._join_datasets_columns_and_delete(row, cols_to_be_joined["metadata"])
records.append(TextClassificationRecord.parse_obj(row))
@@ -766,18 +727,13 @@ def _to_datasets_dict(self) -> Dict:
]
elif key == "explanation":
ds_dict[key] = [
- {
- key: list(map(dict, tokattrs))
- for key, tokattrs in rec.explanation.items()
- }
+ {key: list(map(dict, tokattrs)) for key, tokattrs in rec.explanation.items()}
if rec.explanation is not None
else None
for rec in self._records
]
elif key == "id":
- ds_dict[key] = [
- None if rec.id is None else str(rec.id) for rec in self._records
- ]
+ ds_dict[key] = [None if rec.id is None else str(rec.id) for rec in self._records]
elif key == "metadata":
ds_dict[key] = [getattr(rec, key) or None for rec in self._records]
else:
@@ -787,9 +743,7 @@ def _to_datasets_dict(self) -> Dict:
@classmethod
def _from_pandas(cls, dataframe: pd.DataFrame) -> "DatasetForTextClassification":
- return cls(
- [TextClassificationRecord(**row) for row in dataframe.to_dict("records")]
- )
+ return cls([TextClassificationRecord(**row) for row in dataframe.to_dict("records")])
@_requires_datasets
def _prepare_for_training_with_transformers(
@@ -800,12 +754,7 @@ def _prepare_for_training_with_transformers(
):
import datasets
- inputs_keys = {
- key: None
- for rec in self._records
- for key in rec.inputs
- if rec.annotation is not None
- }.keys()
+ inputs_keys = {key: None for rec in self._records for key in rec.inputs if rec.annotation is not None}.keys()
ds_dict = {**{key: [] for key in inputs_keys}, "label": []}
for rec in self._records:
@@ -832,9 +781,7 @@ def _prepare_for_training_with_transformers(
"label": [class_label] if self._records[0].multi_label else class_label,
}
- ds = datasets.Dataset.from_dict(
- ds_dict, features=datasets.Features(feature_dict)
- )
+ ds = datasets.Dataset.from_dict(ds_dict, features=datasets.Features(feature_dict))
if self._records[0].multi_label:
from sklearn.preprocessing import MultiLabelBinarizer
@@ -853,16 +800,12 @@ def _prepare_for_training_with_transformers(
features=datasets.Features(feature_dict),
)
if test_size:
- ds = ds.train_test_split(
- train_size=train_size, test_size=test_size, seed=seed
- )
+ ds = ds.train_test_split(train_size=train_size, test_size=test_size, seed=seed)
return ds
@_requires_spacy
- def _prepare_for_training_with_spacy(
- self, nlp: "spacy.Language", records: List[Record]
- ) -> "spacy.tokens.DocBin":
+ def _prepare_for_training_with_spacy(self, nlp: "spacy.Language", records: List[Record]) -> "spacy.tokens.DocBin":
from spacy.tokens import DocBin
db = DocBin()
@@ -889,9 +832,7 @@ def _prepare_for_training_with_spacy(
return db
- def _prepare_for_training_with_spark_nlp(
- self, records: List[Record]
- ) -> "pandas.DataFrame":
+ def _prepare_for_training_with_spark_nlp(self, records: List[Record]) -> "pandas.DataFrame":
if records[0].multi_label:
label_name = "labels"
else:
@@ -1016,9 +957,7 @@ def from_datasets(
continue
if row.get("tags"):
- row["tags"] = cls._parse_datasets_column_with_classlabel(
- row["tags"], dataset.features["tags"]
- )
+ row["tags"] = cls._parse_datasets_column_with_classlabel(row["tags"], dataset.features["tags"])
if row.get("prediction"):
row["prediction"] = cls.__entities_to_tuple__(row["prediction"])
@@ -1027,9 +966,7 @@ def from_datasets(
row["annotation"] = cls.__entities_to_tuple__(row["annotation"])
if cols_to_be_joined.get("metadata"):
- row["metadata"] = cls._join_datasets_columns_and_delete(
- row, cols_to_be_joined["metadata"]
- )
+ row["metadata"] = cls._join_datasets_columns_and_delete(row, cols_to_be_joined["metadata"])
records.append(TokenClassificationRecord.parse_obj(row))
@@ -1062,13 +999,7 @@ def _prepare_for_training_with_transformers(
return datasets.Dataset.from_dict({})
class_tags = ["O"]
- class_tags.extend(
- [
- f"{pre}-{label}"
- for label in sorted(self.__all_labels__())
- for pre in ["B", "I"]
- ]
- )
+ class_tags.extend([f"{pre}-{label}" for label in sorted(self.__all_labels__()) for pre in ["B", "I"]])
class_tags = datasets.ClassLabel(names=class_tags)
def spans2iob(example):
@@ -1078,26 +1009,18 @@ def spans2iob(example):
return class_tags.str2int(tags)
- ds = (
- self.to_datasets()
- .filter(self.__only_annotations__)
- .map(lambda example: {"ner_tags": spans2iob(example)})
- )
+ ds = self.to_datasets().filter(self.__only_annotations__).map(lambda example: {"ner_tags": spans2iob(example)})
new_features = ds.features.copy()
new_features["ner_tags"] = [class_tags]
ds = ds.cast(new_features)
if train_size or test_size:
- ds = ds.train_test_split(
- train_size=train_size, test_size=test_size, seed=seed
- )
+ ds = ds.train_test_split(train_size=train_size, test_size=test_size, seed=seed)
return ds
@_requires_spacy
- def _prepare_for_training_with_spacy(
- self, nlp: "spacy.Language", records: List[Record]
- ) -> "spacy.tokens.DocBin":
+ def _prepare_for_training_with_spacy(self, nlp: "spacy.Language", records: List[Record]) -> "spacy.tokens.DocBin":
from spacy.tokens import DocBin
db = DocBin()
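The collapsed `class_tags` comprehension in the hunk above builds the standard IOB2 tag names; a tiny standalone check (label names are hypothetical):

    # Same construction as above, exercised in isolation.
    labels = ["ORG", "PER"]
    class_tags = ["O"]
    class_tags.extend([f"{pre}-{label}" for label in sorted(labels) for pre in ["B", "I"]])
    assert class_tags == ["O", "B-ORG", "I-ORG", "B-PER", "I-PER"]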
@@ -1128,9 +1051,7 @@ def _prepare_for_training_with_spacy(
return db
- def _prepare_for_training_with_spark_nlp(
- self, records: List[Record]
- ) -> "pandas.DataFrame":
+ def _prepare_for_training_with_spark_nlp(self, records: List[Record]) -> "pandas.DataFrame":
for record in records:
if record.id is None:
record.id = str(uuid.uuid4())
@@ -1163,9 +1084,7 @@ def _to_datasets_dict(self) -> Dict:
# create a dict first, where we make the necessary transformations
def entities_to_dict(
- entities: Optional[
- List[Union[Tuple[str, int, int, float], Tuple[str, int, int]]]
- ]
+ entities: Optional[List[Union[Tuple[str, int, int, float], Tuple[str, int, int]]]]
) -> Optional[List[Dict[str, Union[str, int, float]]]]:
if entities is None:
return None
@@ -1179,17 +1098,11 @@ def entities_to_dict(
ds_dict = {}
for key in self._RECORD_TYPE.__fields__:
if key == "prediction":
- ds_dict[key] = [
- entities_to_dict(rec.prediction) for rec in self._records
- ]
+ ds_dict[key] = [entities_to_dict(rec.prediction) for rec in self._records]
elif key == "annotation":
- ds_dict[key] = [
- entities_to_dict(rec.annotation) for rec in self._records
- ]
+ ds_dict[key] = [entities_to_dict(rec.annotation) for rec in self._records]
elif key == "id":
- ds_dict[key] = [
- None if rec.id is None else str(rec.id) for rec in self._records
- ]
+ ds_dict[key] = [None if rec.id is None else str(rec.id) for rec in self._records]
elif key == "metadata":
ds_dict[key] = [getattr(rec, key) or None for rec in self._records]
else:
@@ -1210,9 +1123,7 @@ def __entities_to_tuple__(
@classmethod
def _from_pandas(cls, dataframe: pd.DataFrame) -> "DatasetForTokenClassification":
- return cls(
- [TokenClassificationRecord(**row) for row in dataframe.to_dict("records")]
- )
+ return cls([TokenClassificationRecord(**row) for row in dataframe.to_dict("records")])
@_prepend_docstring(Text2TextRecord)
@@ -1299,9 +1210,7 @@ def from_datasets(
row["prediction"] = cls._parse_prediction_field(row["prediction"])
if cols_to_be_joined.get("metadata"):
- row["metadata"] = cls._join_datasets_columns_and_delete(
- row, cols_to_be_joined["metadata"]
- )
+ row["metadata"] = cls._join_datasets_columns_and_delete(row, cols_to_be_joined["metadata"])
records.append(Text2TextRecord.parse_obj(row))
@@ -1337,15 +1246,11 @@ def pred_to_dict(pred: Union[str, Tuple[str, float]]):
for key in self._RECORD_TYPE.__fields__:
if key == "prediction":
ds_dict[key] = [
- [pred_to_dict(pred) for pred in rec.prediction]
- if rec.prediction is not None
- else None
+ [pred_to_dict(pred) for pred in rec.prediction] if rec.prediction is not None else None
for rec in self._records
]
elif key == "id":
- ds_dict[key] = [
- None if rec.id is None else str(rec.id) for rec in self._records
- ]
+ ds_dict[key] = [None if rec.id is None else str(rec.id) for rec in self._records]
elif key == "metadata":
ds_dict[key] = [getattr(rec, key) or None for rec in self._records]
else:
@@ -1371,14 +1276,10 @@ def prepare_for_training(self, **kwargs) -> "datasets.Dataset":
raise NotImplementedError
-Dataset = Union[
- DatasetForTextClassification, DatasetForTokenClassification, DatasetForText2Text
-]
+Dataset = Union[DatasetForTextClassification, DatasetForTokenClassification, DatasetForText2Text]
-def read_datasets(
- dataset: "datasets.Dataset", task: Union[str, TaskType], **kwargs
-) -> Dataset:
+def read_datasets(dataset: "datasets.Dataset", task: Union[str, TaskType], **kwargs) -> Dataset:
"""Reads a datasets Dataset and returns a argilla Dataset
Args:
@@ -1431,9 +1332,7 @@ def read_datasets(
return DatasetForTokenClassification.from_datasets(dataset, **kwargs)
if task is TaskType.text2text:
return DatasetForText2Text.from_datasets(dataset, **kwargs)
- raise NotImplementedError(
- "Reading a datasets Dataset is not implemented for the given task!"
- )
+ raise NotImplementedError("Reading a datasets Dataset is not implemented for the given task!")
def read_pandas(dataframe: pd.DataFrame, task: Union[str, TaskType]) -> Dataset:
@@ -1488,9 +1387,7 @@ def read_pandas(dataframe: pd.DataFrame, task: Union[str, TaskType]) -> Dataset:
return DatasetForTokenClassification.from_pandas(dataframe)
if task is TaskType.text2text:
return DatasetForText2Text.from_pandas(dataframe)
- raise NotImplementedError(
- "Reading a pandas DataFrame is not implemented for the given task!"
- )
+ raise NotImplementedError("Reading a pandas DataFrame is not implemented for the given task!")
class WrongRecordTypeError(Exception):
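Both dispatchers above route on the task and raise `NotImplementedError` otherwise. A minimal usage sketch, assuming `read_pandas` is exported at the package root (as `argilla.read_pandas`) and that `"TextClassification"` is a valid `TaskType` value:

    import pandas as pd
    import argilla as rg

    # Hypothetical two-record dataframe; column names follow the record fields.
    df = pd.DataFrame({"text": ["good film", "bad film"], "annotation": ["pos", "neg"]})
    ds = rg.read_pandas(df, task="TextClassification")
    assert type(ds).__name__ == "DatasetForTextClassification"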
diff --git a/src/argilla/client/models.py b/src/argilla/client/models.py
index 845e7f8eb0..89a274865e 100644
--- a/src/argilla/client/models.py
+++ b/src/argilla/client/models.py
@@ -44,8 +44,7 @@ class Framework(Enum):
@classmethod
def _missing_(cls, value):
raise ValueError(
- f"{value} is not a valid {cls.__name__}, please select one of"
- f" {list(cls._value2member_map_.keys())}"
+ f"{value} is not a valid {cls.__name__}, please select one of" f" {list(cls._value2member_map_.keys())}"
)
@@ -68,8 +67,7 @@ def _check_value_length(cls, metadata):
message = (
"Some metadata values could exceed the max length. For those cases,"
" values will be truncated by keeping only the last"
- f" {DEFAULT_MAX_KEYWORD_LENGTH} characters. "
- + _messages.ARGILLA_METADATA_FIELD_WARNING_MESSAGE
+ f" {DEFAULT_MAX_KEYWORD_LENGTH} characters. " + _messages.ARGILLA_METADATA_FIELD_WARNING_MESSAGE
)
warnings.warn(message, UserWarning)
@@ -114,9 +112,7 @@ def _nat_to_none_and_one_to_now(cls, v):
@root_validator
def _check_and_update_status(cls, values):
"""Updates the status if an annotation is provided and no status is specified."""
- values["status"] = values.get("status") or (
- "Default" if values.get("annotation") is None else "Validated"
- )
+ values["status"] = values.get("status") or ("Default" if values.get("annotation") is None else "Validated")
return values
@@ -261,10 +257,7 @@ def _check_text_and_inputs(cls, values):
and values.get("inputs") is not None
and values["text"] != values["inputs"].get("text")
):
- raise ValueError(
- "For a TextClassificationRecord you must provide either 'text' or"
- " 'inputs'"
- )
+ raise ValueError("For a TextClassificationRecord you must provide either 'text' or" " 'inputs'")
if values.get("text") is not None:
values["inputs"] = dict(text=values["text"])
@@ -333,9 +326,7 @@ class TokenClassificationRecord(_Validators):
text: Optional[str] = Field(None, min_length=1)
tokens: Optional[Union[List[str], Tuple[str, ...]]] = None
- prediction: Optional[
- List[Union[Tuple[str, int, int], Tuple[str, int, int, Optional[float]]]]
- ] = None
+ prediction: Optional[List[Union[Tuple[str, int, int], Tuple[str, int, int, Optional[float]]]]] = None
prediction_agent: Optional[str] = None
annotation: Optional[List[Tuple[str, int, int]]] = None
annotation_agent: Optional[str] = None
@@ -358,16 +349,10 @@ def __init__(
**data,
):
if text is None and tokens is None:
- raise AssertionError(
- "Missing fields: At least one of `text` or `tokens` argument must be"
- " provided!"
- )
+ raise AssertionError("Missing fields: At least one of `text` or `tokens` argument must be" " provided!")
if (data.get("annotation") or data.get("prediction")) and text is None:
- raise AssertionError(
- "Missing field `text`: "
- "char level spans must be provided with a raw text sentence"
- )
+ raise AssertionError("Missing field `text`: " "char level spans must be provided with a raw text sentence")
if text is None:
text = " ".join(tokens)
@@ -392,9 +377,7 @@ def __setattr__(self, name: str, value: Any):
raise AttributeError(f"You cannot assign a new value to `{name}`")
super().__setattr__(name, value)
- def _validate_spans(
- self, spans: List[Tuple[str, int, int]]
- ) -> List[Tuple[str, int, int]]:
+ def _validate_spans(self, spans: List[Tuple[str, int, int]]) -> List[Tuple[str, int, int]]:
"""Validates the entity spans with respect to the tokens.
If necessary, also performs an automatic correction of the spans.
@@ -427,17 +410,13 @@ def _normalize_tokens(cls, value):
@validator("prediction")
def _add_default_score(
cls,
- prediction: Optional[
- List[Union[Tuple[str, int, int], Tuple[str, int, int, Optional[float]]]]
- ],
+ prediction: Optional[List[Union[Tuple[str, int, int], Tuple[str, int, int, Optional[float]]]]],
):
"""Adds the default score to the predictions if it is missing"""
if prediction is None:
return prediction
return [
- (pred[0], pred[1], pred[2], 0.0)
- if len(pred) == 3
- else (pred[0], pred[1], pred[2], pred[3] or 0.0)
+ (pred[0], pred[1], pred[2], 0.0) if len(pred) == 3 else (pred[0], pred[1], pred[2], pred[3] or 0.0)
for pred in prediction
]
@@ -492,9 +471,7 @@ def token_span(self, token_idx: int) -> Tuple[int, int]:
raise IndexError(f"Token id {token_idx} out of bounds")
return self._span_utils.token_to_char_idx[token_idx]
- def spans2iob(
- self, spans: Optional[List[Tuple[str, int, int]]] = None
- ) -> Optional[List[str]]:
+ def spans2iob(self, spans: Optional[List[Tuple[str, int, int]]] = None) -> Optional[List[str]]:
"""DEPRECATED, please use the ``argilla.utils.SpanUtils.to_tags()`` method."""
warnings.warn(
"'spans2iob' is deprecated and will be removed in a future version. Please"
@@ -570,9 +547,7 @@ class Text2TextRecord(_Validators):
search_keywords: Optional[List[str]] = None
@validator("prediction")
- def prediction_as_tuples(
- cls, prediction: Optional[List[Union[str, Tuple[str, float]]]]
- ):
+ def prediction_as_tuples(cls, prediction: Optional[List[Union[str, Tuple[str, float]]]]):
"""Preprocess the predictions and wraps them in a tuple if needed"""
if prediction is None:
return prediction
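The `_add_default_score` validator collapsed earlier in this file normalizes 3-tuples to 4-tuples; the rule in isolation, with hypothetical entity predictions:

    # Predictions without a score get 0.0 appended; None scores are coerced to 0.0.
    preds = [("PER", 0, 5), ("ORG", 10, 18, None)]
    normalized = [
        (p[0], p[1], p[2], 0.0) if len(p) == 3 else (p[0], p[1], p[2], p[3] or 0.0)
        for p in preds
    ]
    assert normalized == [("PER", 0, 5, 0.0), ("ORG", 10, 18, 0.0)]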
diff --git a/src/argilla/client/sdk/client.py b/src/argilla/client/sdk/client.py
index be4d2983b9..9f925b6795 100644
--- a/src/argilla/client/sdk/client.py
+++ b/src/argilla/client/sdk/client.py
@@ -108,10 +108,7 @@ def __hash__(self):
def with_httpx_error_handler(func):
def wrap_error(base_url: str):
- err_str = (
- f"Your Api endpoint at {base_url} is not available or not"
- " responding."
- )
+ err_str = f"Your Api endpoint at {base_url} is not available or not" " responding."
raise BaseClientError(err_str) from None
@functools.wraps(func)
diff --git a/src/argilla/client/sdk/commons/api.py b/src/argilla/client/sdk/commons/api.py
index 3810aebdf1..5c96eff3f8 100644
--- a/src/argilla/client/sdk/commons/api.py
+++ b/src/argilla/client/sdk/commons/api.py
@@ -50,9 +50,7 @@
}
-def build_param_dict(
- id_from: Optional[str], limit: Optional[int]
-) -> Optional[Dict[str, Union[str, int]]]:
+def build_param_dict(id_from: Optional[str], limit: Optional[int]) -> Optional[Dict[str, Union[str, int]]]:
params = {}
if id_from:
params["id_from"] = id_from
@@ -64,9 +62,7 @@ def build_param_dict(
def bulk(
client: AuthenticatedClient,
name: str,
- json_body: Union[
- TextClassificationBulkData, TokenClassificationBulkData, Text2TextBulkData
- ],
+ json_body: Union[TextClassificationBulkData, TokenClassificationBulkData, Text2TextBulkData],
) -> Response[BulkResponse]:
url = f"{client.base_url}/api/datasets/{name}/{_TASK_TO_ENDPOINT[type(json_body)]}:bulk"
@@ -100,9 +96,7 @@ async def async_bulk(
return build_bulk_response(response, name=name, body=json_body)
-def build_bulk_response(
- response: httpx.Response, name: str, body: Any
-) -> Response[BulkResponse]:
+def build_bulk_response(response: httpx.Response, name: str, body: Any) -> Response[BulkResponse]:
if 200 <= response.status_code < 400:
return Response(
status_code=response.status_code,
@@ -117,9 +111,7 @@ def build_bulk_response(
T = TypeVar("T")
-def build_data_response(
- response: httpx.Response, data_type: Type[T]
-) -> Response[List[T]]:
+def build_data_response(response: httpx.Response, data_type: Type[T]) -> Response[List[T]]:
if 200 <= response.status_code < 400:
parsed_responses = []
for r in response.iter_lines():
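For reference, `build_param_dict` above assembles optional query params; a sketch of the expected behaviour (the `limit` branch and the final `return params or None` are assumptions inferred from the `Optional` return type, not shown in the hunk):

    # Assumed behaviour, mirroring the id_from branch shown above.
    def build_param_dict(id_from, limit):
        params = {}
        if id_from:
            params["id_from"] = id_from
        if limit:
            params["limit"] = limit
        return params or None

    assert build_param_dict("rec-42", 50) == {"id_from": "rec-42", "limit": 50}
    assert build_param_dict(None, None) is None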
diff --git a/src/argilla/client/sdk/commons/errors.py b/src/argilla/client/sdk/commons/errors.py
index d3698c404a..ca7dfd056f 100644
--- a/src/argilla/client/sdk/commons/errors.py
+++ b/src/argilla/client/sdk/commons/errors.py
@@ -26,11 +26,7 @@ def __init__(self, message: str, response: Any):
self.response = response
def __str__(self):
- return (
- f"\nUnexpected response: {self.message}"
- "\nResponse content:"
- f"\n{self.response}"
- )
+ return f"\nUnexpected response: {self.message}" "\nResponse content:" f"\n{self.response}"
class InputValueError(BaseClientError):
@@ -72,8 +68,7 @@ def __init__(self, **ctx):
def __str__(self):
return (
- f"Argilla server returned an error with http status: {self.HTTP_STATUS}"
- + f"\nError details: [{self.ctx}]"
+ f"Argilla server returned an error with http status: {self.HTTP_STATUS}" + f"\nError details: [{self.ctx}]"
)
diff --git a/src/argilla/client/sdk/commons/errors_handler.py b/src/argilla/client/sdk/commons/errors_handler.py
index f1292095dd..9b425b24db 100644
--- a/src/argilla/client/sdk/commons/errors_handler.py
+++ b/src/argilla/client/sdk/commons/errors_handler.py
@@ -29,9 +29,7 @@
)
-def handle_response_error(
- response: httpx.Response, parse_response: bool = True, **client_ctx
-):
+def handle_response_error(response: httpx.Response, parse_response: bool = True, **client_ctx):
try:
response_content = response.json() if parse_response else {}
except JSONDecodeError:
diff --git a/src/argilla/client/sdk/commons/models.py b/src/argilla/client/sdk/commons/models.py
index 926e892cb1..2dff92592a 100644
--- a/src/argilla/client/sdk/commons/models.py
+++ b/src/argilla/client/sdk/commons/models.py
@@ -78,9 +78,7 @@ def datetime_to_isoformat(cls, v: Optional[datetime]):
def _from_client_vectors(vectors: ClientVectors) -> SdkVectors:
sdk_vectors = None
if vectors:
- sdk_vectors = {
- name: VectorInfo(value=vector) for name, vector in vectors.items()
- }
+ sdk_vectors = {name: VectorInfo(value=vector) for name, vector in vectors.items()}
return sdk_vectors
@staticmethod
diff --git a/src/argilla/client/sdk/datasets/api.py b/src/argilla/client/sdk/datasets/api.py
index b836b949d0..a254b93974 100644
--- a/src/argilla/client/sdk/datasets/api.py
+++ b/src/argilla/client/sdk/datasets/api.py
@@ -86,9 +86,7 @@ def delete_dataset(
return handle_response_error(response, dataset=name)
-def _build_response(
- response: httpx.Response, name: str
-) -> Response[Union[Dataset, ErrorMessage, HTTPValidationError]]:
+def _build_response(response: httpx.Response, name: str) -> Response[Union[Dataset, ErrorMessage, HTTPValidationError]]:
if response.status_code == 200:
parsed_response = Dataset(**response.json())
return Response(
diff --git a/src/argilla/client/sdk/metrics/api.py b/src/argilla/client/sdk/metrics/api.py
index 30171dc458..c700c34970 100644
--- a/src/argilla/client/sdk/metrics/api.py
+++ b/src/argilla/client/sdk/metrics/api.py
@@ -32,9 +32,7 @@
def get_dataset_metrics(
client: AuthenticatedClient, name: str, task: str
) -> Response[Union[List[MetricInfo], ErrorMessage, HTTPValidationError]]:
- url = "{}/api/datasets/{task}/{name}/metrics".format(
- client.base_url, name=name, task=task
- )
+ url = "{}/api/datasets/{task}/{name}/metrics".format(client.base_url, name=name, task=task)
response = httpx.get(
url=url,
diff --git a/src/argilla/client/sdk/metrics/models.py b/src/argilla/client/sdk/metrics/models.py
index 3e272966e1..dbcb60c5c8 100644
--- a/src/argilla/client/sdk/metrics/models.py
+++ b/src/argilla/client/sdk/metrics/models.py
@@ -22,6 +22,4 @@ class MetricInfo(BaseModel):
id: str = Field(description="The metric id")
name: str = Field(description="The metric name")
- description: Optional[str] = Field(
- default=None, description="The metric description"
- )
+ description: Optional[str] = Field(default=None, description="The metric description")
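A quick check of the collapsed `description` field above, with the model re-declared so the sketch is self-contained:

    from typing import Optional
    from pydantic import BaseModel, Field

    class MetricInfo(BaseModel):
        id: str = Field(description="The metric id")
        name: str = Field(description="The metric name")
        description: Optional[str] = Field(default=None, description="The metric description")

    # The description stays None unless explicitly provided.
    info = MetricInfo(id="F1", name="F1 score")
    assert info.description is None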
diff --git a/src/argilla/client/sdk/text2text/models.py b/src/argilla/client/sdk/text2text/models.py
index f8f563caee..7f067820d1 100644
--- a/src/argilla/client/sdk/text2text/models.py
+++ b/src/argilla/client/sdk/text2text/models.py
@@ -82,10 +82,7 @@ class Text2TextRecord(CreationText2TextRecord):
def to_client(self) -> ClientText2TextRecord:
return ClientText2TextRecord(
text=self.text,
- prediction=[
- (sentence.text, sentence.score)
- for sentence in self.prediction.sentences
- ]
+ prediction=[(sentence.text, sentence.score) for sentence in self.prediction.sentences]
if self.prediction
else None,
prediction_agent=self.prediction.agent if self.prediction else None,
diff --git a/src/argilla/client/sdk/text_classification/api.py b/src/argilla/client/sdk/text_classification/api.py
index 39c603411e..50778ccd85 100644
--- a/src/argilla/client/sdk/text_classification/api.py
+++ b/src/argilla/client/sdk/text_classification/api.py
@@ -52,9 +52,7 @@ def data(
params=params if params else None,
json=request.dict() if request else {},
) as response:
- return build_data_response(
- response=response, data_type=TextClassificationRecord
- )
+ return build_data_response(response=response, data_type=TextClassificationRecord)
def add_dataset_labeling_rule(
@@ -62,9 +60,7 @@ def add_dataset_labeling_rule(
name: str,
rule: LabelingRule,
) -> Response[Union[LabelingRule, HTTPValidationError, ErrorMessage]]:
- url = "{}/api/datasets/{name}/TextClassification/labeling/rules".format(
- client.base_url, name=name
- )
+ url = "{}/api/datasets/{name}/TextClassification/labeling/rules".format(client.base_url, name=name)
response = httpx.post(
url=url,
@@ -118,9 +114,7 @@ def fetch_dataset_labeling_rules(
client: AuthenticatedClient,
name: str,
) -> Response[Union[List[LabelingRule], HTTPValidationError, ErrorMessage]]:
- url = "{}/api/datasets/TextClassification/{name}/labeling/rules".format(
- client.base_url, name=name
- )
+ url = "{}/api/datasets/TextClassification/{name}/labeling/rules".format(client.base_url, name=name)
response = httpx.get(
url=url,
@@ -149,6 +143,4 @@ def dataset_rule_metrics(
timeout=client.get_timeout(),
)
- return build_typed_response(
- response, response_type_class=LabelingRuleMetricsSummary
- )
+ return build_typed_response(response, response_type_class=LabelingRuleMetricsSummary)
diff --git a/src/argilla/client/sdk/text_classification/models.py b/src/argilla/client/sdk/text_classification/models.py
index cfe016cb7b..2fc82b6762 100644
--- a/src/argilla/client/sdk/text_classification/models.py
+++ b/src/argilla/client/sdk/text_classification/models.py
@@ -60,24 +60,15 @@ def from_client(cls, record: ClientTextClassificationRecord):
prediction = None
if record.prediction is not None:
prediction = TextClassificationAnnotation(
- labels=[
- ClassPrediction(**{"class": label, "score": score})
- for label, score in record.prediction
- ],
+ labels=[ClassPrediction(**{"class": label, "score": score}) for label, score in record.prediction],
agent=record.prediction_agent or MACHINE_NAME,
)
annotation = None
if record.annotation is not None:
- annotation_list = (
- record.annotation
- if isinstance(record.annotation, list)
- else [record.annotation]
- )
+ annotation_list = record.annotation if isinstance(record.annotation, list) else [record.annotation]
annotation = TextClassificationAnnotation(
- labels=[
- ClassPrediction(**{"class": label}) for label in annotation_list
- ],
+ labels=[ClassPrediction(**{"class": label}) for label in annotation_list],
agent=record.annotation_agent or MACHINE_NAME,
)
@@ -101,11 +92,7 @@ class TextClassificationRecord(CreationTextClassificationRecord):
def to_client(self) -> ClientTextClassificationRecord:
"""Returns the client model"""
- annotations = (
- [label.class_label for label in self.annotation.labels]
- if self.annotation
- else None
- )
+ annotations = [label.class_label for label in self.annotation.labels] if self.annotation else None
if annotations and not self.multi_label:
annotations = annotations[0]
@@ -116,9 +103,7 @@ def to_client(self) -> ClientTextClassificationRecord:
multi_label=self.multi_label,
status=self.status,
metadata=self.metadata or {},
- prediction=[
- (label.class_label, label.score) for label in self.prediction.labels
- ]
+ prediction=[(label.class_label, label.score) for label in self.prediction.labels]
if self.prediction
else None,
prediction_agent=self.prediction.agent if self.prediction else None,
@@ -126,10 +111,7 @@ def to_client(self) -> ClientTextClassificationRecord:
annotation_agent=self.annotation.agent if self.annotation else None,
vectors=self._to_client_vectors(self.vectors),
explanation={
- key: [
- ClientTokenAttributions.parse_obj(attribution)
- for attribution in attributions
- ]
+ key: [ClientTokenAttributions.parse_obj(attribution) for attribution in attributions]
for key, attributions in self.explanation.items()
}
if self.explanation
diff --git a/src/argilla/client/sdk/token_classification/api.py b/src/argilla/client/sdk/token_classification/api.py
index 10a91d0f59..9d9d8909fc 100644
--- a/src/argilla/client/sdk/token_classification/api.py
+++ b/src/argilla/client/sdk/token_classification/api.py
@@ -36,9 +36,7 @@ def data(
request: Optional[TokenClassificationQuery] = None,
limit: Optional[int] = None,
id_from: Optional[str] = None,
-) -> Response[
- Union[List[TokenClassificationRecord], HTTPValidationError, ErrorMessage]
-]:
+) -> Response[Union[List[TokenClassificationRecord], HTTPValidationError, ErrorMessage]]:
path = f"/api/datasets/{name}/TokenClassification/data"
params = build_param_dict(id_from, limit)
@@ -48,6 +46,4 @@ def data(
params=params if params else None,
json=request.dict() if request else {},
) as response:
- return build_data_response(
- response=response, data_type=TokenClassificationRecord
- )
+ return build_data_response(response=response, data_type=TokenClassificationRecord)
diff --git a/src/argilla/client/sdk/token_classification/models.py b/src/argilla/client/sdk/token_classification/models.py
index 8a549de39f..0bea10bba2 100644
--- a/src/argilla/client/sdk/token_classification/models.py
+++ b/src/argilla/client/sdk/token_classification/models.py
@@ -63,9 +63,7 @@ def from_client(cls, record: ClientTokenClassificationRecord):
entities=[
EntitySpan(label=ent[0], start=ent[1], end=ent[2])
if len(ent) == 3
- else EntitySpan(
- label=ent[0], start=ent[1], end=ent[2], score=ent[3]
- )
+ else EntitySpan(label=ent[0], start=ent[1], end=ent[2], score=ent[3])
for ent in record.prediction
],
agent=record.prediction_agent or MACHINE_NAME,
@@ -74,10 +72,7 @@ def from_client(cls, record: ClientTokenClassificationRecord):
annotation = None
if record.annotation is not None:
annotation = TokenClassificationAnnotation(
- entities=[
- EntitySpan(label=ent[0], start=ent[1], end=ent[2])
- for ent in record.annotation
- ],
+ entities=[EntitySpan(label=ent[0], start=ent[1], end=ent[2]) for ent in record.annotation],
agent=record.annotation_agent or MACHINE_NAME,
)
@@ -102,16 +97,11 @@ def to_client(self) -> ClientTokenClassificationRecord:
return ClientTokenClassificationRecord(
text=self.text,
tokens=self.tokens,
- prediction=[
- (ent.label, ent.start, ent.end, ent.score)
- for ent in self.prediction.entities
- ]
+ prediction=[(ent.label, ent.start, ent.end, ent.score) for ent in self.prediction.entities]
if self.prediction
else None,
prediction_agent=self.prediction.agent if self.prediction else None,
- annotation=[
- (ent.label, ent.start, ent.end) for ent in self.annotation.entities
- ]
+ annotation=[(ent.label, ent.start, ent.end) for ent in self.annotation.entities]
if self.annotation
else None,
annotation_agent=self.annotation.agent if self.annotation else None,
diff --git a/src/argilla/labeling/text_classification/label_errors.py b/src/argilla/labeling/text_classification/label_errors.py
index e9a8583f50..2460c816cf 100644
--- a/src/argilla/labeling/text_classification/label_errors.py
+++ b/src/argilla/labeling/text_classification/label_errors.py
@@ -95,9 +95,7 @@ def find_label_errors(
# select only records with prediction and annotation
records = [rec for rec in records if rec.prediction and rec.annotation]
if not records:
- raise NoRecordsError(
- "It seems that none of your records have a prediction AND annotation!"
- )
+ raise NoRecordsError("It seems that none of your records have a prediction AND annotation!")
# check and update kwargs for get_noise_indices
_check_and_update_kwargs(cleanlab.__version__, records[0], sort_by, kwargs)
@@ -117,9 +115,7 @@ def find_label_errors(
return records_with_label_errors
-def _check_and_update_kwargs(
- version: str, record: TextClassificationRecord, sort_by: SortBy, kwargs: Dict
-):
+def _check_and_update_kwargs(version: str, record: TextClassificationRecord, sort_by: SortBy, kwargs: Dict):
"""Helper function to check and update the kwargs passed on to cleanlab's `get_noise_indices`.
Args:
@@ -140,9 +136,7 @@ def _check_and_update_kwargs(
if parse_version(version) < parse_version("2.0"):
if "sorted_index_method" in kwargs:
- raise ValueError(
- "The 'sorted_index_method' kwarg is not supported, please use 'sort_by' instead."
- )
+ raise ValueError("The 'sorted_index_method' kwarg is not supported, please use 'sort_by' instead.")
kwargs["sorted_index_method"] = "normalized_margin"
if sort_by is SortBy.PREDICTION:
kwargs["sorted_index_method"] = "prob_given_label"
@@ -150,9 +144,7 @@ def _check_and_update_kwargs(
kwargs["sorted_index_method"] = None
else:
if "return_indices_ranked_by" in kwargs:
- raise ValueError(
- "The 'return_indices_ranked_by' kwarg is not supported, please use 'sort_by' instead."
- )
+ raise ValueError("The 'return_indices_ranked_by' kwarg is not supported, please use 'sort_by' instead.")
kwargs["return_indices_ranked_by"] = "normalized_margin"
if sort_by is SortBy.PREDICTION:
kwargs["return_indices_ranked_by"] = "self_confidence"
@@ -188,20 +180,14 @@ def _construct_s_and_psx(
labels.update(predictions[-1].keys())
labels_mapping = {label: i for i, label in enumerate(sorted(labels))}
- s = (
- np.empty(len(records), dtype=object)
- if records[0].multi_label
- else np.zeros(len(records), dtype=np.short)
- )
+ s = np.empty(len(records), dtype=object) if records[0].multi_label else np.zeros(len(records), dtype=np.short)
psx = np.zeros((len(records), len(labels)), dtype=float)
for i, rec, pred in zip(range(len(records)), records, predictions):
try:
psx[i] = [pred[label] for label in labels_mapping]
except KeyError as error:
- raise MissingPredictionError(
- f"It seems a prediction for {error} is missing in the following record: {rec}"
- )
+ raise MissingPredictionError(f"It seems a prediction for {error} is missing in the following record: {rec}")
try:
s[i] = (
@@ -210,9 +196,7 @@ def _construct_s_and_psx(
else labels_mapping[rec.annotation]
)
except KeyError as error:
- raise MissingPredictionError(
- f"It seems predictions are missing for the label {error}!"
- )
+ raise MissingPredictionError(f"It seems predictions are missing for the label {error}!")
return s, psx
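`_construct_s_and_psx` above feeds cleanlab: `s` holds the noisy (annotated) labels and `psx` the predicted probability per label, while the multi-label branch switches `s` to `dtype=object`. A shape sketch with hypothetical values for the single-label case:

    import numpy as np

    labels_mapping = {"neg": 0, "pos": 1}   # label -> column index (hypothetical)
    n_records = 3
    s = np.zeros(n_records, dtype=np.short)            # one noisy label per record
    psx = np.zeros((n_records, len(labels_mapping)))   # predicted prob per label
    psx[0] = [0.9, 0.1]
    s[0] = labels_mapping["neg"]
    assert s.shape == (3,) and psx.shape == (3, 2)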
diff --git a/src/argilla/labeling/text_classification/label_models.py b/src/argilla/labeling/text_classification/label_models.py
index 7186c46b9e..705894549b 100644
--- a/src/argilla/labeling/text_classification/label_models.py
+++ b/src/argilla/labeling/text_classification/label_models.py
@@ -35,8 +35,7 @@ class TieBreakPolicy(Enum):
@classmethod
def _missing_(cls, value):
raise ValueError(
- f"{value} is not a valid {cls.__name__}, please select one of"
- f" {list(cls._value2member_map_.keys())}"
+ f"{value} is not a valid {cls.__name__}, please select one of" f" {list(cls._value2member_map_.keys())}"
)
@@ -133,12 +132,8 @@ def predict(
Returns:
A dataset of records that include the predictions of the label model.
"""
- wl_matrix = self._weak_labels.matrix(
- has_annotation=None if include_annotated_records else False
- )
- records = self._weak_labels.records(
- has_annotation=None if include_annotated_records else False
- )
+ wl_matrix = self._weak_labels.matrix(has_annotation=None if include_annotated_records else False)
+ records = self._weak_labels.records(has_annotation=None if include_annotated_records else False)
assert records, ValueError(
"No records are being passed. Use `include_annotated_records` to include"
@@ -178,9 +173,7 @@ def _compute_single_label_probs(self, wl_matrix: np.ndarray) -> np.ndarray:
"""
counts = np.column_stack(
[
- np.count_nonzero(
- wl_matrix == self._weak_labels.label2int[label], axis=1
- )
+ np.count_nonzero(wl_matrix == self._weak_labels.label2int[label], axis=1)
for label in self._weak_labels.labels
]
)
@@ -228,34 +221,22 @@ def _make_single_label_records(
tie = True
# maybe skip record
- if not include_abstentions and (
- tie and tie_break_policy is TieBreakPolicy.ABSTAIN
- ):
+ if not include_abstentions and (tie and tie_break_policy is TieBreakPolicy.ABSTAIN):
continue
if not tie:
- pred_for_rec = [
- (self._weak_labels.labels[idx], prob[idx])
- for idx in np.argsort(prob)[::-1]
- ]
+ pred_for_rec = [(self._weak_labels.labels[idx], prob[idx]) for idx in np.argsort(prob)[::-1]]
# resolve ties following the tie break policy
elif tie_break_policy is TieBreakPolicy.ABSTAIN:
pred_for_rec = None
elif tie_break_policy is TieBreakPolicy.RANDOM:
- random_idx = int(hashlib.sha1(f"{i}".encode()).hexdigest(), 16) % len(
- equal_prob_idx
- )
+ random_idx = int(hashlib.sha1(f"{i}".encode()).hexdigest(), 16) % len(equal_prob_idx)
for idx in equal_prob_idx:
if idx == random_idx:
prob[idx] += self._PROBABILITY_INCREASE_ON_TIE_BREAK
else:
- prob[idx] -= self._PROBABILITY_INCREASE_ON_TIE_BREAK / (
- len(equal_prob_idx) - 1
- )
- pred_for_rec = [
- (self._weak_labels.labels[idx], prob[idx])
- for idx in np.argsort(prob)[::-1]
- ]
+ prob[idx] -= self._PROBABILITY_INCREASE_ON_TIE_BREAK / (len(equal_prob_idx) - 1)
+ pred_for_rec = [(self._weak_labels.labels[idx], prob[idx]) for idx in np.argsort(prob)[::-1]]
else:
raise NotImplementedError(
f"The tie break policy '{tie_break_policy.value}' is not"
@@ -321,10 +302,7 @@ def _make_multi_label_records(
pred_for_rec = None
if not all_abstained:
- pred_for_rec = [
- (self._weak_labels.labels[i], prob[i])
- for i in np.argsort(prob)[::-1]
- ]
+ pred_for_rec = [(self._weak_labels.labels[i], prob[i]) for i in np.argsort(prob)[::-1]]
records_with_prediction.append(rec.copy(deep=True))
records_with_prediction[-1].prediction = pred_for_rec
@@ -388,9 +366,7 @@ def score(
probabilities = self._compute_single_label_probs(wl_matrix)
- annotation, prediction = self._score_single_label(
- probabilities, tie_break_policy
- )
+ annotation, prediction = self._score_single_label(probabilities, tie_break_policy)
target_names = self._weak_labels.labels[: annotation.max() + 1]
return classification_report(
@@ -419,18 +395,13 @@ def _score_single_label(
A tuple of the annotation and prediction array.
"""
# 1.e-8 is taken from the abs tolerance of np.isclose
- is_max = (
- np.abs(probabilities.max(axis=1, keepdims=True) - probabilities) < 1.0e-8
- )
+ is_max = np.abs(probabilities.max(axis=1, keepdims=True) - probabilities) < 1.0e-8
is_tie = is_max.sum(axis=1) > 1
prediction = np.argmax(is_max, axis=1)
# we need to transform the indexes!
annotation = np.array(
- [
- self._weak_labels.labels.index(self._weak_labels.int2label[i])
- for i in self._weak_labels.annotation()
- ],
+ [self._weak_labels.labels.index(self._weak_labels.int2label[i]) for i in self._weak_labels.annotation()],
dtype=np.short,
)
@@ -442,21 +413,16 @@ def _score_single_label(
elif tie_break_policy is TieBreakPolicy.RANDOM:
for i in np.nonzero(is_tie)[0]:
equal_prob_idx = np.nonzero(is_max[i])[0]
- random_idx = int(hashlib.sha1(f"{i}".encode()).hexdigest(), 16) % len(
- equal_prob_idx
- )
+ random_idx = int(hashlib.sha1(f"{i}".encode()).hexdigest(), 16) % len(equal_prob_idx)
prediction[i] = equal_prob_idx[random_idx]
else:
raise NotImplementedError(
- f"The tie break policy '{tie_break_policy.value}' is not implemented"
- " for MajorityVoter!"
+ f"The tie break policy '{tie_break_policy.value}' is not implemented" " for MajorityVoter!"
)
return annotation, prediction
- def _score_multi_label(
- self, probabilities: np.ndarray
- ) -> Tuple[np.ndarray, np.ndarray]:
+ def _score_multi_label(self, probabilities: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Helper method to compute scores for multi-label classifications.
Args:
@@ -496,9 +462,7 @@ class Snorkel(LabelModel):
>>> records = label_model.predict()
"""
- def __init__(
- self, weak_labels: WeakLabels, verbose: bool = True, device: str = "cpu"
- ):
+ def __init__(self, weak_labels: WeakLabels, verbose: bool = True, device: str = "cpu"):
try:
import snorkel # noqa: F401
except ModuleNotFoundError:
@@ -513,23 +477,18 @@ def __init__(
# Check if we need to remap the weak labels to int mapping
# Snorkel expects the abstain id to be -1 and the rest of the labels to be sequential
- if self._weak_labels.label2int[None] != -1 or sorted(
- self._weak_labels.int2label
- ) != list(range(-1, self._weak_labels.cardinality)):
+ if self._weak_labels.label2int[None] != -1 or sorted(self._weak_labels.int2label) != list(
+ range(-1, self._weak_labels.cardinality)
+ ):
self._need_remap = True
self._weaklabelsInt2snorkelInt = {
- self._weak_labels.label2int[label]: i
- for i, label in enumerate([None] + self._weak_labels.labels, -1)
+ self._weak_labels.label2int[label]: i for i, label in enumerate([None] + self._weak_labels.labels, -1)
}
else:
self._need_remap = False
- self._weaklabelsInt2snorkelInt = {
- i: i for i in range(-1, self._weak_labels.cardinality)
- }
+ self._weaklabelsInt2snorkelInt = {i: i for i in range(-1, self._weak_labels.cardinality)}
- self._snorkelInt2weaklabelsInt = {
- val: key for key, val in self._weaklabelsInt2snorkelInt.items()
- }
+ self._snorkelInt2weaklabelsInt = {val: key for key, val in self._weaklabelsInt2snorkelInt.items()}
# instantiate Snorkel's label model
self._model = SnorkelLabelModel(
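The remapping collapsed in the hunk above exists because, as the comment notes, Snorkel expects the abstain id to be -1 and the remaining label ids to be sequential. In isolation, with a hypothetical non-sequential weak-label mapping:

    label2int = {None: -1, "neg": 2, "pos": 5}   # hypothetical mapping
    labels = ["neg", "pos"]
    # enumerate(..., -1) pairs None with -1 and the labels with 0, 1, ...
    wl2snorkel = {label2int[label]: i for i, label in enumerate([None] + labels, -1)}
    assert wl2snorkel == {-1: -1, 2: 0, 5: 1}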
@@ -548,13 +507,9 @@ def fit(self, include_annotated_records: bool = False, **kwargs):
They must not contain ``L_train``, the label matrix is provided automatically.
"""
if "L_train" in kwargs:
- raise ValueError(
- "Your kwargs must not contain 'L_train', it is provided automatically."
- )
+ raise ValueError("Your kwargs must not contain 'L_train', it is provided automatically.")
- l_train = self._weak_labels.matrix(
- has_annotation=None if include_annotated_records else False
- )
+ l_train = self._weak_labels.matrix(has_annotation=None if include_annotated_records else False)
if self._need_remap:
l_train = self._copy_and_remap(l_train)
@@ -614,9 +569,7 @@ def predict(
if isinstance(tie_break_policy, str):
tie_break_policy = TieBreakPolicy(tie_break_policy)
- l_pred = self._weak_labels.matrix(
- has_annotation=None if include_annotated_records else False
- )
+ l_pred = self._weak_labels.matrix(has_annotation=None if include_annotated_records else False)
if self._need_remap:
l_pred = self._copy_and_remap(l_pred)
@@ -630,9 +583,7 @@ def predict(
# add predictions to records
records_with_prediction = []
for rec, pred, prob in zip(
- self._weak_labels.records(
- has_annotation=None if include_annotated_records else False
- ),
+ self._weak_labels.records(has_annotation=None if include_annotated_records else False),
predictions,
probabilities,
):
@@ -652,15 +603,11 @@ def predict(
if idx == pred:
prob[idx] += self._PROBABILITY_INCREASE_ON_TIE_BREAK
else:
- prob[idx] -= self._PROBABILITY_INCREASE_ON_TIE_BREAK / (
- len(equal_prob_idx) - 1
- )
+ prob[idx] -= self._PROBABILITY_INCREASE_ON_TIE_BREAK / (len(equal_prob_idx) - 1)
pred_for_rec = [
(
- self._weak_labels.int2label[
- self._snorkelInt2weaklabelsInt[snorkel_idx]
- ],
+ self._weak_labels.int2label[self._snorkelInt2weaklabelsInt[snorkel_idx]],
prob[snorkel_idx],
)
for snorkel_idx in np.argsort(prob)[::-1]
@@ -713,8 +660,7 @@ def score(
if self._weak_labels.annotation().size == 0:
raise MissingAnnotationError(
- "You need annotated records to compute scores/metrics for your label"
- " model."
+ "You need annotated records to compute scores/metrics for your label" " model."
)
l_pred = self._weak_labels.matrix(has_annotation=True)
@@ -779,14 +725,10 @@ def __init__(self, weak_labels: WeakLabels, **kwargs):
super().__init__(weak_labels)
if len(self._weak_labels.rules) < 3:
- raise TooFewRulesError(
- "The FlyingSquid label model needs at least three (independent) rules!"
- )
+ raise TooFewRulesError("The FlyingSquid label model needs at least three (independent) rules!")
if "m" in kwargs:
- raise ValueError(
- "Your kwargs must not contain 'm', it is provided automatically."
- )
+ raise ValueError("Your kwargs must not contain 'm', it is provided automatically.")
self._init_kwargs = kwargs
self._models: List[FlyingSquidLabelModel] = []
@@ -800,21 +742,15 @@ def fit(self, include_annotated_records: bool = False, **kwargs):
`LabelModel.fit() `__
method.
"""
- wl_matrix = self._weak_labels.matrix(
- has_annotation=None if include_annotated_records else False
- )
+ wl_matrix = self._weak_labels.matrix(has_annotation=None if include_annotated_records else False)
models = []
# create a label model for each label (except for binary classification)
# much of the implementation is taken from wrench:
# https://github.com/JieyuZ2/wrench/blob/main/wrench/labelmodel/flyingsquid.py
# If binary, we only need one model
- for i in range(
- 1 if self._weak_labels.cardinality == 2 else self._weak_labels.cardinality
- ):
- model = self._FlyingSquidLabelModel(
- m=len(self._weak_labels.rules), **self._init_kwargs
- )
+ for i in range(1 if self._weak_labels.cardinality == 2 else self._weak_labels.cardinality):
+ model = self._FlyingSquidLabelModel(m=len(self._weak_labels.rules), **self._init_kwargs)
wl_matrix_i = self._copy_and_transform_wl_matrix(wl_matrix, i)
model.fit(L_train=wl_matrix_i, **kwargs)
models.append(model)
@@ -839,9 +775,7 @@ def _copy_and_transform_wl_matrix(self, weak_label_matrix: np.ndarray, i: int):
"""
wl_matrix_i = weak_label_matrix.copy()
- target_mask = (
- wl_matrix_i == self._weak_labels.label2int[self._weak_labels.labels[i]]
- )
+ target_mask = wl_matrix_i == self._weak_labels.label2int[self._weak_labels.labels[i]]
abstain_mask = wl_matrix_i == self._weak_labels.label2int[None]
other_mask = (~target_mask) & (~abstain_mask)
@@ -883,9 +817,7 @@ def predict(
if isinstance(tie_break_policy, str):
tie_break_policy = TieBreakPolicy(tie_break_policy)
- wl_matrix = self._weak_labels.matrix(
- has_annotation=None if include_annotated_records else False
- )
+ wl_matrix = self._weak_labels.matrix(has_annotation=None if include_annotated_records else False)
probabilities = self._predict(wl_matrix, verbose)
# add predictions to records
@@ -893,9 +825,7 @@ def predict(
for i, prob, rec in zip(
range(len(probabilities)),
probabilities,
- self._weak_labels.records(
- has_annotation=None if include_annotated_records else False
- ),
+ self._weak_labels.records(has_annotation=None if include_annotated_records else False),
):
# Check if model abstains, that is if the highest probability is assigned to more than one label
# 1.e-8 is taken from the abs tolerance of np.isclose
@@ -905,38 +835,25 @@ def predict(
tie = True
# maybe skip record
- if not include_abstentions and (
- tie and tie_break_policy is TieBreakPolicy.ABSTAIN
- ):
+ if not include_abstentions and (tie and tie_break_policy is TieBreakPolicy.ABSTAIN):
continue
if not tie:
- pred_for_rec = [
- (self._weak_labels.labels[i], prob[i])
- for i in np.argsort(prob)[::-1]
- ]
+ pred_for_rec = [(self._weak_labels.labels[i], prob[i]) for i in np.argsort(prob)[::-1]]
# resolve ties following the tie break policy
elif tie_break_policy is TieBreakPolicy.ABSTAIN:
pred_for_rec = None
elif tie_break_policy is TieBreakPolicy.RANDOM:
- random_idx = int(hashlib.sha1(f"{i}".encode()).hexdigest(), 16) % len(
- equal_prob_idx
- )
+ random_idx = int(hashlib.sha1(f"{i}".encode()).hexdigest(), 16) % len(equal_prob_idx)
for idx in equal_prob_idx:
if idx == random_idx:
prob[idx] += self._PROBABILITY_INCREASE_ON_TIE_BREAK
else:
- prob[idx] -= self._PROBABILITY_INCREASE_ON_TIE_BREAK / (
- len(equal_prob_idx) - 1
- )
- pred_for_rec = [
- (self._weak_labels.labels[i], prob[i])
- for i in np.argsort(prob)[::-1]
- ]
+ prob[idx] -= self._PROBABILITY_INCREASE_ON_TIE_BREAK / (len(equal_prob_idx) - 1)
+ pred_for_rec = [(self._weak_labels.labels[i], prob[i]) for i in np.argsort(prob)[::-1]]
else:
raise NotImplementedError(
- f"The tie break policy '{tie_break_policy.value}' is not"
- " implemented for FlyingSquid!"
+ f"The tie break policy '{tie_break_policy.value}' is not" " implemented for FlyingSquid!"
)
records_with_prediction.append(rec.copy(deep=True))
@@ -962,26 +879,19 @@ def _predict(self, weak_label_matrix: np.ndarray, verbose: bool) -> np.ndarray:
NotFittedError: If the label model was still not fitted.
"""
if not self._models:
- raise NotFittedError(
- "This FlyingSquid instance is not fitted yet. Call `fit` before using"
- " this model."
- )
+ raise NotFittedError("This FlyingSquid instance is not fitted yet. Call `fit` before using" " this model.")
# create predictions for each label
if self._weak_labels.cardinality > 2:
probas = np.zeros((len(weak_label_matrix), self._weak_labels.cardinality))
for i in range(self._weak_labels.cardinality):
wl_matrix_i = self._copy_and_transform_wl_matrix(weak_label_matrix, i)
- probas[:, i] = self._models[i].predict_proba(
- L_matrix=wl_matrix_i, verbose=verbose
- )[:, 0]
+ probas[:, i] = self._models[i].predict_proba(L_matrix=wl_matrix_i, verbose=verbose)[:, 0]
probas = np.nan_to_num(probas, nan=-np.inf) # handle NaN
probas = np.exp(probas) / np.sum(np.exp(probas), axis=1, keepdims=True)
# if binary, we only have one model
else:
wl_matrix_i = self._copy_and_transform_wl_matrix(weak_label_matrix, 0)
- probas = self._models[0].predict_proba(
- L_matrix=wl_matrix_i, verbose=verbose
- )
+ probas = self._models[0].predict_proba(L_matrix=wl_matrix_i, verbose=verbose)
return probas
@@ -1038,18 +948,13 @@ def score(
probabilities = self._predict(wl_matrix, verbose)
# 1.e-8 is taken from the abs tolerance of np.isclose
- is_max = (
- np.abs(probabilities.max(axis=1, keepdims=True) - probabilities) < 1.0e-8
- )
+ is_max = np.abs(probabilities.max(axis=1, keepdims=True) - probabilities) < 1.0e-8
is_tie = is_max.sum(axis=1) > 1
prediction = np.argmax(is_max, axis=1)
# we need to transform the indexes!
annotation = np.array(
- [
- self._weak_labels.labels.index(self._weak_labels.int2label[i])
- for i in self._weak_labels.annotation()
- ],
+ [self._weak_labels.labels.index(self._weak_labels.int2label[i]) for i in self._weak_labels.annotation()],
dtype=np.short,
)
@@ -1061,14 +966,11 @@ def score(
elif tie_break_policy is TieBreakPolicy.RANDOM:
for i in np.nonzero(is_tie)[0]:
equal_prob_idx = np.nonzero(is_max[i])[0]
- random_idx = int(hashlib.sha1(f"{i}".encode()).hexdigest(), 16) % len(
- equal_prob_idx
- )
+ random_idx = int(hashlib.sha1(f"{i}".encode()).hexdigest(), 16) % len(equal_prob_idx)
prediction[i] = equal_prob_idx[random_idx]
else:
raise NotImplementedError(
- f"The tie break policy '{tie_break_policy.value}' is not implemented"
- " for FlyingSquid!"
+ f"The tie break policy '{tie_break_policy.value}' is not implemented" " for FlyingSquid!"
)
return classification_report(
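The sha1-based tie break that recurs throughout this file is deterministic by design: the record index is hashed, so repeated runs resolve ties identically. In isolation:

    import hashlib

    i = 7                        # record index (hypothetical)
    equal_prob_idx = [0, 2, 3]   # label indexes tied at the max probability
    random_idx = int(hashlib.sha1(f"{i}".encode()).hexdigest(), 16) % len(equal_prob_idx)
    # Same pick on every run for the same record index.
    print(equal_prob_idx[random_idx])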
diff --git a/src/argilla/labeling/text_classification/rule.py b/src/argilla/labeling/text_classification/rule.py
index 658e63f7fc..b845882863 100644
--- a/src/argilla/labeling/text_classification/rule.py
+++ b/src/argilla/labeling/text_classification/rule.py
@@ -89,22 +89,16 @@ def _convert_to_labeling_rule(self):
def add_to_dataset(self, dataset: str):
"""Add to rule to the given dataset"""
- api.active_api().add_dataset_labeling_rules(
- dataset, rules=[self._convert_to_labeling_rule()]
- )
+ api.active_api().add_dataset_labeling_rules(dataset, rules=[self._convert_to_labeling_rule()])
def remove_from_dataset(self, dataset: str):
"""Removes the rule from the given dataset"""
- api.active_api().delete_dataset_labeling_rules(
- dataset, rules=[self._convert_to_labeling_rule()]
- )
+ api.active_api().delete_dataset_labeling_rules(dataset, rules=[self._convert_to_labeling_rule()])
def update_at_dataset(self, dataset: str):
"""Updates the rule at the given dataset"""
- api.active_api().update_dataset_labeling_rules(
- dataset, rules=[self._convert_to_labeling_rule()]
- )
+ api.active_api().update_dataset_labeling_rules(dataset, rules=[self._convert_to_labeling_rule()])
def apply(self, dataset: str):
"""Apply the rule to a dataset and save matching ids of the records.
@@ -140,15 +134,11 @@ def metrics(self, dataset: str) -> Dict[str, Union[int, float]]:
"coverage": metrics.coverage,
"annotated_coverage": metrics.coverage_annotated,
"correct": int(metrics.correct) if metrics.correct is not None else None,
- "incorrect": int(metrics.incorrect)
- if metrics.incorrect is not None
- else None,
+ "incorrect": int(metrics.incorrect) if metrics.incorrect is not None else None,
"precision": metrics.precision if metrics.precision is not None else None,
}
- def __call__(
- self, record: TextClassificationRecord
- ) -> Optional[Union[str, List[str]]]:
+ def __call__(self, record: TextClassificationRecord) -> Optional[Union[str, List[str]]]:
"""Check if the given record is among the matching ids from the ``self.apply`` call.
Args:
@@ -161,9 +151,7 @@ def __call__(
RuleNotAppliedError: If the rule was not applied to the dataset before.
"""
if self._matching_ids is None:
- raise RuleNotAppliedError(
- "Rule was still not applied. Please call `self.apply(dataset)` first."
- )
+ raise RuleNotAppliedError("Rule was still not applied. Please call `self.apply(dataset)` first.")
try:
self._matching_ids[record.id]
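Tying this file together: the call protocol above means a rule must be applied before it can label records. A hypothetical end-to-end sketch (the `Rule(query=..., label=...)` constructor is an assumption, as is the dataset name; `apply` needs a reachable Argilla server):

    from argilla.labeling.text_classification import Rule

    rule = Rule(query="money AND bank", label="SPAM")  # assumed constructor signature
    rule.apply("my_dataset")  # caches the ids of the matching records
    # rule(record) now returns "SPAM" for matching records and None otherwise;
    # calling it before apply() raises RuleNotAppliedError, as documented above.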
diff --git a/src/argilla/labeling/text_classification/weak_labels.py b/src/argilla/labeling/text_classification/weak_labels.py
index 21fbeeaf45..0cc6c11615 100644
--- a/src/argilla/labeling/text_classification/weak_labels.py
+++ b/src/argilla/labeling/text_classification/weak_labels.py
@@ -62,16 +62,12 @@ def __init__(
query: Optional[str] = None,
):
if not isinstance(dataset, str):
- raise TypeError(
- f"The name of the dataset must be a string, but you provided: {dataset}"
- )
+ raise TypeError(f"The name of the dataset must be a string, but you provided: {dataset}")
self._dataset = dataset
self._rules = rules or load_rules(dataset)
if not self._rules:
- raise NoRulesFoundError(
- f"No rules were found in the given dataset '{dataset}'"
- )
+ raise NoRulesFoundError(f"No rules were found in the given dataset '{dataset}'")
self._rules_index2name = {
# covers our Rule class, snorkel's LabelingFunction class and arbitrary methods
@@ -94,14 +90,10 @@ def __init__(
f"Following rule names are duplicated x times: { {key: val for key, val in counts.items() if val > 1} }"
" Please make sure to provide unique rule names."
)
- self._rules_name2index = {
- val: key for key, val in self._rules_index2name.items()
- }
+ self._rules_name2index = {val: key for key, val in self._rules_index2name.items()}
# load records and check compatibility
- self._records: DatasetForTextClassification = load(
- dataset, query=query, ids=ids
- )
+ self._records: DatasetForTextClassification = load(dataset, query=query, ids=ids)
if not self._records:
raise NoRecordsFoundError(
f"No records found in dataset '{dataset}'"
@@ -127,9 +119,7 @@ def cardinality(self) -> int:
"""The number of labels."""
raise NotImplementedError
- def records(
- self, has_annotation: Optional[bool] = None
- ) -> List[TextClassificationRecord]:
+ def records(self, has_annotation: Optional[bool] = None) -> List[TextClassificationRecord]:
"""Returns the records corresponding to the weak label matrix
Args:
@@ -228,15 +218,10 @@ def _compute_overlaps_conflicts(
Array of fractions of overlaps/conflicts for each rule, optionally normalized by their coverages.
"""
overlaps_or_conflicts = (
- has_weak_label
- * np.repeat(has_overlaps_or_conflicts, len(self._rules)).reshape(
- has_weak_label.shape
- )
+ has_weak_label * np.repeat(has_overlaps_or_conflicts, len(self._rules)).reshape(has_weak_label.shape)
).sum(axis=0) / len(self._records)
# total
- overlaps_or_conflicts = np.append(
- overlaps_or_conflicts, has_overlaps_or_conflicts.sum() / len(self._records)
- )
+ overlaps_or_conflicts = np.append(overlaps_or_conflicts, has_overlaps_or_conflicts.sum() / len(self._records))
if normalize_by_coverage:
# ignore division by 0 warnings, as we convert the nan back to 0.0 afterwards
@@ -291,17 +276,15 @@ def extend_matrix(
np.copy(embeddings).astype(np.float32), abstains, supports, gpu=gpu
)
elif self._extension_queries is None:
- raise ValueError(
- "Embeddings are not optional the first time a matrix is extended."
- )
+ raise ValueError("Embeddings are not optional the first time a matrix is extended.")
dists, nearest = self._extension_queries
self._extended_matrix = np.copy(self._matrix)
new_points = [(dists[i] > thresholds[i]) for i in range(self._matrix.shape[1])]
for i in range(self._matrix.shape[1]):
- self._extended_matrix[abstains[i][new_points[i]], i] = self._matrix[
- supports[i], i
- ][nearest[i][new_points[i]]]
+ self._extended_matrix[abstains[i][new_points[i]], i] = self._matrix[supports[i], i][
+ nearest[i][new_points[i]]
+ ]
self._extend_matrix_postprocess()
@@ -323,15 +306,11 @@ def _find_dists_and_nearest(
faiss.normalize_L2(embeddings)
embeddings_length = embeddings.shape[1]
- label_fn_indexes = [
- faiss.IndexFlatIP(embeddings_length) for i in range(self._matrix.shape[1])
- ]
+ label_fn_indexes = [faiss.IndexFlatIP(embeddings_length) for i in range(self._matrix.shape[1])]
if gpu:
res = faiss.StandardGpuResources()
- label_fn_indexes = [
- faiss.index_cpu_to_gpu(res, 0, x) for x in label_fn_indexes
- ]
+ label_fn_indexes = [faiss.index_cpu_to_gpu(res, 0, x) for x in label_fn_indexes]
for i in range(self._matrix.shape[1]):
label_fn_indexes[i].add(embeddings[support[i]])
@@ -342,12 +321,8 @@ def _find_dists_and_nearest(
faiss.normalize_L2(embs_query)
dists_and_nearest.append(label_fn_indexes[i].search(embs_query, 1))
- dists = [
- dist_and_nearest[0].flatten() for dist_and_nearest in dists_and_nearest
- ]
- nearest = [
- dist_and_nearest[1].flatten() for dist_and_nearest in dists_and_nearest
- ]
+ dists = [dist_and_nearest[0].flatten() for dist_and_nearest in dists_and_nearest]
+ nearest = [dist_and_nearest[1].flatten() for dist_and_nearest in dists_and_nearest]
return dists, nearest
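The faiss pattern collapsed above builds one inner-product index per labeling function; after `normalize_L2`, inner product equals cosine similarity. A minimal self-contained sketch (requires the `faiss-cpu` package; data is random):

    import numpy as np
    import faiss

    embeddings = np.random.rand(8, 4).astype(np.float32)
    faiss.normalize_L2(embeddings)              # in-place L2 normalization
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    dists, nearest = index.search(embeddings[:2], 1)  # top-1 match per query row
    print(dists.flatten(), nearest.flatten())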
@@ -453,9 +428,7 @@ def _apply_rules(
rule.apply(self._dataset)
# create weak label matrix, annotation array, final label2int
- weak_label_matrix = np.empty(
- (len(self._records), len(self._rules)), dtype=np.short
- )
+ weak_label_matrix = np.empty((len(self._records), len(self._rules)), dtype=np.short)
annotation_array = np.empty(len(self._records), dtype=np.short)
_label2int = {None: -1} if label2int is None else label2int
if None not in _label2int:
@@ -463,9 +436,7 @@ def _apply_rules(
"Your provided `label2int` mapping does not contain the required abstention label `None`."
)
- for n, record in tqdm(
- enumerate(self._records), total=len(self._records), desc="Applying rules"
- ):
+ for n, record in tqdm(enumerate(self._records), total=len(self._records), desc="Applying rules"):
# FIRST: fill annotation array
try:
annotation = _label2int[record.annotation]
@@ -535,9 +506,7 @@ def matrix(self, has_annotation: Optional[bool] = None) -> np.ndarray:
Returns:
The weak label matrix, or optionally just a part of it.
"""
- matrix = (
- self._matrix if self._extended_matrix is None else self._extended_matrix
- )
+ matrix = self._matrix if self._extended_matrix is None else self._extended_matrix
if has_annotation is True:
return matrix[self._annotation != self._label2int[None]]
@@ -606,9 +575,7 @@ def summary(
polarity = [
set(
self._int2label[integer]
- for integer in np.unique(
- self.matrix()[:, i][self.matrix()[:, i] != self._label2int[None]]
- )
+ for integer in np.unique(self.matrix()[:, i][self.matrix()[:, i] != self._label2int[None]])
)
for i in range(len(self._rules))
]
@@ -623,9 +590,7 @@ def summary(
# overlaps
has_overlaps = has_weak_label.sum(axis=1) > 1
- overlaps = self._compute_overlaps_conflicts(
- has_weak_label, has_overlaps, coverage, normalize_by_coverage
- )
+ overlaps = self._compute_overlaps_conflicts(has_weak_label, has_overlaps, coverage, normalize_by_coverage)
# conflicts
# TODO: For a lot of records (~1e6), this could become slow (~10s) ... a vectorized solution would be better.
@@ -634,9 +599,7 @@ def summary(
axis=1,
arr=self.matrix(),
)
- conflicts = self._compute_overlaps_conflicts(
- has_weak_label, has_conflicts, coverage, normalize_by_coverage
- )
+ conflicts = self._compute_overlaps_conflicts(has_weak_label, has_conflicts, coverage, normalize_by_coverage)
# index for the summary
index = list(self._rules_name2index.keys()) + ["total"]
@@ -645,13 +608,10 @@ def summary(
has_annotation = annotation != self._label2int[None]
if any(has_annotation):
# annotated coverage
- annotated_coverage = (
- has_weak_label[has_annotation].sum(axis=0) / has_annotation.sum()
- )
+ annotated_coverage = has_weak_label[has_annotation].sum(axis=0) / has_annotation.sum()
annotated_coverage = np.append(
annotated_coverage,
- (has_weak_label[has_annotation].sum(axis=1) > 0).sum()
- / has_annotation.sum(),
+ (has_weak_label[has_annotation].sum(axis=1) > 0).sum() / has_annotation.sum(),
)
# correct/incorrect
@@ -692,9 +652,7 @@ def _compute_correct_incorrect(
self, has_weak_label: np.ndarray, annotation: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
"""Helper method to compute the correctly and incorrectly predicted annotations by the rules"""
- annotation_matrix = np.repeat(annotation, len(self._rules)).reshape(
- self.matrix().shape
- )
+ annotation_matrix = np.repeat(annotation, len(self._rules)).reshape(self.matrix().shape)
# correct
correct_with_abstain = annotation_matrix == self.matrix()
@@ -737,13 +695,8 @@ def show_records(
# get rule mask
if rules is not None:
- rules = [
- self._rules_name2index[rule] if isinstance(rule, str) else rule
- for rule in rules
- ]
- idx_by_rules = (self.matrix()[:, rules] != self._label2int[None]).sum(
- axis=1
- ) == len(rules)
+ rules = [self._rules_name2index[rule] if isinstance(rule, str) else rule for rule in rules]
+ idx_by_rules = (self.matrix()[:, rules] != self._label2int[None]).sum(axis=1) == len(rules)
else:
idx_by_rules = np.ones_like(self._records).astype(bool)
@@ -767,9 +720,7 @@ def change_mapping(self, label2int: Dict[str, int]):
for label in self._label2int:
# Check new label2int mapping
if label not in label2int:
- raise MissingLabelError(
- f"The label '{label}' is missing in the new mapping."
- )
+ raise MissingLabelError(f"The label '{label}' is missing in the new mapping.")
# compute masks
label_masks[label] = self.matrix() == self._label2int[label]
annotation_masks[label] = self._annotation == self._label2int[label]
@@ -794,22 +745,18 @@ def extend_matrix(
def _extend_matrix_preprocess(self) -> Tuple[List[np.ndarray], List[np.ndarray]]:
abstains = [
- np.argwhere(self._matrix[:, i] == self._label2int[None]).flatten()
- for i in range(self._matrix.shape[1])
+ np.argwhere(self._matrix[:, i] == self._label2int[None]).flatten() for i in range(self._matrix.shape[1])
]
supports = [
- np.argwhere(self._matrix[:, i] != self._label2int[None]).flatten()
- for i in range(self._matrix.shape[1])
+ np.argwhere(self._matrix[:, i] != self._label2int[None]).flatten() for i in range(self._matrix.shape[1])
]
return abstains, supports
def _extend_matrix_postprocess(self):
"""Keeps the rows of the original weak label matrix, for which at least on rule did not abstain."""
- recs_with_votes = np.argwhere(
- (self._matrix != self._label2int[None]).sum(-1) > 0
- ).flatten()
+ recs_with_votes = np.argwhere((self._matrix != self._label2int[None]).sum(-1) > 0).flatten()
self._extended_matrix[recs_with_votes] = self._matrix[recs_with_votes]
@@ -866,26 +813,16 @@ def _apply_rules(self) -> Tuple[np.ndarray, np.ndarray, List[str]]:
# we make two passes over the records:
# FIRST: Get labels from rules and annotations
annotations, weak_labels = [], []
- for record in tqdm(
- self._records, total=len(self._records), desc="Applying rules"
- ):
- annotations.append(
- record.annotation
- if isinstance(record.annotation, list)
- else [record.annotation]
- )
+ for record in tqdm(self._records, total=len(self._records), desc="Applying rules"):
+ annotations.append(record.annotation if isinstance(record.annotation, list) else [record.annotation])
weak_labels.append([np.atleast_1d(rule(record)) for rule in self._rules])
annotation_set = {ann for anns in annotations for ann in anns}
- weak_label_set = {
- wl for wl_record in weak_labels for wl_rule in wl_record for wl in wl_rule
- }
+ weak_label_set = {wl for wl_record in weak_labels for wl_rule in wl_record for wl in wl_rule}
labels = sorted(list(annotation_set.union(weak_label_set) - {None}))
# create weak label matrix (3D), annotation matrix
- weak_label_matrix = np.empty(
- (len(self._records), len(self._rules), len(labels)), dtype=np.byte
- )
+ weak_label_matrix = np.empty((len(self._records), len(self._rules), len(labels)), dtype=np.byte)
annotation_matrix = np.empty((len(self._records), len(labels)), dtype=np.byte)
# SECOND: Fill arrays with weak labels
@@ -939,9 +876,7 @@ def matrix(self, has_annotation: Optional[bool] = None) -> np.ndarray:
Returns:
The 3 dimensional weak label matrix, or optionally just a part of it.
"""
- matrix = (
- self._matrix if self._extended_matrix is None else self._extended_matrix
- )
+ matrix = self._matrix if self._extended_matrix is None else self._extended_matrix
if has_annotation is True:
return matrix[self._annotation.sum(1) >= 0]
@@ -1023,9 +958,7 @@ def summary(
# overlaps
has_overlaps = has_weak_label.sum(axis=1) > 1
- overlaps = self._compute_overlaps_conflicts(
- has_weak_label, has_overlaps, coverage, normalize_by_coverage
- )
+ overlaps = self._compute_overlaps_conflicts(has_weak_label, has_overlaps, coverage, normalize_by_coverage)
# index for the summary
index = list(self._rules_name2index.keys()) + ["total"]
@@ -1034,13 +967,10 @@ def summary(
has_annotation = annotation.sum(1) >= 0
if any(has_annotation):
# annotated coverage
- annotated_coverage = (
- has_weak_label[has_annotation].sum(axis=0) / has_annotation.sum()
- )
+ annotated_coverage = has_weak_label[has_annotation].sum(axis=0) / has_annotation.sum()
annotated_coverage = np.append(
annotated_coverage,
- (has_weak_label[has_annotation].sum(axis=1) > 0).sum()
- / has_annotation.sum(),
+ (has_weak_label[has_annotation].sum(axis=1) > 0).sum() / has_annotation.sum(),
)
# correct/incorrect
@@ -1072,24 +1002,16 @@ def summary(
index=index,
)
- def _compute_correct_incorrect(
- self, annotation: np.ndarray
- ) -> Tuple[np.ndarray, np.ndarray]:
+ def _compute_correct_incorrect(self, annotation: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Helper method to compute the correctly and incorrectly predicted annotations by the rules"""
# transform annotation to tensor
- annotation = np.repeat(annotation, len(self._rules), axis=0).reshape(
- self.matrix().shape
- )
+ annotation = np.repeat(annotation, len(self._rules), axis=0).reshape(self.matrix().shape)
# correct, we don't want to count the "correct non predictions"
correct = ((annotation == self.matrix()) & (self.matrix() == 1)).sum(2).sum(0)
# incorrect, we don't want to count the "misses", since we focus on precision, not recall
- incorrect = (
- ((annotation != self.matrix()) & (self.matrix() == 1) & (annotation != -1))
- .sum(2)
- .sum(0)
- )
+ incorrect = ((annotation != self.matrix()) & (self.matrix() == 1) & (annotation != -1)).sum(2).sum(0)
# add totals at the end
return np.append(correct, correct.sum()), np.append(incorrect, incorrect.sum())
@@ -1120,13 +1042,8 @@ def show_records(
# get rule mask
if rules is not None:
- rules = [
- self._rules_name2index[rule] if isinstance(rule, str) else rule
- for rule in rules
- ]
- idx_by_rules = (self.matrix()[:, rules, :].sum(axis=2) >= 0).sum(
- axis=1
- ) == len(rules)
+ rules = [self._rules_name2index[rule] if isinstance(rule, str) else rule for rule in rules]
+ idx_by_rules = (self.matrix()[:, rules, :].sum(axis=2) >= 0).sum(axis=1) == len(rules)
else:
idx_by_rules = np.ones_like(self._records).astype(bool)
@@ -1135,9 +1052,7 @@ def show_records(
return pd.DataFrame(map(lambda x: x.dict(), filtered_records))
- @_add_docstr(
- WeakLabelsBase.extend_matrix.__doc__.format(class_name="WeakMultiLabels")
- )
+ @_add_docstr(WeakLabelsBase.extend_matrix.__doc__.format(class_name="WeakMultiLabels"))
def extend_matrix(
self,
thresholds: Union[List[float], np.ndarray],
@@ -1147,15 +1062,9 @@ def extend_matrix(
super().extend_matrix(thresholds=thresholds, embeddings=embeddings, gpu=gpu)
def _extend_matrix_preprocess(self) -> Tuple[List[np.ndarray], List[np.ndarray]]:
- abstains = [
- np.argwhere(self._matrix[:, i].sum(-1) < 0).flatten()
- for i in range(self._matrix.shape[1])
- ]
+ abstains = [np.argwhere(self._matrix[:, i].sum(-1) < 0).flatten() for i in range(self._matrix.shape[1])]
- supports = [
- np.argwhere(self._matrix[:, i].sum(-1) >= 0).flatten()
- for i in range(self._matrix.shape[1])
- ]
+ supports = [np.argwhere(self._matrix[:, i].sum(-1) >= 0).flatten() for i in range(self._matrix.shape[1])]
return abstains, supports
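A minimal sketch of the abstain/support bookkeeping collapsed above, assuming a 2D integer weak label matrix where -1 plays the role of self._label2int[None] (rule abstained):

import numpy as np

matrix = np.array(
    [
        [0, -1, 1],    # record 0: rules 0 and 2 voted, rule 1 abstained
        [-1, -1, -1],  # record 1: every rule abstained
        [1, 1, -1],    # record 2: rules 0 and 1 voted
    ]
)
ABSTAIN = -1

# per-rule record indices where the rule abstained / voted
abstains = [np.argwhere(matrix[:, i] == ABSTAIN).flatten() for i in range(matrix.shape[1])]
supports = [np.argwhere(matrix[:, i] != ABSTAIN).flatten() for i in range(matrix.shape[1])]

# records for which at least one rule did not abstain
recs_with_votes = np.argwhere((matrix != ABSTAIN).sum(-1) > 0).flatten()
print(recs_with_votes)  # -> [0 2]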
diff --git a/src/argilla/listeners/listener.py b/src/argilla/listeners/listener.py
index d836bbf5e6..75ed7f6103 100644
--- a/src/argilla/listeners/listener.py
+++ b/src/argilla/listeners/listener.py
@@ -69,9 +69,7 @@ def formatted_query(self) -> Optional[str]:
return None
return self.query.format(**(self.query_params or {}))
- __listener_job__: Optional[schedule.Job] = dataclasses.field(
- init=False, default=None
- )
+ __listener_job__: Optional[schedule.Job] = dataclasses.field(init=False, default=None)
__stop_schedule_event__ = None
__current_thread__ = None
__scheduler__ = schedule.Scheduler()
@@ -120,13 +118,11 @@ def start(self, *action_args, **action_kwargs):
if self.is_running():
raise ValueError("Listener is already running")
- job_step = self.__catch_exceptions__(cancel_on_failure=True)(
- self.__listener_iteration_job__
- )
+ job_step = self.__catch_exceptions__(cancel_on_failure=True)(self.__listener_iteration_job__)
- self.__listener_job__ = self.__scheduler__.every(
- self.interval_in_seconds
- ).seconds.do(job_step, *action_args, **action_kwargs)
+ self.__listener_job__ = self.__scheduler__.every(self.interval_in_seconds).seconds.do(
+ job_step, *action_args, **action_kwargs
+ )
class _ScheduleThread(threading.Thread):
_WAIT_EVENT = threading.Event()
@@ -183,9 +179,7 @@ def __listener_iteration_job__(self, *args, **kwargs):
ctx = RGListenerContext(
listener=self,
query_params=self.query_params,
- metrics=self.__compute_metrics__(
- current_api, dataset, query=self.formatted_query
- ),
+ metrics=self.__compute_metrics__(current_api, dataset, query=self.formatted_query),
)
if self.condition is None:
self._LOGGER.debug("No condition found! Running action...")
@@ -230,9 +224,7 @@ def __run_action__(self, ctx: Optional[RGListenerContext] = None, *args, **kwarg
try:
action_args = [ctx] if ctx else []
if self.query_records:
- action_args.insert(
- 0, argilla.load(name=self.dataset, query=self.formatted_query)
- )
+ action_args.insert(0, argilla.load(name=self.dataset, query=self.formatted_query))
self._LOGGER.debug(f"Running action with arguments: {action_args}")
return self.action(*args, *action_args, **kwargs)
except: # noqa: E722
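The listener wiring above chains the `schedule` library's fluent API; a minimal sketch of the same pattern, with an illustrative interval and job body:

import time

import schedule

scheduler = schedule.Scheduler()

def iteration_job():
    # stands in for __listener_iteration_job__: poll the dataset and run the action
    print("polling dataset...")

job = scheduler.every(1).seconds.do(iteration_job)

# a background thread normally drives this loop; a few inline ticks for illustration
for _ in range(3):
    scheduler.run_pending()
    time.sleep(1)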
diff --git a/src/argilla/metrics/commons.py b/src/argilla/metrics/commons.py
index a97f2d57ca..449c5343bf 100644
--- a/src/argilla/metrics/commons.py
+++ b/src/argilla/metrics/commons.py
@@ -41,9 +41,7 @@ def text_length(name: str, query: Optional[str] = None) -> MetricSummary:
return MetricSummary.new_summary(
data=metric.results,
- visualization=lambda: helpers.histogram(
- data=metric.results, title=metric.description
- ),
+ visualization=lambda: helpers.histogram(data=metric.results, title=metric.description),
)
@@ -65,15 +63,11 @@ def records_status(name: str, query: Optional[str] = None) -> MetricSummary:
>>> summary.visualize() # will plot an histogram with results
>>> summary.data # returns the raw result data
"""
- metric = api.active_api().compute_metric(
- name, metric="status_distribution", query=query
- )
+ metric = api.active_api().compute_metric(name, metric="status_distribution", query=query)
return MetricSummary.new_summary(
data=metric.results,
- visualization=lambda: helpers.bar(
- data=metric.results, title=metric.description
- ),
+ visualization=lambda: helpers.bar(data=metric.results, title=metric.description),
)
diff --git a/src/argilla/metrics/helpers.py b/src/argilla/metrics/helpers.py
index 098d599036..aa07aee14f 100644
--- a/src/argilla/metrics/helpers.py
+++ b/src/argilla/metrics/helpers.py
@@ -34,10 +34,7 @@ def bar(data: dict, title: str = "Bar", x_legend: str = "", y_legend: str = ""):
return empty_visualization()
keys, values = zip(*data.items())
- keys = [
- key.encode("unicode-escape").decode() if isinstance(key, str) else key
- for key in keys
- ]
+ keys = [key.encode("unicode-escape").decode() if isinstance(key, str) else key for key in keys]
fig = go.Figure(data=go.Bar(y=values, x=keys))
fig.update_layout(
title=title,
@@ -138,11 +135,7 @@ def f1(data: Dict[str, float], title: str):
row=1,
col=2,
)
- per_label = {
- k: v
- for k, v in data.items()
- if all(key not in k for key in ["macro", "micro", "support"])
- }
+ per_label = {k: v for k, v in data.items() if all(key not in k for key in ["macro", "micro", "support"])}
fig.add_bar(
x=[k for k, v in per_label.items()],
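The comprehension above unicode-escapes string keys so non-ASCII category labels render predictably on the plotly axis; a tiny standalone example with made-up data:

data = {"café": 3, "naïve": 1, 42: 7}
keys, values = zip(*data.items())
keys = [key.encode("unicode-escape").decode() if isinstance(key, str) else key for key in keys]
print(keys)  # -> ['caf\\xe9', 'na\\xefve', 42]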
diff --git a/src/argilla/metrics/models.py b/src/argilla/metrics/models.py
index 9fe9c462bb..5f9bcffc77 100644
--- a/src/argilla/metrics/models.py
+++ b/src/argilla/metrics/models.py
@@ -29,15 +29,10 @@ def visualize(self):
try:
return self._build_visualization()
except ModuleNotFoundError:
- warnings.warn(
- "Please, install plotly in order to use this feature\n"
- "%>pip install plotly"
- )
+ warnings.warn("Please, install plotly in order to use this feature\n" "%>pip install plotly")
@classmethod
- def new_summary(
- cls, data: Dict[str, Any], visualization: Callable
- ) -> "MetricSummary":
+ def new_summary(cls, data: Dict[str, Any], visualization: Callable) -> "MetricSummary":
summary = cls(data=data)
summary._build_visualization = visualization
return summary
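new_summary builds the summary and attaches a lazily evaluated visualization callable to the instance; a minimal self-contained sketch of the same pattern (class and field names are illustrative):

from typing import Any, Callable, Dict

class Summary:
    def __init__(self, data: Dict[str, Any]):
        self.data = data
        self._build_visualization: Callable = lambda: None

    @classmethod
    def new_summary(cls, data: Dict[str, Any], visualization: Callable) -> "Summary":
        summary = cls(data=data)
        summary._build_visualization = visualization
        return summary

    def visualize(self):
        # the callable is only invoked here, so plotting stays optional
        return self._build_visualization()

s = Summary.new_summary({"mean": 3.2}, visualization=lambda: print("plotting 3.2"))
s.visualize()  # -> plotting 3.2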
diff --git a/src/argilla/metrics/token_classification/metrics.py b/src/argilla/metrics/token_classification/metrics.py
index 31e8b52b0d..b6591e26e2 100644
--- a/src/argilla/metrics/token_classification/metrics.py
+++ b/src/argilla/metrics/token_classification/metrics.py
@@ -22,9 +22,7 @@
from argilla.metrics.models import MetricSummary
-def tokens_length(
- name: str, query: Optional[str] = None, interval: int = 1
-) -> MetricSummary:
+def tokens_length(name: str, query: Optional[str] = None, interval: int = 1) -> MetricSummary:
"""Computes the text length distribution measured in number of tokens.
Args:
@@ -42,9 +40,7 @@ def tokens_length(
>>> summary.visualize() # will plot a histogram with results
>>> summary.data # the raw histogram data with bins of size 5
"""
- metric = api.active_api().compute_metric(
- name, metric="tokens_length", query=query, interval=interval
- )
+ metric = api.active_api().compute_metric(name, metric="tokens_length", query=query, interval=interval)
return MetricSummary.new_summary(
data=metric.results,
@@ -56,9 +52,7 @@ def tokens_length(
)
-def token_frequency(
- name: str, query: Optional[str] = None, tokens: int = 1000
-) -> MetricSummary:
+def token_frequency(name: str, query: Optional[str] = None, tokens: int = 1000) -> MetricSummary:
"""Computes the token frequency distribution for a numbe of tokens.
Args:
@@ -76,9 +70,7 @@ def token_frequency(
>>> summary.visualize() # will plot a histogram with results
>>> summary.data # the top-50 tokens frequency
"""
- metric = api.active_api().compute_metric(
- name, metric="token_frequency", query=query, size=tokens
- )
+ metric = api.active_api().compute_metric(name, metric="token_frequency", query=query, size=tokens)
return MetricSummary.new_summary(
data=metric.results,
@@ -143,9 +135,7 @@ def token_capitalness(name: str, query: Optional[str] = None) -> MetricSummary:
>>> summary.visualize() # will plot a histogram with results
>>> summary.data # The token capitalness distribution
"""
- metric = api.active_api().compute_metric(
- name, metric="token_capitalness", query=query
- )
+ metric = api.active_api().compute_metric(name, metric="token_capitalness", query=query)
return MetricSummary.new_summary(
data=metric.results,
@@ -214,9 +204,7 @@ def mention_length(
"""
level = (level or "token").lower().strip()
accepted_levels = ["token", "char"]
- assert (
- level in accepted_levels
- ), f"Unexpected value for level. Accepted values are {accepted_levels}"
+ assert level in accepted_levels, f"Unexpected value for level. Accepted values are {accepted_levels}"
metric = api.active_api().compute_metric(
name,
@@ -421,9 +409,7 @@ def top_k_mentions(
for mention in metric.results["mentions"]:
entities = mention["entities"]
if post_label_filter:
- entities = [
- entity for entity in entities if entity["label"] in post_label_filter
- ]
+ entities = [entity for entity in entities if entity["label"] in post_label_filter]
if entities:
mention["entities"] = entities
filtered_mentions.append(mention)
@@ -434,9 +420,7 @@ def top_k_mentions(
for entity in mention["entities"]:
label = entity["label"]
mentions_for_label = entities.get(label, [0] * len(filtered_mentions))
- mentions_for_label[mention_values.index(mention["mention"])] = entity[
- "count"
- ]
+ mentions_for_label[mention_values.index(mention["mention"])] = entity["count"]
entities[label] = mentions_for_label
return MetricSummary.new_summary(
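The loop above pivots per-mention entity counts into one count vector per label; a minimal sketch, assuming mentions shaped like {"mention": str, "entities": [{"label": str, "count": int}]}:

filtered_mentions = [
    {"mention": "Paris", "entities": [{"label": "LOC", "count": 4}]},
    {"mention": "Alice", "entities": [{"label": "PER", "count": 2}, {"label": "LOC", "count": 1}]},
]
mention_values = [m["mention"] for m in filtered_mentions]

entities = {}
for mention in filtered_mentions:
    for entity in mention["entities"]:
        label = entity["label"]
        mentions_for_label = entities.get(label, [0] * len(filtered_mentions))
        mentions_for_label[mention_values.index(mention["mention"])] = entity["count"]
        entities[label] = mentions_for_label

print(entities)  # -> {'LOC': [4, 1], 'PER': [0, 2]}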
diff --git a/src/argilla/monitoring/_flair.py b/src/argilla/monitoring/_flair.py
index 6330b9fc78..ad03da91ad 100644
--- a/src/argilla/monitoring/_flair.py
+++ b/src/argilla/monitoring/_flair.py
@@ -69,11 +69,7 @@ def predict(self, sentences: Union[List[Sentence], Sentence], *args, **kwargs):
if not metadata:
metadata = [{}] * len(sentences)
- filtered_data = [
- (sentence, meta)
- for sentence, meta in zip(sentences, metadata)
- if self.is_record_accepted()
- ]
+ filtered_data = [(sentence, meta) for sentence, meta in zip(sentences, metadata) if self.is_record_accepted()]
if filtered_data:
self._log_future = self.send_records(filtered_data)
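The filter above keeps each (sentence, metadata) pair with the monitor's sampling probability; a minimal sketch of that idea, with is_record_accepted reduced to a plain random draw:

import random

sample_rate = 0.2

def is_record_accepted() -> bool:
    return random.random() <= sample_rate

sentences = ["a", "b", "c", "d", "e"]
metadata = [{}] * len(sentences)
filtered_data = [(s, m) for s, m in zip(sentences, metadata) if is_record_accepted()]
print(f"{len(filtered_data)} of {len(sentences)} records sampled for logging")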
diff --git a/src/argilla/monitoring/_spacy.py b/src/argilla/monitoring/_spacy.py
index a1c80833ac..195e42f0cb 100644
--- a/src/argilla/monitoring/_spacy.py
+++ b/src/argilla/monitoring/_spacy.py
@@ -61,14 +61,10 @@ def doc2token_classification(
event_timestamp=datetime.utcnow(),
)
- def _prepare_log_data(
- self, docs_info: Tuple[Doc, Optional[Dict[str, Any]]]
- ) -> Dict[str, Any]:
+ def _prepare_log_data(self, docs_info: Tuple[Doc, Optional[Dict[str, Any]]]) -> Dict[str, Any]:
return dict(
records=[
- self.doc2token_classification(
- doc, agent=self.__wrapped__.path.name, metadata=metadata
- )
+ self.doc2token_classification(doc, agent=self.__wrapped__.path.name, metadata=metadata)
for doc, metadata in docs_info
],
name=self.dataset,
diff --git a/src/argilla/monitoring/_transformers.py b/src/argilla/monitoring/_transformers.py
index 46fe973f26..8ddcde1c8b 100644
--- a/src/argilla/monitoring/_transformers.py
+++ b/src/argilla/monitoring/_transformers.py
@@ -63,9 +63,7 @@ def _prepare_log_data(
record = TextClassificationRecord(
text=input_ if isinstance(input_, str) else None,
inputs=input_ if not isinstance(input_, str) else None,
- prediction=[
- (prediction.label, prediction.score) for prediction in predictions
- ],
+ prediction=[(prediction.label, prediction.score) for prediction in predictions],
prediction_agent=agent,
metadata=metadata or {},
multi_label=multi_label,
@@ -82,9 +80,7 @@ def _prepare_log_data(
name=dataset_name,
tags={
"name": self.model_config.name_or_path,
- "transformers_version": self.fetch_transformers_version(
- self.model_config
- ),
+ "transformers_version": self.fetch_transformers_version(self.model_config),
"model_type": self.model_config.model_type,
"task": self.__model__.task,
},
diff --git a/src/argilla/monitoring/asgi.py b/src/argilla/monitoring/asgi.py
index f861eee958..deabe0ef47 100644
--- a/src/argilla/monitoring/asgi.py
+++ b/src/argilla/monitoring/asgi.py
@@ -56,9 +56,7 @@ def token_classification_mapper(inputs, outputs):
tokens=tokens or _default_tokenization_pattern.split(text),
prediction=[
(entity["label"], entity["start"], entity["end"])
- for entity in (
- outputs.get("entities") if isinstance(outputs, dict) else outputs
- )
+ for entity in (outputs.get("entities") if isinstance(outputs, dict) else outputs)
],
event_timestamp=datetime.datetime.now(),
)
@@ -67,12 +65,7 @@ def token_classification_mapper(inputs, outputs):
def text_classification_mapper(inputs, outputs):
return TextClassificationRecord(
inputs=inputs,
- prediction=[
- (label, score)
- for label, score in zip(
- outputs.get("labels", []), outputs.get("scores", [])
- )
- ],
+ prediction=[(label, score) for label, score in zip(outputs.get("labels", []), outputs.get("scores", []))],
event_timestamp=datetime.datetime.now(),
)
@@ -167,16 +160,11 @@ async def dispatch(
elif cached_request.method == "GET":
inputs = cached_request.query_params._dict
else:
- raise NotImplementedError(
- "Only request methods POST, PUT and GET are implemented."
- )
+ raise NotImplementedError("Only request methods POST, PUT and GET are implemented.")
# Must obtain response from request
response: Response = await call_next(cached_request)
- if (
- not isinstance(response, (JSONResponse, StreamingResponse))
- or response.status_code >= 400
- ):
+ if not isinstance(response, (JSONResponse, StreamingResponse)) or response.status_code >= 400:
return response
new_response, outputs = await self._extract_response_content(response)
@@ -186,9 +174,7 @@ async def dispatch(
_logger.error("Cannot log to argilla", exc_info=ex)
return await call_next(request)
- async def _extract_response_content(
- self, response: Response
- ) -> Tuple[Response, List[Dict[str, Any]]]:
+ async def _extract_response_content(self, response: Response) -> Tuple[Response, List[Dict[str, Any]]]:
"""Extracts response body content from response and returns a new processable response"""
body = b""
new_response = response
@@ -206,23 +192,17 @@ async def _extract_response_content(
body = response.body
return new_response, json.loads(body)
- def _prepare_argilla_data(
- self, inputs: List[Dict[str, Any]], outputs: List[Dict[str, Any]], **tags
- ):
+ def _prepare_argilla_data(self, inputs: List[Dict[str, Any]], outputs: List[Dict[str, Any]], **tags):
# using the base monitor, we only need to provide the input data to the rg.log function
# and the monitor will handle the sample rate, queue and argilla interaction
try:
records = self._records_mapper(inputs, outputs)
- assert records, ValueError(
- "The records_mapper returns and empty record list."
- )
+ assert records, ValueError("The records_mapper returns an empty record list.")
if not isinstance(records, list):
records = [records]
except Exception as ex:
records = []
- _logger.error(
- "Cannot log to argilla. Error in records mapper.", exc_info=ex
- )
+ _logger.error("Cannot log to argilla. Error in records mapper.", exc_info=ex)
for record in records:
if self._monitor.agent is not None and not record.prediction_agent:
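The dispatch method above is a Starlette/FastAPI HTTP middleware: it captures the inputs, lets the app answer, then maps (inputs, outputs) to records. A minimal sketch that only reads GET query params; the real middleware caches the request so POST/PUT bodies can be read safely, and unwraps streaming responses:

from starlette.middleware.base import BaseHTTPMiddleware

class MonitorMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request, call_next):
        inputs = dict(request.query_params) if request.method == "GET" else None
        response = await call_next(request)
        if response.status_code >= 400:
            return response  # failed requests are never logged
        if inputs is not None:
            print("would map and log:", inputs)  # records_mapper + rg.log in the real code
        return response

# attached with: app.add_middleware(MonitorMiddleware)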
diff --git a/src/argilla/monitoring/base.py b/src/argilla/monitoring/base.py
index 59fc8976d4..86fb2491e7 100644
--- a/src/argilla/monitoring/base.py
+++ b/src/argilla/monitoring/base.py
@@ -187,9 +187,7 @@ def __init__(
super().__init__(*args, **kwargs)
assert dataset, "Missing dataset"
- assert (
- 0.0 < sample_rate <= 1.0
- ), "Wrong sample rate. Set a value in (0, 1] range."
+ assert 0.0 < sample_rate <= 1.0, "Wrong sample rate. Set a value in (0, 1] range."
self.dataset = dataset
self.sample_rate = sample_rate
diff --git a/src/argilla/monitoring/model_monitor.py b/src/argilla/monitoring/model_monitor.py
index 0caca68e27..9d126329b8 100644
--- a/src/argilla/monitoring/model_monitor.py
+++ b/src/argilla/monitoring/model_monitor.py
@@ -62,7 +62,6 @@ def monitor(
return model_monitor
warnings.warn(
- "The provided task model is not supported by monitoring module. "
- "Predictions won't be logged into argilla"
+ "The provided task model is not supported by monitoring module. " "Predictions won't be logged into argilla"
)
return task_model
diff --git a/src/argilla/server/apis/v0/handlers/datasets.py b/src/argilla/server/apis/v0/handlers/datasets.py
index 0d1470f0f9..46adf21652 100644
--- a/src/argilla/server/apis/v0/handlers/datasets.py
+++ b/src/argilla/server/apis/v0/handlers/datasets.py
@@ -49,9 +49,7 @@ async def list_datasets(
) -> List[Dataset]:
return service.list(
user=current_user,
- workspaces=[request_deps.workspace]
- if request_deps.workspace is not None
- else None,
+ workspaces=[request_deps.workspace] if request_deps.workspace is not None else None,
)
@@ -113,9 +111,7 @@ def update_dataset(
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> Dataset:
- found_ds = service.find_by_name(
- user=current_user, name=name, workspace=ds_params.workspace
- )
+ found_ds = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace)
return service.update(
user=current_user,
@@ -159,9 +155,7 @@ def close_dataset(
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
):
- found_ds = service.find_by_name(
- user=current_user, name=name, workspace=ds_params.workspace
- )
+ found_ds = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace)
service.close(user=current_user, dataset=found_ds)
@@ -175,9 +169,7 @@ def open_dataset(
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
):
- found_ds = service.find_by_name(
- user=current_user, name=name, workspace=ds_params.workspace
- )
+ found_ds = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace)
service.open(user=current_user, dataset=found_ds)
@@ -194,9 +186,7 @@ def copy_dataset(
service: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> Dataset:
- found = service.find_by_name(
- user=current_user, name=name, workspace=ds_params.workspace
- )
+ found = service.find_by_name(user=current_user, name=name, workspace=ds_params.workspace)
return service.copy_dataset(
user=current_user,
dataset=found,
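All these handlers follow the same FastAPI dependency-injection shape: the service singleton and the authenticated user are injected per request. A minimal sketch of the service half (names below are illustrative stand-ins):

from fastapi import Depends, FastAPI

app = FastAPI()

class DatasetsService:
    _INSTANCE = None

    @classmethod
    def get_instance(cls) -> "DatasetsService":
        # FastAPI calls this per request; the cached singleton keeps it cheap
        if cls._INSTANCE is None:
            cls._INSTANCE = cls()
        return cls._INSTANCE

    def find_by_name(self, name: str) -> dict:
        return {"name": name}

@app.get("/datasets/{name}")
def get_dataset(name: str, service: DatasetsService = Depends(DatasetsService.get_instance)) -> dict:
    return service.find_by_name(name=name)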
diff --git a/src/argilla/server/apis/v0/handlers/metrics.py b/src/argilla/server/apis/v0/handlers/metrics.py
index 68b92bbb37..3f4e0b6743 100644
--- a/src/argilla/server/apis/v0/handlers/metrics.py
+++ b/src/argilla/server/apis/v0/handlers/metrics.py
@@ -31,9 +31,7 @@
class MetricInfo(BaseModel):
id: str = Field(description="The metric id")
name: str = Field(description="The metric name")
- description: Optional[str] = Field(
- default=None, description="The metric description"
- )
+ description: Optional[str] = Field(default=None, description="The metric description")
@dataclass
diff --git a/src/argilla/server/apis/v0/handlers/text2text.py b/src/argilla/server/apis/v0/handlers/text2text.py
index d0ca90bb62..cf6230c4ad 100644
--- a/src/argilla/server/apis/v0/handlers/text2text.py
+++ b/src/argilla/server/apis/v0/handlers/text2text.py
@@ -120,9 +120,7 @@ def search_records(
name: str,
search: Text2TextSearchRequest = None,
common_params: CommonTaskHandlerDependencies = Depends(),
- include_metrics: bool = Query(
- False, description="If enabled, return related record metrics"
- ),
+ include_metrics: bool = Query(False, description="If enabled, return related record metrics"),
pagination: RequestPagination = Depends(),
service: Text2TextService = Depends(Text2TextService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
@@ -149,9 +147,7 @@ def search_records(
return Text2TextSearchResults(
total=result.total,
records=[Text2TextRecord.parse_obj(r) for r in result.records],
- aggregations=Text2TextSearchAggregations.parse_obj(result.metrics)
- if result.metrics
- else None,
+ aggregations=Text2TextSearchAggregations.parse_obj(result.metrics) if result.metrics else None,
)
def scan_data_response(
@@ -183,9 +179,7 @@ def grouper(n, iterable, fillvalue=None):
)
) + "\n"
- return StreamingResponseWithErrorHandling(
- stream_generator(data_stream), media_type="application/json"
- )
+ return StreamingResponseWithErrorHandling(stream_generator(data_stream), media_type="application/json")
@router.post(
path=f"{base_endpoint}/data",
diff --git a/src/argilla/server/apis/v0/handlers/text_classification.py b/src/argilla/server/apis/v0/handlers/text_classification.py
index 414529dbbf..e5d4ca840b 100644
--- a/src/argilla/server/apis/v0/handlers/text_classification.py
+++ b/src/argilla/server/apis/v0/handlers/text_classification.py
@@ -89,9 +89,7 @@ async def bulk_records(
name: str,
bulk: TextClassificationBulkRequest,
common_params: CommonTaskHandlerDependencies = Depends(),
- service: TextClassificationService = Depends(
- TextClassificationService.get_instance
- ),
+ service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
validator: DatasetValidator = Depends(DatasetValidator.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
@@ -120,9 +118,7 @@ async def bulk_records(
# TODO(@frascuchon): Validator should be applied in the service layer
records = [ServiceTextClassificationRecord.parse_obj(r) for r in bulk.records]
- await validator.validate_dataset_records(
- user=current_user, dataset=dataset, records=records
- )
+ await validator.validate_dataset_records(user=current_user, dataset=dataset, records=records)
result = await service.add_records(
dataset=dataset,
@@ -144,13 +140,9 @@ def search_records(
name: str,
search: TextClassificationSearchRequest = None,
common_params: CommonTaskHandlerDependencies = Depends(),
- include_metrics: bool = Query(
- False, description="If enabled, return related record metrics"
- ),
+ include_metrics: bool = Query(False, description="If enabled, return related record metrics"),
pagination: RequestPagination = Depends(),
- service: TextClassificationService = Depends(
- TextClassificationService.get_instance
- ),
+ service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> TextClassificationSearchResults:
@@ -204,9 +196,7 @@ def search_records(
return TextClassificationSearchResults(
total=result.total,
records=result.records,
- aggregations=TextClassificationSearchAggregations.parse_obj(result.metrics)
- if result.metrics
- else None,
+ aggregations=TextClassificationSearchAggregations.parse_obj(result.metrics) if result.metrics else None,
)
def scan_data_response(
@@ -238,9 +228,7 @@ def grouper(n, iterable, fillvalue=None):
)
) + "\n"
- return StreamingResponseWithErrorHandling(
- stream_generator(data_stream), media_type="application/json"
- )
+ return StreamingResponseWithErrorHandling(stream_generator(data_stream), media_type="application/json")
@router.post(
f"{base_endpoint}/data",
@@ -253,9 +241,7 @@ async def stream_data(
common_params: CommonTaskHandlerDependencies = Depends(),
id_from: Optional[str] = None,
limit: Optional[int] = Query(None, description="Limit loaded records", gt=0),
- service: TextClassificationService = Depends(
- TextClassificationService.get_instance
- ),
+ service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> StreamingResponse:
@@ -319,9 +305,7 @@ async def list_labeling_rules(
name: str,
common_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
- service: TextClassificationService = Depends(
- TextClassificationService.get_instance
- ),
+ service: TextClassificationService = Depends(TextClassificationService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> List[LabelingRule]:
dataset = datasets.find_by_name(
@@ -332,9 +316,7 @@ async def list_labeling_rules(
as_dataset_class=TasksFactory.get_task_dataset(task_type),
)
- return [
- LabelingRule.parse_obj(rule) for rule in service.get_labeling_rules(dataset)
- ]
+ return [LabelingRule.parse_obj(rule) for rule in service.get_labeling_rules(dataset)]
@deprecate_endpoint(
path=f"{new_base_endpoint}/labeling/rules",
@@ -349,9 +331,7 @@ async def create_rule(
name: str,
rule: CreateLabelingRule,
common_params: CommonTaskHandlerDependencies = Depends(),
- service: TextClassificationService = Depends(
- TextClassificationService.get_instance
- ),
+ service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> LabelingRule:
@@ -385,13 +365,9 @@ async def create_rule(
async def compute_rule_metrics(
name: str,
query: str,
- labels: Optional[List[str]] = Query(
- None, description="Label related to query rule", alias="label"
- ),
+ labels: Optional[List[str]] = Query(None, description="Label related to query rule", alias="label"),
common_params: CommonTaskHandlerDependencies = Depends(),
- service: TextClassificationService = Depends(
- TextClassificationService.get_instance
- ),
+ service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> LabelingRuleMetricsSummary:
@@ -417,9 +393,7 @@ async def compute_rule_metrics(
async def compute_dataset_rules_metrics(
name: str,
common_params: CommonTaskHandlerDependencies = Depends(),
- service: TextClassificationService = Depends(
- TextClassificationService.get_instance
- ),
+ service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> DatasetLabelingRulesMetricsSummary:
@@ -444,9 +418,7 @@ async def delete_labeling_rule(
name: str,
query: str,
common_params: CommonTaskHandlerDependencies = Depends(),
- service: TextClassificationService = Depends(
- TextClassificationService.get_instance
- ),
+ service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> None:
@@ -473,9 +445,7 @@ async def get_rule(
name: str,
query: str,
common_params: CommonTaskHandlerDependencies = Depends(),
- service: TextClassificationService = Depends(
- TextClassificationService.get_instance
- ),
+ service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> LabelingRule:
@@ -506,9 +476,7 @@ async def update_rule(
query: str,
update: UpdateLabelingRule,
common_params: CommonTaskHandlerDependencies = Depends(),
- service: TextClassificationService = Depends(
- TextClassificationService.get_instance
- ),
+ service: TextClassificationService = Depends(TextClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> LabelingRule:
diff --git a/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py b/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py
index e40c97e079..70c6b327d5 100644
--- a/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py
+++ b/src/argilla/server/apis/v0/handlers/text_classification_dataset_settings.py
@@ -63,9 +63,7 @@ async def get_dataset_settings(
task=task,
)
- settings = await datasets.get_settings(
- user=user, dataset=found_ds, class_type=__svc_settings_class__
- )
+ settings = await datasets.get_settings(user=user, dataset=found_ds, class_type=__svc_settings_class__)
return TextClassificationSettings.parse_obj(settings)
@deprecate_endpoint(
@@ -79,9 +77,7 @@ async def get_dataset_settings(
response_model=TextClassificationSettings,
)
async def save_settings(
- request: TextClassificationSettings = Body(
- ..., description=f"The {task} dataset settings"
- ),
+ request: TextClassificationSettings = Body(..., description=f"The {task} dataset settings"),
name: str = DATASET_NAME_PATH_PARAM,
ws_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
@@ -94,9 +90,7 @@ async def save_settings(
task=task,
workspace=ws_params.workspace,
)
- await validator.validate_dataset_settings(
- user=user, dataset=found_ds, settings=request
- )
+ await validator.validate_dataset_settings(user=user, dataset=found_ds, settings=request)
settings = await datasets.save_settings(
user=user,
dataset=found_ds,
diff --git a/src/argilla/server/apis/v0/handlers/token_classification.py b/src/argilla/server/apis/v0/handlers/token_classification.py
index 7a21310e1d..b4a851b8de 100644
--- a/src/argilla/server/apis/v0/handlers/token_classification.py
+++ b/src/argilla/server/apis/v0/handlers/token_classification.py
@@ -83,9 +83,7 @@ async def bulk_records(
name: str,
bulk: TokenClassificationBulkRequest,
common_params: CommonTaskHandlerDependencies = Depends(),
- service: TokenClassificationService = Depends(
- TokenClassificationService.get_instance
- ),
+ service: TokenClassificationService = Depends(TokenClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
validator: DatasetValidator = Depends(DatasetValidator.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
@@ -145,9 +143,7 @@ def search_records(
description="If enabled, return related record metrics",
),
pagination: RequestPagination = Depends(),
- service: TokenClassificationService = Depends(
- TokenClassificationService.get_instance
- ),
+ service: TokenClassificationService = Depends(TokenClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
) -> TokenClassificationSearchResults:
@@ -173,9 +169,7 @@ def search_records(
return TokenClassificationSearchResults(
total=results.total,
records=[TokenClassificationRecord.parse_obj(r) for r in results.records],
- aggregations=TokenClassificationAggregations.parse_obj(results.metrics)
- if results.metrics
- else None,
+ aggregations=TokenClassificationAggregations.parse_obj(results.metrics) if results.metrics else None,
)
def scan_data_response(
@@ -207,9 +201,7 @@ def grouper(n, iterable, fillvalue=None):
)
) + "\n"
- return StreamingResponseWithErrorHandling(
- stream_generator(data_stream), media_type="application/json"
- )
+ return StreamingResponseWithErrorHandling(stream_generator(data_stream), media_type="application/json")
@router.post(
path=f"{base_endpoint}/data",
@@ -221,9 +213,7 @@ async def stream_data(
query: Optional[TokenClassificationQuery] = None,
common_params: CommonTaskHandlerDependencies = Depends(),
limit: Optional[int] = Query(None, description="Limit loaded records", gt=0),
- service: TokenClassificationService = Depends(
- TokenClassificationService.get_instance
- ),
+ service: TokenClassificationService = Depends(TokenClassificationService.get_instance),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
current_user: User = Security(auth.get_user, scopes=[]),
id_from: Optional[str] = None,
diff --git a/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py b/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py
index 4756aeb9da..a5ac2411f3 100644
--- a/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py
+++ b/src/argilla/server/apis/v0/handlers/token_classification_dataset_settings.py
@@ -63,9 +63,7 @@ async def get_dataset_settings(
task=task,
)
- settings = await datasets.get_settings(
- user=user, dataset=found_ds, class_type=__svc_settings_class__
- )
+ settings = await datasets.get_settings(user=user, dataset=found_ds, class_type=__svc_settings_class__)
return TokenClassificationSettings.parse_obj(settings)
@deprecate_endpoint(
@@ -79,9 +77,7 @@ async def get_dataset_settings(
response_model=TokenClassificationSettings,
)
async def save_settings(
- request: TokenClassificationSettings = Body(
- ..., description=f"The {task} dataset settings"
- ),
+ request: TokenClassificationSettings = Body(..., description=f"The {task} dataset settings"),
name: str = DATASET_NAME_PATH_PARAM,
ws_params: CommonTaskHandlerDependencies = Depends(),
datasets: DatasetsService = Depends(DatasetsService.get_instance),
@@ -94,9 +90,7 @@ async def save_settings(
task=task,
workspace=ws_params.workspace,
)
- await validator.validate_dataset_settings(
- user=user, dataset=found_ds, settings=request
- )
+ await validator.validate_dataset_settings(user=user, dataset=found_ds, settings=request)
settings = await datasets.save_settings(
user=user,
dataset=found_ds,
diff --git a/src/argilla/server/apis/v0/handlers/users.py b/src/argilla/server/apis/v0/handlers/users.py
index d48d89c367..3d17b69492 100644
--- a/src/argilla/server/apis/v0/handlers/users.py
+++ b/src/argilla/server/apis/v0/handlers/users.py
@@ -22,12 +22,8 @@
router = APIRouter(tags=["users"])
-@router.get(
- "/me", response_model=User, response_model_exclude_none=True, operation_id="whoami"
-)
-async def whoami(
- request: Request, current_user: User = Security(auth.get_user, scopes=[])
-):
+@router.get("/me", response_model=User, response_model_exclude_none=True, operation_id="whoami")
+async def whoami(request: Request, current_user: User = Security(auth.get_user, scopes=[])):
"""
User info endpoint
diff --git a/src/argilla/server/apis/v0/helpers.py b/src/argilla/server/apis/v0/helpers.py
index 55874f58f5..905a7ed421 100644
--- a/src/argilla/server/apis/v0/helpers.py
+++ b/src/argilla/server/apis/v0/helpers.py
@@ -15,9 +15,7 @@
from typing import Callable
-def deprecate_endpoint(
- path: str, new_path: str, router_method: Callable, *args, **kwargs
-):
+def deprecate_endpoint(path: str, new_path: str, router_method: Callable, *args, **kwargs):
"""
Helper decorator to deprecate a `path` endpoint, adding the `new_path` endpoint.
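A plausible reconstruction of what such a decorator can do, registering the handler on both paths and flagging the old one as deprecated in the OpenAPI schema (an illustration, not the exact implementation):

from typing import Callable

from fastapi import APIRouter

router = APIRouter()

def deprecate_endpoint(path: str, new_path: str, router_method: Callable, *args, **kwargs):
    def decorator(func):
        router_method(path, *args, deprecated=True, **kwargs)(func)
        router_method(new_path, *args, **kwargs)(func)
        return func

    return decorator

@deprecate_endpoint("/old/rules", "/new/rules", router_method=router.get)
async def list_rules():
    return []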
diff --git a/src/argilla/server/apis/v0/models/commons/model.py b/src/argilla/server/apis/v0/models/commons/model.py
index 25a62c0657..f25289ab9a 100644
--- a/src/argilla/server/apis/v0/models/commons/model.py
+++ b/src/argilla/server/apis/v0/models/commons/model.py
@@ -67,9 +67,7 @@ class BulkResponse(BaseModel):
failed: int = 0
-class BaseSearchResults(
- GenericModel, Generic[_Record, ServiceSearchResultsAggregations]
-):
+class BaseSearchResults(GenericModel, Generic[_Record, ServiceSearchResultsAggregations]):
total: int = 0
records: List[_Record] = Field(default_factory=list)
aggregations: ServiceSearchResultsAggregations = None
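BaseSearchResults relies on pydantic v1 generics to reuse one results container across tasks; a minimal sketch of the same idea:

from typing import Generic, List, TypeVar

from pydantic import BaseModel, Field
from pydantic.generics import GenericModel

Record = TypeVar("Record")

class SearchResults(GenericModel, Generic[Record]):
    total: int = 0
    records: List[Record] = Field(default_factory=list)

class TextRecord(BaseModel):
    text: str

results = SearchResults[TextRecord](total=1, records=[TextRecord(text="hello")])
print(results.total, results.records[0].text)  # -> 1 hello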
diff --git a/src/argilla/server/apis/v0/models/commons/params.py b/src/argilla/server/apis/v0/models/commons/params.py
index 09a27b559f..4bd92aa4e2 100644
--- a/src/argilla/server/apis/v0/models/commons/params.py
+++ b/src/argilla/server/apis/v0/models/commons/params.py
@@ -23,9 +23,7 @@
)
from argilla.server.security.model import WORKSPACE_NAME_PATTERN
-DATASET_NAME_PATH_PARAM = Path(
- ..., regex=DATASET_NAME_REGEX_PATTERN, description="The dataset name"
-)
+DATASET_NAME_PATH_PARAM = Path(..., regex=DATASET_NAME_REGEX_PATTERN, description="The dataset name")
@dataclass
@@ -33,9 +31,7 @@ class RequestPagination:
"""Query pagination params"""
limit: int = Query(50, gte=0, le=1000, description="Response records limit")
- from_: int = Query(
- 0, ge=0, le=10000, alias="from", description="Record sequence from"
- )
+ from_: int = Query(0, ge=0, le=10000, alias="from", description="Record sequence from")
@dataclass
@@ -60,14 +56,9 @@ class CommonTaskHandlerDependencies:
@property
def workspace(self) -> str:
"""Return read workspace. Query param prior to header param"""
- workspace = (
- self.__workspace_param__
- or self.__workspace_header__
- or self.__old_workspace_header__
- )
+ workspace = self.__workspace_param__ or self.__workspace_header__ or self.__old_workspace_header__
if workspace:
assert WORKSPACE_NAME_PATTERN.match(workspace), (
- "Wrong workspace format. "
- f"Workspace must match pattern {WORKSPACE_NAME_PATTERN.pattern}"
+ "Wrong workspace format. " f"Workspace must match pattern {WORKSPACE_NAME_PATTERN.pattern}"
)
return workspace
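The property above resolves the workspace from three sources with a simple `or` chain and then validates the format; a standalone sketch (the regex here is illustrative, not the real WORKSPACE_NAME_PATTERN):

import re

WORKSPACE_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+$")  # illustrative pattern

def resolve_workspace(query_param, header, old_header):
    workspace = query_param or header or old_header
    if workspace:
        assert WORKSPACE_NAME_PATTERN.match(
            workspace
        ), f"Wrong workspace format. Workspace must match pattern {WORKSPACE_NAME_PATTERN.pattern}"
    return workspace

print(resolve_workspace(None, "team-a", None))  # -> team-a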
diff --git a/src/argilla/server/apis/v0/models/text2text.py b/src/argilla/server/apis/v0/models/text2text.py
index 5345f95746..1caa2b107d 100644
--- a/src/argilla/server/apis/v0/models/text2text.py
+++ b/src/argilla/server/apis/v0/models/text2text.py
@@ -71,9 +71,7 @@ class Text2TextSearchAggregations(ServiceBaseSearchResultsAggregations):
annotated_text: Dict[str, int] = Field(default_factory=dict)
-class Text2TextSearchResults(
- BaseSearchResults[Text2TextRecord, Text2TextSearchAggregations]
-):
+class Text2TextSearchResults(BaseSearchResults[Text2TextRecord, Text2TextSearchAggregations]):
pass
diff --git a/src/argilla/server/apis/v0/models/text_classification.py b/src/argilla/server/apis/v0/models/text_classification.py
index b2886fe56b..a4aa5621b4 100644
--- a/src/argilla/server/apis/v0/models/text_classification.py
+++ b/src/argilla/server/apis/v0/models/text_classification.py
@@ -47,17 +47,12 @@
class UpdateLabelingRule(BaseModel):
- label: Optional[str] = Field(
- default=None, description="@Deprecated::The label associated with the rule."
- )
+ label: Optional[str] = Field(default=None, description="@Deprecated::The label associated with the rule.")
labels: List[str] = Field(
default_factory=list,
- description="For multi label problems, a list of labels. "
- "It will replace the `label` field",
- )
- description: Optional[str] = Field(
- None, description="A brief description of the rule"
+ description="For multi label problems, a list of labels. " "It will replace the `label` field",
)
+ description: Optional[str] = Field(None, description="A brief description of the rule")
@root_validator
def initialize_labels(cls, values):
@@ -83,9 +78,7 @@ def strip_query(cls, query: str) -> str:
class LabelingRule(CreateLabelingRule):
author: str = Field(description="User who created the rule")
- created_at: Optional[datetime] = Field(
- default_factory=datetime.utcnow, description="Rule creation timestamp"
- )
+ created_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="Rule creation timestamp")
class LabelingRuleMetricsSummary(_LabelingRuleMetricsSummary):
@@ -110,9 +103,7 @@ class TextClassificationRecordInputs(BaseRecordInputs[TextClassificationAnnotati
explanation: Optional[Dict[str, List[TokenAttributions]]] = None
-class TextClassificationRecord(
- TextClassificationRecordInputs, BaseRecord[TextClassificationAnnotation]
-):
+class TextClassificationRecord(TextClassificationRecordInputs, BaseRecord[TextClassificationAnnotation]):
pass
@@ -125,9 +116,7 @@ def check_multi_label_integrity(cls, records: List[TextClassificationRecord]):
if records:
multi_label = records[0].multi_label
for record in records[1:]:
- assert (
- multi_label == record.multi_label
- ), "All records must be single/multi labelled"
+ assert multi_label == record.multi_label, "All records must be single/multi labelled"
return records
diff --git a/src/argilla/server/apis/v0/models/token_classification.py b/src/argilla/server/apis/v0/models/token_classification.py
index 63eaaa9632..4ec6608a85 100644
--- a/src/argilla/server/apis/v0/models/token_classification.py
+++ b/src/argilla/server/apis/v0/models/token_classification.py
@@ -61,9 +61,7 @@ def check_text_content(cls, text: str):
return text
-class TokenClassificationRecord(
- TokenClassificationRecordInputs, BaseRecord[TokenClassificationAnnotation]
-):
+class TokenClassificationRecord(TokenClassificationRecordInputs, BaseRecord[TokenClassificationAnnotation]):
pass
@@ -88,9 +86,7 @@ class TokenClassificationAggregations(ServiceBaseSearchResultsAggregations):
mentions: Dict[str, Dict[str, int]] = Field(default_factory=dict)
-class TokenClassificationSearchResults(
- BaseSearchResults[TokenClassificationRecord, TokenClassificationAggregations]
-):
+class TokenClassificationSearchResults(BaseSearchResults[TokenClassificationRecord, TokenClassificationAggregations]):
pass
diff --git a/src/argilla/server/apis/v0/validators/text_classification.py b/src/argilla/server/apis/v0/validators/text_classification.py
index e225749bed..c7a3209a64 100644
--- a/src/argilla/server/apis/v0/validators/text_classification.py
+++ b/src/argilla/server/apis/v0/validators/text_classification.py
@@ -54,9 +54,7 @@ def get_instance(
cls._INSTANCE = cls(datasets, metrics=metrics)
return cls._INSTANCE
- async def validate_dataset_settings(
- self, user: User, dataset: Dataset, settings: TextClassificationSettings
- ):
+ async def validate_dataset_settings(self, user: User, dataset: Dataset, settings: TextClassificationSettings):
if settings and settings.label_schema:
results = self.__metrics__.summarize_metric(
dataset=dataset,
@@ -66,9 +64,7 @@ async def validate_dataset_settings(
)
if results:
labels = results.get("labels", [])
- label_schema = set(
- [label.name for label in settings.label_schema.labels]
- )
+ label_schema = set([label.name for label in settings.label_schema.labels])
for label in labels:
if label not in label_schema:
raise BadRequestError(
@@ -87,9 +83,7 @@ async def validate_dataset_records(
user=user, dataset=dataset, class_type=__svc_settings_class__
)
if settings and settings.label_schema:
- label_schema = set(
- [label.name for label in settings.label_schema.labels]
- )
+ label_schema = set([label.name for label in settings.label_schema.labels])
for r in records:
if r.prediction:
self.__check_label_classes__(
diff --git a/src/argilla/server/apis/v0/validators/token_classification.py b/src/argilla/server/apis/v0/validators/token_classification.py
index a5004c9dfa..52616d366c 100644
--- a/src/argilla/server/apis/v0/validators/token_classification.py
+++ b/src/argilla/server/apis/v0/validators/token_classification.py
@@ -55,9 +55,7 @@ def get_instance(
cls._INSTANCE = cls(datasets, metrics=metrics)
return cls._INSTANCE
- async def validate_dataset_settings(
- self, user: User, dataset: Dataset, settings: TokenClassificationSettings
- ):
+ async def validate_dataset_settings(self, user: User, dataset: Dataset, settings: TokenClassificationSettings):
if settings and settings.label_schema:
results = self.__metrics__.summarize_metric(
dataset=dataset,
@@ -67,9 +65,7 @@ async def validate_dataset_settings(
)
if results:
labels = results.get("labels", [])
- label_schema = set(
- [label.name for label in settings.label_schema.labels]
- )
+ label_schema = set([label.name for label in settings.label_schema.labels])
for label in labels:
if label not in label_schema:
raise BadRequestError(
@@ -84,15 +80,11 @@ async def validate_dataset_records(
records: List[ServiceTokenClassificationRecord],
):
try:
- settings: TokenClassificationSettings = (
- await self.__datasets__.get_settings(
- user=user, dataset=dataset, class_type=__svc_settings_class__
- )
+ settings: TokenClassificationSettings = await self.__datasets__.get_settings(
+ user=user, dataset=dataset, class_type=__svc_settings_class__
)
if settings and settings.label_schema:
- label_schema = set(
- [label.name for label in settings.label_schema.labels]
- )
+ label_schema = set([label.name for label in settings.label_schema.labels])
for r in records:
if r.prediction:
@@ -103,9 +95,7 @@ async def validate_dataset_records(
pass
@staticmethod
- def __check_label_entities__(
- label_schema: Set[str], annotation: ServiceTokenClassificationAnnotation
- ):
+ def __check_label_entities__(label_schema: Set[str], annotation: ServiceTokenClassificationAnnotation):
if not annotation:
return
for entity in annotation.entities:
diff --git a/src/argilla/server/commons/config.py b/src/argilla/server/commons/config.py
index b12dfc368c..47d068fc39 100644
--- a/src/argilla/server/commons/config.py
+++ b/src/argilla/server/commons/config.py
@@ -85,18 +85,14 @@ def __get_task_config__(cls, task):
return config
@classmethod
- def find_task_metric(
- cls, task: TaskType, metric_id: str
- ) -> Optional[ServiceBaseMetric]:
+ def find_task_metric(cls, task: TaskType, metric_id: str) -> Optional[ServiceBaseMetric]:
metrics = cls.find_task_metrics(task, {metric_id})
if metrics:
return metrics[0]
raise EntityNotFoundError(name=metric_id, type=ServiceBaseMetric)
@classmethod
- def find_task_metrics(
- cls, task: TaskType, metric_ids: Set[str]
- ) -> List[ServiceBaseMetric]:
+ def find_task_metrics(cls, task: TaskType, metric_ids: Set[str]) -> List[ServiceBaseMetric]:
if not metric_ids:
return []
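The two classmethods above are a thin registry lookup: resolve a set of metric ids for a task, and raise when a single requested metric is unknown. A minimal sketch with a plain dict standing in for the task configs:

from typing import Dict, List, Optional, Set

class EntityNotFoundError(Exception):
    pass

TASK_METRICS: Dict[str, List[str]] = {"text_classification": ["f1", "coverage"]}

def find_task_metrics(task: str, metric_ids: Set[str]) -> List[str]:
    if not metric_ids:
        return []
    return [m for m in TASK_METRICS.get(task, []) if m in metric_ids]

def find_task_metric(task: str, metric_id: str) -> Optional[str]:
    metrics = find_task_metrics(task, {metric_id})
    if metrics:
        return metrics[0]
    raise EntityNotFoundError(metric_id)

print(find_task_metric("text_classification", "f1"))  # -> f1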
diff --git a/src/argilla/server/commons/telemetry.py b/src/argilla/server/commons/telemetry.py
index 3f39929033..e00abd15c6 100644
--- a/src/argilla/server/commons/telemetry.py
+++ b/src/argilla/server/commons/telemetry.py
@@ -67,9 +67,7 @@ def get(cls):
try:
cls.__INSTANCE__ = cls(client=_configure_analytics())
except Exception as err:
- logging.getLogger(__name__).warning(
- f"Cannot initialize telemetry. Error: {err}. Disabling..."
- )
+ logging.getLogger(__name__).warning(f"Cannot initialize telemetry. Error: {err}. Disabling...")
settings.enable_telemetry = False
return None
return cls.__INSTANCE__
@@ -88,9 +86,7 @@ def __post_init__(self):
"version": __version__,
}
- def track_data(
- self, action: str, data: Dict[str, Any], include_system_info: bool = True
- ):
+ def track_data(self, action: str, data: Dict[str, Any], include_system_info: bool = True):
event_data = data.copy()
self.client.track(
user_id=self.__server_id_str__,
@@ -101,18 +97,13 @@ def track_data(
def _process_request_info(request: Request):
- return {
- header: request.headers.get(header)
- for header in ["user-agent", "accept-language"]
- }
+ return {header: request.headers.get(header) for header in ["user-agent", "accept-language"]}
async def track_error(error: ServerError, request: Request):
client = _TelemetryClient.get()
if client:
- client.track_data(
- "ServerErrorFound", {"code": error.code, **_process_request_info(request)}
- )
+ client.track_data("ServerErrorFound", {"code": error.code, **_process_request_info(request)})
async def track_bulk(task: TaskType, records: int):
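_process_request_info keeps telemetry to an explicit header whitelist; a tiny standalone version of the same comprehension:

incoming = {"user-agent": "curl/8.0", "accept-language": "en-US", "cookie": "secret"}

def process_request_info(headers: dict) -> dict:
    return {header: headers.get(header) for header in ["user-agent", "accept-language"]}

print(process_request_info(incoming))  # the cookie header is never tracked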
diff --git a/src/argilla/server/daos/backend/client_adapters/opensearch.py b/src/argilla/server/daos/backend/client_adapters/opensearch.py
index 088b48da84..83629300cb 100644
--- a/src/argilla/server/daos/backend/client_adapters/opensearch.py
+++ b/src/argilla/server/daos/backend/client_adapters/opensearch.py
@@ -220,11 +220,7 @@ def compute_index_metric(
}
)
- filtered_params = {
- argument: params[argument]
- for argument in metric.metric_arg_names
- if argument in params
- }
+ filtered_params = {argument: params[argument] for argument in metric.metric_arg_names if argument in params}
aggregations = metric.aggregation_request(**filtered_params)
if not aggregations:
@@ -250,9 +246,7 @@ def compute_index_metric(
search_aggregations = search.get("aggregations", {})
if search_aggregations:
- parsed_aggregations = query_helpers.parse_aggregations(
- search_aggregations
- )
+ parsed_aggregations = query_helpers.parse_aggregations(search_aggregations)
results.update(parsed_aggregations)
return metric.aggregation_result(results.get(metric.id, results))
@@ -459,9 +453,7 @@ def is_subfield(key: str):
return True
return False
- schema = {
- key: value for key, value in schema.items() if not is_subfield(key)
- }
+ schema = {key: value for key, value in schema.items() if not is_subfield(key)}
return schema
except IndexNotFoundError:
@@ -688,11 +680,7 @@ def is_read_only_index(self, index: str) -> bool:
allow_no_indices=True,
flat_settings=True,
)
- return (
- response[index]["settings"]["index.blocks.write"] == "true"
- if response
- else False
- )
+ return response[index]["settings"]["index.blocks.write"] == "true" if response else False
def enable_read_only_index(self, index: str):
return self._enable_or_disable_read_only_index(
@@ -770,8 +758,7 @@ def _get_fields_schema(
data = response
return {
- key: list(definition["mapping"].values())[0]["type"]
- for key, definition in data["mappings"].items()
+ key: list(definition["mapping"].values())[0]["type"] for key, definition in data["mappings"].items()
}
def _normalize_document(
diff --git a/src/argilla/server/daos/backend/generic_elastic.py b/src/argilla/server/daos/backend/generic_elastic.py
index 9a1bdd103b..61d2ece88a 100644
--- a/src/argilla/server/daos/backend/generic_elastic.py
+++ b/src/argilla/server/daos/backend/generic_elastic.py
@@ -291,8 +291,7 @@ def _check_max_number_of_vectors():
if len(vector_names) > settings.vectors_fields_limit:
raise BadRequestError(
- detail=f"Cannot create more than {settings.vectors_fields_limit} "
- "kind of vectors per dataset"
+ detail=f"Cannot create more than {settings.vectors_fields_limit} " "kind of vectors per dataset"
)
_check_max_number_of_vectors()
@@ -489,8 +488,7 @@ def find_dataset(
if len(docs) > 1:
raise ValueError(
- f"Ambiguous dataset info found for name {name}. "
- "Please provide a valid owner/workspace"
+ f"Ambiguous dataset info found for name {name}. " "Please provide a valid owner/workspace"
)
document = docs[0]
return document
diff --git a/src/argilla/server/daos/backend/mappings/helpers.py b/src/argilla/server/daos/backend/mappings/helpers.py
index b0726f94a2..4b32da9244 100644
--- a/src/argilla/server/daos/backend/mappings/helpers.py
+++ b/src/argilla/server/daos/backend/mappings/helpers.py
@@ -185,11 +185,7 @@ def dynamic_metadata_text():
def dynamic_annotations_text(path: str):
path = f"{path}.*"
- return {
- path: mappings.path_match_keyword_template(
- path=path, enable_text_search_in_keywords=True
- )
- }
+ return {path: mappings.path_match_keyword_template(path=path, enable_text_search_in_keywords=True)}
def tasks_common_mappings():
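dynamic_annotations_text maps every subfield under `path` with an Elasticsearch dynamic template; a sketch under the assumption that path_match_keyword_template produces a keyword mapping with an optional text subfield (the template body is a plausible shape, not the exact one):

def path_match_keyword_template(path: str, enable_text_search_in_keywords: bool = False) -> dict:
    mapping = {"type": "keyword"}
    if enable_text_search_in_keywords:
        mapping["fields"] = {"text": {"type": "text"}}
    return {"path_match": path, "mapping": mapping}

def dynamic_annotations_text(path: str):
    path = f"{path}.*"
    return {path: path_match_keyword_template(path=path, enable_text_search_in_keywords=True)}

print(dynamic_annotations_text("annotations"))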
diff --git a/src/argilla/server/daos/backend/metrics/base.py b/src/argilla/server/daos/backend/metrics/base.py
index 256ffcdbf1..82bd60a1e7 100644
--- a/src/argilla/server/daos/backend/metrics/base.py
+++ b/src/argilla/server/daos/backend/metrics/base.py
@@ -37,9 +37,7 @@ def __post_init__(self):
def get_function_arg_names(func):
return func.__code__.co_varnames
- def aggregation_request(
- self, *args, **kwargs
- ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
+ def aggregation_request(self, *args, **kwargs) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
"""
Configures the summary es aggregation definition
"""
@@ -121,9 +119,7 @@ def _build_aggregation(self, interval: Optional[float] = None) -> Dict[str, Any]
if self.fixed_interval:
interval = self.fixed_interval
- return aggregations.histogram_aggregation(
- field_name=self.field, script=self.script, interval=interval
- )
+ return aggregations.histogram_aggregation(field_name=self.field, script=self.script, interval=interval)
@dataclasses.dataclass
diff --git a/src/argilla/server/daos/backend/metrics/datasets.py b/src/argilla/server/daos/backend/metrics/datasets.py
index a673b7ba02..b261599aed 100644
--- a/src/argilla/server/daos/backend/metrics/datasets.py
+++ b/src/argilla/server/daos/backend/metrics/datasets.py
@@ -15,6 +15,4 @@
# All metrics related to the datasets index
from argilla.server.daos.backend.metrics.base import TermsAggregation
-METRICS = {
- "all_workspaces": TermsAggregation(id="all_workspaces", field="owner.keyword")
-}
+METRICS = {"all_workspaces": TermsAggregation(id="all_workspaces", field="owner.keyword")}
diff --git a/src/argilla/server/daos/backend/metrics/text_classification.py b/src/argilla/server/daos/backend/metrics/text_classification.py
index fa59630558..8d3a83228c 100644
--- a/src/argilla/server/daos/backend/metrics/text_classification.py
+++ b/src/argilla/server/daos/backend/metrics/text_classification.py
@@ -29,9 +29,7 @@ def _build_aggregation(self, queries: List[str]) -> Dict[str, Any]:
rules_filters = [filters.text_query(rule_query) for rule_query in queries]
return aggregations.filters_aggregation(
filters={
- "covered_records": filters.boolean_filter(
- should_filters=rules_filters, minimum_should_match=1
- ),
+ "covered_records": filters.boolean_filter(should_filters=rules_filters, minimum_should_match=1),
"annotated_covered_records": filters.boolean_filter(
filter_query=filters.exists_field("annotated_as"),
should_filters=rules_filters,
@@ -45,9 +43,7 @@ def _build_aggregation(self, queries: List[str]) -> Dict[str, Any]:
class LabelingRulesMetric(ElasticsearchMetric):
id: str
- def _build_aggregation(
- self, rule_query: str, labels: Optional[List[str]]
- ) -> Dict[str, Any]:
+ def _build_aggregation(self, rule_query: str, labels: Optional[List[str]]) -> Dict[str, Any]:
annotated_records_filter = filters.exists_field("annotated_as")
rule_query_filter = filters.text_query(rule_query)
aggr_filters = {
@@ -60,9 +56,7 @@ def _build_aggregation(
if labels is not None:
for label in labels:
- rule_label_annotated_filter = filters.term_filter(
- "annotated_as", value=label
- )
+ rule_label_annotated_filter = filters.term_filter("annotated_as", value=label)
encoded_label = self._encode_label_name(label)
aggr_filters.update(
{
@@ -99,9 +93,7 @@ def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, An
aggregation_result = unflatten_dict(aggregation_result)
results = {
"covered_records": aggregation_result.pop("covered_records"),
- "annotated_covered_records": aggregation_result.pop(
- "annotated_covered_records"
- ),
+ "annotated_covered_records": aggregation_result.pop("annotated_covered_records"),
}
all_correct = []
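For readers of the coverage hunk above, a sketch of the resulting "filters" aggregation, assuming filters.text_query and filters.boolean_filter build standard query_string and bool clauses (the rule queries are stand-ins):

    rules_filters = [{"query_string": {"query": q}} for q in ("text:cat", "text:dog")]
    coverage_aggregation = {
        "filters": {
            "filters": {
                "covered_records": {
                    "bool": {"should": rules_filters, "minimum_should_match": 1}
                },
                "annotated_covered_records": {
                    "bool": {
                        "filter": {"exists": {"field": "annotated_as"}},
                        "should": rules_filters,
                        "minimum_should_match": 1,
                    }
                },
            }
        }
    }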
diff --git a/src/argilla/server/daos/backend/metrics/token_classification.py b/src/argilla/server/daos/backend/metrics/token_classification.py
index 2c5dc27f0e..35e2f7bb11 100644
--- a/src/argilla/server/daos/backend/metrics/token_classification.py
+++ b/src/argilla/server/daos/backend/metrics/token_classification.py
@@ -57,11 +57,7 @@ def _inner_aggregation(
self.compound_nested_field(self.labels_field),
size=entity_size,
),
- "count": {
- "cardinality": {
- "field": self.compound_nested_field(self.labels_field)
- }
- },
+ "count": {"cardinality": {"field": self.compound_nested_field(self.labels_field)}},
"entities_variability_filter": {
"bucket_selector": {
"buckets_path": {"numLabels": "count"},
@@ -83,10 +79,7 @@ def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, An
result = [
{
"mention": mention,
- "entities": [
- {"label": entity, "count": count}
- for entity, count in mention_aggs["entities"].items()
- ],
+ "entities": [{"label": entity, "count": count} for entity, count in mention_aggs["entities"].items()],
}
for mention, mention_aggs in aggregation_result.items()
]
@@ -194,16 +187,12 @@ def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, An
"predicted_mentions_distribution": NestedBidimensionalTermsAggregation(
id="predicted_mentions_distribution",
nested_path="metrics.predicted.mentions",
- biterms=BidimensionalTermsAggregation(
- id="bi-dimensional", field_x="label", field_y="value"
- ),
+ biterms=BidimensionalTermsAggregation(id="bi-dimensional", field_x="label", field_y="value"),
),
"annotated_mentions_distribution": NestedBidimensionalTermsAggregation(
id="predicted_mentions_distribution",
nested_path="metrics.annotated.mentions",
- biterms=BidimensionalTermsAggregation(
- id="bi-dimensional", field_x="label", field_y="value"
- ),
+ biterms=BidimensionalTermsAggregation(id="bi-dimensional", field_x="label", field_y="value"),
),
"predicted_top_k_mentions_consistency": TopKMentionsConsistency(
id="predicted_top_k_mentions_consistency",
diff --git a/src/argilla/server/daos/backend/query_helpers.py b/src/argilla/server/daos/backend/query_helpers.py
index 1c57c5ccee..92a0f4bbd5 100644
--- a/src/argilla/server/daos/backend/query_helpers.py
+++ b/src/argilla/server/daos/backend/query_helpers.py
@@ -33,16 +33,11 @@ def resolve_mapping(info) -> Dict[str, Any]:
return {
"type": "nested",
"include_in_root": True,
- "properties": {
- key: resolve_mapping(info)
- for key, info in model_class.schema()["properties"].items()
- },
+ "properties": {key: resolve_mapping(info) for key, info in model_class.schema()["properties"].items()},
}
-def parse_aggregations(
- es_aggregations: Dict[str, Any] = None
-) -> Optional[Dict[str, Any]]:
+def parse_aggregations(es_aggregations: Dict[str, Any] = None) -> Optional[Dict[str, Any]]:
"""Transforms elasticsearch raw aggregations into a more friendly structure"""
if es_aggregations is None:
@@ -80,9 +75,7 @@ def parse_buckets(buckets: List[Dict[str, Any]]) -> Dict[str, Any]:
key_metrics = {}
for metric_key, metric in list(bucket.items()):
if "buckets" in metric:
- key_metrics.update(
- {metric_key: parse_buckets(metric.get("buckets", []))}
- )
+ key_metrics.update({metric_key: parse_buckets(metric.get("buckets", []))})
else:
metric_values = list(metric.values())
value = metric_values[0] if len(metric_values) == 1 else metric
@@ -171,13 +164,7 @@ def metadata(metadata: Dict[str, Union[str, List[str]]]) -> List[Dict[str, Any]]
return []
return [
- {
- "terms": {
- f"metadata.{key}": query_text
- if isinstance(query_text, List)
- else [query_text]
- }
- }
+ {"terms": {f"metadata.{key}": query_text if isinstance(query_text, List) else [query_text]}}
for key, query_text in metadata.items()
]
@@ -237,9 +224,7 @@ class aggregations:
MAX_AGGREGATION_SIZE = 5000 # TODO: improve by setting env var
@staticmethod
- def nested_aggregation(
- nested_path: str, inner_aggregation: Dict[str, Any]
- ) -> Dict[str, Any]:
+ def nested_aggregation(nested_path: str, inner_aggregation: Dict[str, Any]) -> Dict[str, Any]:
inner_meta = list(inner_aggregation.values())[0].get("meta", {})
return {
"meta": {
@@ -250,15 +235,11 @@ def nested_aggregation(
}
@staticmethod
- def bidimentional_terms_aggregations(
- field_name_x: str, field_name_y: str, size=DEFAULT_AGGREGATION_SIZE
- ):
+ def bidimentional_terms_aggregations(field_name_x: str, field_name_y: str, size=DEFAULT_AGGREGATION_SIZE):
return {
**aggregations.terms_aggregation(field_name_x, size=size),
"meta": {"kind": "2d-terms"},
- "aggs": {
- field_name_y: aggregations.terms_aggregation(field_name_y, size=size)
- },
+ "aggs": {field_name_y: aggregations.terms_aggregation(field_name_y, size=size)},
}
@staticmethod
@@ -321,9 +302,7 @@ def custom_fields(
) -> Dict[str, Dict[str, Any]]:
"""Build a set of aggregations for a given field definition (extracted from index mapping)"""
- def __resolve_aggregation_for_field_type(
- field_type: str, field_name: str
- ) -> Optional[Dict[str, Any]]:
+ def __resolve_aggregation_for_field_type(field_type: str, field_name: str) -> Optional[Dict[str, Any]]:
if field_type in ["keyword", "long", "integer", "boolean"]:
return aggregations.terms_aggregation(field_name=field_name, size=size)
if field_type in ["float", "date"]:
@@ -337,9 +316,7 @@ def __resolve_aggregation_for_field_type(
return {
key: aggregation
for key, type_ in fields_definitions.items()
- for aggregation in [
- __resolve_aggregation_for_field_type(type_, field_name=key)
- ]
+ for aggregation in [__resolve_aggregation_for_field_type(type_, field_name=key)]
if aggregation
}
@@ -348,9 +325,7 @@ def filters_aggregation(filters: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
return {"filters": {"filters": filters}}
-def find_nested_field_path(
- field_name: str, mapping_definition: Dict[str, Any]
-) -> Optional[str]:
+def find_nested_field_path(field_name: str, mapping_definition: Dict[str, Any]) -> Optional[str]:
"""
Given a field name, find the nested path if any related to field name
definition in provided mapping definition
@@ -367,9 +342,7 @@ def find_nested_field_path(
The found nested path if any, None otherwise
"""
- def build_flatten_properties_map(
- properties: Dict[str, Any], prefix: str = ""
- ) -> Dict[str, Any]:
+ def build_flatten_properties_map(properties: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
results = {}
for prop_name, prop_value in properties.items():
if prefix:
@@ -377,11 +350,7 @@ def build_flatten_properties_map(
if "type" in prop_value:
results[prop_name] = prop_value["type"]
if "properties" in prop_value:
- results.update(
- build_flatten_properties_map(
- prop_value["properties"], prefix=prop_name
- )
- )
+ results.update(build_flatten_properties_map(prop_value["properties"], prefix=prop_name))
return results
properties_map = build_flatten_properties_map(mapping_definition)
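A self-contained restatement of build_flatten_properties_map from the hunk above, with a tiny usage check; the logic mirrors the diffed code, mapping dotted property paths to their ES types so nested paths can be looked up:

    def build_flatten_properties_map(properties, prefix=""):
        # recurse into nested "properties" blocks, joining names with dots
        results = {}
        for prop_name, prop_value in properties.items():
            if prefix:
                prop_name = f"{prefix}.{prop_name}"
            if "type" in prop_value:
                results[prop_name] = prop_value["type"]
            if "properties" in prop_value:
                results.update(build_flatten_properties_map(prop_value["properties"], prefix=prop_name))
        return results

    mapping = {"metrics": {"properties": {"mentions": {"type": "nested"}}}}
    assert build_flatten_properties_map(mapping) == {"metrics.mentions": "nested"}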
diff --git a/src/argilla/server/daos/backend/search/model.py b/src/argilla/server/daos/backend/search/model.py
index fcaad0e5fa..49a4986595 100644
--- a/src/argilla/server/daos/backend/search/model.py
+++ b/src/argilla/server/daos/backend/search/model.py
@@ -70,8 +70,7 @@ class VectorSearch(BaseModel):
value: List[float]
k: Optional[int] = Field(
default=None,
- description="Number of elements to retrieve. "
- "If not provided, the request size will be used instead",
+ description="Number of elements to retrieve. " "If not provided, the request size will be used instead",
)
diff --git a/src/argilla/server/daos/backend/search/query_builder.py b/src/argilla/server/daos/backend/search/query_builder.py
index 76d9e989ff..3da0591864 100644
--- a/src/argilla/server/daos/backend/search/query_builder.py
+++ b/src/argilla/server/daos/backend/search/query_builder.py
@@ -36,13 +36,9 @@ class HighlightParser:
__HIGHLIGHT_PRE_TAG__ = "<@@-ar-key>"
__HIGHLIGHT_POST_TAG__ = "@@-ar-key>"
- __HIGHLIGHT_VALUES_REGEX__ = re.compile(
- rf"{__HIGHLIGHT_PRE_TAG__}(.+?){__HIGHLIGHT_POST_TAG__}"
- )
+ __HIGHLIGHT_VALUES_REGEX__ = re.compile(rf"{__HIGHLIGHT_PRE_TAG__}(.+?){__HIGHLIGHT_POST_TAG__}")
- __HIGHLIGHT_PHRASE_PRE_PARSER_REGEX__ = re.compile(
- rf"{__HIGHLIGHT_POST_TAG__}\s+{__HIGHLIGHT_PRE_TAG__}"
- )
+ __HIGHLIGHT_PHRASE_PRE_PARSER_REGEX__ = re.compile(rf"{__HIGHLIGHT_POST_TAG__}\s+{__HIGHLIGHT_PRE_TAG__}")
@property
def search_keywords_field(self) -> str:
@@ -102,9 +98,7 @@ def get_instance(cls):
cls._INSTANCE = cls()
return cls._INSTANCE
- def _datasets_to_es_query(
- self, query: Optional[BackendDatasetsQuery] = None
- ) -> Dict[str, Any]:
+ def _datasets_to_es_query(self, query: Optional[BackendDatasetsQuery] = None) -> Dict[str, Any]:
if not query:
return filters.match_all()
@@ -120,9 +114,7 @@ def _datasets_to_es_query(
minimum_should_match=1, # OR Condition
should_filters=[
owners_filter,
- filters.boolean_filter(
- must_not_query=filters.exists_field("owner")
- ),
+ filters.boolean_filter(must_not_query=filters.exists_field("owner")),
],
)
)
@@ -279,8 +271,7 @@ def map_2_es_sort_configuration(
if valid_fields:
if sortable_field.id.split(".")[0] not in valid_fields:
raise AssertionError(
- f"Wrong sort id {sortable_field.id}. Valid values are: "
- f"{[str(v) for v in valid_fields]}"
+ f"Wrong sort id {sortable_field.id}. Valid values are: " f"{[str(v) for v in valid_fields]}"
)
field = sortable_field.id
if field == id_field and use_id_keyword:
@@ -323,9 +314,7 @@ def _to_es_query(cls, query: BackendRecordsQuery) -> Dict[str, Any]:
elif isinstance(value, (str, Enum)):
key_filter = filters.term_filter(key, value)
elif isinstance(value, QueryRange):
- key_filter = filters.range_filter(
- field=key, value_from=value.range_from, value_to=value.range_to
- )
+ key_filter = filters.range_filter(field=key, value_from=value.range_from, value_to=value.range_to)
else:
cls._LOGGER.warning(f"Cannot parse query value {value} for key {key}")
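The QueryRange branch above delegates to filters.range_filter; a minimal sketch of the assumed output, a standard ES range query where open-ended bounds are simply omitted:

    def range_filter(field, value_from=None, value_to=None):
        # assumption: only provided bounds appear in the clause
        conditions = {}
        if value_from is not None:
            conditions["gte"] = value_from
        if value_to is not None:
            conditions["lte"] = value_to
        return {"range": {field: conditions}}

    print(range_filter("metrics.text_length", value_from=10, value_to=100))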
diff --git a/src/argilla/server/daos/datasets.py b/src/argilla/server/daos/datasets.py
index c45f8ddd20..d798f823ea 100644
--- a/src/argilla/server/daos/datasets.py
+++ b/src/argilla/server/daos/datasets.py
@@ -39,9 +39,7 @@ class DatasetsDAO:
@classmethod
def get_instance(
cls,
- es: GenericElasticEngineBackend = Depends(
- GenericElasticEngineBackend.get_instance
- ),
+ es: GenericElasticEngineBackend = Depends(GenericElasticEngineBackend.get_instance),
records_dao: DatasetRecordsDAO = Depends(DatasetRecordsDAO.get_instance),
) -> "DatasetsDAO":
"""
@@ -117,9 +115,7 @@ def update_dataset(
self,
dataset: DatasetDB,
) -> DatasetDB:
- self._es.update_dataset_document(
- id=dataset.id, document=self._dataset_to_es_doc(dataset)
- )
+ self._es.update_dataset_document(id=dataset.id, document=self._dataset_to_es_doc(dataset))
return dataset
def delete_dataset(self, dataset: DatasetDB):
@@ -141,9 +137,7 @@ def find_by_name(
return None
base_ds = self._es_doc_to_instance(document)
if task and task != base_ds.task:
- raise WrongTaskError(
- detail=f"Provided task {task} cannot be applied to dataset"
- )
+ raise WrongTaskError(detail=f"Provided task {task} cannot be applied to dataset")
dataset_type = as_dataset_class or BaseDatasetDB
return self._es_doc_to_instance(document, ds_class=dataset_type)
@@ -154,9 +148,7 @@ def _es_doc_to_instance(
) -> DatasetDB:
"""Transforms a stored elasticsearch document into a `BaseDatasetDB`"""
- def key_value_list_to_dict(
- key_value_list: List[Dict[str, Any]]
- ) -> Dict[str, Any]:
+ def key_value_list_to_dict(key_value_list: List[Dict[str, Any]]) -> Dict[str, Any]:
return {data["key"]: json.loads(data["value"]) for data in key_value_list}
tags = doc.get("tags", [])
@@ -173,9 +165,7 @@ def key_value_list_to_dict(
@staticmethod
def _dataset_to_es_doc(dataset: DatasetDB) -> Dict[str, Any]:
def dict_to_key_value_list(data: Dict[str, Any]) -> List[Dict[str, Any]]:
- return [
- {"key": key, "value": json.dumps(value)} for key, value in data.items()
- ]
+ return [{"key": key, "value": json.dumps(value)} for key, value in data.items()]
data = dataset.dict(by_alias=True)
tags = data.get("tags", {})
@@ -208,9 +198,7 @@ def open(self, dataset: DatasetDB):
def get_all_workspaces(self) -> List[str]:
"""Get all datasets (Only for super users)"""
- metric_data = self._es.compute_argilla_metric(
- metric_id="all_argilla_workspaces"
- )
+ metric_data = self._es.compute_argilla_metric(metric_id="all_argilla_workspaces")
return [k for k in metric_data]
def save_settings(
@@ -228,19 +216,14 @@ def save_settings(
def _configure_vectors(self, dataset, settings):
if not settings.vectors:
return
- vectors_cfg = {
- k: v.dim if isinstance(v, EmbeddingsConfig) else int(v)
- for k, v in settings.vectors.items()
- }
+ vectors_cfg = {k: v.dim if isinstance(v, EmbeddingsConfig) else int(v) for k, v in settings.vectors.items()}
self._es.create_dataset(
id=dataset.id,
task=dataset.task,
vectors_cfg=vectors_cfg,
)
- def load_settings(
- self, dataset: DatasetDB, as_class: Type[DatasetSettingsDB]
- ) -> Optional[DatasetSettingsDB]:
+ def load_settings(self, dataset: DatasetDB, as_class: Type[DatasetSettingsDB]) -> Optional[DatasetSettingsDB]:
doc = self._es.find_dataset(id=dataset.id)
if doc and "settings" in doc:
settings = doc["settings"]
diff --git a/src/argilla/server/daos/models/records.py b/src/argilla/server/daos/models/records.py
index 117d39581e..14bb2e96d7 100644
--- a/src/argilla/server/daos/models/records.py
+++ b/src/argilla/server/daos/models/records.py
@@ -60,9 +60,7 @@ class BaseRecordInDB(GenericModel, Generic[AnnotationDB]):
metadata: Dict[str, Any] = Field(default=None)
event_timestamp: Optional[datetime] = None
status: Optional[TaskStatus] = None
- prediction: Optional[AnnotationDB] = Field(
- None, description="Deprecated. Use `predictions` instead"
- )
+ prediction: Optional[AnnotationDB] = Field(None, description="Deprecated. Use `predictions` instead")
annotation: Optional[AnnotationDB] = None
vectors: Optional[Dict[str, BaseEmbeddingVectorDB]] = Field(
@@ -98,21 +96,13 @@ def update_annotation(values, annotation_field: str):
if not annotation.agent:
raise AssertionError("Agent must be defined!")
- annotations.update(
- {
- annotation.agent: annotation.__class__.parse_obj(
- annotation.dict(exclude={"agent"})
- )
- }
- )
+ annotations.update({annotation.agent: annotation.__class__.parse_obj(annotation.dict(exclude={"agent"}))})
values[field_to_update] = annotations
if annotations and not annotation:
# set first annotation
key, value = list(annotations.items())[0]
- values[annotation_field] = value.__class__(
- agent=key, **value.dict(exclude={"agent"})
- )
+ values[annotation_field] = value.__class__(agent=key, **value.dict(exclude={"agent"}))
return values
diff --git a/src/argilla/server/daos/records.py b/src/argilla/server/daos/records.py
index 1ab091d885..83ea922db1 100644
--- a/src/argilla/server/daos/records.py
+++ b/src/argilla/server/daos/records.py
@@ -38,9 +38,7 @@ class DatasetRecordsDAO:
@classmethod
def get_instance(
cls,
- es: GenericElasticEngineBackend = Depends(
- GenericElasticEngineBackend.get_instance
- ),
+ es: GenericElasticEngineBackend = Depends(GenericElasticEngineBackend.get_instance),
) -> "DatasetRecordsDAO":
"""
Creates a dataset records dao instance
@@ -91,9 +89,7 @@ def add_records(
mapping = self._es.get_schema(id=dataset.id)
exclude_fields = [
- name
- for name in record_class.schema()["properties"]
- if name not in mapping["mappings"]["properties"]
+ name for name in record_class.schema()["properties"] if name not in mapping["mappings"]["properties"]
]
vectors_configuration = {}
@@ -173,9 +169,7 @@ def search_records(
except ClosedIndexError:
raise ClosedDatasetError(dataset.name)
except IndexNotFoundError:
- raise MissingDatasetRecordsError(
- f"No records index found for dataset {dataset.name}"
- )
+ raise MissingDatasetRecordsError(f"No records index found for dataset {dataset.name}")
def scan_dataset(
self,
diff --git a/src/argilla/server/errors/base_errors.py b/src/argilla/server/errors/base_errors.py
index 36ed5bd700..80e8c2018d 100644
--- a/src/argilla/server/errors/base_errors.py
+++ b/src/argilla/server/errors/base_errors.py
@@ -46,11 +46,7 @@ def code(self) -> str:
@property
def arguments(self):
- return (
- {k: v for k, v in vars(self).items() if v is not None}
- if vars(self)
- else None
- )
+ return {k: v for k, v in vars(self).items() if v is not None} if vars(self) else None
def __str__(self):
args = self.arguments or {}
@@ -77,11 +73,7 @@ def __init__(self, error: Exception):
@classmethod
def api_documentation(cls):
return {
- "content": {
- "application/json": {
- "example": {"detail": {"code": "builtins.TypeError"}}
- }
- },
+ "content": {"application/json": {"example": {"detail": {"code": "builtins.TypeError"}}}},
}
diff --git a/src/argilla/server/helpers.py b/src/argilla/server/helpers.py
index c523f918ba..81b6ad6be6 100644
--- a/src/argilla/server/helpers.py
+++ b/src/argilla/server/helpers.py
@@ -22,9 +22,7 @@
_LOGGER = logging.getLogger("argilla.server")
-def unflatten_dict(
- data: Dict[str, Any], sep: str = ".", stop_keys: Optional[List[str]] = None
-) -> Dict[str, Any]:
+def unflatten_dict(data: Dict[str, Any], sep: str = ".", stop_keys: Optional[List[str]] = None) -> Dict[str, Any]:
"""
Given a flat dictionary keys, build a hierarchical version by grouping keys
@@ -57,9 +55,7 @@ def unflatten_dict(
return resultDict
-def flatten_dict(
- data: Dict[str, Any], sep: str = ".", drop_empty: bool = False
-) -> Dict[str, Any]:
+def flatten_dict(data: Dict[str, Any], sep: str = ".", drop_empty: bool = False) -> Dict[str, Any]:
"""
Flatten a data dictionary
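A toy round-trip illustrating what unflatten_dict above is expected to do (a minimal sketch; the stop_keys and drop_empty semantics follow the signatures, details are assumptions):

    def unflatten(data, sep="."):
        result = {}
        for key, value in data.items():
            node = result
            *parents, leaf = key.split(sep)
            for part in parents:
                node = node.setdefault(part, {})
            node[leaf] = value
        return result

    assert unflatten({"a.b": 1, "a.c": 2}) == {"a": {"b": 1, "c": 2}}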
diff --git a/src/argilla/server/responses/api_responses.py b/src/argilla/server/responses/api_responses.py
index 157e38a2b1..0056b2cda0 100644
--- a/src/argilla/server/responses/api_responses.py
+++ b/src/argilla/server/responses/api_responses.py
@@ -23,9 +23,7 @@ async def stream_response(self, send: Send) -> None:
try:
return await super().stream_response(send)
except Exception as ex:
- json_response: JSONResponse = (
- await APIErrorHandler.common_exception_handler(send, error=ex)
- )
+ json_response: JSONResponse = await APIErrorHandler.common_exception_handler(send, error=ex)
await send(
{
"type": "http.response.body",
diff --git a/src/argilla/server/routes.py b/src/argilla/server/routes.py
index be28d9fbb7..45fce83d42 100644
--- a/src/argilla/server/routes.py
+++ b/src/argilla/server/routes.py
@@ -34,9 +34,7 @@
)
from argilla.server.errors.base_errors import __ALL__
-api_router = APIRouter(
- responses={error.HTTP_STATUS: error.api_documentation() for error in __ALL__}
-)
+api_router = APIRouter(responses={error.HTTP_STATUS: error.api_documentation() for error in __ALL__})
dependencies = []
diff --git a/src/argilla/server/security/auth_provider/local/provider.py b/src/argilla/server/security/auth_provider/local/provider.py
index 3a8851df79..3658f7c049 100644
--- a/src/argilla/server/security/auth_provider/local/provider.py
+++ b/src/argilla/server/security/auth_provider/local/provider.py
@@ -77,17 +77,11 @@ async def login_for_access_token(
user = self.users.authenticate_user(form_data.username, form_data.password)
if not user:
raise UnauthorizedError()
- access_token_expires = timedelta(
- minutes=self.settings.token_expiration_in_minutes
- )
- access_token = self._create_access_token(
- user.username, expires_delta=access_token_expires
- )
+ access_token_expires = timedelta(minutes=self.settings.token_expiration_in_minutes)
+ access_token = self._create_access_token(user.username, expires_delta=access_token_expires)
return Token(access_token=access_token)
- def _create_access_token(
- self, username: str, expires_delta: Optional[timedelta] = None
- ) -> str:
+ def _create_access_token(self, username: str, expires_delta: Optional[timedelta] = None) -> str:
"""
Creates an access token
@@ -162,9 +156,7 @@ async def get_user(
-------
"""
- user = await self._find_user_by_api_key(
- api_key
- ) or await self._find_user_by_api_key(old_api_key)
+ user = await self._find_user_by_api_key(api_key) or await self._find_user_by_api_key(old_api_key)
if user:
return user
if token:
diff --git a/src/argilla/server/security/auth_provider/local/settings.py b/src/argilla/server/security/auth_provider/local/settings.py
index 43bce77a3e..b2157e1e61 100644
--- a/src/argilla/server/security/auth_provider/local/settings.py
+++ b/src/argilla/server/security/auth_provider/local/settings.py
@@ -43,17 +43,13 @@ class Settings(BaseSettings):
token_api_url: str = "/api/security/token"
default_apikey: str = DEFAULT_API_KEY
- default_password: str = (
- "$2y$12$MPcRR71ByqgSI8AaqgxrMeSdrD4BcxDIdYkr.ePQoKz7wsGK7SAca" # 1234
- )
+ default_password: str = "$2y$12$MPcRR71ByqgSI8AaqgxrMeSdrD4BcxDIdYkr.ePQoKz7wsGK7SAca" # 1234
users_db_file: str = ".users.yml"
@property
def public_oauth_token_url(self):
"""The final public token url used for openapi doc setup"""
- return (
- f"{server_settings.base_url}{helpers.remove_prefix(self.token_api_url,'/')}"
- )
+ return f"{server_settings.base_url}{helpers.remove_prefix(self.token_api_url,'/')}"
class Config:
env_prefix = "ARGILLA_LOCAL_AUTH_"
diff --git a/src/argilla/server/security/auth_provider/local/users/dao.py b/src/argilla/server/security/auth_provider/local/users/dao.py
index 0d6e1025bd..895c6fdf77 100644
--- a/src/argilla/server/security/auth_provider/local/users/dao.py
+++ b/src/argilla/server/security/auth_provider/local/users/dao.py
@@ -34,9 +34,7 @@ def __init__(self, users_file: str):
self.__users__: Dict[str, UserInDB] = {}
try:
with open(users_file) as file:
- user_list = [
- UserInDB(**user_data) for user_data in yaml.safe_load(file)
- ]
+ user_list = [UserInDB(**user_data) for user_data in yaml.safe_load(file)]
self.__users__ = {user.username: user for user in user_list}
except FileNotFoundError:
self.__users__ = {
diff --git a/src/argilla/server/security/model.py b/src/argilla/server/security/model.py
index bb8ac6c972..9f182f5e6c 100644
--- a/src/argilla/server/security/model.py
+++ b/src/argilla/server/security/model.py
@@ -37,8 +37,7 @@ class User(BaseModel):
def check_username(cls, value):
if not re.compile(DATASET_NAME_REGEX_PATTERN).match(value):
raise ValueError(
- "Wrong username. "
- f"The username {value} does not match the pattern {DATASET_NAME_REGEX_PATTERN}"
+ "Wrong username. " f"The username {value} does not match the pattern {DATASET_NAME_REGEX_PATTERN}"
)
return value
@@ -48,8 +47,7 @@ def check_workspace_pattern(cls, workspace: str):
if not workspace:
return workspace
assert WORKSPACE_NAME_PATTERN.match(workspace), (
- "Wrong workspace format. "
- f"Workspace must match pattern {WORKSPACE_NAME_PATTERN.pattern}"
+ "Wrong workspace format. " f"Workspace must match pattern {WORKSPACE_NAME_PATTERN.pattern}"
)
return workspace
diff --git a/src/argilla/server/server.py b/src/argilla/server/server.py
index 67e146d670..def025bf53 100644
--- a/src/argilla/server/server.py
+++ b/src/argilla/server/server.py
@@ -65,9 +65,7 @@ def configure_api_exceptions(api: FastAPI):
"""Configures fastapi exception handlers"""
api.exception_handler(EntityNotFoundError)(APIErrorHandler.common_exception_handler)
api.exception_handler(Exception)(APIErrorHandler.common_exception_handler)
- api.exception_handler(RequestValidationError)(
- APIErrorHandler.common_exception_handler
- )
+ api.exception_handler(RequestValidationError)(APIErrorHandler.common_exception_handler)
def configure_api_router(app: FastAPI):
@@ -185,11 +183,7 @@ def configure_telemetry(app):
"""
)
message += "\n\n "
- message += (
- "#set ARGILLA_ENABLE_TELEMETRY=0"
- if os.name == "nt"
- else "$>export ARGILLA_ENABLE_TELEMETRY=0"
- )
+ message += "#set ARGILLA_ENABLE_TELEMETRY=0" if os.name == "nt" else "$>export ARGILLA_ENABLE_TELEMETRY=0"
message += "\n"
@app.on_event("startup")
diff --git a/src/argilla/server/services/datasets.py b/src/argilla/server/services/datasets.py
index d67914d354..b6af813d4b 100644
--- a/src/argilla/server/services/datasets.py
+++ b/src/argilla/server/services/datasets.py
@@ -38,18 +38,14 @@ class ServiceBaseDatasetSettings(BaseDatasetSettingsDB):
ServiceDataset = TypeVar("ServiceDataset", bound=ServiceBaseDataset)
-ServiceDatasetSettings = TypeVar(
- "ServiceDatasetSettings", bound=ServiceBaseDatasetSettings
-)
+ServiceDatasetSettings = TypeVar("ServiceDatasetSettings", bound=ServiceBaseDatasetSettings)
class DatasetsService:
_INSTANCE: "DatasetsService" = None
@classmethod
- def get_instance(
- cls, dao: DatasetsDAO = Depends(DatasetsDAO.get_instance)
- ) -> "DatasetsService":
+ def get_instance(cls, dao: DatasetsDAO = Depends(DatasetsDAO.get_instance)) -> "DatasetsService":
if not cls._INSTANCE:
cls._INSTANCE = cls(dao)
return cls._INSTANCE
@@ -61,16 +57,10 @@ def create_dataset(self, user: User, dataset: ServiceDataset) -> ServiceDataset:
user.check_workspace(dataset.owner)
try:
- self.find_by_name(
- user=user, name=dataset.name, task=dataset.task, workspace=dataset.owner
- )
- raise EntityAlreadyExistsError(
- name=dataset.name, type=ServiceDataset, workspace=dataset.owner
- )
+ self.find_by_name(user=user, name=dataset.name, task=dataset.task, workspace=dataset.owner)
+ raise EntityAlreadyExistsError(name=dataset.name, type=ServiceDataset, workspace=dataset.owner)
except WrongTaskError: # Found a dataset with same name but different task
- raise EntityAlreadyExistsError(
- name=dataset.name, type=ServiceDataset, workspace=dataset.owner
- )
+ raise EntityAlreadyExistsError(name=dataset.name, type=ServiceDataset, workspace=dataset.owner)
except EntityNotFoundError:
# The dataset does not exist -> create it !
date_now = datetime.utcnow()
@@ -117,9 +107,7 @@ def __find_by_name_with_superuser_fallback__(
as_dataset_class: Optional[Type[ServiceDataset]],
task: Optional[str] = None,
):
- found_ds = self.__dao__.find_by_name(
- name=name, owner=owner, task=task, as_dataset_class=as_dataset_class
- )
+ found_ds = self.__dao__.find_by_name(name=name, owner=owner, task=task, as_dataset_class=as_dataset_class)
if not found_ds and user.is_superuser():
try:
found_ds = self.__dao__.find_by_name(
@@ -157,15 +145,11 @@ def update(
tags: Dict[str, str],
metadata: Dict[str, Any],
) -> ServiceDataset:
- found = self.find_by_name(
- user=user, name=dataset.name, task=dataset.task, workspace=dataset.owner
- )
+ found = self.find_by_name(user=user, name=dataset.name, task=dataset.task, workspace=dataset.owner)
dataset.tags = {**found.tags, **(tags or {})}
dataset.metadata = {**found.metadata, **(metadata or {})}
- updated = found.copy(
- update={**dataset.dict(by_alias=True), "last_updated": datetime.utcnow()}
- )
+ updated = found.copy(update={**dataset.dict(by_alias=True), "last_updated": datetime.utcnow()})
return self.__dao__.update_dataset(updated)
def list(
@@ -175,9 +159,7 @@ def list(
task2dataset_map: Dict[str, Type[ServiceDataset]] = None,
) -> List[ServiceDataset]:
owners = user.check_workspaces(workspaces)
- return self.__dao__.list_datasets(
- owner_list=owners, task2dataset_map=task2dataset_map
- )
+ return self.__dao__.list_datasets(owner_list=owners, task2dataset_map=task2dataset_map)
def close(self, user: User, dataset: ServiceDataset):
found = self.find_by_name(user=user, name=dataset.name, workspace=dataset.owner)
diff --git a/src/argilla/server/services/info.py b/src/argilla/server/services/info.py
index 6e8b7c2f0d..d4f64a3cc6 100644
--- a/src/argilla/server/services/info.py
+++ b/src/argilla/server/services/info.py
@@ -72,9 +72,7 @@ class ApiInfoService:
@classmethod
def get_instance(
cls,
- backend: GenericElasticEngineBackend = Depends(
- GenericElasticEngineBackend.get_instance
- ),
+ backend: GenericElasticEngineBackend = Depends(GenericElasticEngineBackend.get_instance),
) -> "ApiInfoService":
"""
Creates an api info service
diff --git a/src/argilla/server/services/search/service.py b/src/argilla/server/services/search/service.py
index b543a50cf7..4aa71bd2fe 100644
--- a/src/argilla/server/services/search/service.py
+++ b/src/argilla/server/services/search/service.py
@@ -83,9 +83,7 @@ def search(
size=size,
record_from=record_from,
exclude_fields=exclude_fields,
- highligth_results=query is not None
- and query.query_text is not None
- and len(query.query_text) > 0,
+ highligth_results=query is not None and query.query_text is not None and len(query.query_text) > 0,
)
metrics_results = {}
for metric in metrics or []:
@@ -98,9 +96,7 @@ def search(
)
metrics_results[metric.id] = metrics_
except Exception as ex:
- self.__LOGGER__.warning(
- "Cannot compute metric [%s]. Error: %s", metric.id, ex
- )
+ self.__LOGGER__.warning("Cannot compute metric [%s]. Error: %s", metric.id, ex)
metrics_results[metric.id] = {}
return ServiceSearchResults(
diff --git a/src/argilla/server/services/storage/service.py b/src/argilla/server/services/storage/service.py
index 703adaab55..9fecae36d6 100644
--- a/src/argilla/server/services/storage/service.py
+++ b/src/argilla/server/services/storage/service.py
@@ -98,9 +98,7 @@ async def delete_records(
"Only dataset creators or administrators can delete datasets"
)
- processed, deleted = await self.__dao__.delete_records_by_query(
- dataset, query=query
- )
+ processed, deleted = await self.__dao__.delete_records_by_query(dataset, query=query)
return DeleteRecordsOut(
processed=processed or 0,
diff --git a/src/argilla/server/services/tasks/commons/models.py b/src/argilla/server/services/tasks/commons/models.py
index a39388230c..6a99db7c7c 100644
--- a/src/argilla/server/services/tasks/commons/models.py
+++ b/src/argilla/server/services/tasks/commons/models.py
@@ -36,9 +36,7 @@ class BulkResponse(BaseModel):
ServiceAnnotation = TypeVar("ServiceAnnotation", bound=ServiceBaseAnnotation)
-class ServiceBaseRecordInputs(
- BaseRecordInDB[ServiceAnnotation], Generic[ServiceAnnotation]
-):
+class ServiceBaseRecordInputs(BaseRecordInDB[ServiceAnnotation], Generic[ServiceAnnotation]):
pass
diff --git a/src/argilla/server/services/tasks/text2text/models.py b/src/argilla/server/services/tasks/text2text/models.py
index f8d1e8438c..d07e0c1002 100644
--- a/src/argilla/server/services/tasks/text2text/models.py
+++ b/src/argilla/server/services/tasks/text2text/models.py
@@ -50,19 +50,11 @@ def all_text(self) -> str:
@property
def predicted_as(self) -> Optional[List[str]]:
- return (
- [sentence.text for sentence in self.prediction.sentences]
- if self.prediction
- else None
- )
+ return [sentence.text for sentence in self.prediction.sentences] if self.prediction else None
@property
def annotated_as(self) -> Optional[List[str]]:
- return (
- [sentence.text for sentence in self.annotation.sentences]
- if self.annotation
- else None
- )
+ return [sentence.text for sentence in self.annotation.sentences] if self.annotation else None
@property
def scores(self) -> List[float]:
diff --git a/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py b/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py
index 9b0b4871c0..38471bfccc 100644
--- a/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py
+++ b/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py
@@ -59,9 +59,7 @@ def __init__(self, datasets: DatasetsDAO, records: DatasetRecordsDAO):
self.__records__ = records
# TODO(@frascuchon): Move all rules management methods to the common datasets service like settings
- def list_rules(
- self, dataset: ServiceTextClassificationDataset
- ) -> List[ServiceLabelingRule]:
+ def list_rules(self, dataset: ServiceTextClassificationDataset) -> List[ServiceLabelingRule]:
"""List a set of rules for a given dataset"""
return dataset.rules
@@ -72,9 +70,7 @@ def delete_rule(self, dataset: ServiceTextClassificationDataset, rule_query: str
dataset.rules = new_rules_set
self.__datasets__.update_dataset(dataset)
- def add_rule(
- self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule
- ) -> ServiceLabelingRule:
+ def add_rule(self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule) -> ServiceLabelingRule:
"""Adds a rule to a dataset"""
for r in dataset.rules:
if r.query == rule.query:
@@ -105,9 +101,7 @@ def compute_rule_metrics(
LabelingRuleSummary.parse_obj(metric_data),
)
- def _count_annotated_records(
- self, dataset: ServiceTextClassificationDataset
- ) -> int:
+ def _count_annotated_records(self, dataset: ServiceTextClassificationDataset) -> int:
results = self.__records__.search_records(
dataset,
size=0,
@@ -132,18 +126,14 @@ def all_rules_metrics(
DatasetLabelingRulesSummary.parse_obj(metric_data),
)
- def find_rule_by_query(
- self, dataset: ServiceTextClassificationDataset, rule_query: str
- ) -> ServiceLabelingRule:
+ def find_rule_by_query(self, dataset: ServiceTextClassificationDataset, rule_query: str) -> ServiceLabelingRule:
rule_query = rule_query.strip()
for rule in dataset.rules:
if rule.query == rule_query:
return rule
raise EntityNotFoundError(rule_query, type=ServiceLabelingRule)
- def replace_rule(
- self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule
- ):
+ def replace_rule(self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule):
for idx, r in enumerate(dataset.rules):
if r.query == rule.query:
dataset.rules[idx] = rule
diff --git a/src/argilla/server/services/tasks/text_classification/metrics.py b/src/argilla/server/services/tasks/text_classification/metrics.py
index 93c2a856dc..d09891620b 100644
--- a/src/argilla/server/services/tasks/text_classification/metrics.py
+++ b/src/argilla/server/services/tasks/text_classification/metrics.py
@@ -41,11 +41,7 @@ class F1Metric(ServicePythonMetric):
def apply(self, records: Iterable[ServiceTextClassificationRecord]) -> Any:
filtered_records = list(filter(lambda r: r.predicted is not None, records))
# TODO: This must be precomputed with using a global dataset metric
- ds_labels = {
- label
- for record in filtered_records
- for label in record.annotated_as + record.predicted_as
- }
+ ds_labels = {label for record in filtered_records for label in record.annotated_as + record.predicted_as}
if not len(ds_labels):
return {}
@@ -69,12 +65,8 @@ def apply(self, records: Iterable[ServiceTextClassificationRecord]) -> Any:
y_true = mlb.fit_transform(y_true)
y_pred = mlb.fit_transform(y_pred)
- micro_p, micro_r, micro_f, _ = precision_recall_fscore_support(
- y_true=y_true, y_pred=y_pred, average="micro"
- )
- macro_p, macro_r, macro_f, _ = precision_recall_fscore_support(
- y_true=y_true, y_pred=y_pred, average="macro"
- )
+ micro_p, micro_r, micro_f, _ = precision_recall_fscore_support(y_true=y_true, y_pred=y_pred, average="micro")
+ macro_p, macro_r, macro_f, _ = precision_recall_fscore_support(y_true=y_true, y_pred=y_pred, average="macro")
per_label = {}
for label, p, r, f, s in zip(
@@ -127,21 +119,15 @@ def apply(
records: Iterable[ServiceTextClassificationRecord],
) -> Dict[str, Any]:
ds_labels = set()
- for _ in range(
- 0, self.records_to_fetch
- ): # Only a few of records will be parsed
+ for _ in range(0, self.records_to_fetch): # Only a few records will be parsed
record = next(records, None)
if record is None:
break
if record.annotation:
- ds_labels.update(
- [label.class_label for label in record.annotation.labels]
- )
+ ds_labels.update([label.class_label for label in record.annotation.labels])
if record.prediction:
- ds_labels.update(
- [label.class_label for label in record.prediction.labels]
- )
+ ds_labels.update([label.class_label for label in record.prediction.labels])
return {"labels": ds_labels or []}
diff --git a/src/argilla/server/services/tasks/text_classification/model.py b/src/argilla/server/services/tasks/text_classification/model.py
index a94a2632a4..73c5cd7d05 100644
--- a/src/argilla/server/services/tasks/text_classification/model.py
+++ b/src/argilla/server/services/tasks/text_classification/model.py
@@ -36,21 +36,14 @@ class ServiceLabelingRule(BaseModel):
query: str = Field(description="The es rule query")
author: str = Field(description="User who created the rule")
- created_at: Optional[datetime] = Field(
- default_factory=datetime.utcnow, description="Rule creation timestamp"
- )
+ created_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="Rule creation timestamp")
- label: Optional[str] = Field(
- default=None, description="@Deprecated::The label associated with the rule."
- )
+ label: Optional[str] = Field(default=None, description="@Deprecated::The label associated with the rule.")
labels: List[str] = Field(
default_factory=list,
- description="For multi label problems, a list of labels. "
- "It will replace the `label` field",
- )
- description: Optional[str] = Field(
- None, description="A brief description of the rule"
+ description="For multi label problems, a list of labels. " "It will replace the `label` field",
)
+ description: Optional[str] = Field(None, description="A brief description of the rule")
@root_validator
def initialize_labels(cls, values):
@@ -198,19 +191,13 @@ def _check_annotation_integrity(
status: TaskStatus,
):
if status == TaskStatus.validated and not multi_label:
- assert (
- annotation and len(annotation.labels) > 0
- ), "Annotation must include some label for validated records"
+ assert annotation and len(annotation.labels) > 0, "Annotation must include some label for validated records"
if not multi_label and annotation:
- assert (
- len(annotation.labels) == 1
- ), "Single label record must include only one annotation label"
+ assert len(annotation.labels) == 1, "Single label record must include only one annotation label"
@classmethod
- def _check_score_integrity(
- cls, prediction: TextClassificationAnnotation, multi_label: bool
- ):
+ def _check_score_integrity(cls, prediction: TextClassificationAnnotation, multi_label: bool):
"""
Checks the score value integrity
@@ -235,24 +222,16 @@ def task(cls) -> TaskType:
@property
def predicted(self) -> Optional[PredictionStatus]:
if self.predicted_by and self.annotated_by:
- return (
- PredictionStatus.OK
- if set(self.predicted_as) == set(self.annotated_as)
- else PredictionStatus.KO
- )
+ return PredictionStatus.OK if set(self.predicted_as) == set(self.annotated_as) else PredictionStatus.KO
return None
@property
def predicted_as(self) -> List[str]:
- return self._labels_from_annotation(
- self.prediction, multi_label=self.multi_label
- )
+ return self._labels_from_annotation(self.prediction, multi_label=self.multi_label)
@property
def annotated_as(self) -> List[str]:
- return self._labels_from_annotation(
- self.annotation, multi_label=self.multi_label
- )
+ return self._labels_from_annotation(self.annotation, multi_label=self.multi_label)
@property
def scores(self) -> List[float]:
@@ -263,11 +242,7 @@ def scores(self) -> List[float]:
if self.multi_label
else [
prediction_class.score
- for prediction_class in [
- self._max_class_prediction(
- self.prediction, multi_label=self.multi_label
- )
- ]
+ for prediction_class in [self._max_class_prediction(self.prediction, multi_label=self.multi_label)]
if prediction_class
]
)
@@ -303,22 +278,16 @@ def _labels_from_annotation(
return []
if multi_label:
- return [
- label.class_label for label in annotation.labels if label.score > 0.5
- ]
+ return [label.class_label for label in annotation.labels if label.score > 0.5]
- class_prediction = cls._max_class_prediction(
- annotation, multi_label=multi_label
- )
+ class_prediction = cls._max_class_prediction(annotation, multi_label=multi_label)
if class_prediction is None:
return []
return [class_prediction.class_label]
@staticmethod
- def _max_class_prediction(
- p: TextClassificationAnnotation, multi_label: bool
- ) -> Optional[ClassPrediction]:
+ def _max_class_prediction(p: TextClassificationAnnotation, multi_label: bool) -> Optional[ClassPrediction]:
if multi_label or p is None or not p.labels:
return None
return p.labels[0]
diff --git a/src/argilla/server/services/tasks/text_classification/service.py b/src/argilla/server/services/tasks/text_classification/service.py
index e6351d13e9..8a5407258c 100644
--- a/src/argilla/server/services/tasks/text_classification/service.py
+++ b/src/argilla/server/services/tasks/text_classification/service.py
@@ -191,16 +191,13 @@ def _check_multi_label_integrity(
is_multi_label_dataset = self._is_dataset_multi_label(dataset)
if is_multi_label_dataset is not None:
is_multi_label = records[0].multi_label
- assert is_multi_label == is_multi_label_dataset, (
- "You cannot pass {labels_type} records for this dataset. "
- "Stored records are {labels_type}".format(
- labels_type="multi-label" if is_multi_label else "single-label"
- )
+ assert (
+ is_multi_label == is_multi_label_dataset
+ ), "You cannot pass {labels_type} records for this dataset. " "Stored records are {labels_type}".format(
+ labels_type="multi-label" if is_multi_label else "single-label"
)
- def _is_dataset_multi_label(
- self, dataset: ServiceTextClassificationDataset
- ) -> Optional[bool]:
+ def _is_dataset_multi_label(self, dataset: ServiceTextClassificationDataset) -> Optional[bool]:
try:
results = self.__search__.search(
dataset,
@@ -212,14 +209,10 @@ def _is_dataset_multi_label(
if results.records:
return results.records[0].multi_label
- def get_labeling_rules(
- self, dataset: ServiceTextClassificationDataset
- ) -> Iterable[ServiceLabelingRule]:
+ def get_labeling_rules(self, dataset: ServiceTextClassificationDataset) -> Iterable[ServiceLabelingRule]:
return self.__labeling__.list_rules(dataset)
- def add_labeling_rule(
- self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule
- ) -> None:
+ def add_labeling_rule(self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule) -> None:
"""
Adds a labeling rule
@@ -252,14 +245,10 @@ def update_labeling_rule(
self.__labeling__.replace_rule(dataset, found_rule)
return found_rule
- def find_labeling_rule(
- self, dataset: ServiceTextClassificationDataset, rule_query: str
- ) -> ServiceLabelingRule:
+ def find_labeling_rule(self, dataset: ServiceTextClassificationDataset, rule_query: str) -> ServiceLabelingRule:
return self.__labeling__.find_rule_by_query(dataset, rule_query=rule_query)
- def delete_labeling_rule(
- self, dataset: ServiceTextClassificationDataset, rule_query: str
- ):
+ def delete_labeling_rule(self, dataset: ServiceTextClassificationDataset, rule_query: str):
if rule_query.strip():
return self.__labeling__.delete_rule(dataset, rule_query)
@@ -309,9 +298,7 @@ def compute_rule_metrics(
)
coverage = metrics.covered_records / total if total > 0 else None
- coverage_annotated = (
- metrics.annotated_covered_records / annotated if annotated > 0 else None
- )
+ coverage_annotated = metrics.annotated_covered_records / annotated if annotated > 0 else None
return LabelingRuleMetricsSummary(
total_records=total,
@@ -326,9 +313,7 @@ def compute_rule_metrics(
def compute_overall_rules_metrics(self, dataset: ServiceTextClassificationDataset):
total, annotated, metrics = self.__labeling__.all_rules_metrics(dataset)
coverage = metrics.covered_records / total if total else None
- coverage_annotated = (
- metrics.annotated_covered_records / annotated if annotated else None
- )
+ coverage_annotated = metrics.annotated_covered_records / annotated if annotated else None
return DatasetLabelingRulesMetricsSummary(
coverage=coverage,
coverage_annotated=coverage_annotated,
diff --git a/src/argilla/server/services/tasks/token_classification/metrics.py b/src/argilla/server/services/tasks/token_classification/metrics.py
index e210b0960f..9169229b18 100644
--- a/src/argilla/server/services/tasks/token_classification/metrics.py
+++ b/src/argilla/server/services/tasks/token_classification/metrics.py
@@ -36,9 +36,7 @@ class F1Metric(ServicePythonMetric[ServiceTokenClassificationRecord]):
A named entity is correct only if it is an exact match (...).”`
"""
- def apply(
- self, records: Iterable[ServiceTokenClassificationRecord]
- ) -> Dict[str, Any]:
+ def apply(self, records: Iterable[ServiceTokenClassificationRecord]) -> Dict[str, Any]:
# store entities per label in dicts
predicted_entities = {}
annotated_entities = {}
@@ -66,9 +64,7 @@ def apply(
{
f"{label}_precision": precision,
f"{label}_recall": recall,
- f"{label}_f1": self._safe_divide(
- 2 * precision * recall, precision + recall
- ),
+ f"{label}_f1": self._safe_divide(2 * precision * recall, precision + recall),
}
)
@@ -83,9 +79,7 @@ def apply(
averaged_metrics = {
"precision_macro": precision_macro,
"recall_macro": recall_macro,
- "f1_macro": self._safe_divide(
- 2 * precision_macro * recall_macro, precision_macro + recall_macro
- ),
+ "f1_macro": self._safe_divide(2 * precision_macro * recall_macro, precision_macro + recall_macro),
}
precision_micro = self._safe_divide(correct_total, predicted_total)
@@ -94,18 +88,14 @@ def apply(
{
"precision_micro": precision_micro,
"recall_micro": recall_micro,
- "f1_micro": self._safe_divide(
- 2 * precision_micro * recall_micro, precision_micro + recall_micro
- ),
+ "f1_micro": self._safe_divide(2 * precision_micro * recall_micro, precision_micro + recall_micro),
}
)
return {**averaged_metrics, **per_label_metrics}
@staticmethod
- def _add_entities_to_dict(
- entities: List[EntitySpan], dictionary: Dict[str, Set[Tuple[int, int]]]
- ):
+ def _add_entities_to_dict(entities: List[EntitySpan], dictionary: Dict[str, Set[Tuple[int, int]]]):
"""Helper function for the apply method."""
for ent in entities:
try:
@@ -144,21 +134,15 @@ def apply(
) -> Dict[str, Any]:
ds_labels = set()
- for _ in range(
- 0, self.records_to_fetch
- ): # Only a few of records will be parsed
+ for _ in range(0, self.records_to_fetch): # Only a few records will be parsed
record: ServiceTokenClassificationRecord = next(records, None)
if record is None:
break
if record.annotation:
- ds_labels.update(
- [entity.label for entity in record.annotation.entities]
- )
+ ds_labels.update([entity.label for entity in record.annotation.entities])
if record.prediction:
- ds_labels.update(
- [entity.label for entity in record.prediction.entities]
- )
+ ds_labels.update([entity.label for entity in record.prediction.entities])
return {"labels": ds_labels or []}
@@ -252,9 +236,7 @@ def mention_tokens_length(entity: EntitySpan) -> int:
label=entity.label,
score=entity.score,
capitalness=TokenClassificationMetrics.capitalness(mention),
- density=TokenClassificationMetrics.density(
- _tokens_length, sentence_length=len(record.tokens)
- ),
+ density=TokenClassificationMetrics.density(_tokens_length, sentence_length=len(record.tokens)),
tokens_length=_tokens_length,
chars_length=len(mention),
)
@@ -295,26 +277,18 @@ def record_metrics(cls, record: ServiceTokenClassificationRecord) -> Dict[str, A
annotated_tags = cls._compute_iob_tags(span_utils, record.annotation) or []
predicted_tags = cls._compute_iob_tags(span_utils, record.prediction) or []
- tokens_metrics = cls.build_tokens_metrics(
- record, predicted_tags or annotated_tags
- )
+ tokens_metrics = cls.build_tokens_metrics(record, predicted_tags or annotated_tags)
return {
**base_metrics,
"tokens": tokens_metrics,
"tokens_length": len(record.tokens),
"predicted": {
"mentions": cls.mentions_metrics(record, record.predicted_mentions()),
- "tags": [
- TokenTagMetrics(tag=tag, value=token)
- for tag, token in zip(predicted_tags, record.tokens)
- ],
+ "tags": [TokenTagMetrics(tag=tag, value=token) for tag, token in zip(predicted_tags, record.tokens)],
},
"annotated": {
"mentions": cls.mentions_metrics(record, record.annotated_mentions()),
- "tags": [
- TokenTagMetrics(tag=tag, value=token)
- for tag, token in zip(annotated_tags, record.tokens)
- ],
+ "tags": [TokenTagMetrics(tag=tag, value=token) for tag, token in zip(annotated_tags, record.tokens)],
},
}
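Every F1 value in this hunk is the harmonic mean of precision and recall, guarded by _safe_divide; a sketch of the assumed helper and the formula:

    def safe_divide(numerator, denominator):
        # assumption: returns 0.0 rather than raising on a zero denominator
        return numerator / denominator if denominator else 0.0

    precision, recall = 0.8, 0.5
    f1 = safe_divide(2 * precision * recall, precision + recall)  # ~0.615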
diff --git a/src/argilla/server/services/tasks/token_classification/model.py b/src/argilla/server/services/tasks/token_classification/model.py
index 926783ecbd..ce01b96887 100644
--- a/src/argilla/server/services/tasks/token_classification/model.py
+++ b/src/argilla/server/services/tasks/token_classification/model.py
@@ -74,9 +74,7 @@ class ServiceTokenClassificationAnnotation(ServiceBaseAnnotation):
score: Optional[float] = None
-class ServiceTokenClassificationRecord(
- ServiceBaseRecord[ServiceTokenClassificationAnnotation]
-):
+class ServiceTokenClassificationRecord(ServiceBaseRecord[ServiceTokenClassificationAnnotation]):
tokens: List[str] = Field(min_items=1)
text: str = Field()
_raw_text: Optional[str] = Field(alias="raw_text")
@@ -94,8 +92,7 @@ def extended_fields(self) -> Dict[str, Any]:
for mention, entity in self.predicted_mentions()
],
MENTIONS_ES_FIELD_NAME: [
- {"mention": mention, "entity": entity.label}
- for mention, entity in self.annotated_mentions()
+ {"mention": mention, "entity": entity.label} for mention, entity in self.annotated_mentions()
],
}
@@ -148,11 +145,7 @@ def predicted(self) -> Optional[PredictionStatus]:
return PredictionStatus.KO
for ann, pred in zip(annotated_entities, predicted_entities):
- if (
- ann.start != pred.start
- or ann.end != pred.end
- or ann.label != pred.label
- ):
+ if ann.start != pred.start or ann.end != pred.end or ann.label != pred.label:
return PredictionStatus.KO
return PredictionStatus.OK
@@ -179,18 +172,12 @@ def all_text(self) -> str:
def predicted_mentions(self) -> List[Tuple[str, EntitySpan]]:
return [
- (mention, entity)
- for mention, entity in self.__mentions_from_entities__(
- self.predicted_entities()
- ).items()
+ (mention, entity) for mention, entity in self.__mentions_from_entities__(self.predicted_entities()).items()
]
def annotated_mentions(self) -> List[Tuple[str, EntitySpan]]:
return [
- (mention, entity)
- for mention, entity in self.__mentions_from_entities__(
- self.annotated_entities()
- ).items()
+ (mention, entity) for mention, entity in self.__mentions_from_entities__(self.annotated_entities()).items()
]
def annotated_entities(self) -> Set[EntitySpan]:
@@ -205,14 +192,8 @@ def predicted_entities(self) -> Set[EntitySpan]:
return set()
return set(self.prediction.entities)
- def __mentions_from_entities__(
- self, entities: Set[EntitySpan]
- ) -> Dict[str, EntitySpan]:
- return {
- mention: entity
- for entity in entities
- for mention in [self.text[entity.start : entity.end]]
- }
+ def __mentions_from_entities__(self, entities: Set[EntitySpan]) -> Dict[str, EntitySpan]:
+ return {mention: entity for entity in entities for mention in [self.text[entity.start : entity.end]]}
class Config:
allow_population_by_field_name = True
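What __mentions_from_entities__ above boils down to: each entity span is keyed by the surface text it covers in the record. A self-contained sketch with stand-in spans:

    text = "Argilla is in Madrid"
    entities = [("ORG", 0, 7), ("LOC", 14, 20)]  # (label, start, end) stand-ins for EntitySpan
    mentions = {text[start:end]: label for label, start, end in entities}
    assert mentions == {"Argilla": "ORG", "Madrid": "LOC"}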
diff --git a/src/argilla/server/services/tasks/token_classification/service.py b/src/argilla/server/services/tasks/token_classification/service.py
index 534ff6920b..50ef8dae2d 100644
--- a/src/argilla/server/services/tasks/token_classification/service.py
+++ b/src/argilla/server/services/tasks/token_classification/service.py
@@ -131,12 +131,8 @@ def search(
results.metrics["status"] = results.metrics["status_distribution"]
results.metrics["predicted"] = results.metrics["error_distribution"]
results.metrics["predicted"].pop("unknown", None)
- results.metrics["mentions"] = results.metrics[
- "annotated_mentions_distribution"
- ]
- results.metrics["predicted_mentions"] = results.metrics[
- "predicted_mentions_distribution"
- ]
+ results.metrics["mentions"] = results.metrics["annotated_mentions_distribution"]
+ results.metrics["predicted_mentions"] = results.metrics["predicted_mentions_distribution"]
return results
diff --git a/src/argilla/utils/span_utils.py b/src/argilla/utils/span_utils.py
index bed9e172f7..208450a24b 100644
--- a/src/argilla/utils/span_utils.py
+++ b/src/argilla/utils/span_utils.py
@@ -100,9 +100,7 @@ def validate(self, spans: List[Tuple[str, int, int]]):
if not_valid_spans_errors or misaligned_spans_errors:
message = ""
if not_valid_spans_errors:
- message += (
- f"Following entity spans are not valid: {not_valid_spans_errors}\n"
- )
+ message += f"Following entity spans are not valid: {not_valid_spans_errors}\n"
if misaligned_spans_errors:
spans = "\n".join(misaligned_spans_errors)
@@ -191,9 +189,7 @@ def get_prefix_and_entity(tag_str: str) -> Tuple[str, Optional[str]]:
return splits[0], "-".join(splits[1:])
if len(tags) != len(self.tokens):
- raise ValueError(
- "The list of tags must have the same length as the list of tokens!"
- )
+ raise ValueError("The list of tags must have the same length as the list of tokens!")
spans, start_idx = [], None
for idx, tag in enumerate(tags):
diff --git a/src/argilla/utils/utils.py b/src/argilla/utils/utils.py
index e45e3fa4d9..902d440eca 100644
--- a/src/argilla/utils/utils.py
+++ b/src/argilla/utils/utils.py
@@ -44,9 +44,7 @@ def __init__(
for value in values:
self._class_to_module[value] = key
# Needed for autocompletion in an IDE
- self.__all__ = list(import_structure.keys()) + list(
- chain(*import_structure.values())
- )
+ self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
self.__file__ = module_file
self.__spec__ = module_spec
self.__path__ = [os.path.dirname(module_file)]
@@ -83,9 +81,7 @@ def __getattr__(self, name: str) -> Any:
elif name in self._deprecated_modules:
value = self._get_module(name, deprecated=True)
elif name in self._deprecated_class_to_module.keys():
- module = self._get_module(
- self._deprecated_class_to_module[name], deprecated=True, class_name=name
- )
+ module = self._get_module(self._deprecated_class_to_module[name], deprecated=True, class_name=name)
value = getattr(module, name)
else:
raise AttributeError(f"module {self.__name__} has no attribute {name}")
@@ -139,9 +135,7 @@ def limit_value_length(data: Any, max_length: int) -> Any:
if isinstance(data, str):
return data[-max_length:]
if isinstance(data, dict):
- return {
- k: limit_value_length(v, max_length=max_length) for k, v in data.items()
- }
+ return {k: limit_value_length(v, max_length=max_length) for k, v in data.items()}
if isinstance(data, (list, tuple, set)):
new_values = map(lambda x: limit_value_length(x, max_length=max_length), data)
return type(data)(new_values)
@@ -167,9 +161,7 @@ def start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
if not (__LOOP__ and __THREAD__):
loop = asyncio.new_event_loop()
- thread = threading.Thread(
- target=start_background_loop, args=(loop,), daemon=True
- )
+ thread = threading.Thread(target=start_background_loop, args=(loop,), daemon=True)
thread.start()
__LOOP__, __THREAD__ = loop, thread
return __LOOP__, __THREAD__
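Usage sketch for the background-loop helper above: coroutines can be submitted from synchronous code via asyncio.run_coroutine_threadsafe on the daemon thread's loop (the accessor name below is illustrative, not the module's actual API):

    import asyncio

    async def ping():
        await asyncio.sleep(0)
        return "pong"

    # loop, _thread = get_or_create_background_loop()  # hypothetical accessor
    # future = asyncio.run_coroutine_threadsafe(ping(), loop)
    # assert future.result(timeout=5) == "pong"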
diff --git a/tests/client/conftest.py b/tests/client/conftest.py
index ab95abdd6f..f80b838c10 100644
--- a/tests/client/conftest.py
+++ b/tests/client/conftest.py
@@ -87,13 +87,7 @@ def singlelabel_textclassification_records(
id=1,
event_timestamp=datetime.datetime(2000, 1, 1),
metadata={"mock_metadata": "mock"},
- explanation={
- "text": [
- ar.TokenAttributions(
- token="mock", attributions={"a": 0.1, "b": 0.5}
- )
- ]
- },
+ explanation={"text": [ar.TokenAttributions(token="mock", attributions={"a": 0.1, "b": 0.5})]},
status="Validated",
),
ar.TextClassificationRecord(
@@ -103,13 +97,7 @@ def singlelabel_textclassification_records(
id=2,
event_timestamp=datetime.datetime(2000, 2, 1),
metadata={"mock2_metadata": "mock2"},
- explanation={
- "text": [
- ar.TokenAttributions(
- token="mock2", attributions={"a": 0.7, "b": 0.2}
- )
- ]
- },
+ explanation={"text": [ar.TokenAttributions(token="mock2", attributions={"a": 0.7, "b": 0.2})]},
status="Default",
),
ar.TextClassificationRecord(
@@ -152,8 +140,7 @@ def log_singlelabel_textclassification_records(
"multi_label": False,
},
records=[
- CreationTextClassificationRecord.from_client(rec)
- for rec in singlelabel_textclassification_records
+ CreationTextClassificationRecord.from_client(rec) for rec in singlelabel_textclassification_records
],
).dict(by_alias=True),
)
@@ -174,13 +161,7 @@ def multilabel_textclassification_records(request) -> List[ar.TextClassification
id=1,
event_timestamp=datetime.datetime(2000, 1, 1),
metadata={"mock_metadata": "mock"},
- explanation={
- "text": [
- ar.TokenAttributions(
- token="mock", attributions={"a": 0.1, "b": 0.5}
- )
- ]
- },
+ explanation={"text": [ar.TokenAttributions(token="mock", attributions={"a": 0.1, "b": 0.5})]},
status="Validated",
),
ar.TextClassificationRecord(
@@ -191,13 +172,7 @@ def multilabel_textclassification_records(request) -> List[ar.TextClassification
id=2,
event_timestamp=datetime.datetime(2000, 2, 1),
metadata={"mock2_metadata": "mock2"},
- explanation={
- "text": [
- ar.TokenAttributions(
- token="mock2", attributions={"a": 0.7, "b": 0.2}
- )
- ]
- },
+ explanation={"text": [ar.TokenAttributions(token="mock2", attributions={"a": 0.7, "b": 0.2})]},
status="Default",
),
ar.TextClassificationRecord(
@@ -245,8 +220,7 @@ def log_multilabel_textclassification_records(
"multi_label": True,
},
records=[
- CreationTextClassificationRecord.from_client(rec)
- for rec in multilabel_textclassification_records
+ CreationTextClassificationRecord.from_client(rec) for rec in multilabel_textclassification_records
],
).dict(by_alias=True),
)
@@ -321,10 +295,7 @@ def log_tokenclassification_records(
"env": "test",
"task": TaskType.token_classification,
},
- records=[
- CreationTokenClassificationRecord.from_client(rec)
- for rec in tokenclassification_records
- ],
+ records=[CreationTokenClassificationRecord.from_client(rec) for rec in tokenclassification_records],
).dict(by_alias=True),
)
@@ -395,9 +366,7 @@ def log_text2text_records(
"env": "test",
"task": TaskType.text2text,
},
- records=[
- CreationText2TextRecord.from_client(rec) for rec in text2text_records
- ],
+ records=[CreationText2TextRecord.from_client(rec) for rec in text2text_records],
).dict(by_alias=True),
)
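
For readers skimming these fixtures: the `explanation` argument maps each input field name (here "text") to a list of `ar.TokenAttributions`, whose `attributions` dict assigns a weight per label. A minimal sketch built from the same values the fixture uses, assuming argilla is importable as `ar`:

    import datetime
    import argilla as ar

    record = ar.TextClassificationRecord(
        text="mock",
        event_timestamp=datetime.datetime(2000, 1, 1),
        metadata={"mock_metadata": "mock"},
        # One TokenAttributions entry per token of the "text" input field.
        explanation={"text": [ar.TokenAttributions(token="mock", attributions={"a": 0.1, "b": 0.5})]},
        status="Validated",
    )
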
diff --git a/tests/client/sdk/commons/api.py b/tests/client/sdk/commons/api.py
index b7ca635486..554a03611c 100644
--- a/tests/client/sdk/commons/api.py
+++ b/tests/client/sdk/commons/api.py
@@ -81,13 +81,9 @@ def test_build_bulk_response(status_code, expected):
elif status_code == 500:
server_response = ErrorMessage(detail="test")
elif status_code == 422:
- server_response = HTTPValidationError(
- detail=[ValidationError(loc=["test"], msg="test", type="test")]
- )
+ server_response = HTTPValidationError(detail=[ValidationError(loc=["test"], msg="test", type="test")])
- httpx_response = HttpxResponse(
- status_code=status_code, content=server_response.json()
- )
+ httpx_response = HttpxResponse(status_code=status_code, content=server_response.json())
response = build_bulk_response(httpx_response, name="mock-dataset", body={})
assert isinstance(response, Response)
@@ -112,9 +108,7 @@ def test_build_data_response(status_code, expected):
elif status_code == 500:
server_response = ErrorMessage(detail="test")
elif status_code == 422:
- server_response = HTTPValidationError(
- detail=[ValidationError(loc=["test"], msg="test", type="test")]
- )
+ server_response = HTTPValidationError(detail=[ValidationError(loc=["test"], msg="test", type="test")])
httpx_response = HttpxResponse(
status_code=status_code,
diff --git a/tests/client/sdk/conftest.py b/tests/client/sdk/conftest.py
index 75ef26042b..68b91ffe0f 100644
--- a/tests/client/sdk/conftest.py
+++ b/tests/client/sdk/conftest.py
@@ -61,8 +61,7 @@ def check_schema_props(client_props, server_props):
continue
if name not in server_props:
LOGGER.warning(
- f"Client property {name} not found in server properties. "
- "Make sure your API compatibility"
+ f"Client property {name} not found in server properties. " "Make sure your API compatibility"
)
different_props.append(name)
continue
@@ -105,9 +104,7 @@ def _expands_schema(
expanded_props = self._expands_schema(field_props, definitions)
definition["items"] = expanded_props.get("properties", expanded_props)
new_schema[name] = definition
- elif "additionalProperties" in definition and "$ref" in definition.get(
- "additionalProperties", {}
- ):
+ elif "additionalProperties" in definition and "$ref" in definition.get("additionalProperties", {}):
additionalProperties_refs = self._expands_schema(
{name: definition["additionalProperties"]},
definitions=definitions,
@@ -116,9 +113,7 @@ def _expands_schema(
elif "allOf" in definition:
allOf_expanded = [
self._expands_schema(
- definitions[def_["$ref"].replace("#/definitions/", "")].get(
- "properties", {}
- ),
+ definitions[def_["$ref"].replace("#/definitions/", "")].get("properties", {}),
definitions,
)
for def_ in definition["allOf"]
@@ -140,18 +135,14 @@ def helpers():
@pytest.fixture
def sdk_client(mocked_client, monkeypatch):
- client = AuthenticatedClient(
- base_url="http://localhost:6900", token=DEFAULT_API_KEY
- )
+ client = AuthenticatedClient(base_url="http://localhost:6900", token=DEFAULT_API_KEY)
monkeypatch.setattr(client, "__httpx__", mocked_client)
return client
@pytest.fixture
def bulk_textclass_data():
- explanation = {
- "text": [ar.TokenAttributions(token="test", attributions={"test": 0.5})]
- }
+ explanation = {"text": [ar.TokenAttributions(token="test", attributions={"test": 0.5})]}
records = [
ar.TextClassificationRecord(
text="test",
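
A side effect worth noticing in `check_schema_props` above (and in many `pytest.raises(..., match=...)` calls later in this patch): black joins the physical lines but keeps adjacent string literals as separate tokens, so the merged line still relies on implicit concatenation. The literals fuse at compile time either way:

    name = "id"  # illustrative value
    # Two adjacent literals on one physical line are still one string:
    msg = f"Client property {name} not found in server properties. " "Make sure your API compatibility"
    assert msg == "Client property id not found in server properties. Make sure your API compatibility"
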
diff --git a/tests/client/sdk/datasets/test_models.py b/tests/client/sdk/datasets/test_models.py
index c490db0e9b..a1281e782b 100644
--- a/tests/client/sdk/datasets/test_models.py
+++ b/tests/client/sdk/datasets/test_models.py
@@ -19,13 +19,9 @@
def test_dataset_schema(helpers):
client_schema = Dataset.schema()
- server_schema = helpers.remove_key(
- ServerDataset.schema(), key="created_by"
- ) # don't care about creator here
+ server_schema = helpers.remove_key(ServerDataset.schema(), key="created_by") # don't care about creator here
- assert helpers.remove_description(client_schema) == helpers.remove_description(
- server_schema
- )
+ assert helpers.remove_description(client_schema) == helpers.remove_description(server_schema)
def test_TaskType_enum():
diff --git a/tests/client/sdk/text2text/test_models.py b/tests/client/sdk/text2text/test_models.py
index 57f2e70178..0783606f3e 100644
--- a/tests/client/sdk/text2text/test_models.py
+++ b/tests/client/sdk/text2text/test_models.py
@@ -65,9 +65,7 @@ def test_from_client_prediction(prediction, expected):
sdk_record = CreationText2TextRecord.from_client(record)
assert len(sdk_record.prediction.sentences) == len(prediction)
- assert all(
- [sentence.score == expected for sentence in sdk_record.prediction.sentences]
- )
+ assert all([sentence.score == expected for sentence in sdk_record.prediction.sentences])
assert sdk_record.metrics == {}
diff --git a/tests/client/sdk/text_classification/test_models.py b/tests/client/sdk/text_classification/test_models.py
index 9bebb747c5..583483a951 100644
--- a/tests/client/sdk/text_classification/test_models.py
+++ b/tests/client/sdk/text_classification/test_models.py
@@ -68,15 +68,11 @@ def test_labeling_rule_metrics_schema(helpers):
client_schema = LabelingRuleMetricsSummary.schema()
server_schema = ServerLabelingRuleMetricsSummary.schema()
- assert helpers.remove_description(client_schema) == helpers.remove_description(
- server_schema
- )
+ assert helpers.remove_description(client_schema) == helpers.remove_description(server_schema)
def test_from_client_explanation():
- token_attributions = [
- TokenAttributions(token="test", attributions={"label1": 1.0, "label2": 2.0})
- ]
+ token_attributions = [TokenAttributions(token="test", attributions={"label1": 1.0, "label2": 2.0})]
record = TextClassificationRecord(
inputs={"text": "test"},
prediction=[("label1", 0.5), ("label2", 0.5)],
@@ -90,9 +86,7 @@ def test_from_client_explanation():
assert sdk_record.explanation["text"] == token_attributions
-@pytest.mark.parametrize(
- "annotation,expected", [("label1", 1), (["label1", "label2"], 2)]
-)
+@pytest.mark.parametrize("annotation,expected", [("label1", 1), (["label1", "label2"], 2)])
def test_from_client_annotation(annotation, expected):
record = TextClassificationRecord(
inputs={"text": "test"},
@@ -127,13 +121,9 @@ def test_from_client_agent(pred_agent, annot_agent, pred_expected, annot_expecte
assert sdk_record.metrics == {}
-@pytest.mark.parametrize(
- "multi_label,expected", [(False, "annot_label"), (True, ["annot_label"])]
-)
+@pytest.mark.parametrize("multi_label,expected", [(False, "annot_label"), (True, ["annot_label"])])
def test_to_client(multi_label, expected):
- annotation = TextClassificationAnnotation(
- labels=[ClassPrediction(**{"class": "annot_label"})], agent="annot_agent"
- )
+ annotation = TextClassificationAnnotation(labels=[ClassPrediction(**{"class": "annot_label"})], agent="annot_agent")
prediction = TextClassificationAnnotation(
labels=[
ClassPrediction(**{"class": "label1", "score": 0.5}),
diff --git a/tests/client/sdk/users/test_api.py b/tests/client/sdk/users/test_api.py
index 8eb8f2b3c7..a4a1f211c9 100644
--- a/tests/client/sdk/users/test_api.py
+++ b/tests/client/sdk/users/test_api.py
@@ -26,15 +26,11 @@ def test_whoami(mocked_client, sdk_client):
def test_whoami_with_auth_error(monkeypatch, mocked_client):
with pytest.raises(UnauthorizedApiError):
- sdk_client = AuthenticatedClient(
- base_url="http://localhost:6900", token="wrong-apikey"
- )
+ sdk_client = AuthenticatedClient(base_url="http://localhost:6900", token="wrong-apikey")
monkeypatch.setattr(sdk_client, "__httpx__", mocked_client)
whoami(sdk_client)
def test_whoami_with_connection_error():
with pytest.raises(BaseClientError):
- whoami(
- AuthenticatedClient(base_url="http://localhost:6900", token="wrong-apikey")
- )
+ whoami(AuthenticatedClient(base_url="http://localhost:6900", token="wrong-apikey"))
diff --git a/tests/client/test_api.py b/tests/client/test_api.py
index 12e66b634c..4c7ec4bb44 100644
--- a/tests/client/test_api.py
+++ b/tests/client/test_api.py
@@ -166,9 +166,7 @@ def test_log_something(monkeypatch, mocked_client):
assert response.processed == 1
assert response.failed == 0
- response = mocked_client.post(
- f"/api/datasets/{dataset_name}/TextClassification:search"
- )
+ response = mocked_client.post(f"/api/datasets/{dataset_name}/TextClassification:search")
assert response.status_code == 200, response.json()
results = TextClassificationSearchResults.parse_obj(response.json())
@@ -200,9 +198,7 @@ def test_load_limits(mocked_client, supported_vector_search):
def test_log_records_with_too_long_text(mocked_client):
dataset_name = "test_log_records_with_too_long_text"
mocked_client.delete(f"/api/datasets/{dataset_name}")
- item = ar.TextClassificationRecord(
- inputs={"text": "This is a toooooo long text\n" * 10000}
- )
+ item = ar.TextClassificationRecord(inputs={"text": "This is a toooooo long text\n" * 10000})
api.log([item], name=dataset_name)
@@ -218,9 +214,7 @@ def test_log_without_name(mocked_client):
match="Empty dataset name has been passed as argument.",
):
api.log(
- ar.TextClassificationRecord(
- inputs={"text": "This is a single record. Only this. No more."}
- ),
+ ar.TextClassificationRecord(inputs={"text": "This is a single record. Only this. No more."}),
name=None,
)
@@ -311,9 +305,7 @@ def inner(*args, **kwargs):
return inner
with pytest.raises(error_type):
- monkeypatch.setattr(
- httpx, "delete", send_mock_response_with_http_status(status)
- )
+ monkeypatch.setattr(httpx, "delete", send_mock_response_with_http_status(status))
api.delete("dataset")
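
These tests use `monkeypatch.setattr` in both of its supported forms: `(target_object, "attribute", value)` as in the httpx patch above, and a single dotted-path string as in the `"argilla.client.datasets..."` calls in the next file. A minimal sketch of the two equivalent spellings; the `json` module here is just a hypothetical stand-in target:

    import json
    import pytest

    def test_setattr_forms(monkeypatch: pytest.MonkeyPatch) -> None:
        # Form 1: target object plus attribute name.
        monkeypatch.setattr(json, "dumps", lambda obj, **kw: "patched")
        # Form 2: a single dotted-path string; monkeypatch imports the module.
        monkeypatch.setattr("json.loads", lambda s: "patched")
        assert json.dumps({}) == "patched"
        assert json.loads("{}") == "patched"
        # Both patches are undone automatically when the test ends.
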
diff --git a/tests/client/test_dataset.py b/tests/client/test_dataset.py
index 83ae429af0..404b2768b4 100644
--- a/tests/client/test_dataset.py
+++ b/tests/client/test_dataset.py
@@ -50,9 +50,7 @@ def test_init_NotImplementedError(self):
DatasetBase()
def test_init(self, monkeypatch, singlelabel_textclassification_records):
- monkeypatch.setattr(
- "argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord
- )
+ monkeypatch.setattr("argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord)
ds = DatasetBase(
records=singlelabel_textclassification_records,
@@ -62,9 +60,7 @@ def test_init(self, monkeypatch, singlelabel_textclassification_records):
ds = DatasetBase()
assert ds._records == []
- with pytest.raises(
- WrongRecordTypeError, match="but you provided Text2TextRecord"
- ):
+ with pytest.raises(WrongRecordTypeError, match="but you provided Text2TextRecord"):
DatasetBase(
records=[ar.Text2TextRecord(text="test")],
)
@@ -90,9 +86,7 @@ def test_init(self, monkeypatch, singlelabel_textclassification_records):
ds._from_pandas("mock")
def test_to_dataframe(self, monkeypatch, singlelabel_textclassification_records):
- monkeypatch.setattr(
- "argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord
- )
+ monkeypatch.setattr("argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord)
df = DatasetBase(singlelabel_textclassification_records).to_pandas()
@@ -101,9 +95,7 @@ def test_to_dataframe(self, monkeypatch, singlelabel_textclassification_records)
assert list(df.columns) == list(TextClassificationRecord.__fields__)
def test_prepare_dataset_and_column_mapping(self, monkeypatch, caplog):
- monkeypatch.setattr(
- "argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord
- )
+ monkeypatch.setattr("argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord)
ds = datasets.Dataset.from_dict(
{
@@ -119,12 +111,8 @@ def test_prepare_dataset_and_column_mapping(self, monkeypatch, caplog):
with pytest.raises(ValueError, match="datasets.DatasetDict` are not supported"):
DatasetBase._prepare_dataset_and_column_mapping(ds_dict, None)
- col_mapping = dict(
- id="ID", inputs=["inputs_a", "inputs_b"], metadata="metadata"
- )
- prepared_ds, col_to_be_joined = DatasetBase._prepare_dataset_and_column_mapping(
- ds, col_mapping
- )
+ col_mapping = dict(id="ID", inputs=["inputs_a", "inputs_b"], metadata="metadata")
+ prepared_ds, col_to_be_joined = DatasetBase._prepare_dataset_and_column_mapping(ds, col_mapping)
assert prepared_ds.column_names == ["id", "inputs_a", "inputs_b", "metadata"]
assert col_to_be_joined == {
@@ -138,12 +126,8 @@ def test_prepare_dataset_and_column_mapping(self, monkeypatch, caplog):
)
def test_from_pandas(self, monkeypatch, caplog):
- monkeypatch.setattr(
- "argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord
- )
- monkeypatch.setattr(
- "argilla.client.datasets.DatasetBase._from_pandas", lambda x: x
- )
+ monkeypatch.setattr("argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord)
+ monkeypatch.setattr("argilla.client.datasets.DatasetBase._from_pandas", lambda x: x)
df = pd.DataFrame({"unsupported_column": [None]})
empty_df = DatasetBase.from_pandas(df)
@@ -153,8 +137,7 @@ def test_from_pandas(self, monkeypatch, caplog):
assert caplog.record_tuples[0][1] == 30
assert (
"Following columns are not supported by the "
- "TextClassificationRecord model and are ignored: ['unsupported_column']"
- == caplog.record_tuples[0][2]
+ "TextClassificationRecord model and are ignored: ['unsupported_column']" == caplog.record_tuples[0][2]
)
def test_to_datasets(self, monkeypatch, caplog):
@@ -170,9 +153,7 @@ def test_to_datasets(self, monkeypatch, caplog):
assert len(datasets_ds) == 0
assert len(caplog.record_tuples) == 1
assert caplog.record_tuples[0][1] == 30
- assert (
- "The 'metadata' of the records were removed" in caplog.record_tuples[0][2]
- )
+ assert "The 'metadata' of the records were removed" in caplog.record_tuples[0][2]
def test_datasets_not_installed(self, monkeypatch):
monkeypatch.setattr("argilla.client.datasets.DatasetBase._RECORD_TYPE", "mock")
@@ -186,12 +167,8 @@ def test_datasets_wrong_version(self, monkeypatch):
with pytest.raises(ModuleNotFoundError, match="pip install -U datasets>1.17.0"):
DatasetBase().to_datasets()
- def test_iter_len_getitem(
- self, monkeypatch, singlelabel_textclassification_records
- ):
- monkeypatch.setattr(
- "argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord
- )
+ def test_iter_len_getitem(self, monkeypatch, singlelabel_textclassification_records):
+ monkeypatch.setattr("argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord)
dataset = DatasetBase(singlelabel_textclassification_records)
for record, expected in zip(dataset, singlelabel_textclassification_records):
@@ -201,9 +178,7 @@ def test_iter_len_getitem(
assert dataset[1] is singlelabel_textclassification_records[1]
def test_setitem_delitem(self, monkeypatch, singlelabel_textclassification_records):
- monkeypatch.setattr(
- "argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord
- )
+ monkeypatch.setattr("argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord)
dataset = DatasetBase(
[rec.copy(deep=True) for rec in singlelabel_textclassification_records],
)
@@ -226,12 +201,8 @@ def test_setitem_delitem(self, monkeypatch, singlelabel_textclassification_recor
):
dataset[0] = ar.Text2TextRecord(text="mock")
- def test_prepare_for_training_train_test_splits(
- self, monkeypatch, singlelabel_textclassification_records
- ):
- monkeypatch.setattr(
- "argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord
- )
+ def test_prepare_for_training_train_test_splits(self, monkeypatch, singlelabel_textclassification_records):
+ monkeypatch.setattr("argilla.client.datasets.DatasetBase._RECORD_TYPE", TextClassificationRecord)
temp_records = copy.deepcopy(singlelabel_textclassification_records)
ds = DatasetBase(temp_records)
@@ -241,9 +212,7 @@ def test_prepare_for_training_train_test_splits(
):
ds.prepare_for_training(train_size=-1)
- with pytest.raises(
- AssertionError, match="`train_size` and `test_size` must sum to 1."
- ):
+ with pytest.raises(AssertionError, match="`train_size` and `test_size` must sum to 1."):
ds.prepare_for_training(test_size=0.1, train_size=0.6)
for rec in ds:
@@ -307,18 +276,14 @@ def test_to_from_datasets(self, records, request):
assert rec.inputs == {"text": "mock"}
def test_from_to_datasets_id(self):
- dataset_rb = ar.DatasetForTextClassification(
- [ar.TextClassificationRecord(text="mock")]
- )
+ dataset_rb = ar.DatasetForTextClassification([ar.TextClassificationRecord(text="mock")])
dataset_ds = dataset_rb.to_datasets()
assert dataset_ds["id"] == [None]
assert ar.read_datasets(dataset_ds, task="TextClassification")[0].id is None
def test_datasets_empty_metadata(self):
- dataset = ar.DatasetForTextClassification(
- [ar.TextClassificationRecord(text="mock")]
- )
+ dataset = ar.DatasetForTextClassification([ar.TextClassificationRecord(text="mock")])
assert dataset.to_datasets()["metadata"] == [None]
@pytest.mark.parametrize(
@@ -359,9 +324,7 @@ def test_push_to_hub(self, request, name: str):
# TODO(@frascuchon): move dataset to new organization
dataset_name = f"rubrix/_test_text_classification_records-{name}"
dataset_ds = ar.DatasetForTextClassification(records).to_datasets()
- _push_to_hub_with_retries(
- dataset_ds, repo_id=dataset_name, token=_HF_HUB_ACCESS_TOKEN, private=True
- )
+ _push_to_hub_with_retries(dataset_ds, repo_id=dataset_name, token=_HF_HUB_ACCESS_TOKEN, private=True)
sleep(1)
dataset_ds = datasets.load_dataset(
dataset_name,
@@ -490,9 +453,7 @@ def test_from_datasets_with_annotation_arg(self):
}
),
)
- dataset_rb = ar.DatasetForTextClassification.from_datasets(
- dataset_ds, annotation="label"
- )
+ dataset_rb = ar.DatasetForTextClassification.from_datasets(dataset_ds, annotation="label")
assert [rec.annotation for rec in dataset_rb] == ["HAM", None]
@@ -546,32 +507,24 @@ def test_to_from_datasets(self, tokenclassification_records):
assert isinstance(dataset, ar.DatasetForTokenClassification)
_compare_datasets(dataset, expected_dataset)
- missing_optional_cols = datasets.Dataset.from_dict(
- {"text": ["mock"], "tokens": [["mock"]]}
- )
+ missing_optional_cols = datasets.Dataset.from_dict({"text": ["mock"], "tokens": [["mock"]]})
rec = ar.DatasetForTokenClassification.from_datasets(missing_optional_cols)[0]
assert rec.text == "mock" and rec.tokens == ["mock"]
def test_from_to_datasets_id(self):
- dataset_rb = ar.DatasetForTokenClassification(
- [ar.TokenClassificationRecord(text="mock", tokens=["mock"])]
- )
+ dataset_rb = ar.DatasetForTokenClassification([ar.TokenClassificationRecord(text="mock", tokens=["mock"])])
dataset_ds = dataset_rb.to_datasets()
assert dataset_ds["id"] == [None]
assert ar.read_datasets(dataset_ds, task="TokenClassification")[0].id is None
def test_prepare_for_training_empty(self):
- dataset = ar.DatasetForTokenClassification(
- [ar.TokenClassificationRecord(text="mock", tokens=["mock"])]
- )
+ dataset = ar.DatasetForTokenClassification([ar.TokenClassificationRecord(text="mock", tokens=["mock"])])
with pytest.raises(AssertionError):
dataset.prepare_for_training()
def test_datasets_empty_metadata(self):
- dataset = ar.DatasetForTokenClassification(
- [ar.TokenClassificationRecord(text="mock", tokens=["mock"])]
- )
+ dataset = ar.DatasetForTokenClassification([ar.TokenClassificationRecord(text="mock", tokens=["mock"])])
assert dataset.to_datasets()["metadata"] == [None]
def test_to_from_pandas(self, tokenclassification_records):
@@ -593,9 +546,7 @@ def test_to_from_pandas(self, tokenclassification_records):
reason="You need a HF Hub access token to test the push_to_hub feature",
)
def test_push_to_hub(self, tokenclassification_records):
- dataset_ds = ar.DatasetForTokenClassification(
- tokenclassification_records
- ).to_datasets()
+ dataset_ds = ar.DatasetForTokenClassification(tokenclassification_records).to_datasets()
_push_to_hub_with_retries(
dataset_ds,
# TODO(@frascuchon): Move dataset to the new org
@@ -624,26 +575,18 @@ def test_prepare_for_training_with_spacy(self):
use_auth_token=_HF_HUB_ACCESS_TOKEN,
split="train",
)
- rb_dataset: DatasetForTokenClassification = ar.read_datasets(
- ner_dataset, task="TokenClassification"
- )
+ rb_dataset: DatasetForTokenClassification = ar.read_datasets(ner_dataset, task="TokenClassification")
for r in rb_dataset:
- r.annotation = [
- (label, start, end) for label, start, end, _ in r.prediction
- ]
+ r.annotation = [(label, start, end) for label, start, end, _ in r.prediction]
with pytest.raises(ValueError):
train = rb_dataset.prepare_for_training(framework="spacy")
- train = rb_dataset.prepare_for_training(
- framework="spacy", lang=spacy.blank("en")
- )
+ train = rb_dataset.prepare_for_training(framework="spacy", lang=spacy.blank("en"))
assert isinstance(train, spacy.tokens.DocBin)
assert len(train) == 100
- train, test = rb_dataset.prepare_for_training(
- framework="spacy", lang=spacy.blank("en"), train_size=0.8
- )
+ train, test = rb_dataset.prepare_for_training(framework="spacy", lang=spacy.blank("en"), train_size=0.8)
assert isinstance(train, spacy.tokens.DocBin)
assert isinstance(test, spacy.tokens.DocBin)
assert len(train) == 80
@@ -660,21 +603,15 @@ def test_prepare_for_training_with_spark_nlp(self):
use_auth_token=_HF_HUB_ACCESS_TOKEN,
split="train",
)
- rb_dataset: DatasetForTokenClassification = ar.read_datasets(
- ner_dataset, task="TokenClassification"
- )
+ rb_dataset: DatasetForTokenClassification = ar.read_datasets(ner_dataset, task="TokenClassification")
for r in rb_dataset:
- r.annotation = [
- (label, start, end) for label, start, end, _ in r.prediction
- ]
+ r.annotation = [(label, start, end) for label, start, end, _ in r.prediction]
train = rb_dataset.prepare_for_training(framework="spark-nlp")
assert isinstance(train, pd.DataFrame)
assert len(train) == 100
- train, test = rb_dataset.prepare_for_training(
- framework="spark-nlp", train_size=0.8
- )
+ train, test = rb_dataset.prepare_for_training(framework="spark-nlp", train_size=0.8)
assert isinstance(train, pd.DataFrame)
assert isinstance(test, pd.DataFrame)
assert len(train) == 80
@@ -691,18 +628,12 @@ def test_prepare_for_training(self):
use_auth_token=_HF_HUB_ACCESS_TOKEN,
split="train",
)
- rb_dataset: DatasetForTokenClassification = ar.read_datasets(
- ner_dataset, task="TokenClassification"
- )
+ rb_dataset: DatasetForTokenClassification = ar.read_datasets(ner_dataset, task="TokenClassification")
for r in rb_dataset:
- r.annotation = [
- (label, start, end) for label, start, end, _ in r.prediction
- ]
+ r.annotation = [(label, start, end) for label, start, end, _ in r.prediction]
train = rb_dataset.prepare_for_training()
- assert isinstance(train, datasets.DatasetD.Dataset) or isinstance(
- train, datasets.Dataset
- )
+ assert isinstance(train, datasets.DatasetD.Dataset) or isinstance(train, datasets.Dataset)
assert "ner_tags" in train.column_names
assert len(train) == 100
assert train.features["ner_tags"] == [
@@ -760,9 +691,7 @@ def test_from_dataset_with_non_argilla_format(self):
use_auth_token=_HF_HUB_ACCESS_TOKEN,
)
- rb_ds = ar.DatasetForTokenClassification.from_datasets(
- ds, tags="ner_tags", metadata=["spans"]
- )
+ rb_ds = ar.DatasetForTokenClassification.from_datasets(ds, tags="ner_tags", metadata=["spans"])
again_the_ds = rb_ds.to_datasets()
assert again_the_ds.column_names == [
@@ -781,9 +710,7 @@ def test_from_dataset_with_non_argilla_format(self):
def test_from_datasets_with_empty_tokens(self, caplog):
dataset_ds = datasets.Dataset.from_dict({"empty_tokens": [["mock"], []]})
- dataset_rb = ar.DatasetForTokenClassification.from_datasets(
- dataset_ds, tokens="empty_tokens"
- )
+ dataset_rb = ar.DatasetForTokenClassification.from_datasets(dataset_ds, tokens="empty_tokens")
assert caplog.record_tuples[0][1] == 30
assert caplog.record_tuples[0][2] == "Ignoring row with no tokens."
@@ -834,9 +761,7 @@ def test_to_from_datasets(self, text2text_records):
assert rec.text == "mock"
# alternative format for the predictions
- ds = datasets.Dataset.from_dict(
- {"text": ["example"], "prediction": [["ejemplo"]]}
- )
+ ds = datasets.Dataset.from_dict({"text": ["example"], "prediction": [["ejemplo"]]})
rec = ar.DatasetForText2Text.from_datasets(ds)[0]
assert rec.prediction[0][0] == "ejemplo"
assert rec.prediction[0][1] == pytest.approx(1.0)
@@ -899,9 +824,7 @@ def test_from_dataset_with_non_argilla_format(self):
use_auth_token=_HF_HUB_ACCESS_TOKEN,
)
- rb_ds = ar.DatasetForText2Text.from_datasets(
- ds, text="description", annotation="abstract"
- )
+ rb_ds = ar.DatasetForText2Text.from_datasets(ds, text="description", annotation="abstract")
again_the_ds = rb_ds.to_datasets()
assert again_the_ds.column_names == [
@@ -924,9 +847,7 @@ def _compare_datasets(dataset, expected_dataset):
# TODO: have to think about how we deal with `None`s
if col in ["metadata", "metrics"]:
continue
- assert getattr(rec, col) == getattr(
- expected, col
- ), f"Wrong column value '{col}'"
+ assert getattr(rec, col) == getattr(expected, col), f"Wrong column value '{col}'"
@pytest.mark.parametrize(
@@ -941,9 +862,7 @@ def test_read_pandas(monkeypatch, task, dataset_class):
def mock_from_pandas(mock):
return mock
- monkeypatch.setattr(
- f"argilla.client.datasets.{dataset_class}.from_pandas", mock_from_pandas
- )
+ monkeypatch.setattr(f"argilla.client.datasets.{dataset_class}.from_pandas", mock_from_pandas)
assert ar.read_pandas("mock", task) == "mock"
@@ -960,8 +879,6 @@ def test_read_datasets(monkeypatch, task, dataset_class):
def mock_from_datasets(mock):
return mock
- monkeypatch.setattr(
- f"argilla.client.datasets.{dataset_class}.from_datasets", mock_from_datasets
- )
+ monkeypatch.setattr(f"argilla.client.datasets.{dataset_class}.from_datasets", mock_from_datasets)
assert ar.read_datasets("mock", task) == "mock"
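
Many hunks in this file build tiny in-memory Hugging Face datasets with `datasets.Dataset.from_dict`: each dict key becomes a column and each value is that column's row list, which is why the collapsed one-liners read naturally. A minimal sketch using the same shape as the text2text example above:

    import datasets

    # Keys become columns; values are aligned per-row lists.
    ds = datasets.Dataset.from_dict({"text": ["example"], "prediction": [["ejemplo"]]})
    assert ds.column_names == ["text", "prediction"]
    assert ds[0] == {"text": "example", "prediction": ["ejemplo"]}
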
diff --git a/tests/client/test_init.py b/tests/client/test_init.py
index 2b998d2e02..5efebd74d4 100644
--- a/tests/client/test_init.py
+++ b/tests/client/test_init.py
@@ -45,9 +45,7 @@ def test_init_with_extra_headers(mocked_client):
active_api = api.active_api()
for key, value in expected_headers.items():
- assert (
- active_api.http_client.headers[key] == value
- ), f"{key}:{value} not in client headers"
+ assert active_api.http_client.headers[key] == value, f"{key}:{value} not in client headers"
def test_init(mocked_client):
diff --git a/tests/client/test_models.py b/tests/client/test_models.py
index bf3d72c366..7fe2c38e0f 100644
--- a/tests/client/test_models.py
+++ b/tests/client/test_models.py
@@ -40,29 +40,21 @@
def test_text_classification_record(annotation, status, expected_status):
"""Just testing its dynamic defaults"""
if status:
- record = TextClassificationRecord(
- inputs={"text": "test"}, annotation=annotation, status=status
- )
+ record = TextClassificationRecord(inputs={"text": "test"}, annotation=annotation, status=status)
else:
- record = TextClassificationRecord(
- inputs={"text": "test"}, annotation=annotation
- )
+ record = TextClassificationRecord(inputs={"text": "test"}, annotation=annotation)
assert record.status == expected_status
def test_text_classification_input_string():
event_timestamp = datetime.datetime.now()
- assert TextClassificationRecord(
- text="A text", event_timestamp=event_timestamp
- ) == TextClassificationRecord(
+ assert TextClassificationRecord(text="A text", event_timestamp=event_timestamp) == TextClassificationRecord(
inputs=dict(text="A text"), event_timestamp=event_timestamp
)
assert TextClassificationRecord(
inputs=["A text", "another text"], event_timestamp=event_timestamp
- ) == TextClassificationRecord(
- inputs=dict(text=["A text", "another text"]), event_timestamp=event_timestamp
- )
+ ) == TextClassificationRecord(inputs=dict(text=["A text", "another text"]), event_timestamp=event_timestamp)
def test_text_classification_text_inputs():
@@ -74,38 +66,27 @@ def test_text_classification_text_inputs():
with pytest.warns(
FutureWarning,
- match=(
- "the `inputs` argument of the `TextClassificationRecord` will not accept"
- " strings."
- ),
+ match=("the `inputs` argument of the `TextClassificationRecord` will not accept" " strings."),
):
TextClassificationRecord(inputs="mock")
event_timestamp = datetime.datetime.now()
- assert TextClassificationRecord(
- text="mock", event_timestamp=event_timestamp
- ) == TextClassificationRecord(
+ assert TextClassificationRecord(text="mock", event_timestamp=event_timestamp) == TextClassificationRecord(
inputs={"text": "mock"}, event_timestamp=event_timestamp
)
- assert TextClassificationRecord(
- inputs=["mock"], event_timestamp=event_timestamp
- ) == TextClassificationRecord(
+ assert TextClassificationRecord(inputs=["mock"], event_timestamp=event_timestamp) == TextClassificationRecord(
inputs={"text": ["mock"]}, event_timestamp=event_timestamp
)
assert TextClassificationRecord(
text="mock", inputs={"text": "mock"}, event_timestamp=event_timestamp
- ) == TextClassificationRecord(
- inputs={"text": "mock"}, event_timestamp=event_timestamp
- )
+ ) == TextClassificationRecord(inputs={"text": "mock"}, event_timestamp=event_timestamp)
rec = TextClassificationRecord(text="mock")
with pytest.raises(AttributeError, match="You cannot assign a new value to `text`"):
rec.text = "mock"
- with pytest.raises(
- AttributeError, match="You cannot assign a new value to `inputs`"
- ):
+ with pytest.raises(AttributeError, match="You cannot assign a new value to `inputs`"):
rec.inputs = "mock"
@@ -120,9 +101,7 @@ def test_text_classification_text_inputs():
)
def test_token_classification_record(annotation, status, expected_status, expected_iob):
"""Just testing its dynamic defaults"""
- record = TokenClassificationRecord(
- text="test text", tokens=["test", "text"], annotation=annotation, status=status
- )
+ record = TokenClassificationRecord(text="test text", tokens=["test", "text"], annotation=annotation, status=status)
assert record.status == expected_status
assert record.spans2iob(record.annotation) == expected_iob
@@ -146,10 +125,7 @@ def test_token_classification_with_tokens_and_tags(tokens, tags, annotation):
def test_token_classification_validations():
with pytest.raises(
AssertionError,
- match=(
- "Missing fields: "
- "At least one of `text` or `tokens` argument must be provided!"
- ),
+ match=("Missing fields: " "At least one of `text` or `tokens` argument must be provided!"),
):
TokenClassificationRecord()
@@ -157,25 +133,17 @@ def test_token_classification_validations():
annotation = [("test", 0, 4)]
with pytest.raises(
AssertionError,
- match=(
- "Missing field `text`: "
- "char level spans must be provided with a raw text sentence"
- ),
+ match=("Missing field `text`: " "char level spans must be provided with a raw text sentence"),
):
TokenClassificationRecord(tokens=tokens, annotation=annotation)
with pytest.raises(
AssertionError,
- match=(
- "Missing field `text`: "
- "char level spans must be provided with a raw text sentence"
- ),
+ match=("Missing field `text`: " "char level spans must be provided with a raw text sentence"),
):
TokenClassificationRecord(tokens=tokens, prediction=annotation)
- TokenClassificationRecord(
- text=" ".join(tokens), tokens=tokens, prediction=annotation
- )
+ TokenClassificationRecord(text=" ".join(tokens), tokens=tokens, prediction=annotation)
record = TokenClassificationRecord(tokens=tokens)
assert record.text == "test text"
@@ -185,16 +153,12 @@ def test_token_classification_with_mutation():
text_a = "The text"
text_b = "Another text sample here !!!"
- record = TokenClassificationRecord(
- text=text_a, tokens=text_a.split(" "), annotation=[]
- )
+ record = TokenClassificationRecord(text=text_a, tokens=text_a.split(" "), annotation=[])
assert record.spans2iob(record.annotation) == ["O"] * len(text_a.split(" "))
with pytest.raises(AttributeError, match="You cannot assign a new value to `text`"):
record.text = text_b
- with pytest.raises(
- AttributeError, match="You cannot assign a new value to `tokens`"
- ):
+ with pytest.raises(AttributeError, match="You cannot assign a new value to `tokens`"):
record.tokens = text_b.split(" ")
@@ -212,9 +176,7 @@ def test_token_classification_with_mutation():
],
)
def test_token_classification_prediction_validator(prediction, expected):
- record = TokenClassificationRecord(
- text="this", tokens=["this"], prediction=prediction
- )
+ record = TokenClassificationRecord(text="this", tokens=["this"], prediction=prediction)
assert record.prediction == expected
@@ -246,9 +208,7 @@ def test_metadata_values_length():
def test_model_serialization_with_numpy_nan():
- record = Text2TextRecord(
- text="My name is Sarah and I love my dog.", metadata={"nan": numpy.nan}
- )
+ record = Text2TextRecord(text="My name is Sarah and I love my dog.", metadata={"nan": numpy.nan})
json_record = json.loads(record.json())
@@ -260,13 +220,9 @@ class MockRecord(_Validators):
annotation: Optional[Any] = None
annotation_agent: Optional[str] = None
- with pytest.warns(
- UserWarning, match="`prediction_agent` will not be logged to the server."
- ):
+ with pytest.warns(UserWarning, match="`prediction_agent` will not be logged to the server."):
MockRecord(prediction_agent="mock")
- with pytest.warns(
- UserWarning, match="`annotation_agent` will not be logged to the server."
- ):
+ with pytest.warns(UserWarning, match="`annotation_agent` will not be logged to the server."):
MockRecord(annotation_agent="mock")
@@ -312,9 +268,7 @@ def test_text2text_prediction_validator(prediction, expected):
"record",
[
TextClassificationRecord(text="This is a test"),
- TokenClassificationRecord(
- text="This is a test", tokens="This is a test".split()
- ),
+ TokenClassificationRecord(text="This is a test", tokens="This is a test".split()),
Text2TextRecord(text="This is a test"),
],
)
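
Nearly every collapsed `pytest.raises(...)` in this file passes `match=`; note that `match` is applied with `re.search` against the string of the raised exception, which is why plain substrings work and why the backtick-quoted fragments above need no anchoring. A minimal sketch:

    import pytest

    def test_match_is_a_regex_search() -> None:
        # match= is re.search'ed against str(exception): substrings pass,
        # but regex metacharacters in the message would need escaping.
        with pytest.raises(AttributeError, match="You cannot assign a new value to `text`"):
            raise AttributeError("You cannot assign a new value to `text`")
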
diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py
index 4499d25964..7384aa9bae 100644
--- a/tests/datasets/test_datasets.py
+++ b/tests/datasets/test_datasets.py
@@ -57,9 +57,7 @@ def test_delete_dataset_by_non_creator(mocked_client):
try:
dataset = "test_delete_dataset_by_non_creator"
ar.delete(dataset)
- ar.configure_dataset(
- dataset, settings=TextClassificationSettings(label_schema={"A", "B", "C"})
- )
+ ar.configure_dataset(dataset, settings=TextClassificationSettings(label_schema={"A", "B", "C"}))
mocked_client.change_current_user("mock-user")
with pytest.raises(ForbiddenApiError):
ar.delete(dataset)
diff --git a/tests/functional_tests/datasets/test_delete_records_from_datasets.py b/tests/functional_tests/datasets/test_delete_records_from_datasets.py
index 238c358338..56b2e6820c 100644
--- a/tests/functional_tests/datasets/test_delete_records_from_datasets.py
+++ b/tests/functional_tests/datasets/test_delete_records_from_datasets.py
@@ -26,10 +26,7 @@ def test_delete_records_from_dataset(mocked_client):
ar.log(
name=dataset,
records=[
- ar.TextClassificationRecord(
- id=i, text="This is the text", metadata=dict(idx=i)
- )
- for i in range(0, 50)
+ ar.TextClassificationRecord(id=i, text="This is the text", metadata=dict(idx=i)) for i in range(0, 50)
],
)
@@ -40,9 +37,7 @@ def test_delete_records_from_dataset(mocked_client):
assert len(ds) == 50
time.sleep(1)
- matched, processed = ar.delete_records(
- name=dataset, query="id:10", discard_only=False
- )
+ matched, processed = ar.delete_records(name=dataset, query="id:10", discard_only=False)
assert matched, processed == (1, 1)
time.sleep(1)
@@ -58,10 +53,7 @@ def test_delete_records_without_permission(mocked_client):
ar.log(
name=dataset,
records=[
- ar.TextClassificationRecord(
- id=i, text="This is the text", metadata=dict(idx=i)
- )
- for i in range(0, 50)
+ ar.TextClassificationRecord(id=i, text="This is the text", metadata=dict(idx=i)) for i in range(0, 50)
],
)
try:
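
One context line in `test_delete_records_from_dataset` above deserves a careful read: `assert matched, processed == (1, 1)` parses as `assert matched` with `processed == (1, 1)` as the assertion message, not as a comparison of the pair. The unambiguous spelling, with illustrative values standing in for the `ar.delete_records` return, would be:

    matched, processed = 1, 1  # illustrative return values of ar.delete_records
    # The comma makes the second expression the assert *message*; to compare
    # the pair, parenthesize it:
    assert (matched, processed) == (1, 1)
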
diff --git a/tests/functional_tests/search/test_search_service.py b/tests/functional_tests/search/test_search_service.py
index 9fc0b9b439..90ea518ae1 100644
--- a/tests/functional_tests/search/test_search_service.py
+++ b/tests/functional_tests/search/test_search_service.py
@@ -70,9 +70,7 @@ def test_query_builder_with_query_range(backend: GenericElasticEngineBackend):
}
-def test_query_builder_with_nested(
- mocked_client, dao, backend: GenericElasticEngineBackend
-):
+def test_query_builder_with_nested(mocked_client, dao, backend: GenericElasticEngineBackend):
dataset = Dataset(
name="test_query_builder_with_nested",
owner=argilla.get_workspace(),
@@ -106,20 +104,8 @@ def test_query_builder_with_nested(
"query": {
"bool": {
"must": [
- {
- "term": {
- "metrics.predicted.mentions.label": {
- "value": "NAME"
- }
- }
- },
- {
- "range": {
- "metrics.predicted.mentions.score": {
- "lte": "0.1"
- }
- }
- },
+ {"term": {"metrics.predicted.mentions.label": {"value": "NAME"}}},
+ {"range": {"metrics.predicted.mentions.score": {"lte": "0.1"}}},
]
}
},
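
The collapsed leaf clauses in the expected query above make its shape easy to read off: an exact-match `term` clause and a `range` clause with an inclusive upper bound, combined under `bool.must`. Stand-alone, the two clauses from the test are:

    # `term` matches an exact value; `range` with "lte" is <= (inclusive).
    must = [
        {"term": {"metrics.predicted.mentions.label": {"value": "NAME"}}},
        {"range": {"metrics.predicted.mentions.score": {"lte": "0.1"}}},
    ]
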
diff --git a/tests/functional_tests/test_log_for_text_classification.py b/tests/functional_tests/test_log_for_text_classification.py
index 5e2a16f8bb..be50cdf406 100644
--- a/tests/functional_tests/test_log_for_text_classification.py
+++ b/tests/functional_tests/test_log_for_text_classification.py
@@ -226,9 +226,7 @@ def test_search_keywords(mocked_client):
top_keywords = set(
[
keyword
- for keywords in df.search_keywords.value_counts(sort=True, ascending=False)
- .index[:3]
- .tolist()
+ for keywords in df.search_keywords.value_counts(sort=True, ascending=False).index[:3].tolist()
for keyword in keywords
]
)
@@ -262,17 +260,12 @@ def test_logging_with_metadata_limits_exceeded(mocked_client):
expected_record = ar.TextClassificationRecord(
text="The input text",
- metadata={
- k: f"this is a string {k}"
- for k in range(0, settings.metadata_fields_limit + 1)
- },
+ metadata={k: f"this is a string {k}" for k in range(0, settings.metadata_fields_limit + 1)},
)
with pytest.raises(BadRequestApiError):
ar.log(expected_record, name=dataset)
- expected_record.metadata = {
- k: f"This is a string {k}" for k in range(0, settings.metadata_fields_limit)
- }
+ expected_record.metadata = {k: f"This is a string {k}" for k in range(0, settings.metadata_fields_limit)}
# Dataset creation with data
ar.log(expected_record, name=dataset)
# This call will check already included fields
diff --git a/tests/functional_tests/test_log_for_token_classification.py b/tests/functional_tests/test_log_for_token_classification.py
index 2df8c37c5c..7aa45b3ed4 100644
--- a/tests/functional_tests/test_log_for_token_classification.py
+++ b/tests/functional_tests/test_log_for_token_classification.py
@@ -29,9 +29,7 @@ def test_log_with_empty_text(mocked_client):
text = " "
argilla.delete(dataset)
- with pytest.raises(
- Exception, match="The provided `text` contains only whitespaces."
- ):
+ with pytest.raises(Exception, match="The provided `text` contains only whitespaces."):
argilla.log(
TokenClassificationRecord(id=0, text=text, tokens=["a", "b", "c"]),
name=dataset,
@@ -531,9 +529,7 @@ def test_search_keywords(mocked_client):
top_keywords = set(
[
keyword
- for keywords in df.search_keywords.value_counts(sort=True, ascending=False)
- .index[:3]
- .tolist()
+ for keywords in df.search_keywords.value_counts(sort=True, ascending=False).index[:3].tolist()
for keyword in keywords
]
)
diff --git a/tests/labeling/text_classification/test_label_errors.py b/tests/labeling/text_classification/test_label_errors.py
index f33f9d8ab6..b07780ec71 100644
--- a/tests/labeling/text_classification/test_label_errors.py
+++ b/tests/labeling/text_classification/test_label_errors.py
@@ -27,15 +27,11 @@
from pkg_resources import parse_version
-@pytest.fixture(
- params=[False, True], ids=["single_label", "multi_label"], scope="module"
-)
+@pytest.fixture(params=[False, True], ids=["single_label", "multi_label"], scope="module")
def records(request):
if request.param:
return [
- ar.TextClassificationRecord(
- text="test", annotation=anot, prediction=pred, multi_label=True, id=i
- )
+ ar.TextClassificationRecord(text="test", annotation=anot, prediction=pred, multi_label=True, id=i)
for i, anot, pred in zip(
range(2 * 6),
[["bad"], ["bad", "good"]] * 6,
@@ -70,9 +66,7 @@ def test_no_records():
ar.TextClassificationRecord(text="test", annotation="test"),
]
- with pytest.raises(
- NoRecordsError, match="none of your records have a prediction AND annotation"
- ):
+ with pytest.raises(NoRecordsError, match="none of your records have a prediction AND annotation"):
find_label_errors(records)
@@ -84,10 +78,7 @@ def test_multi_label_warning(caplog):
multi_label=True,
)
find_label_errors([record], multi_label="True")
- assert (
- "You provided the kwarg 'multi_label', but it is determined automatically"
- in caplog.text
- )
+ assert "You provided the kwarg 'multi_label', but it is determined automatically" in caplog.text
@pytest.mark.parametrize(
@@ -120,9 +111,7 @@ def mock_find_label_issues(*args, **kwargs):
mock_find_label_issues,
)
- record = ar.TextClassificationRecord(
- text="mock", prediction=[("mock", 0.1)], annotation="mock"
- )
+ record = ar.TextClassificationRecord(text="mock", prediction=[("mock", 0.1)], annotation="mock")
find_label_errors(records=[record], sort_by=sort_by)
@@ -144,9 +133,7 @@ def mock_get_noise_indices(s, psx, n_jobs, **kwargs):
mock_get_noise_indices,
)
- with pytest.raises(
- ValueError, match="'sorted_index_method' kwarg is not supported"
- ):
+ with pytest.raises(ValueError, match="'sorted_index_method' kwarg is not supported"):
find_label_errors(records=records, sorted_index_method="mock")
find_label_errors(records=records, mock="mock")
@@ -156,9 +143,7 @@ def mock_find_label_issues(s, psx, n_jobs, **kwargs):
assert kwargs == {
"mock": "mock",
"multi_label": is_multi_label,
- "return_indices_ranked_by": "normalized_margin"
- if not is_multi_label
- else "self_confidence",
+ "return_indices_ranked_by": "normalized_margin" if not is_multi_label else "self_confidence",
}
return []
@@ -167,9 +152,7 @@ def mock_find_label_issues(s, psx, n_jobs, **kwargs):
mock_find_label_issues,
)
- with pytest.raises(
- ValueError, match="'return_indices_ranked_by' kwarg is not supported"
- ):
+ with pytest.raises(ValueError, match="'return_indices_ranked_by' kwarg is not supported"):
find_label_errors(records=records, return_indices_ranked_by="mock")
find_label_errors(records=records, mock="mock")
@@ -207,22 +190,14 @@ def test_construct_s_and_psx(records):
def test_missing_predictions():
- records = [
- ar.TextClassificationRecord(
- text="test", annotation="mock", prediction=[("mock2", 0.1)]
- )
- ]
+ records = [ar.TextClassificationRecord(text="test", annotation="mock", prediction=[("mock2", 0.1)])]
with pytest.raises(
MissingPredictionError,
match="It seems predictions are missing for the label 'mock'",
):
_construct_s_and_psx(records)
- records.append(
- ar.TextClassificationRecord(
- text="test", annotation="mock", prediction=[("mock", 0.1)]
- )
- )
+ records.append(ar.TextClassificationRecord(text="test", annotation="mock", prediction=[("mock", 0.1)]))
with pytest.raises(
MissingPredictionError,
match="It seems a prediction for 'mock' is missing in the following record",
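
The `records` fixture at the top of this file is parametrized, so every test requesting it runs once per parameter; `ids=["single_label", "multi_label"]` only labels the generated test IDs, and `scope="module"` builds each variant once per module. A minimal sketch of the pattern:

    import pytest

    @pytest.fixture(params=[False, True], ids=["single_label", "multi_label"], scope="module")
    def multi_label(request) -> bool:
        # request.param is False for the first variant, True for the second.
        return request.param

    def test_runs_once_per_variant(multi_label: bool) -> None:
        assert multi_label in (False, True)
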
diff --git a/tests/labeling/text_classification/test_label_models.py b/tests/labeling/text_classification/test_label_models.py
index 0cde14796f..4cad2d86f4 100644
--- a/tests/labeling/text_classification/test_label_models.py
+++ b/tests/labeling/text_classification/test_label_models.py
@@ -43,9 +43,7 @@ def mock_load(*args, **kwargs):
TextClassificationRecord(text="test", annotation="neutral"),
]
- monkeypatch.setattr(
- "argilla.labeling.text_classification.weak_labels.load", mock_load
- )
+ monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load)
def mock_apply(self, *args, **kwargs):
weak_label_matrix = np.array(
@@ -63,17 +61,13 @@ def mock_apply(self, *args, **kwargs):
@pytest.fixture
def weak_labels_from_guide(monkeypatch, resources):
- matrix_and_annotation = np.load(
- str(resources / "weak-supervision-guide-matrix.npy")
- )
+ matrix_and_annotation = np.load(str(resources / "weak-supervision-guide-matrix.npy"))
matrix, annotation = matrix_and_annotation[:, :-1], matrix_and_annotation[:, -1]
def mock_load(*args, **kwargs):
return [TextClassificationRecord(text="mock", id=i) for i in range(len(matrix))]
- monkeypatch.setattr(
- "argilla.labeling.text_classification.weak_labels.load", mock_load
- )
+ monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load)
def mock_apply(self, *args, **kwargs):
return matrix, annotation, {None: -1, "SPAM": 0, "HAM": 1}
@@ -87,19 +81,13 @@ def mock_apply(self, *args, **kwargs):
def weak_multi_labels(monkeypatch):
def mock_load(*args, **kwargs):
return [
- TextClassificationRecord(
- text="test", multi_label=True, annotation=["scared"]
- ),
- TextClassificationRecord(
- text="test", multi_label=True, annotation=["sad", "scared"]
- ),
+ TextClassificationRecord(text="test", multi_label=True, annotation=["scared"]),
+ TextClassificationRecord(text="test", multi_label=True, annotation=["sad", "scared"]),
TextClassificationRecord(text="test", multi_label=True, annotation=[]),
TextClassificationRecord(text="test", multi_label=True),
]
- monkeypatch.setattr(
- "argilla.labeling.text_classification.weak_labels.load", mock_load
- )
+ monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load)
def mock_apply(self, *args, **kwargs):
weak_label_matrix = np.array(
@@ -111,9 +99,7 @@ def mock_apply(self, *args, **kwargs):
],
dtype=np.byte,
)
- annotation_array = np.array(
- [[0, 0, 1], [1, 0, 1], [0, 0, 0], [-1, -1, -1]], dtype=np.byte
- )
+ annotation_array = np.array([[0, 0, 1], [1, 0, 1], [0, 0, 0], [-1, -1, -1]], dtype=np.byte)
labels = ["sad", "happy", "scared"]
return weak_label_matrix, annotation_array, labels
@@ -159,9 +145,7 @@ def test_no_need_to_fit_error(self):
("weak_multi_labels", False, 1),
],
)
- def test_predict(
- self, monkeypatch, request, wls, include_annotated_records, expected
- ):
+ def test_predict(self, monkeypatch, request, wls, include_annotated_records, expected):
def compute_probs(self, wl_matrix, **kwargs):
assert len(wl_matrix) == expected
compute_probs.called = None
@@ -171,20 +155,13 @@ def make_records(self, probabilities, records, **kwargs):
return records
single_or_multi = "multi" if wls == "weak_multi_labels" else "single"
- monkeypatch.setattr(
- MajorityVoter, f"_compute_{single_or_multi}_label_probs", compute_probs
- )
- monkeypatch.setattr(
- MajorityVoter, f"_make_{single_or_multi}_label_records", make_records
- )
+ monkeypatch.setattr(MajorityVoter, f"_compute_{single_or_multi}_label_probs", compute_probs)
+ monkeypatch.setattr(MajorityVoter, f"_make_{single_or_multi}_label_records", make_records)
weak_labels = request.getfixturevalue(wls)
mj = MajorityVoter(weak_labels)
- assert (
- len(mj.predict(include_annotated_records=include_annotated_records))
- == expected
- )
+ assert len(mj.predict(include_annotated_records=include_annotated_records)) == expected
assert hasattr(compute_probs, "called")
def test_compute_single_label_probs(self, weak_labels):
@@ -210,9 +187,7 @@ def test_compute_single_label_probs(self, weak_labels):
(False, TieBreakPolicy.RANDOM, 4),
],
)
- def test_make_single_label_records(
- self, weak_labels, include_abstentions, tie_break_policy, expected
- ):
+ def test_make_single_label_records(self, weak_labels, include_abstentions, tie_break_policy, expected):
mj = MajorityVoter(weak_labels)
probs = mj._compute_single_label_probs(weak_labels.matrix())
@@ -273,9 +248,7 @@ def test_compute_multi_label_probs(self, weak_multi_labels):
(False, 3),
],
)
- def test_make_multi_label_records(
- self, weak_multi_labels, include_abstentions, expected
- ):
+ def test_make_multi_label_records(self, weak_multi_labels, include_abstentions, expected):
mj = MajorityVoter(weak_multi_labels)
probs = mj._compute_multi_label_probs(weak_multi_labels.matrix())
@@ -324,9 +297,7 @@ def score(self, probabilities, tie_break_policy=None):
return np.array([[1, 1, 1], [0, 0, 0]]), np.array([[1, 1, 1], [1, 0, 0]])
single_or_multi = "multi" if wls == "weak_multi_labels" else "single"
- monkeypatch.setattr(
- MajorityVoter, f"_compute_{single_or_multi}_label_probs", compute_probs
- )
+ monkeypatch.setattr(MajorityVoter, f"_compute_{single_or_multi}_label_probs", compute_probs)
monkeypatch.setattr(MajorityVoter, f"_score_{single_or_multi}_label", score)
weak_labels = request.getfixturevalue(wls)
@@ -349,45 +320,31 @@ def score(self, probabilities, tie_break_policy=None):
def test_score_single_label(self, weak_labels, tie_break_policy, expected):
mj = MajorityVoter(weak_labels)
- probabilities = np.array(
- [[0.5, 0.5, 0.0], [0.5, 0.0, 0.5], [1.0 / 3, 0.0, 2.0 / 3]]
- )
+ probabilities = np.array([[0.5, 0.5, 0.0], [0.5, 0.0, 0.5], [1.0 / 3, 0.0, 2.0 / 3]])
if tie_break_policy is TieBreakPolicy.TRUE_RANDOM:
- with pytest.raises(
- NotImplementedError, match="not implemented for MajorityVoter"
- ):
+ with pytest.raises(NotImplementedError, match="not implemented for MajorityVoter"):
mj._score_single_label(probabilities, tie_break_policy)
return
- annotation, prediction = mj._score_single_label(
- probabilities=probabilities, tie_break_policy=tie_break_policy
- )
+ annotation, prediction = mj._score_single_label(probabilities=probabilities, tie_break_policy=tie_break_policy)
assert np.allclose(annotation, expected[0])
assert np.allclose(prediction, expected[1])
def test_score_single_label_no_ties(self, weak_labels):
mj = MajorityVoter(weak_labels)
- probabilities = np.array(
- [[0.5, 0.3, 0.0], [0.5, 0.0, 0.0], [1.0 / 3, 0.0, 2.0 / 3]]
- )
+ probabilities = np.array([[0.5, 0.3, 0.0], [0.5, 0.0, 0.0], [1.0 / 3, 0.0, 2.0 / 3]])
- _, prediction = mj._score_single_label(
- probabilities=probabilities, tie_break_policy=TieBreakPolicy.ABSTAIN
- )
- _, prediction2 = mj._score_single_label(
- probabilities=probabilities, tie_break_policy=TieBreakPolicy.RANDOM
- )
+ _, prediction = mj._score_single_label(probabilities=probabilities, tie_break_policy=TieBreakPolicy.ABSTAIN)
+ _, prediction2 = mj._score_single_label(probabilities=probabilities, tie_break_policy=TieBreakPolicy.RANDOM)
assert np.allclose(prediction, prediction2)
def test_score_multi_label(self, weak_multi_labels):
mj = MajorityVoter(weak_multi_labels)
- probabilities = np.array(
- [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0], [np.nan, np.nan, np.nan]]
- )
+ probabilities = np.array([[0.0, 0.0, 1.0], [1.0, 1.0, 1.0], [np.nan, np.nan, np.nan]])
annotation, prediction = mj._score_multi_label(probabilities=probabilities)
@@ -428,9 +385,7 @@ def test_init_wrong_mapping(self, weak_labels, wrong_mapping, expected):
label_model = Snorkel(weak_labels)
assert label_model._weaklabelsInt2snorkelInt == expected
- assert label_model._snorkelInt2weaklabelsInt == {
- k: v for v, k in expected.items()
- }
+ assert label_model._snorkelInt2weaklabelsInt == {k: v for v, k in expected.items()}
@pytest.mark.parametrize(
"include_annotated_records",
@@ -450,9 +405,7 @@ def mock_fit(self, L_train, *args, **kwargs):
)
label_model = Snorkel(weak_labels)
- label_model.fit(
- include_annotated_records=include_annotated_records, passed_on=None
- )
+ label_model.fit(include_annotated_records=include_annotated_records, passed_on=None)
def test_fit_automatically_added_kwargs(self, weak_labels):
label_model = Snorkel(weak_labels)
@@ -525,12 +478,8 @@ def mock_predict(self, L, return_probs, tie_break_policy, *args, **kwargs):
prediction_agent="mock_agent",
)
assert len(records) == expected[0]
- assert [
- rec.prediction[0][0] if rec.prediction else None for rec in records
- ] == expected[1]
- assert [
- rec.prediction[0][1] if rec.prediction else None for rec in records
- ] == expected[2]
+ assert [rec.prediction[0][0] if rec.prediction else None for rec in records] == expected[1]
+ assert [rec.prediction[0][1] if rec.prediction else None for rec in records] == expected[2]
assert records[0].prediction_agent == "mock_agent"
@pytest.mark.parametrize("policy,expected", [("abstain", 0.5), ("random", 2.0 / 3)])
@@ -795,16 +744,12 @@ def mock_predict(self, weak_label_matrix, verbose):
assert isinstance(label_model.score(output_str=True), str)
- @pytest.mark.parametrize(
- "tbp,vrb,expected", [("abstain", False, 1.0), ("random", True, 2 / 3.0)]
- )
+ @pytest.mark.parametrize("tbp,vrb,expected", [("abstain", False, 1.0), ("random", True, 2 / 3.0)])
def test_score_tbp(self, monkeypatch, weak_labels, tbp, vrb, expected):
def mock_predict(self, weak_label_matrix, verbose):
assert verbose is vrb
assert len(weak_label_matrix) == 3
- return np.array(
- [[0.8, 0.1, 0.1], [0.4, 0.4, 0.2], [1 / 3.0, 1 / 3.0, 1 / 3.0]]
- )
+ return np.array([[0.8, 0.1, 0.1], [0.4, 0.4, 0.2], [1 / 3.0, 1 / 3.0, 1 / 3.0]])
monkeypatch.setattr(FlyingSquid, "_predict", mock_predict)
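
The one-line `@pytest.mark.parametrize` decorators in these hunks pack names and cases together: the first argument is a comma-separated list of parameter names, the second a list with one tuple per case. Reusing the exact decorator from `test_score_tbp` above:

    import pytest

    @pytest.mark.parametrize("tbp,vrb,expected", [("abstain", False, 1.0), ("random", True, 2 / 3.0)])
    def test_cases(tbp: str, vrb: bool, expected: float) -> None:
        # Runs twice: ("abstain", False, 1.0), then ("random", True, 2/3).
        assert isinstance(expected, float)
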
diff --git a/tests/labeling/text_classification/test_rule.py b/tests/labeling/text_classification/test_rule.py
index c46c7cdf3e..7920319b4e 100644
--- a/tests/labeling/text_classification/test_rule.py
+++ b/tests/labeling/text_classification/test_rule.py
@@ -69,9 +69,7 @@ def log_dataset(mocked_client) -> str:
"id": idx,
}
)
- for text, label, idx in zip(
- ["negative", "positive"], ["negative", "positive"], [1, 2]
- )
+ for text, label, idx in zip(["negative", "positive"], ["negative", "positive"], [1, 2])
]
mocked_client.post(
f"/api/datasets/{dataset_name}/TextClassification:bulk",
@@ -83,9 +81,7 @@ def log_dataset(mocked_client) -> str:
return dataset_name
-@pytest.mark.parametrize(
- "name,expected", [(None, "query_string"), ("test_name", "test_name")]
-)
+@pytest.mark.parametrize("name,expected", [(None, "query_string"), ("test_name", "test_name")])
def test_name(name, expected):
rule = Rule(query="query_string", label="mock", name=name)
assert rule.name == expected
@@ -357,9 +353,7 @@ def test_rule_metrics(mocked_client, log_dataset, rule, expected_metrics):
),
],
)
-def test_rule_metrics_without_annotated(
- mocked_client, log_dataset_without_annotations, rule, expected_metrics
-):
+def test_rule_metrics_without_annotated(mocked_client, log_dataset_without_annotations, rule, expected_metrics):
delete_rule_silently(mocked_client, log_dataset_without_annotations, rule)
mocked_client.post(
@@ -373,8 +367,6 @@ def test_rule_metrics_without_annotated(
def delete_rule_silently(client, dataset: str, rule: Rule):
try:
- client.delete(
- f"/api/datasets/TextClassification/{dataset}/labeling/rules/{rule.query}"
- )
+ client.delete(f"/api/datasets/TextClassification/{dataset}/labeling/rules/{rule.query}")
except EntityNotFoundError:
pass
diff --git a/tests/labeling/text_classification/test_weak_labels.py b/tests/labeling/text_classification/test_weak_labels.py
index 8a42263881..e6d69f1341 100644
--- a/tests/labeling/text_classification/test_weak_labels.py
+++ b/tests/labeling/text_classification/test_weak_labels.py
@@ -163,9 +163,7 @@ def test_rules_from_dataset(self, monkeypatch, log_dataset):
assert wl.rules is mock_rules
def test_norulesfounderror(self, monkeypatch):
- monkeypatch.setattr(
- "argilla.labeling.text_classification.weak_labels.load_rules", lambda x: []
- )
+ monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load_rules", lambda x: [])
with pytest.raises(NoRulesFoundError, match="No rules were found"):
WeakLabelsBase("mock")
@@ -179,13 +177,9 @@ def test_no_records_found_error(self, monkeypatch):
def mock_load(*args, **kwargs):
return []
- monkeypatch.setattr(
- "argilla.labeling.text_classification.weak_labels.load", mock_load
- )
+ monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load)
- with pytest.raises(
- NoRecordsFoundError, match="No records found in dataset 'mock'."
- ):
+ with pytest.raises(NoRecordsFoundError, match="No records found in dataset 'mock'."):
WeakLabels(rules=[lambda x: None], dataset="mock")
with pytest.raises(
NoRecordsFoundError,
@@ -212,9 +206,7 @@ def test_rules_records_properties(self, monkeypatch):
def mock_load(*args, **kwargs):
return expected_records
- monkeypatch.setattr(
- "argilla.labeling.text_classification.weak_labels.load", mock_load
- )
+ monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load)
weak_labels = WeakLabelsBase(rules=[lambda x: "mock"] * 2, dataset="mock")
@@ -233,9 +225,7 @@ def test_not_implemented_errors(self, monkeypatch):
def mock_load(*args, **kwargs):
return ["mock"]
- monkeypatch.setattr(
- "argilla.labeling.text_classification.weak_labels.load", mock_load
- )
+ monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load)
weak_labels = WeakLabelsBase(rules=["mock"], dataset="mock")
@@ -258,9 +248,7 @@ def test_faiss_not_installed(self, monkeypatch):
def mock_load(*args, **kwargs):
return ["mock"]
- monkeypatch.setattr(
- "argilla.labeling.text_classification.weak_labels.load", mock_load
- )
+ monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load)
monkeypatch.setitem(sys.modules, "faiss", None)
with pytest.raises(ModuleNotFoundError, match="pip install faiss-cpu"):
weak_labels = WeakLabelsBase(rules=[lambda x: "mock"] * 2, dataset="mock")
@@ -272,9 +260,7 @@ def test_multi_label_error(self, monkeypatch):
def mock_load(*args, **kwargs):
return [TextClassificationRecord(text="test", multi_label=True)]
- monkeypatch.setattr(
- "argilla.labeling.text_classification.weak_labels.load", mock_load
- )
+ monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load)
with pytest.raises(MultiLabelError):
WeakLabels(rules=[lambda x: None], dataset="mock")
@@ -291,9 +277,7 @@ def mock_load(*args, **kwargs):
(
{None: -10, "negative": 50, "positive": 10},
{None: -10, "negative": 50, "positive": 10},
- np.array(
- [[50, -10, 10], [-10, 10, 10], [-10, 10, -10]], dtype=np.short
- ),
+ np.array([[50, -10, 10], [-10, 10, 10], [-10, 10, -10]], dtype=np.short),
np.array([50, 10, -10], dtype=np.short),
),
({}, None, None, None),
@@ -335,9 +319,7 @@ def test_apply(
assert (weak_labels._annotation == expected_annotation_array).all()
def test_apply_MultiLabelError(self, log_dataset):
- with pytest.raises(
- MultiLabelError, match="For rules that do not return exactly 1 label"
- ):
+ with pytest.raises(MultiLabelError, match="For rules that do not return exactly 1 label"):
WeakLabels(rules=[lambda x: ["a", "b"]], dataset=log_dataset)
def test_matrix_annotation_properties(self, monkeypatch):
@@ -349,9 +331,7 @@ def test_matrix_annotation_properties(self, monkeypatch):
def mock_load(*args, **kwargs):
return expected_records
- monkeypatch.setattr(
- "argilla.labeling.text_classification.weak_labels.load", mock_load
- )
+ monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load)
def mock_apply(self, *args, **kwargs):
weak_label_matrix = np.array([[0, 1], [-1, 0]], dtype=np.short)
@@ -363,36 +343,21 @@ def mock_apply(self, *args, **kwargs):
weak_labels = WeakLabels(rules=[lambda x: "mock"] * 2, dataset="mock")
- assert (
- weak_labels.matrix(has_annotation=False)
- == np.array([[0, 1]], dtype=np.short)
- ).all()
- assert (
- weak_labels.matrix(has_annotation=True)
- == np.array([[-1, 0]], dtype=np.short)
- ).all()
+ assert (weak_labels.matrix(has_annotation=False) == np.array([[0, 1]], dtype=np.short)).all()
+ assert (weak_labels.matrix(has_annotation=True) == np.array([[-1, 0]], dtype=np.short)).all()
assert (weak_labels.annotation() == np.array([[0]], dtype=np.short)).all()
- assert (
- weak_labels.annotation(include_missing=True)
- == np.array([[-1, 0]], dtype=np.short)
- ).all()
- with pytest.warns(
- FutureWarning, match="'exclude_missing_annotations' is deprecated"
- ):
+ assert (weak_labels.annotation(include_missing=True) == np.array([[-1, 0]], dtype=np.short)).all()
+ with pytest.warns(FutureWarning, match="'exclude_missing_annotations' is deprecated"):
weak_labels.annotation(exclude_missing_annotations=True)
def test_summary(self, monkeypatch, rules):
def mock_load(*args, **kwargs):
return [TextClassificationRecord(text="test")] * 4
- monkeypatch.setattr(
- "argilla.labeling.text_classification.weak_labels.load", mock_load
- )
+ monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load)
def mock_apply(self, *args, **kwargs):
- weak_label_matrix = np.array(
- [[0, 1, -1], [-1, 0, -1], [-1, -1, -1], [1, 1, -1]], dtype=np.short
- )
+ weak_label_matrix = np.array([[0, 1, -1], [-1, 0, -1], [-1, -1, -1], [1, 1, -1]], dtype=np.short)
# weak_label_matrix = np.random.randint(-1, 30, (int(1e5), 50), dtype=np.short)
annotation_array = np.array([-1, -1, -1, -1], dtype=np.short)
# annotation_array = np.random.randint(-1, 30, int(1e5), dtype=np.short)
@@ -465,9 +430,7 @@ def test_show_records(self, monkeypatch, rules):
def mock_load(*args, **kwargs):
return [TextClassificationRecord(text="test", id=i) for i in range(5)]
- monkeypatch.setattr(
- "argilla.labeling.text_classification.weak_labels.load", mock_load
- )
+ monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load)
def mock_apply(self, *args, **kwargs):
weak_label_matrix = np.array(
@@ -495,17 +458,13 @@ def mock_apply(self, *args, **kwargs):
1,
4,
]
- assert weak_labels.show_records(
- labels=["positive"], rules=["argilla_rule"]
- ).empty
+ assert weak_labels.show_records(labels=["positive"], rules=["argilla_rule"]).empty
def test_change_mapping(self, monkeypatch, rules):
def mock_load(*args, **kwargs):
return [TextClassificationRecord(text="test", id=i) for i in range(5)]
- monkeypatch.setattr(
- "argilla.labeling.text_classification.weak_labels.load", mock_load
- )
+ monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load)
def mock_apply(self, *args, **kwargs):
weak_label_matrix = np.array(
@@ -543,10 +502,7 @@ def mock_apply(self, *args, **kwargs):
dtype=np.short,
)
).all()
- assert (
- weak_labels.annotation(include_missing=True)
- == np.array([2, 10, -10, 1, 2], dtype=np.short)
- ).all()
+ assert (weak_labels.annotation(include_missing=True) == np.array([2, 10, -10, 1, 2], dtype=np.short)).all()
assert weak_labels.label2int == new_mapping
assert weak_labels.int2label == {val: key for key, val in new_mapping.items()}
@@ -559,9 +515,7 @@ def weak_labels(self, monkeypatch, rules):
def mock_load(*args, **kwargs):
return [TextClassificationRecord(text="test", id=i) for i in range(3)]
- monkeypatch.setattr(
- "argilla.labeling.text_classification.weak_labels.load", mock_load
- )
+ monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load)
def mock_apply(self, *args, **kwargs):
weak_label_matrix = np.array(
@@ -583,19 +537,13 @@ def test_extend_matrix(self, weak_labels):
):
weak_labels.extend_matrix([1.0, 0.5, 0.5])
- weak_labels.extend_matrix(
- [1.0, 0.5, 0.5], np.array([[0.1, 0.1], [0.1, 0.11], [0.11, 0.1]])
- )
+ weak_labels.extend_matrix([1.0, 0.5, 0.5], np.array([[0.1, 0.1], [0.1, 0.11], [0.11, 0.1]]))
- np.testing.assert_equal(
- weak_labels.matrix(), np.array([[0, -1, -1], [-1, 1, -1], [-1, 1, -1]])
- )
+ np.testing.assert_equal(weak_labels.matrix(), np.array([[0, -1, -1], [-1, 1, -1], [-1, 1, -1]]))
weak_labels.extend_matrix([1.0, 1.0, 1.0])
- np.testing.assert_equal(
- weak_labels.matrix(), np.array([[0, -1, -1], [-1, 1, -1], [-1, -1, -1]])
- )
+ np.testing.assert_equal(weak_labels.matrix(), np.array([[0, -1, -1], [-1, 1, -1], [-1, -1, -1]]))
class TestWeakMultiLabels:
@@ -604,9 +552,7 @@ def test_apply(
log_multilabel_dataset,
multilabel_rules,
):
- weak_labels = WeakMultiLabels(
- rules=multilabel_rules, dataset=log_multilabel_dataset
- )
+ weak_labels = WeakMultiLabels(rules=multilabel_rules, dataset=log_multilabel_dataset)
assert weak_labels.labels == ["negative", "positive"]
@@ -625,30 +571,21 @@ def test_apply(
)
).all()
- assert (
- weak_labels._annotation
- == np.array([[1, 0], [1, 1], [-1, -1]], dtype=np.short)
- ).all()
+ assert (weak_labels._annotation == np.array([[1, 0], [1, 1], [-1, -1]], dtype=np.short)).all()
def test_matrix_annotation(self, monkeypatch):
expected_records = [
TextClassificationRecord(text="test without annot", multi_label=True),
- TextClassificationRecord(
- text="test with annot", annotation="positive", multi_label=True
- ),
+ TextClassificationRecord(text="test with annot", annotation="positive", multi_label=True),
]
def mock_load(*args, **kwargs):
return expected_records
- monkeypatch.setattr(
- "argilla.labeling.text_classification.weak_labels.load", mock_load
- )
+ monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load)
def mock_apply(self, *args, **kwargs):
- weak_label_matrix = np.array(
- [[[1, 0], [0, 1]], [[-1, -1], [1, 0]]], dtype=np.short
- )
+ weak_label_matrix = np.array([[[1, 0], [0, 1]], [[-1, -1], [1, 0]]], dtype=np.short)
annotation_array = np.array([[-1, -1], [1, 0]], dtype=np.short)
labels = ["negative", "positive"]
return weak_label_matrix, annotation_array, labels
@@ -657,27 +594,16 @@ def mock_apply(self, *args, **kwargs):
weak_labels = WeakMultiLabels(rules=[lambda x: "mock"] * 2, dataset="mock")
- assert (
- weak_labels.matrix(has_annotation=False)
- == np.array([[[1, 0], [0, 1]]], dtype=np.short)
- ).all()
- assert (
- weak_labels.matrix(has_annotation=True)
- == np.array([[[-1, -1], [1, 0]]], dtype=np.short)
- ).all()
+ assert (weak_labels.matrix(has_annotation=False) == np.array([[[1, 0], [0, 1]]], dtype=np.short)).all()
+ assert (weak_labels.matrix(has_annotation=True) == np.array([[[-1, -1], [1, 0]]], dtype=np.short)).all()
assert (weak_labels.annotation() == np.array([[1, 0]], dtype=np.short)).all()
- assert (
- weak_labels.annotation(include_missing=True)
- == np.array([[[-1, -1], [1, 0]]], dtype=np.short)
- ).all()
+ assert (weak_labels.annotation(include_missing=True) == np.array([[[-1, -1], [1, 0]]], dtype=np.short)).all()
def test_summary(self, monkeypatch, multilabel_rules):
def mock_load(*args, **kwargs):
return [TextClassificationRecord(text="test", multi_label=True)] * 4
- monkeypatch.setattr(
- "argilla.labeling.text_classification.weak_labels.load", mock_load
- )
+ monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load)
def mock_apply(self, *args, **kwargs):
weak_label_matrix = np.array(
@@ -690,9 +616,7 @@ def mock_apply(self, *args, **kwargs):
dtype=np.short,
)
# weak_label_matrix = np.random.randint(-1, 30, (int(1e5), 50), dtype=np.short)
- annotation_array = np.array(
- [[-1, -1], [-1, -1], [-1, -1], [-1, -1]], dtype=np.short
- )
+ annotation_array = np.array([[-1, -1], [-1, -1], [-1, -1], [-1, -1]], dtype=np.short)
# annotation_array = np.random.randint(-1, 30, int(1e5), dtype=np.short)
labels = ["negative", "positive"]
# label2int = {k: v for k, v in zip(["None"]+list(range(30)), list(range(-1, 30)))}
@@ -734,9 +658,7 @@ def mock_apply(self, *args, **kwargs):
)
assert_frame_equal(summary, expected)
- summary = weak_labels.summary(
- annotation=np.array([[0, 1], [-1, -1], [0, 1], [1, 1]])
- )
+ summary = weak_labels.summary(annotation=np.array([[0, 1], [-1, -1], [0, 1], [1, 1]]))
expected = pd.DataFrame(
{
"label": [
@@ -762,9 +684,7 @@ def test_compute_correct_incorrect(self, monkeypatch):
def mock_load(*args, **kwargs):
return [TextClassificationRecord(text="mock")]
- monkeypatch.setattr(
- "argilla.labeling.text_classification.weak_labels.load", mock_load
- )
+ monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load)
def mock_apply(self, *args, **kwargs):
weak_label_matrix = np.array([[[1, 0, 1, 0], [0, 1, 0, 1]]], dtype=np.short)
@@ -773,23 +693,16 @@ def mock_apply(self, *args, **kwargs):
monkeypatch.setattr(WeakMultiLabels, "_apply_rules", mock_apply)
weak_labels = WeakMultiLabels(rules=[lambda x: "mock"] * 2, dataset="mock")
- correct, incorrect = weak_labels._compute_correct_incorrect(
- annotation=np.array([[1, 0, 1, 0]])
- )
+ correct, incorrect = weak_labels._compute_correct_incorrect(annotation=np.array([[1, 0, 1, 0]]))
assert np.allclose(correct, np.array([2, 0, 2]))
assert np.allclose(incorrect, np.array([0, 2, 2]))
def test_show_records(self, monkeypatch, multilabel_rules):
def mock_load(*args, **kwargs):
- return [
- TextClassificationRecord(text="test", id=i, multi_label=True)
- for i in range(5)
- ]
+ return [TextClassificationRecord(text="test", id=i, multi_label=True) for i in range(5)]
- monkeypatch.setattr(
- "argilla.labeling.text_classification.weak_labels.load", mock_load
- )
+ monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load)
def mock_apply(self, *args, **kwargs):
weak_label_matrix = np.array(
@@ -812,9 +725,7 @@ def mock_apply(self, *args, **kwargs):
assert weak_labels.show_records().id.tolist() == [0, 1, 2, 3, 4]
assert weak_labels.show_records(labels=["positive"]).id.tolist() == [0, 1, 3]
assert weak_labels.show_records(labels=["negative"]).id.tolist() == [0, 1, 4]
- assert weak_labels.show_records(
- labels=["negative", "positive"]
- ).id.tolist() == [0, 1]
+ assert weak_labels.show_records(labels=["negative", "positive"]).id.tolist() == [0, 1]
assert weak_labels.show_records(rules=[0]).id.tolist() == [0, 1, 3]
assert weak_labels.show_records(rules=[0, "rule_1"]).id.tolist() == [0, 1, 3]
@@ -823,21 +734,14 @@ def mock_apply(self, *args, **kwargs):
1,
4,
]
- assert weak_labels.show_records(
- labels=["positive"], rules=["argilla_rule"]
- ).empty
+ assert weak_labels.show_records(labels=["positive"], rules=["argilla_rule"]).empty
@pytest.fixture
def weak_multi_labels(self, monkeypatch, rules):
def mock_load(*args, **kwargs):
- return [
- TextClassificationRecord(text="test", id=i, multi_label=True)
- for i in range(3)
- ]
+ return [TextClassificationRecord(text="test", id=i, multi_label=True) for i in range(3)]
- monkeypatch.setattr(
- "argilla.labeling.text_classification.weak_labels.load", mock_load
- )
+ monkeypatch.setattr("argilla.labeling.text_classification.weak_labels.load", mock_load)
def mock_apply(self, *args, **kwargs):
weak_label_matrix = np.array(
@@ -862,9 +766,7 @@ def test_extend_matrix(self, weak_multi_labels):
):
weak_multi_labels.extend_matrix([1.0, 0.5, 0.5])
- weak_multi_labels.extend_matrix(
- [1.0, 0.5, 0.5], np.array([[0.1, 0.1], [0.1, 0.11], [0.11, 0.1]])
- )
+ weak_multi_labels.extend_matrix([1.0, 0.5, 0.5], np.array([[0.1, 0.1], [0.1, 0.11], [0.11, 0.1]]))
np.testing.assert_equal(
weak_multi_labels.matrix(),
@@ -891,16 +793,10 @@ def test_extend_matrix(self, weak_multi_labels):
)
# The "correct" and "incorrect" columns from `expected_summary` may infer a different
# dtype than `weak_multi_labels.summary()` returns.
- assert_frame_equal(
- weak_multi_labels.summary(), expected_summary, check_dtype=False
- )
+ assert_frame_equal(weak_multi_labels.summary(), expected_summary, check_dtype=False)
- expected_show_records = pd.DataFrame(
- map(lambda x: x.dict(), weak_multi_labels.records())
- )
- assert_frame_equal(
- weak_multi_labels.show_records(rules=["rule_1"]), expected_show_records
- )
+ expected_show_records = pd.DataFrame(map(lambda x: x.dict(), weak_multi_labels.records()))
+ assert_frame_equal(weak_multi_labels.show_records(rules=["rule_1"]), expected_show_records)
weak_multi_labels.extend_matrix([1.0, 1.0, 1.0])
diff --git a/tests/listeners/test_listener.py b/tests/listeners/test_listener.py
index 8549036869..7651a7dbaf 100644
--- a/tests/listeners/test_listener.py
+++ b/tests/listeners/test_listener.py
@@ -39,9 +39,7 @@ def condition_check_params(search):
("dataset", "val + {param}", None, condition_check_params, {"param": 100}),
],
)
-def test_listener_with_parameters(
- mocked_client, dataset, query, metrics, condition, query_params
-):
+def test_listener_with_parameters(mocked_client, dataset, query, metrics, condition, query_params):
ar.delete(dataset)
class TestListener:
diff --git a/tests/metrics/test_common_metrics.py b/tests/metrics/test_common_metrics.py
index 4cbb6f80f6..1bf51590aa 100644
--- a/tests/metrics/test_common_metrics.py
+++ b/tests/metrics/test_common_metrics.py
@@ -137,9 +137,7 @@ def test_keywords_metrics(mocked_client, gutenberg_spacy_ner):
"two": 18,
}
- assert keywords(name=gutenberg_spacy_ner) == keywords(
- name=gutenberg_spacy_ner, query=""
- )
+ assert keywords(name=gutenberg_spacy_ner) == keywords(name=gutenberg_spacy_ner, query="")
with pytest.raises(AssertionError, match="size must be greater than 0"):
keywords(name=gutenberg_spacy_ner, size=0)
diff --git a/tests/monitoring/test_monitor.py b/tests/monitoring/test_monitor.py
index 18240551af..39b5db059a 100644
--- a/tests/monitoring/test_monitor.py
+++ b/tests/monitoring/test_monitor.py
@@ -30,8 +30,7 @@ def test_monitor_with_non_supported_model():
assert len(warning_list) == 1
warn_text = warning_list[0].message.args[0]
assert (
- warn_text
- == "The provided task model is not supported by monitoring module. "
+ warn_text == "The provided task model is not supported by monitoring module. "
"Predictions won't be logged into argilla"
)
@@ -53,7 +52,6 @@ def test_monitor_non_supported_huggingface_model():
assert len(warning_list) == 1
warn_text = warning_list[0].message.args[0]
assert (
- warn_text
- == "The provided task model is not supported by monitoring module. "
+ warn_text == "The provided task model is not supported by monitoring module. "
"Predictions won't be logged into argilla"
)
diff --git a/tests/monitoring/test_transformers_monitoring.py b/tests/monitoring/test_transformers_monitoring.py
index 225f606518..e478f1d572 100644
--- a/tests/monitoring/test_transformers_monitoring.py
+++ b/tests/monitoring/test_transformers_monitoring.py
@@ -198,10 +198,7 @@ def check_zero_shot_results(
assert record.inputs["text"] == text
assert record.metadata == {"labels": labels, "hypothesis_template": hypothesis}
assert record.prediction_agent == zero_shot_classifier.model.config.name_or_path
- assert record.prediction == [
- (label, score)
- for label, score in zip(predictions["labels"], predictions["scores"])
- ]
+ assert record.prediction == [(label, score) for label, score in zip(predictions["labels"], predictions["scores"])]
@pytest.mark.parametrize(
@@ -306,9 +303,7 @@ def test_monitor_zero_shot_with_text_array(
dataset,
):
argilla.delete(dataset)
- predictions = mocked_monitor(
- [text], candidate_labels=labels, hypothesis_template=hypothesis
- )
+ predictions = mocked_monitor([text], candidate_labels=labels, hypothesis_template=hypothesis)
check_zero_shot_results(
predictions,
diff --git a/tests/server/backend/test_query_builder.py b/tests/server/backend/test_query_builder.py
index 0276da15eb..4db8dff1c3 100644
--- a/tests/server/backend/test_query_builder.py
+++ b/tests/server/backend/test_query_builder.py
@@ -62,9 +62,7 @@
def test_build_sort_configuration(index_schema, sort_cfg, expected_sort):
builder = EsQueryBuilder()
- es_sort = builder.map_2_es_sort_configuration(
- sort=SortConfig(sort_by=sort_cfg), schema=index_schema
- )
+ es_sort = builder.map_2_es_sort_configuration(sort=SortConfig(sort_by=sort_cfg), schema=index_schema)
assert es_sort == expected_sort
@@ -72,9 +70,7 @@ def test_build_sort_with_wrong_field_name():
builder = EsQueryBuilder()
with pytest.raises(Exception):
- builder.map_2_es_sort_configuration(
- sort=SortConfig(sort_by=[SortableField(id="wat?!")])
- )
+ builder.map_2_es_sort_configuration(sort=SortConfig(sort_by=[SortableField(id="wat?!")]))
def test_build_sort_without_sort_config():
diff --git a/tests/server/commons/test_telemetry.py b/tests/server/commons/test_telemetry.py
index f3ce903918..0b2902ea03 100644
--- a/tests/server/commons/test_telemetry.py
+++ b/tests/server/commons/test_telemetry.py
@@ -45,9 +45,7 @@ async def test_track_bulk(telemetry_track_data):
task, records = TaskType.token_classification, 100
await telemetry.track_bulk(task=task, records=records)
- telemetry_track_data.assert_called_once_with(
- "LogRecordsRequested", {"task": task, "records": records}
- )
+ telemetry_track_data.assert_called_once_with("LogRecordsRequested", {"task": task, "records": records})
@pytest.mark.asyncio
diff --git a/tests/server/datasets/test_api.py b/tests/server/datasets/test_api.py
index 2d0ebebf0b..4f613f0123 100644
--- a/tests/server/datasets/test_api.py
+++ b/tests/server/datasets/test_api.py
@@ -178,9 +178,7 @@ def test_update_dataset(mocked_client):
delete_dataset(mocked_client, dataset)
create_mock_dataset(mocked_client, dataset)
- response = mocked_client.patch(
- f"/api/datasets/{dataset}", json={"metadata": {"new": "value"}}
- )
+ response = mocked_client.patch(f"/api/datasets/{dataset}", json={"metadata": {"new": "value"}})
assert response.status_code == 200
response = mocked_client.get(f"/api/datasets/{dataset}")
@@ -206,12 +204,7 @@ def test_open_and_close_dataset(mocked_client):
}
assert mocked_client.put(f"/api/datasets/{dataset}:open").status_code == 200
- assert (
- mocked_client.post(
- f"/api/datasets/{dataset}/TextClassification:search"
- ).status_code
- == 200
- )
+ assert mocked_client.post(f"/api/datasets/{dataset}/TextClassification:search").status_code == 200
def delete_dataset(client, dataset, workspace: Optional[str] = None):
@@ -268,9 +261,7 @@ def test_delete_records(mocked_client):
}
}
- response = mocked_client.delete(
- f"/api/datasets/{dataset_name}/data?mark_as_discarded=true"
- )
+ response = mocked_client.delete(f"/api/datasets/{dataset_name}/data?mark_as_discarded=true")
assert response.status_code == 200
assert response.json() == {
"matched": 99,
diff --git a/tests/server/metrics/test_api.py b/tests/server/metrics/test_api.py
index 21032362b2..0dc30ba139 100644
--- a/tests/server/metrics/test_api.py
+++ b/tests/server/metrics/test_api.py
@@ -61,9 +61,7 @@ def test_wrong_dataset_metrics(mocked_client):
assert response.json() == {
"detail": {
"code": "argilla.api.errors::WrongTaskError",
- "params": {
- "message": "Provided task TokenClassification cannot be applied to dataset"
- },
+ "params": {"message": "Provided task TokenClassification cannot be applied to dataset"},
}
}
@@ -76,9 +74,7 @@ def test_wrong_dataset_metrics(mocked_client):
assert response.json() == {
"detail": {
"code": "argilla.api.errors::WrongTaskError",
- "params": {
- "message": "Provided task TokenClassification cannot be applied to dataset"
- },
+ "params": {"message": "Provided task TokenClassification cannot be applied to dataset"},
}
}
@@ -136,9 +132,7 @@ def test_dataset_for_token_classification(mocked_client):
).status_code
== 200
)
- metrics = mocked_client.get(
- f"/api/datasets/TokenClassification/{dataset}/metrics"
- ).json()
+ metrics = mocked_client.get(f"/api/datasets/TokenClassification/{dataset}/metrics").json()
assert len(metrics) == len(TokenClassificationMetrics.metrics)
for metric in metrics:
@@ -187,9 +181,7 @@ def test_dataset_metrics(mocked_client):
== 200
)
- metrics = mocked_client.get(
- f"/api/datasets/TextClassification/{dataset}/metrics"
- ).json()
+ metrics = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/metrics").json()
assert len(metrics) == COMMON_METRICS_LENGTH + 5
diff --git a/tests/server/security/test_model.py b/tests/server/security/test_model.py
index 757349d6a0..1511695f63 100644
--- a/tests/server/security/test_model.py
+++ b/tests/server/security/test_model.py
@@ -24,17 +24,13 @@ def test_valid_mail(email):
assert user.email == email
-@pytest.mark.parametrize(
- "wrong_email", ["non-valid-email", "wrong@mail", "@wrong" "wrong.mail"]
-)
+@pytest.mark.parametrize("wrong_email", ["non-valid-email", "wrong@mail", "@wrong" "wrong.mail"])
def test_email_validator(wrong_email):
with pytest.raises(ValidationError):
User(username="user", email=wrong_email)
-@pytest.mark.parametrize(
- "wrong_name", ["user name", "user/name", "user.name", "UserName", "userName"]
-)
+@pytest.mark.parametrize("wrong_name", ["user name", "user/name", "user.name", "UserName", "userName"])
def test_username_validator(wrong_name):
with pytest.raises(
ValidationError,
@@ -43,9 +39,7 @@ def test_username_validator(wrong_name):
User(username=wrong_name)
-@pytest.mark.parametrize(
- "wrong_workspace", ["work space", "work/space", "work.space", "_", "-"]
-)
+@pytest.mark.parametrize("wrong_workspace", ["work space", "work/space", "work.space", "_", "-"])
def test_workspace_validator(wrong_workspace):
with pytest.raises(ValidationError):
User(username="username", workspaces=[wrong_workspace])
diff --git a/tests/server/security/test_provider.py b/tests/server/security/test_provider.py
index 6d169e74dc..27e73cfc7e 100644
--- a/tests/server/security/test_provider.py
+++ b/tests/server/security/test_provider.py
@@ -33,9 +33,7 @@ async def test_get_user_via_token():
@pytest.mark.asyncio
async def test_get_user_via_api_key():
- user = await localAuth.get_user(
- security_scopes=security_Scopes, api_key=DEFAULT_API_KEY
- )
+ user = await localAuth.get_user(security_scopes=security_Scopes, api_key=DEFAULT_API_KEY)
assert user.username == "argilla"
diff --git a/tests/server/test_errors.py b/tests/server/test_errors.py
index 0090ab2e31..80da4da6f3 100644
--- a/tests/server/test_errors.py
+++ b/tests/server/test_errors.py
@@ -17,7 +17,4 @@
def test_generic_error():
err = GenericServerError(error=ValueError("this is an error"))
- assert (
- str(err)
- == "argilla.api.errors::GenericServerError(type=builtins.ValueError,message=this is an error)"
- )
+ assert str(err) == "argilla.api.errors::GenericServerError(type=builtins.ValueError,message=this is an error)"
diff --git a/tests/server/text2text/test_api.py b/tests/server/text2text/test_api.py
index c5db714f9d..aec7eec723 100644
--- a/tests/server/text2text/test_api.py
+++ b/tests/server/text2text/test_api.py
@@ -181,9 +181,7 @@ def test_api_with_new_predictions_data_model(mocked_client):
{
"text": "This is a text data",
"predictions": {
- "test": {
- "sentences": [{"text": "This is a test data", "score": 0.6}]
- },
+ "test": {"sentences": [{"text": "This is a test data", "score": 0.6}]},
},
}
),
diff --git a/tests/server/text_classification/test_api.py b/tests/server/text_classification/test_api.py
index 6791614471..79210fdf55 100644
--- a/tests/server/text_classification/test_api.py
+++ b/tests/server/text_classification/test_api.py
@@ -99,9 +99,7 @@ def test_create_records_for_text_classification_with_multi_label(mocked_client):
).dict(by_alias=True),
)
- get_dataset = Dataset.parse_obj(
- mocked_client.get(f"/api/datasets/{dataset}").json()
- )
+ get_dataset = Dataset.parse_obj(mocked_client.get(f"/api/datasets/{dataset}").json())
assert get_dataset.tags == {
"env": "test",
"class": "text classification",
@@ -202,9 +200,7 @@ def test_create_records_for_text_classification(mocked_client, telemetry_track_d
condition=not SUPPORTED_VECTOR_SEARCH,
reason="Vector search not supported",
)
-def test_create_records_for_text_classification_vector_search(
- mocked_client, telemetry_track_data
-):
+def test_create_records_for_text_classification_vector_search(mocked_client, telemetry_track_data):
dataset = "test_create_records_for_text_classification_vector_search"
assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200
tags = {"env": "test", "class": "text classification"}
@@ -271,9 +267,7 @@ def test_create_records_for_text_classification_vector_search(
assert created_dataset.tags == tags
assert created_dataset.metadata == metadata
- response = mocked_client.post(
- f"/api/datasets/{dataset}/TextClassification:search", json={}
- )
+ response = mocked_client.post(f"/api/datasets/{dataset}/TextClassification:search", json={})
assert response.status_code == 200
results = TextClassificationSearchResults.parse_obj(response.json())
@@ -362,9 +356,7 @@ def test_partial_record_update(mocked_client):
response = mocked_client.post(
f"/api/datasets/{name}/TextClassification:search",
json={
- "query": TextClassificationQuery(predicted=PredictionStatus.OK).dict(
- by_alias=True
- ),
+ "query": TextClassificationQuery(predicted=PredictionStatus.OK).dict(by_alias=True),
},
)
@@ -374,9 +366,7 @@ def test_partial_record_update(mocked_client):
first_record = results.records[0]
assert first_record.last_updated is not None
first_record.last_updated = None
- assert TextClassificationRecord(
- **first_record.dict(by_alias=True, exclude_none=True)
- ) == TextClassificationRecord(
+ assert TextClassificationRecord(**first_record.dict(by_alias=True, exclude_none=True)) == TextClassificationRecord(
**{
"id": 1,
"inputs": {"text": "This is a text, oh yeah!"},
@@ -508,10 +498,7 @@ def test_some_sort_by(mocked_client):
},
}
}
- assert (
- response.json()["detail"]["code"]
- == expected_response_property_name_2_value["detail"]["code"]
- )
+ assert response.json()["detail"]["code"] == expected_response_property_name_2_value["detail"]["code"]
assert (
response.json()["detail"]["params"]["message"]
== expected_response_property_name_2_value["detail"]["params"]["message"]
@@ -721,9 +708,7 @@ def test_wrong_text_query(mocked_client):
response = mocked_client.post(
f"/api/datasets/{dataset}/TextClassification:search",
- json=TextClassificationSearchRequest(
- query=TextClassificationQuery(query_text="!")
- ).dict(),
+ json=TextClassificationSearchRequest(query=TextClassificationQuery(query_text="!")).dict(),
)
assert response.status_code == 400
assert response.json() == {
@@ -755,18 +740,14 @@ def test_search_using_text(mocked_client):
response = mocked_client.post(
f"/api/datasets/{dataset}/TextClassification:search",
- json=TextClassificationSearchRequest(
- query=TextClassificationQuery(query_text="text: texto")
- ).dict(),
+ json=TextClassificationSearchRequest(query=TextClassificationQuery(query_text="text: texto")).dict(),
)
assert response.status_code == 200
assert response.json()["total"] == 1
response = mocked_client.post(
f"/api/datasets/{dataset}/TextClassification:search",
- json=TextClassificationSearchRequest(
- query=TextClassificationQuery(query_text="text.exact: texto")
- ).dict(),
+ json=TextClassificationSearchRequest(query=TextClassificationQuery(query_text="text.exact: texto")).dict(),
)
assert response.status_code == 200
assert response.json()["total"] == 0
diff --git a/tests/server/text_classification/test_api_rules.py b/tests/server/text_classification/test_api_rules.py
index 464ab427af..16291950ff 100644
--- a/tests/server/text_classification/test_api_rules.py
+++ b/tests/server/text_classification/test_api_rules.py
@@ -57,9 +57,7 @@ def test_dataset_without_rules(mocked_client):
dataset = "test_dataset_without_rules"
log_some_records(mocked_client, dataset)
- response = mocked_client.get(
- f"/api/datasets/TextClassification/{dataset}/labeling/rules"
- )
+ response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules")
assert response.status_code == 200
assert len(response.json()) == 0
@@ -80,9 +78,7 @@ def test_dataset_update_rule(mocked_client):
json={"label": "NEW Label"},
)
- response = mocked_client.get(
- f"/api/datasets/TextClassification/{dataset}/labeling/rules"
- )
+ response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules")
rules = list(map(LabelingRule.parse_obj, response.json()))
assert len(rules) == 1
assert rules[0].label == "NEW Label"
@@ -94,9 +90,7 @@ def test_dataset_update_rule(mocked_client):
json={"labels": ["A", "B"], "description": "New description"},
)
- response = mocked_client.get(
- f"/api/datasets/TextClassification/{dataset}/labeling/rules"
- )
+ response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules")
rules = list(map(LabelingRule.parse_obj, response.json()))
assert len(rules) == 1
assert rules[0].description == "New description"
@@ -109,9 +103,7 @@ def test_dataset_update_rule(mocked_client):
[
CreateLabelingRule(query="a query", description="Description", label="LALA"),
CreateLabelingRule(query="/a qu?ry/", description="Description", label="LALA"),
- CreateLabelingRule(
- query="another query", description="Description", labels=["A", "B", "C"]
- ),
+ CreateLabelingRule(query="another query", description="Description", labels=["A", "B", "C"]),
],
)
def test_dataset_with_rules(mocked_client, rule):
@@ -130,9 +122,7 @@ def test_dataset_with_rules(mocked_client, rule):
assert created_rule.labels == rule.labels
assert created_rule.description == rule.description
- response = mocked_client.get(
- f"/api/datasets/TextClassification/{dataset}/labeling/rules"
- )
+ response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules")
assert response.status_code == 200
rules = list(map(LabelingRule.parse_obj, response.json()))
assert len(rules) == 1
@@ -143,12 +133,8 @@ def test_dataset_with_rules(mocked_client, rule):
"rule",
[
CreateLabelingRule(query="a query", description="Description", label="LALA"),
- CreateLabelingRule(
- query="/a qu(e|E)ry/", description="Description", label="LALA"
- ),
- CreateLabelingRule(
- query="another query", description="Description", labels=["A", "B", "C"]
- ),
+ CreateLabelingRule(query="/a qu(e|E)ry/", description="Description", label="LALA"),
+ CreateLabelingRule(query="another query", description="Description", labels=["A", "B", "C"]),
],
)
def test_get_dataset_rule(mocked_client, rule):
@@ -161,9 +147,7 @@ def test_get_dataset_rule(mocked_client, rule):
)
assert response.status_code == 200
- response = mocked_client.get(
- f"/api/datasets/TextClassification/{dataset}/labeling/rules/{rule.query}"
- )
+ response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules/{rule.query}")
assert response.status_code == 200
found_rule = LabelingRule.parse_obj(response.json())
assert found_rule.query == rule.query
@@ -178,20 +162,14 @@ def test_delete_dataset_rules(mocked_client):
response = mocked_client.post(
f"/api/datasets/TextClassification/{dataset}/labeling/rules",
- json=CreateLabelingRule(
- query="/a query/", label="TEST", description="Description"
- ).dict(),
+ json=CreateLabelingRule(query="/a query/", label="TEST", description="Description").dict(),
)
assert response.status_code == 200
- response = mocked_client.delete(
- f"/api/datasets/TextClassification/{dataset}/labeling/rules//a query/"
- )
+ response = mocked_client.delete(f"/api/datasets/TextClassification/{dataset}/labeling/rules//a query/")
assert response.status_code == 200
- response = mocked_client.get(
- f"/api/datasets/TextClassification/{dataset}/labeling/rules"
- )
+ response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules")
assert response.status_code == 200
assert len(response.json()) == 0
@@ -242,9 +220,7 @@ def test_rule_metrics_with_missing_label(mocked_client):
dataset = "test_rule_metrics_with_missing_label"
log_some_records(mocked_client, dataset, annotation="OK")
- response = mocked_client.get(
- f"/api/datasets/TextClassification/{dataset}/labeling/rules/a query/metrics"
- )
+ response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules/a query/metrics")
assert response.status_code == 200, response.json()
assert response.json() == {
"coverage": 0.0,
@@ -344,18 +320,12 @@ def test_rule_metrics_with_missing_label(mocked_client):
),
],
)
-def test_rule_metrics_with_missing_label_for_stored_rule(
- mocked_client, rule, expected_metrics
-):
+def test_rule_metrics_with_missing_label_for_stored_rule(mocked_client, rule, expected_metrics):
dataset = "test_rule_metrics_with_missing_label_for_stored_rule"
log_some_records(mocked_client, dataset, annotation="o.k.")
- mocked_client.post(
- f"/api/datasets/TextClassification/{dataset}/labeling/rules", json=rule.dict()
- )
+ mocked_client.post(f"/api/datasets/TextClassification/{dataset}/labeling/rules", json=rule.dict())
- response = mocked_client.get(
- f"/api/datasets/TextClassification/{dataset}/labeling/rules/{rule.query}/metrics"
- )
+ response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules/{rule.query}/metrics")
assert response.status_code == 200
assert response.json() == expected_metrics
@@ -366,21 +336,15 @@ def test_create_rules_and_then_log(mocked_client):
for query in ["ejemplo", "bad query"]:
mocked_client.post(
f"/api/datasets/TextClassification/{dataset}/labeling/rules",
- json=CreateLabelingRule(
- query=query, label="TEST", description="Description"
- ).dict(),
+ json=CreateLabelingRule(query=query, label="TEST", description="Description").dict(),
)
- response = mocked_client.get(
- f"/api/datasets/TextClassification/{dataset}/labeling/rules"
- )
+ response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules")
rules = list(map(LabelingRule.parse_obj, response.json()))
assert len(rules) == 2
log_some_records(mocked_client, dataset, annotation="OK", delete=False)
- response = mocked_client.get(
- f"/api/datasets/TextClassification/{dataset}/labeling/rules"
- )
+ response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules")
rules = list(map(LabelingRule.parse_obj, response.json()))
assert len(rules) == 2
@@ -441,9 +405,7 @@ def test_dataset_rules_metrics(mocked_client, rules, expected_metrics, annotatio
json=rule.dict(),
)
- response = mocked_client.get(
- f"/api/datasets/TextClassification/{dataset}/labeling/rules/metrics"
- )
+ response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules/metrics")
assert response.status_code == 200, response.json()
assert response.json() == expected_metrics
@@ -465,9 +427,7 @@ def test_rule_metric(mocked_client):
assert metrics.incorrect == 1
assert metrics.precision == 0
- response = mocked_client.get(
- f"/api/datasets/TextClassification/{dataset}/labeling/rules/ejemplo/metrics?label=OK"
- )
+ response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules/ejemplo/metrics?label=OK")
assert response.status_code == 200
metrics = LabelingRuleMetricsSummary.parse_obj(response.json())
@@ -475,9 +435,7 @@ def test_rule_metric(mocked_client):
assert metrics.incorrect == 0
assert metrics.precision == 1
- response = mocked_client.get(
- f"/api/datasets/TextClassification/{dataset}/labeling/rules/ejemplo/metrics"
- )
+ response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules/ejemplo/metrics")
assert response.status_code == 200
metrics = LabelingRuleMetricsSummary.parse_obj(response.json())
@@ -486,9 +444,7 @@ def test_rule_metric(mocked_client):
assert metrics.precision is None
assert metrics.coverage_annotated == 1
- response = mocked_client.get(
- f"/api/datasets/TextClassification/{dataset}/labeling/rules/badd/metrics?label=OK"
- )
+ response = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/labeling/rules/badd/metrics?label=OK")
assert response.status_code == 200
metrics = LabelingRuleMetricsSummary.parse_obj(response.json())
diff --git a/tests/server/text_classification/test_api_settings.py b/tests/server/text_classification/test_api_settings.py
index 03e50a3c8c..e091ce7952 100644
--- a/tests/server/text_classification/test_api_settings.py
+++ b/tests/server/text_classification/test_api_settings.py
@@ -17,9 +17,7 @@
def create_dataset(client, name: str):
- response = client.post(
- "/api/datasets", json={"name": name, "task": TaskType.text_classification}
- )
+ response = client.post("/api/datasets", json={"name": name, "task": TaskType.text_classification})
assert response.status_code == 200
@@ -60,9 +58,7 @@ def test_delete_settings(mocked_client):
create_dataset(mocked_client, name)
assert create_settings(mocked_client, name).status_code == 200
- response = mocked_client.delete(
- f"/api/datasets/{TaskType.text_classification}/{name}/settings"
- )
+ response = mocked_client.delete(f"/api/datasets/{TaskType.text_classification}/{name}/settings")
assert response.status_code == 200
assert fetch_settings(mocked_client, name).status_code == 404
@@ -131,6 +127,4 @@ def log_some_data(mocked_client, name):
def fetch_settings(mocked_client, name):
- return mocked_client.get(
- f"/api/datasets/{TaskType.text_classification}/{name}/settings"
- )
+ return mocked_client.get(f"/api/datasets/{TaskType.text_classification}/{name}/settings")
diff --git a/tests/server/text_classification/test_model.py b/tests/server/text_classification/test_model.py
index 87e344cb02..c20d984804 100644
--- a/tests/server/text_classification/test_model.py
+++ b/tests/server/text_classification/test_model.py
@@ -31,9 +31,7 @@
def test_flatten_metadata():
data = {
"inputs": {"text": "bogh"},
- "metadata": {
- "mail": {"subject": "The mail subject", "body": "This is a large text body"}
- },
+ "metadata": {"mail": {"subject": "The mail subject", "body": "This is a large text body"}},
}
record = ServiceTextClassificationRecord.parse_obj(data)
assert list(record.metadata.keys()) == ["mail.subject", "mail.body"]
@@ -378,9 +376,7 @@ def test_empty_labels_for_no_multilabel():
record = ServiceTextClassificationRecord(
inputs={"text": "The input text"},
prediction=TextClassificationAnnotation(agent="ann.", labels=[]),
- annotation=TextClassificationAnnotation(
- agent="ann.", labels=[ClassPrediction(class_label="B")]
- ),
+ annotation=TextClassificationAnnotation(agent="ann.", labels=[ClassPrediction(class_label="B")]),
)
assert record.predicted == PredictionStatus.KO
@@ -400,12 +396,8 @@ def test_using_predictions_dict():
record = ServiceTextClassificationRecord(
inputs={"text": "this is a text"},
predictions={
- "carl": TextClassificationAnnotation(
- agent="wat at", labels=[ClassPrediction(class_label="YES")]
- ),
- "BOB": TextClassificationAnnotation(
- agent="wot wot", labels=[ClassPrediction(class_label="NO")]
- ),
+ "carl": TextClassificationAnnotation(agent="wat at", labels=[ClassPrediction(class_label="YES")]),
+ "BOB": TextClassificationAnnotation(agent="wot wot", labels=[ClassPrediction(class_label="NO")]),
},
)
@@ -415,9 +407,7 @@ def test_using_predictions_dict():
}
assert record.predictions == {
"BOB": TextClassificationAnnotation(labels=[ClassPrediction(class_label="NO")]),
- "carl": TextClassificationAnnotation(
- labels=[ClassPrediction(class_label="YES")]
- ),
+ "carl": TextClassificationAnnotation(labels=[ClassPrediction(class_label="YES")]),
}
@@ -425,7 +415,5 @@ def test_with_no_agent_at_all():
with pytest.raises(ValidationError):
ServiceTextClassificationRecord(
inputs={"text": "this is a text"},
- prediction=TextClassificationAnnotation(
- labels=[ClassPrediction(class_label="YES")]
- ),
+ prediction=TextClassificationAnnotation(labels=[ClassPrediction(class_label="YES")]),
)
diff --git a/tests/server/token_classification/test_api.py b/tests/server/token_classification/test_api.py
index 8260300012..577b84f7cd 100644
--- a/tests/server/token_classification/test_api.py
+++ b/tests/server/token_classification/test_api.py
@@ -58,10 +58,7 @@ def test_load_as_different_task(mocked_client):
assert response.json() == {
"detail": {
"code": "argilla.api.errors::WrongTaskError",
- "params": {
- "message": "Provided task TextClassification cannot be "
- "applied to dataset"
- },
+ "params": {"message": "Provided task TextClassification cannot be " "applied to dataset"},
}
}
@@ -90,9 +87,7 @@ def test_search_special_characters(mocked_client):
response = mocked_client.post(
f"/api/datasets/{dataset}/TokenClassification:search",
- json=TokenClassificationSearchRequest(
- query=TokenClassificationQuery(query_text="\!")
- ).dict(),
+ json=TokenClassificationSearchRequest(query=TokenClassificationQuery(query_text="\!")).dict(),
)
assert response.status_code == 200, response.json()
results = TokenClassificationSearchResults.parse_obj(response.json())
@@ -100,9 +95,7 @@ def test_search_special_characters(mocked_client):
response = mocked_client.post(
f"/api/datasets/{dataset}/TokenClassification:search",
- json=TokenClassificationSearchRequest(
- query=TokenClassificationQuery(query_text="text.exact:\!")
- ).dict(),
+ json=TokenClassificationSearchRequest(query=TokenClassificationQuery(query_text="text.exact:\!")).dict(),
)
assert response.status_code == 200, response.json()
results = TokenClassificationSearchResults.parse_obj(response.json())
@@ -161,9 +154,7 @@ def test_some_sort(mocked_client):
(None, lambda r: len(r.metrics) == 0),
],
)
-def test_create_records_for_token_classification(
- mocked_client, include_metrics: bool, metrics_validator: Callable
-):
+def test_create_records_for_token_classification(mocked_client, include_metrics: bool, metrics_validator: Callable):
dataset = "test_create_records_for_token_classification"
assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200
entity_label = "TEST"
diff --git a/tests/server/token_classification/test_api_settings.py b/tests/server/token_classification/test_api_settings.py
index 3701e4de53..587ea89177 100644
--- a/tests/server/token_classification/test_api_settings.py
+++ b/tests/server/token_classification/test_api_settings.py
@@ -17,9 +17,7 @@
def create_dataset(client, name: str):
- response = client.post(
- "/api/datasets", json={"name": name, "task": TaskType.token_classification}
- )
+ response = client.post("/api/datasets", json={"name": name, "task": TaskType.token_classification})
assert response.status_code == 200
@@ -60,9 +58,7 @@ def test_delete_settings(mocked_client):
create_dataset(mocked_client, name)
assert create_settings(mocked_client, name).status_code == 200
- response = mocked_client.delete(
- f"/api/datasets/{TaskType.token_classification}/{name}/settings"
- )
+ response = mocked_client.delete(f"/api/datasets/{TaskType.token_classification}/{name}/settings")
assert response.status_code == 200
assert fetch_settings(mocked_client, name).status_code == 404
@@ -135,6 +131,4 @@ def test_validate_settings_after_logging(mocked_client):
def fetch_settings(mocked_client, name):
- return mocked_client.get(
- f"/api/datasets/{TaskType.token_classification}/{name}/settings"
- )
+ return mocked_client.get(f"/api/datasets/{TaskType.token_classification}/{name}/settings")
diff --git a/tests/server/token_classification/test_model.py b/tests/server/token_classification/test_model.py
index 8d0c12bbff..878db93577 100644
--- a/tests/server/token_classification/test_model.py
+++ b/tests/server/token_classification/test_model.py
@@ -119,9 +119,7 @@ def test_model_with_predictions():
"metrics": {},
"predictions": {
"test": {
- "entities": [
- {"end": 24, "label": "test", "score": 1.0, "start": 9}
- ],
+ "entities": [{"end": 24, "label": "test", "score": 1.0, "start": 9}],
}
},
"status": "Default",
@@ -162,9 +160,7 @@ def test_too_long_metadata():
def test_entity_label_too_long():
text = "On one ones o no"
- with pytest.raises(
- ValidationError, match="ensure this value has at most 128 character"
- ):
+ with pytest.raises(ValidationError, match="ensure this value has at most 128 character"):
ServiceTokenClassificationRecord(
text=text,
tokens=text.split(),
@@ -269,9 +265,7 @@ def test_annotated_without_entities():
record = ServiceTokenClassificationRecord(
text=text,
tokens=text.split(),
- prediction=TokenClassificationAnnotation(
- agent="pred.test", entities=[EntitySpan(start=0, end=3, label="DET")]
- ),
+ prediction=TokenClassificationAnnotation(agent="pred.test", entities=[EntitySpan(start=0, end=3, label="DET")]),
annotation=TokenClassificationAnnotation(agent="test", entities=[]),
)
diff --git a/tests/utils/test_span_utils.py b/tests/utils/test_span_utils.py
index 3cde38131c..16b71d10b1 100644
--- a/tests/utils/test_span_utils.py
+++ b/tests/utils/test_span_utils.py
@@ -43,9 +43,7 @@ def test_init():
def test_init_value_error():
- with pytest.raises(
- ValueError, match="Token 'ValueError' not found in text: test error"
- ):
+ with pytest.raises(ValueError, match="Token 'ValueError' not found in text: test error"):
SpanUtils(text="test error", tokens=["test", "ValueError"])
@@ -56,9 +54,7 @@ def test_validate():
def test_validate_not_valid_spans():
span_utils = SpanUtils("test this.", ["test", "this", "."])
- with pytest.raises(
- ValueError, match="Following entity spans are not valid: \[\('mock', 2, 1\)\]\n"
- ):
+ with pytest.raises(ValueError, match="Following entity spans are not valid: \[\('mock', 2, 1\)\]\n"):
span_utils.validate([("mock", 2, 1)])
diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py
index 85fe24e9f5..6a08af0c66 100644
--- a/tests/utils/test_utils.py
+++ b/tests/utils/test_utils.py
@@ -34,14 +34,10 @@ def mock_import_module(name, package):
assert lazy_module.title() == ".mock_module".title()
assert lazy_module.string == str
- with pytest.warns(
- FutureWarning, match="Importing 'dep_mock_module' from the argilla namespace"
- ):
+ with pytest.warns(FutureWarning, match="Importing 'dep_mock_module' from the argilla namespace"):
assert lazy_module.dep_mock_module == ".dep_mock_module"
- with pytest.warns(
- FutureWarning, match="Importing 'upper' from the argilla namespace"
- ):
+ with pytest.warns(FutureWarning, match="Importing 'upper' from the argilla namespace"):
assert lazy_module.upper() == ".dep_mock_module".upper()
with pytest.raises(AttributeError):
From 4c5f51377e374fb30649bdc7b9a3291db21c5bb8 Mon Sep 17 00:00:00 2001
From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Date: Thu, 16 Feb 2023 14:40:26 +0100
Subject: [PATCH 10/45] Use `rich` for logging, tracebacks, printing,
progressbars (#2350)
Closes #1843
Hello!
## Pull Request overview
* Use [`rich`](https://github.com/Textualize/rich) for logging,
tracebacks, printing and progressbars.
* Add `rich` as a dependency.
* Remove `loguru` as a dependency and remove all mentions of it in the
codebase.
* Simplify logging configuration according to the logging documentation.
* Update logging tests.
## Before & After
[`rich`](https://github.com/Textualize/rich) is a large Python library
for very colorful formatting in the terminal. Most importantly (in my
opinion), it improves the readability of logs and tracebacks. Let's go
over some before-and-afters:
Printing, Logging & Progressbars
### Before:
![image](https://user-images.githubusercontent.com/37621491/219089678-e57906d3-568d-480e-88a4-9240397f5229.png)
### After:
![image](https://user-images.githubusercontent.com/37621491/219089826-646d57a6-7e5b-426f-9ab1-d6d6317ec885.png)
Note that for the logs, the repeated information on the left side is
removed, and the file location from which each log originates now appears
on the right side. The progress bar has also been updated, and the URL in
the printed output is highlighted automatically.
Tracebacks
### Before:
![image](https://user-images.githubusercontent.com/37621491/219090868-42cfe128-fd98-47ec-9d38-6f6f52a21373.png)
### After:
![image](https://user-images.githubusercontent.com/37621491/219090903-86f1fe11-d509-440d-8a6a-2833c344707b.png)
---
### Before:
![image](https://user-images.githubusercontent.com/37621491/219091343-96bae874-a673-4281-80c5-caebb67e348e.png)
### After:
![image](https://user-images.githubusercontent.com/37621491/219091193-d4cb1f64-11a7-4783-a9b2-0aec1abb8eb7.png)
---
### Before
![image](https://user-images.githubusercontent.com/37621491/219091791-aa8969a1-e0c1-4708-a23d-38d22c2406f2.png)
### After
![image](https://user-images.githubusercontent.com/37621491/219091878-e24c1f6b-83fa-4fed-9705-ede522faee82.png)
## Notes
Note that there are some changes in the logging configuration. Most
notably, it has been simplified according to the note from
[here](https://docs.python.org/3/library/logging.html#logging.Logger.propagate).
In my changes, I only attach our handler to the root logger and let
propagation take care of the rest.
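To make that concrete, here is a minimal sketch of the propagation-based
setup (illustrative only, not the exact code in this PR):

```python
import logging

from rich.logging import RichHandler

# One handler on the root logger; every other logger propagates up to it.
logging.basicConfig(level="INFO", format="%(message)s", handlers=[RichHandler()])

# No handler is attached to this child logger, yet RichHandler renders it.
logging.getLogger("argilla.client").info("rendered via propagation")
```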
Beyond that, I've pinned `rich` to 13.0.1, as newer versions experience a
StopIteration error, as discussed
[here](https://github.com/Textualize/rich/issues/2800#issuecomment-1428764064).
I've replaced `tqdm` with a `rich` progress bar when logging. However, I've
kept the `tqdm` progress bar for the [Weak
Labeling](https://github.com/argilla-io/argilla/blob/develop/src/argilla/labeling/text_classification/weak_labels.py)
for now.
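The migration follows the usual `rich` pattern, roughly as below (a sketch
only; the real change is in the `client.py` diff further down):

```python
from rich.progress import Progress

# tqdm owned a bar object directly; with rich, a context manager owns the
# display, and updates advance a named task by the size of each chunk.
with Progress() as progress_bar:
    task = progress_bar.add_task("Logging...", total=len(records), visible=verbose)
    for i in range(0, len(records), chunk_size):
        chunk = records[i : i + chunk_size]
        ...  # send the chunk to the server
        progress_bar.update(task, advance=len(chunk))
```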
One difference from the old situation is that during `pytest` all logs are
now displayed under "live log call" (including expected errors), whereas
previously only warnings were shown.
## What to review?
Please do the following when reviewing:
1. Ensure that `rich` is correctly set to be installed whenever someone
installs `argilla`. I always set my dependencies explicitly in
setup.py like
[here](https://github.com/nltk/nltk/blob/develop/setup.py#L115) or
[here](https://github.com/huggingface/setfit/blob/78851287535305ef32f789c7a87004628172b5b6/setup.py#L47-L48),
but the one for `argilla` is
[empty](https://github.com/argilla-io/argilla/blob/develop/setup.py),
and `pyproject.toml` is used instead. I'd like for someone to look this
over.
2. Fetch this branch and run some arbitrary code. Load some data, log
some data, crash some programs, and get an idea of the changes.
Especially changes to loggers and tracebacks can be a bit personal, so
I'd like to get people on board with this. Otherwise we can scrap it or
find a compromise. After all, this is also a design PR.
3. Please have a look at my discussion points below.
## Discussion
`rich` is quite configurable, so there are still some changes we could make.
1. The `RichHandler` logging handler can be modified to e.g. include rich
tracebacks in its log output, as discussed
[here](https://rich.readthedocs.io/en/latest/logging.html#handle-exceptions).
Are we interested in this? (A sketch of this and the other options follows
the list.)
2. The `rich` traceback handler can be set up to include local variables
in its traceback, as in the screenshot below:
![image](https://user-images.githubusercontent.com/37621491/219096029-796b57ee-2f1b-485f-af35-c3effd44200b.png)
Are we interested in this? In my opinion this is a bit overkill.
3. We can suppress frames from certain Python modules to exclude them
from the rich tracebacks. Are we interested in this?
4. The default rich traceback shows a maximum of 100 frames, which is a
*lot*. Are we interested in reducing this to only show the first and
last X?
5. The progress bar doesn't automatically stretch to fill the full
available width, while `tqdm` does. If we want, we can set `expand=True`
and it'll also expand to the entire width. Are we interested in this?
6. The progress "bar" does not need to be a bar; we can also use e.g. a
spinner animation. See some more info
[here](https://rich.readthedocs.io/en/latest/progress.html#columns). Are
we interested in this?
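For concreteness, a hedged sketch of the knobs behind points 1-6 (none of
these are enabled by this PR, and the `click` module passed to `suppress`
is a purely illustrative stand-in):

```python
import logging

import click  # stand-in for any module whose frames we might suppress
from rich.logging import RichHandler
from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn
from rich.traceback import install

# 1. Render rich tracebacks inside log records.
logging.basicConfig(handlers=[RichHandler(rich_tracebacks=True)])

# 2-4. Show locals, hide frames from selected modules, cap rendered frames.
install(show_locals=True, suppress=[click], max_frames=20)

# 5-6. Stretch the bar to the full terminal width and lead with a spinner.
progress = Progress(SpinnerColumn(), TextColumn("{task.description}"), BarColumn(), expand=True)
```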
---
**Type of change**
- [x] Refactor (change restructuring the codebase without changing
functionality)
**How Has This Been Tested**
I've updated the tests according to my changes.
**Checklist**
- [x] I have merged the original branch into my forked branch
- [ ] I added relevant documentation
- [x] follows the style guidelines of this project
- [x] I did a self-review of my code
- [x] I added comments to my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- Tom Aarsen
---
environment_dev.yml | 6 +-
pyproject.toml | 6 +-
src/argilla/__init__.py | 8 +++
src/argilla/client/client.py | 40 ++++++-----
src/argilla/logging.py | 70 +++----------------
tests/conftest.py | 19 +----
.../text_classification/test_label_errors.py | 3 +-
tests/test_init.py | 7 +-
tests/test_logging.py | 6 +-
9 files changed, 56 insertions(+), 109 deletions(-)
diff --git a/environment_dev.yml b/environment_dev.yml
index dcb2c0b7fd..3b73689e5c 100644
--- a/environment_dev.yml
+++ b/environment_dev.yml
@@ -43,9 +43,9 @@ dependencies:
- pgmpy
- plotly>=4.1.0
- snorkel>=0.9.7
- - spacy==3.1.0
- - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.1.0/en_core_web_sm-3.1.0.tar.gz
+ - spacy==3.5.0
+ - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.5.0/en_core_web_sm-3.5.0.tar.gz
- transformers[torch]~=4.18.0
- - loguru
+ - rich==13.0.1
# install Argilla in editable mode
- -e .[server,listeners]
diff --git a/pyproject.toml b/pyproject.toml
index 42133cb2c7..bf7123e24f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,11 +38,13 @@ dependencies = [
"wrapt >= 1.13,< 1.15",
# weaksupervision
"numpy < 1.24.0",
+ # for progressbars
"tqdm >= 4.27.0",
# monitor background consumers
"backoff",
- "monotonic"
-
+ "monotonic",
+ # for logging, tracebacks, printing, progressbars
+ "rich <= 13.0.1"
]
dynamic = ["version"]
diff --git a/src/argilla/__init__.py b/src/argilla/__init__.py
index 1dbd654c9f..81ed00085a 100644
--- a/src/argilla/__init__.py
+++ b/src/argilla/__init__.py
@@ -26,6 +26,14 @@
from . import _version
from .utils import LazyargillaModule as _LazyargillaModule
+try:
+ from rich.traceback import install as _install_rich
+
+ # Rely on `rich` for tracebacks
+ _install_rich()
+except ModuleNotFoundError:
+ pass
+
__version__ = _version.version
if _TYPE_CHECKING:
diff --git a/src/argilla/client/client.py b/src/argilla/client/client.py
index ef36f35800..70f9938509 100644
--- a/src/argilla/client/client.py
+++ b/src/argilla/client/client.py
@@ -20,7 +20,8 @@
from asyncio import Future
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
-from tqdm.auto import tqdm
+from rich import print as rprint
+from rich.progress import Progress
from argilla._constants import (
_OLD_WORKSPACE_HEADER_NAME,
@@ -363,25 +364,26 @@ async def log_async(
raise InputValueError(f"Unknown record type {record_type}. Available values are" f" {Record.__args__}")
processed, failed = 0, 0
- progress_bar = tqdm(total=len(records), disable=not verbose)
- for i in range(0, len(records), chunk_size):
- chunk = records[i : i + chunk_size]
-
- response = await async_bulk(
- client=self._client,
- name=name,
- json_body=bulk_class(
- tags=tags,
- metadata=metadata,
- records=[creation_class.from_client(r) for r in chunk],
- ),
- )
+ with Progress() as progress_bar:
+ task = progress_bar.add_task("Logging...", total=len(records), visible=verbose)
+
+ for i in range(0, len(records), chunk_size):
+ chunk = records[i : i + chunk_size]
+
+ response = await async_bulk(
+ client=self._client,
+ name=name,
+ json_body=bulk_class(
+ tags=tags,
+ metadata=metadata,
+ records=[creation_class.from_client(r) for r in chunk],
+ ),
+ )
- processed += response.parsed.processed
- failed += response.parsed.failed
+ processed += response.parsed.processed
+ failed += response.parsed.failed
- progress_bar.update(len(chunk))
- progress_bar.close()
+ progress_bar.update(task, advance=len(chunk))
# TODO: improve logging policy in library
if verbose:
@@ -389,7 +391,7 @@ async def log_async(
workspace = self.get_workspace()
if not workspace: # Just for backward comp. with datasets with no workspaces
workspace = "-"
- print(f"{processed} records logged to" f" {self._client.base_url}/datasets/{workspace}/{name}")
+ rprint(f"{processed} records logged to {self._client.base_url}/datasets/{workspace}/{name}")
# Creating a composite BulkResponse with the total processed and failed
return BulkResponse(dataset=name, processed=processed, failed=failed)
diff --git a/src/argilla/logging.py b/src/argilla/logging.py
index 94b1cc249e..51e3c56b49 100644
--- a/src/argilla/logging.py
+++ b/src/argilla/logging.py
@@ -18,13 +18,13 @@
"""
import logging
-from logging import Logger
+from logging import Logger, StreamHandler
from typing import Type
try:
- from loguru import logger
+ from rich.logging import RichHandler as ArgillaHandler
except ModuleNotFoundError:
- logger = None
+ ArgillaHandler = StreamHandler
def full_qualified_class_name(_class: Type) -> str:
@@ -60,64 +60,10 @@ def logger(self) -> logging.Logger:
return self.__logger__
-class LoguruLoggerHandler(logging.Handler):
- """This logging handler enables an easy way to use loguru fo all built-in logger traces"""
-
- __LOGLEVEL_MAPPING__ = {
- 50: "CRITICAL",
- 40: "ERROR",
- 30: "WARNING",
- 20: "INFO",
- 10: "DEBUG",
- 0: "NOTSET",
- }
-
- @property
- def is_available(self) -> bool:
- """Return True if handler can tackle log records. False otherwise"""
- return logger is not None
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- if not self.is_available:
- self.emit = lambda record: None
-
- def emit(self, record: logging.LogRecord):
- try:
- level = logger.level(record.levelname).name
- except AttributeError:
- level = self.__LOGLEVEL_MAPPING__[record.levelno]
-
- frame, depth = logging.currentframe(), 2
- while frame.f_code.co_filename == logging.__file__:
- frame = frame.f_back
- depth += 1
-
- log = logger.bind(request_id="argilla")
- log.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
-
-
def configure_logging():
"""Normalizes logging configuration for argilla and its dependencies"""
- intercept_handler = LoguruLoggerHandler()
- if not intercept_handler.is_available:
- return
-
- logging.basicConfig(handlers=[intercept_handler], level=logging.WARNING)
- for name in logging.root.manager.loggerDict:
- logger_ = logging.getLogger(name)
- logger_.handlers = []
-
- for name in [
- "uvicorn",
- "uvicorn.lifespan",
- "uvicorn.error",
- "uvicorn.access",
- "fastapi",
- "argilla",
- "argilla.server",
- ]:
- logger_ = logging.getLogger(name)
- logger_.propagate = False
- logger_.handlers = [intercept_handler]
+ handler = ArgillaHandler()
+
+ # See the note here: https://docs.python.org/3/library/logging.html#logging.Logger.propagate
+ # We only attach our handler to the root logger and let propagation take care of the rest
+ logging.basicConfig(handlers=[handler], level=logging.WARNING)
diff --git a/tests/conftest.py b/tests/conftest.py
index 33738a6db6..f8495bff2b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -15,15 +15,10 @@
import httpx
import pytest
from _pytest.logging import LogCaptureFixture
-from argilla.client.sdk.users import api as users_api
-from argilla.server.commons import telemetry
-
-try:
- from loguru import logger
-except ModuleNotFoundError:
- logger = None
from argilla import app
from argilla.client.api import active_api
+from argilla.client.sdk.users import api as users_api
+from argilla.server.commons import telemetry
from starlette.testclient import TestClient
from .helpers import SecuredClient
@@ -68,13 +63,3 @@ def whoami_mocked(client):
monkeypatch.setattr(rb_api._client, "__httpx__", client_)
yield client_
-
-
-@pytest.fixture
-def caplog(caplog: LogCaptureFixture):
- if not logger:
- yield caplog
- else:
- handler_id = logger.add(caplog.handler, format="{message}")
- yield caplog
- logger.remove(handler_id)
diff --git a/tests/labeling/text_classification/test_label_errors.py b/tests/labeling/text_classification/test_label_errors.py
index b07780ec71..0e46394bf0 100644
--- a/tests/labeling/text_classification/test_label_errors.py
+++ b/tests/labeling/text_classification/test_label_errors.py
@@ -17,6 +17,7 @@
import argilla as ar
import cleanlab
import pytest
+from _pytest.logging import LogCaptureFixture
from argilla.labeling.text_classification import find_label_errors
from argilla.labeling.text_classification.label_errors import (
MissingPredictionError,
@@ -70,7 +71,7 @@ def test_no_records():
find_label_errors(records)
-def test_multi_label_warning(caplog):
+def test_multi_label_warning(caplog: LogCaptureFixture):
record = ar.TextClassificationRecord(
text="test",
prediction=[("mock", 0.0), ("mock2", 0.0)],
diff --git a/tests/test_init.py b/tests/test_init.py
index 9a3a37b103..753e230216 100644
--- a/tests/test_init.py
+++ b/tests/test_init.py
@@ -16,7 +16,7 @@
import logging
import sys
-from argilla.logging import LoguruLoggerHandler
+from argilla.logging import ArgillaHandler
from argilla.utils import LazyargillaModule
@@ -25,4 +25,7 @@ def test_lazy_module():
def test_configure_logging_call():
- assert isinstance(logging.getLogger("argilla").handlers[0], LoguruLoggerHandler)
+ # Ensure that the root logger uses the ArgillaHandler (RichHandler if rich is installed),
+ # whereas the other loggers do not have handlers
+ assert isinstance(logging.getLogger().handlers[0], ArgillaHandler)
+ assert len(logging.getLogger("argilla").handlers) == 0
diff --git a/tests/test_logging.py b/tests/test_logging.py
index 054deb7740..5f7f9f5be4 100644
--- a/tests/test_logging.py
+++ b/tests/test_logging.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from argilla.logging import LoggingMixin, LoguruLoggerHandler
+from argilla.logging import ArgillaHandler, LoggingMixin
class LoggingForTest(LoggingMixin):
@@ -50,8 +50,8 @@ def test_logging_mixin_without_breaking_constructors():
def test_logging_handler(mocker):
- mocker.patch.object(LoguruLoggerHandler, "emit", autospec=True)
- handler = LoguruLoggerHandler()
+ mocker.patch.object(ArgillaHandler, "emit", autospec=True)
+ handler = ArgillaHandler()
logger = logging.getLogger(__name__)
logger.handlers = [handler]
From 5a8bb28210ec1255d3805baf3bc1735e106fe4c7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Francisco=20Calvo?=
Date: Mon, 20 Feb 2023 07:57:00 +0100
Subject: [PATCH 11/45] chore: Replace old recognai emails with argilla ones
(#2365)
# Description
Replace old `recogn.ai` emails with `argilla.io` ones.
This will improve the [argilla PyPI package
page](https://pypi.org/project/argilla/) by showing the correct contact
emails.
We still have a reference to `recogn.ai` in this JS model file:
https://github.com/argilla-io/argilla/blob/develop/frontend/models/Dataset.js#L22
@frascuchon and @keithCuniah, can you confirm whether that JS code can
be safely changed, and how?
---
CODE_OF_CONDUCT.md | 2 +-
pyproject.toml | 4 ++--
tests/server/security/test_model.py | 2 +-
3 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
index f18617bd23..d86e29843e 100644
--- a/CODE_OF_CONDUCT.md
+++ b/CODE_OF_CONDUCT.md
@@ -60,7 +60,7 @@ representative at an online or offline event.
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
-contact@recogn.ai.
+contact@argilla.io.
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
diff --git a/pyproject.toml b/pyproject.toml
index bf7123e24f..34f978ea40 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,10 +20,10 @@ keywords = [
"mlops"
]
authors = [
- {name = "recognai", email = "contact@recogn.ai"}
+ {name = "argilla", email = "contact@argilla.io"}
]
maintainers = [
- {name = "recognai", email = "contact@recogn.ai"}
+ {name = "argilla", email = "contact@argilla.io"}
]
dependencies = [
# Client
diff --git a/tests/server/security/test_model.py b/tests/server/security/test_model.py
index 1511695f63..e89ac76842 100644
--- a/tests/server/security/test_model.py
+++ b/tests/server/security/test_model.py
@@ -18,7 +18,7 @@
from pydantic import ValidationError
-@pytest.mark.parametrize("email", ["my@email.com", "infra@recogn.ai"])
+@pytest.mark.parametrize("email", ["my@email.com", "infra@argilla.io"])
def test_valid_mail(email):
user = User(username="user", email=email)
assert user.email == email
From 175a52a3046eba79c2d56772150a53f4ed469c9e Mon Sep 17 00:00:00 2001
From: Francisco Aranda
Date: Mon, 20 Feb 2023 16:39:56 +0100
Subject: [PATCH 12/45] refactor: remove the classification labeling rules
service (#2361)
# Description
This PR moves the labeling rule operations from the labeling rules
service to the text-classification service.
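For reviewers, a quick sketch of how the service API shifts with this PR
(method names are taken from the diff below; `service` and `dataset` are
hypothetical placeholders for a `TextClassificationService` instance and a
text classification dataset):

```python
# Labeling rule operations now live on TextClassificationService directly.
rules = service.list_labeling_rules(dataset)  # was: service.get_labeling_rules(dataset)

# Per-rule metrics; `labels` stays optional.  # was: service.compute_rule_metrics(...)
summary = service.compute_labeling_rule(dataset, rule_query="t*", labels=None)

# Overall rules metrics.  # was: service.compute_overall_rules_metrics(dataset)
overall = service.compute_all_labeling_rules(dataset)
```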
**Type of change**
(Please delete options that are not relevant. Remember to title the PR
according to the type of change)
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [x] Refactor (change restructuring the codebase without changing
functionality)
- [ ] Improvement (change adding some improvement to an existing
functionality)
- [ ] Documentation update
**Checklist**
- [x] I have merged the original branch into my forked branch
- [x] I added relevant documentation
- [x] My code follows the style guidelines of this project
- [x] I did a self-review of my code
- [x] I added comments to my code
- [x] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
---
.../server/apis/v0/handlers/metrics.py | 17 ++-
.../apis/v0/handlers/text_classification.py | 11 +-
.../server/daos/backend/metrics/base.py | 11 +-
.../backend/metrics/text_classification.py | 2 +-
src/argilla/server/security/settings.py | 14 --
src/argilla/server/services/datasets.py | 3 +
.../server/services/metrics/service.py | 13 ++
.../tasks/text_classification/__init__.py | 1 -
.../labeling_rules_service.py | 141 ------------------
.../tasks/text_classification/metrics.py | 8 +
.../tasks/text_classification/model.py | 13 ++
.../tasks/text_classification/service.py | 119 +++++++++++----
tests/server/metrics/test_api.py | 38 ++++-
13 files changed, 189 insertions(+), 202 deletions(-)
delete mode 100644 src/argilla/server/security/settings.py
delete mode 100644 src/argilla/server/services/tasks/text_classification/labeling_rules_service.py
diff --git a/src/argilla/server/apis/v0/handlers/metrics.py b/src/argilla/server/apis/v0/handlers/metrics.py
index 3f4e0b6743..9408f1e3f0 100644
--- a/src/argilla/server/apis/v0/handlers/metrics.py
+++ b/src/argilla/server/apis/v0/handlers/metrics.py
@@ -14,9 +14,9 @@
# limitations under the License.
from dataclasses import dataclass
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
-from fastapi import APIRouter, Depends, Query, Security
+from fastapi import APIRouter, Depends, Query, Request, Security
from pydantic import BaseModel, Field
from argilla.server.apis.v0.helpers import deprecate_endpoint
@@ -36,6 +36,8 @@ class MetricInfo(BaseModel):
@dataclass
class MetricSummaryParams:
+ request: Request
+
interval: Optional[float] = Query(
default=None,
gt=0.0,
@@ -47,6 +49,15 @@ class MetricSummaryParams:
description="The number of terms for terminological summaries",
)
+ @property
+ def parameters(self) -> Dict[str, Any]:
+ """Returns dynamic metric args found in the request query params"""
+ return {
+ "interval": self.interval,
+ "size": self.size,
+ **{k: v for k, v in self.request.query_params.items() if k not in ["interval", "size"]},
+ }
+
def configure_router(router: APIRouter, cfg: TaskConfig):
base_metrics_endpoint = f"/{cfg.task}/{{name}}/metrics"
@@ -112,7 +123,7 @@ def metric_summary(
metric=metric_,
record_class=record_class,
query=query,
- **vars(metric_params),
+ **metric_params.parameters,
)
diff --git a/src/argilla/server/apis/v0/handlers/text_classification.py b/src/argilla/server/apis/v0/handlers/text_classification.py
index e5d4ca840b..59f9436425 100644
--- a/src/argilla/server/apis/v0/handlers/text_classification.py
+++ b/src/argilla/server/apis/v0/handlers/text_classification.py
@@ -316,7 +316,7 @@ async def list_labeling_rules(
as_dataset_class=TasksFactory.get_task_dataset(task_type),
)
- return [LabelingRule.parse_obj(rule) for rule in service.get_labeling_rules(dataset)]
+ return [LabelingRule.parse_obj(rule) for rule in service.list_labeling_rules(dataset)]
@deprecate_endpoint(
path=f"{new_base_endpoint}/labeling/rules",
@@ -379,7 +379,7 @@ async def compute_rule_metrics(
as_dataset_class=TasksFactory.get_task_dataset(task_type),
)
- return service.compute_rule_metrics(dataset, rule_query=query, labels=labels)
+ return service.compute_labeling_rule(dataset, rule_query=query, labels=labels)
@deprecate_endpoint(
path=f"{new_base_endpoint}/labeling/rules/metrics",
@@ -404,7 +404,7 @@ async def compute_dataset_rules_metrics(
workspace=common_params.workspace,
as_dataset_class=TasksFactory.get_task_dataset(task_type),
)
- metrics = service.compute_overall_rules_metrics(dataset)
+ metrics = service.compute_all_labeling_rules(dataset)
return DatasetLabelingRulesMetricsSummary.parse_obj(metrics)
@deprecate_endpoint(
@@ -456,10 +456,7 @@ async def get_rule(
workspace=common_params.workspace,
as_dataset_class=TasksFactory.get_task_dataset(task_type),
)
- rule = service.find_labeling_rule(
- dataset,
- rule_query=query,
- )
+ rule = service.find_labeling_rule(dataset, rule_query=query)
return LabelingRule.parse_obj(rule)
@deprecate_endpoint(
diff --git a/src/argilla/server/daos/backend/metrics/base.py b/src/argilla/server/daos/backend/metrics/base.py
index 82bd60a1e7..344dfeba18 100644
--- a/src/argilla/server/daos/backend/metrics/base.py
+++ b/src/argilla/server/daos/backend/metrics/base.py
@@ -13,6 +13,7 @@
# limitations under the License.
import dataclasses
+import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from argilla.server.daos.backend.query_helpers import aggregations
@@ -22,6 +23,9 @@
from argilla.server.daos.backend.client_adapters.base import IClientAdapter
+_LOGGER = logging.getLogger(__file__)
+
+
@dataclasses.dataclass
class ElasticsearchMetric:
id: str
@@ -37,11 +41,14 @@ def __post_init__(self):
def get_function_arg_names(func):
return func.__code__.co_varnames
- def aggregation_request(self, *args, **kwargs) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
+ def aggregation_request(self, *args, **kwargs) -> Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]:
"""
Configures the summary es aggregation definition
"""
- return {self.id: self._build_aggregation(*args, **kwargs)}
+ try:
+ return {self.id: self._build_aggregation(*args, **kwargs)}
+ except TypeError as ex:
+ _LOGGER.warning(f"Cannot build aggregation for metric {self.id}. Error: {ex}. Skipping...")
def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]:
"""
diff --git a/src/argilla/server/daos/backend/metrics/text_classification.py b/src/argilla/server/daos/backend/metrics/text_classification.py
index 8d3a83228c..cf54f5e12d 100644
--- a/src/argilla/server/daos/backend/metrics/text_classification.py
+++ b/src/argilla/server/daos/backend/metrics/text_classification.py
@@ -43,7 +43,7 @@ def _build_aggregation(self, queries: List[str]) -> Dict[str, Any]:
class LabelingRulesMetric(ElasticsearchMetric):
id: str
- def _build_aggregation(self, rule_query: str, labels: Optional[List[str]]) -> Dict[str, Any]:
+ def _build_aggregation(self, rule_query: str, labels: Optional[List[str]] = None) -> Dict[str, Any]:
annotated_records_filter = filters.exists_field("annotated_as")
rule_query_filter = filters.text_query(rule_query)
aggr_filters = {
diff --git a/src/argilla/server/security/settings.py b/src/argilla/server/security/settings.py
deleted file mode 100644
index 168dce9eb1..0000000000
--- a/src/argilla/server/security/settings.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# coding=utf-8
-# Copyright 2021-present, the Recognai S.L. team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
diff --git a/src/argilla/server/services/datasets.py b/src/argilla/server/services/datasets.py
index b6af813d4b..1c4e7f6843 100644
--- a/src/argilla/server/services/datasets.py
+++ b/src/argilla/server/services/datasets.py
@@ -230,6 +230,9 @@ async def get_settings(
raise EntityNotFoundError(name=dataset.name, type=class_type)
return class_type.parse_obj(settings.dict())
+ def raw_dataset_update(self, dataset):
+ self.__dao__.update_dataset(dataset)
+
async def save_settings(
self, user: User, dataset: ServiceDataset, settings: ServiceDatasetSettings
) -> ServiceDatasetSettings:
diff --git a/src/argilla/server/services/metrics/service.py b/src/argilla/server/services/metrics/service.py
index 9989690387..90b86cf294 100644
--- a/src/argilla/server/services/metrics/service.py
+++ b/src/argilla/server/services/metrics/service.py
@@ -116,3 +116,16 @@ def summarize_metric(
dataset=dataset,
query=query,
)
+
+ def annotated_records(self, dataset: ServiceDataset) -> int:
+ """Return the number of annotated records for a dataset"""
+ results = self.__dao__.search_records(
+ dataset,
+ size=0,
+ search=DaoRecordsSearch(query=ServiceBaseRecordsQuery(has_annotation=True)),
+ )
+ return results.total
+
+ def total_records(self, dataset: ServiceDataset) -> int:
+ """Return the total number of records for a given dataset"""
+ return self.__dao__.search_records(dataset, size=0).total
diff --git a/src/argilla/server/services/tasks/text_classification/__init__.py b/src/argilla/server/services/tasks/text_classification/__init__.py
index 1a0c078cea..b8298320b7 100644
--- a/src/argilla/server/services/tasks/text_classification/__init__.py
+++ b/src/argilla/server/services/tasks/text_classification/__init__.py
@@ -12,5 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .labeling_rules_service import LabelingService
from .service import TextClassificationService
diff --git a/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py b/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py
deleted file mode 100644
index 38471bfccc..0000000000
--- a/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py
+++ /dev/null
@@ -1,141 +0,0 @@
-# Copyright 2021-present, the Recognai S.L. team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import List, Optional, Tuple
-
-from fastapi import Depends
-from pydantic import BaseModel, Field
-
-from argilla.server.daos.datasets import DatasetsDAO
-from argilla.server.daos.models.records import DaoRecordsSearch
-from argilla.server.daos.records import DatasetRecordsDAO
-from argilla.server.errors import EntityAlreadyExistsError, EntityNotFoundError
-from argilla.server.services.search.model import ServiceBaseRecordsQuery
-from argilla.server.services.tasks.text_classification.model import (
- ServiceLabelingRule,
- ServiceTextClassificationDataset,
-)
-
-
-class DatasetLabelingRulesSummary(BaseModel):
- covered_records: int
- annotated_covered_records: int
-
-
-class LabelingRuleSummary(BaseModel):
- covered_records: int
- annotated_covered_records: int
- correct_records: int = Field(default=0)
- incorrect_records: int = Field(default=0)
- precision: Optional[float] = None
-
-
-class LabelingService:
- _INSTANCE = None
-
- @classmethod
- def get_instance(
- cls,
- datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance),
- records: DatasetRecordsDAO = Depends(DatasetRecordsDAO.get_instance),
- ):
- if cls._INSTANCE is None:
- cls._INSTANCE = cls(datasets, records)
- return cls._INSTANCE
-
- def __init__(self, datasets: DatasetsDAO, records: DatasetRecordsDAO):
- self.__datasets__ = datasets
- self.__records__ = records
-
- # TODO(@frascuchon): Move all rules management methods to the common datasets service like settings
- def list_rules(self, dataset: ServiceTextClassificationDataset) -> List[ServiceLabelingRule]:
- """List a set of rules for a given dataset"""
- return dataset.rules
-
- def delete_rule(self, dataset: ServiceTextClassificationDataset, rule_query: str):
- """Delete a rule from a dataset by its defined query string"""
- new_rules_set = [r for r in dataset.rules if r.query != rule_query]
- if len(dataset.rules) != new_rules_set:
- dataset.rules = new_rules_set
- self.__datasets__.update_dataset(dataset)
-
- def add_rule(self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule) -> ServiceLabelingRule:
- """Adds a rule to a dataset"""
- for r in dataset.rules:
- if r.query == rule.query:
- raise EntityAlreadyExistsError(rule.query, type=ServiceLabelingRule)
- dataset.rules.append(rule)
- self.__datasets__.update_dataset(dataset)
- return rule
-
- def compute_rule_metrics(
- self,
- dataset: ServiceTextClassificationDataset,
- rule_query: str,
- labels: Optional[List[str]] = None,
- ) -> Tuple[int, int, LabelingRuleSummary]:
- """Computes metrics for given rule query and optional label against a set of rules"""
-
- annotated_records = self._count_annotated_records(dataset)
- dataset_records = self.__records__.search_records(dataset, size=0).total
- metric_data = self.__records__.compute_metric(
- dataset=dataset,
- metric_id="labeling_rule",
- metric_params=dict(rule_query=rule_query, labels=labels),
- )
-
- return (
- dataset_records,
- annotated_records,
- LabelingRuleSummary.parse_obj(metric_data),
- )
-
- def _count_annotated_records(self, dataset: ServiceTextClassificationDataset) -> int:
- results = self.__records__.search_records(
- dataset,
- size=0,
- search=DaoRecordsSearch(query=ServiceBaseRecordsQuery(has_annotation=True)),
- )
- return results.total
-
- def all_rules_metrics(
- self, dataset: ServiceTextClassificationDataset
- ) -> Tuple[int, int, DatasetLabelingRulesSummary]:
- annotated_records = self._count_annotated_records(dataset)
- dataset_records = self.__records__.search_records(dataset, size=0).total
- metric_data = self.__records__.compute_metric(
- dataset=dataset,
- metric_id="dataset_labeling_rules",
- metric_params=dict(queries=[r.query for r in dataset.rules]),
- )
-
- return (
- dataset_records,
- annotated_records,
- DatasetLabelingRulesSummary.parse_obj(metric_data),
- )
-
- def find_rule_by_query(self, dataset: ServiceTextClassificationDataset, rule_query: str) -> ServiceLabelingRule:
- rule_query = rule_query.strip()
- for rule in dataset.rules:
- if rule.query == rule_query:
- return rule
- raise EntityNotFoundError(rule_query, type=ServiceLabelingRule)
-
- def replace_rule(self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule):
- for idx, r in enumerate(dataset.rules):
- if r.query == rule.query:
- dataset.rules[idx] = rule
- break
- self.__datasets__.update_dataset(dataset)
diff --git a/src/argilla/server/services/tasks/text_classification/metrics.py b/src/argilla/server/services/tasks/text_classification/metrics.py
index d09891620b..2b9bd006a9 100644
--- a/src/argilla/server/services/tasks/text_classification/metrics.py
+++ b/src/argilla/server/services/tasks/text_classification/metrics.py
@@ -159,5 +159,13 @@ class TextClassificationMetrics(CommonTasksMetrics[ServiceTextClassificationReco
id="annotated_as",
name="Annotated labels distribution",
),
+ ServiceBaseMetric(
+ id="labeling_rule",
+ name="Labeling rule metric based on a query rule and a set of labels",
+ ),
+ ServiceBaseMetric(
+ id="dataset_labeling_rules",
+ name="Computes the overall labeling rules stats",
+ ),
]
)
diff --git a/src/argilla/server/services/tasks/text_classification/model.py b/src/argilla/server/services/tasks/text_classification/model.py
index 73c5cd7d05..f8a82db328 100644
--- a/src/argilla/server/services/tasks/text_classification/model.py
+++ b/src/argilla/server/services/tasks/text_classification/model.py
@@ -307,3 +307,16 @@ class ServiceTextClassificationQuery(ServiceBaseRecordsQuery):
predicted: Optional[PredictionStatus] = Field(default=None, nullable=True)
uncovered_by_rules: List[str] = Field(default_factory=list)
+
+
+class DatasetLabelingRulesSummary(BaseModel):
+ covered_records: int
+ annotated_covered_records: int
+
+
+class LabelingRuleSummary(BaseModel):
+ covered_records: int
+ annotated_covered_records: int
+ correct_records: int = Field(default=0)
+ incorrect_records: int = Field(default=0)
+ precision: Optional[float] = None
diff --git a/src/argilla/server/services/tasks/text_classification/service.py b/src/argilla/server/services/tasks/text_classification/service.py
index 8a5407258c..79d6ddb473 100644
--- a/src/argilla/server/services/tasks/text_classification/service.py
+++ b/src/argilla/server/services/tasks/text_classification/service.py
@@ -13,12 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Iterable, List, Optional
+from typing import Iterable, List, Optional, Tuple
from fastapi import Depends
-from argilla.server.commons.config import TasksFactory
-from argilla.server.errors.base_errors import MissingDatasetRecordsError
+from argilla.server.errors.base_errors import (
+ EntityAlreadyExistsError,
+ EntityNotFoundError,
+ MissingDatasetRecordsError,
+)
+from argilla.server.services.datasets import DatasetsService
+from argilla.server.services.metrics import MetricsService
from argilla.server.services.search.model import (
ServiceSearchResults,
ServiceSortableField,
@@ -27,10 +32,14 @@
from argilla.server.services.search.service import SearchRecordsService
from argilla.server.services.storage.service import RecordsStorageService
from argilla.server.services.tasks.commons import BulkResponse
-from argilla.server.services.tasks.text_classification import LabelingService
+from argilla.server.services.tasks.text_classification.metrics import (
+ TextClassificationMetrics,
+)
from argilla.server.services.tasks.text_classification.model import (
DatasetLabelingRulesMetricsSummary,
+ DatasetLabelingRulesSummary,
LabelingRuleMetricsSummary,
+ LabelingRuleSummary,
ServiceLabelingRule,
ServiceTextClassificationDataset,
ServiceTextClassificationQuery,
@@ -49,23 +58,26 @@ class TextClassificationService:
@classmethod
def get_instance(
cls,
+ datasets: DatasetsService = Depends(DatasetsService.get_instance),
+ metrics: MetricsService = Depends(MetricsService.get_instance),
storage: RecordsStorageService = Depends(RecordsStorageService.get_instance),
- labeling: LabelingService = Depends(LabelingService.get_instance),
search: SearchRecordsService = Depends(SearchRecordsService.get_instance),
) -> "TextClassificationService":
if not cls._INSTANCE:
- cls._INSTANCE = cls(storage, labeling=labeling, search=search)
+ cls._INSTANCE = cls(datasets=datasets, metrics=metrics, storage=storage, search=search)
return cls._INSTANCE
def __init__(
self,
+ datasets: DatasetsService,
+ metrics: MetricsService,
storage: RecordsStorageService,
search: SearchRecordsService,
- labeling: LabelingService,
):
self.__storage__ = storage
self.__search__ = search
- self.__labeling__ = labeling
+ self.__metrics__ = metrics
+ self.__datasets__ = datasets
async def add_records(
self,
@@ -116,9 +128,9 @@ def search(
"""
- metrics = TasksFactory.find_task_metrics(
- dataset.task,
- metric_ids={
+ metrics = [
+ TextClassificationMetrics.find_metric(id)
+ for id in {
"words_cloud",
"predicted_by",
"predicted_as",
@@ -128,8 +140,8 @@ def search(
"status_distribution",
"metadata",
"score",
- },
- )
+ }
+ ]
results = self.__search__.search(
dataset,
@@ -209,8 +221,18 @@ def _is_dataset_multi_label(self, dataset: ServiceTextClassificationDataset) ->
if results.records:
return results.records[0].multi_label
- def get_labeling_rules(self, dataset: ServiceTextClassificationDataset) -> Iterable[ServiceLabelingRule]:
- return self.__labeling__.list_rules(dataset)
+ def find_labeling_rule(
+ self, dataset: ServiceTextClassificationDataset, rule_query: str, error_on_missing: bool = True
+ ) -> Optional[ServiceLabelingRule]:
+ rule_query = rule_query.strip()
+ for rule in dataset.rules:
+ if rule.query == rule_query:
+ return rule
+ if error_on_missing:
+ raise EntityNotFoundError(rule_query, type=ServiceLabelingRule)
+
+ def list_labeling_rules(self, dataset: ServiceTextClassificationDataset) -> Iterable[ServiceLabelingRule]:
+ return dataset.rules
def add_labeling_rule(self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule) -> None:
"""
@@ -224,8 +246,13 @@ def add_labeling_rule(self, dataset: ServiceTextClassificationDataset, rule: Ser
rule:
The rule
"""
+
self.__normalized_rule__(rule)
- self.__labeling__.add_rule(dataset, rule)
+ if self.find_labeling_rule(dataset, rule_query=rule.query, error_on_missing=False):
+ raise EntityAlreadyExistsError(rule.query, type=ServiceLabelingRule)
+
+ dataset.rules.append(rule)
+ self.__datasets__.raw_dataset_update(dataset)
def update_labeling_rule(
self,
@@ -234,7 +261,7 @@ def update_labeling_rule(
labels: List[str],
description: Optional[str] = None,
) -> ServiceLabelingRule:
- found_rule = self.__labeling__.find_rule_by_query(dataset, rule_query)
+ found_rule = self.find_labeling_rule(dataset, rule_query)
found_rule.labels = labels
found_rule.label = labels[0] if len(labels) == 1 else None
@@ -242,17 +269,22 @@ def update_labeling_rule(
found_rule.description = description
self.__normalized_rule__(found_rule)
- self.__labeling__.replace_rule(dataset, found_rule)
- return found_rule
+ for idx, r in enumerate(dataset.rules):
+ if r.query == found_rule.query:
+ dataset.rules[idx] = found_rule
+ break
+ self.__datasets__.raw_dataset_update(dataset)
- def find_labeling_rule(self, dataset: ServiceTextClassificationDataset, rule_query: str) -> ServiceLabelingRule:
- return self.__labeling__.find_rule_by_query(dataset, rule_query=rule_query)
+ return found_rule
def delete_labeling_rule(self, dataset: ServiceTextClassificationDataset, rule_query: str):
- if rule_query.strip():
- return self.__labeling__.delete_rule(dataset, rule_query)
+ """Delete a rule from a dataset by its defined query string"""
+ new_rules_set = [r for r in dataset.rules if r.query != rule_query]
+ if len(dataset.rules) != len(new_rules_set):
+ dataset.rules = new_rules_set
+ self.__datasets__.raw_dataset_update(dataset)
- def compute_rule_metrics(
+ def compute_labeling_rule(
self,
dataset: ServiceTextClassificationDataset,
rule_query: str,
@@ -288,14 +320,20 @@ def compute_rule_metrics(
rule_query = rule_query.strip()
if labels is None:
- for rule in self.get_labeling_rules(dataset):
- if rule.query == rule_query:
- labels = rule.labels
- break
+ rule = self.find_labeling_rule(dataset, rule_query=rule_query, error_on_missing=False)
+ if rule:
+ labels = rule.labels
- total, annotated, metrics = self.__labeling__.compute_rule_metrics(
- dataset, rule_query=rule_query, labels=labels
+ metric_data = self.__metrics__.summarize_metric(
+ dataset=dataset,
+ metric=TextClassificationMetrics.find_metric("labeling_rule"),
+ rule_query=rule_query,
+ labels=labels,
)
+ annotated = self.__metrics__.annotated_records(dataset)
+ total = self.__metrics__.total_records(dataset)
+
+ metrics = LabelingRuleSummary.parse_obj(metric_data)
coverage = metrics.covered_records / total if total > 0 else None
coverage_annotated = metrics.annotated_covered_records / annotated if annotated > 0 else None
@@ -310,8 +348,8 @@ def compute_rule_metrics(
precision=metrics.precision if annotated > 0 else None,
)
- def compute_overall_rules_metrics(self, dataset: ServiceTextClassificationDataset):
- total, annotated, metrics = self.__labeling__.all_rules_metrics(dataset)
+ def compute_all_labeling_rules(self, dataset: ServiceTextClassificationDataset):
+ total, annotated, metrics = self._compute_all_lb_rules_metrics(dataset)
coverage = metrics.covered_records / total if total else None
coverage_annotated = metrics.annotated_covered_records / annotated if annotated else None
return DatasetLabelingRulesMetricsSummary(
@@ -321,6 +359,23 @@ def compute_overall_rules_metrics(self, dataset: ServiceTextClassificationDatase
annotated_records=annotated,
)
+ def _compute_all_lb_rules_metrics(
+ self, dataset: ServiceTextClassificationDataset
+ ) -> Tuple[int, int, DatasetLabelingRulesSummary]:
+ annotated_records = self.__metrics__.annotated_records(dataset)
+ dataset_records = self.__metrics__.total_records(dataset)
+ metric_data = self.__metrics__.summarize_metric(
+ dataset=dataset,
+ metric=TextClassificationMetrics.find_metric(id="dataset_labeling_rules"),
+ queries=[r.query for r in dataset.rules],
+ )
+
+ return (
+ dataset_records,
+ annotated_records,
+ DatasetLabelingRulesSummary.parse_obj(metric_data),
+ )
+
@staticmethod
def __normalized_rule__(rule: ServiceLabelingRule) -> ServiceLabelingRule:
if rule.labels and len(rule.labels) == 1:
diff --git a/tests/server/metrics/test_api.py b/tests/server/metrics/test_api.py
index 0dc30ba139..dec0169c54 100644
--- a/tests/server/metrics/test_api.py
+++ b/tests/server/metrics/test_api.py
@@ -26,11 +26,15 @@
TokenClassificationRecord,
)
from argilla.server.services.metrics.models import CommonTasksMetrics
+from argilla.server.services.tasks.text_classification.metrics import (
+ TextClassificationMetrics,
+)
from argilla.server.services.tasks.token_classification.metrics import (
TokenClassificationMetrics,
)
COMMON_METRICS_LENGTH = len(CommonTasksMetrics.metrics)
+CLASSIFICATION_METRICS_LENGTH = len(TextClassificationMetrics.metrics)
def test_wrong_dataset_metrics(mocked_client):
@@ -183,7 +187,7 @@ def test_dataset_metrics(mocked_client):
metrics = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/metrics").json()
- assert len(metrics) == COMMON_METRICS_LENGTH + 5
+ assert len(metrics) == CLASSIFICATION_METRICS_LENGTH
response = mocked_client.post(
f"/api/datasets/TextClassification/{dataset}/metrics/missing_metric:summary",
@@ -206,6 +210,38 @@ def test_dataset_metrics(mocked_client):
assert response.status_code == 200, f"{metric}: {response.json()}"
+def create_some_classification_data(mocked_client, dataset: str, records: list):
+ request = TextClassificationBulkRequest(records=[TextClassificationRecord.parse_obj(r) for r in records])
+
+ assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200
+ assert (
+ mocked_client.post(
+ f"/api/datasets/{dataset}/TextClassification:bulk",
+ json=request.dict(by_alias=True),
+ ).status_code
+ == 200
+ )
+
+
+def test_labeling_rule_metric(mocked_client):
+ dataset = "test_labeling_rule_metric"
+ create_some_classification_data(
+ mocked_client, dataset, records=[{"inputs": {"text": "This is classification record"}}] * 10
+ )
+
+ rule_query = "t*"
+ response = mocked_client.post(
+ f"/api/datasets/TextClassification/{dataset}/metrics/labeling_rule:summary?rule_query={rule_query}",
+ json={},
+ )
+ assert response.json() == {
+ "annotated_covered_records": 0,
+ "correct_records": 0,
+ "covered_records": 10,
+ "incorrect_records": 0,
+ }
+
+
def test_dataset_labels_for_text_classification(mocked_client):
records = [
TextClassificationRecord.parse_obj(data)
From 3c27fd576719965e09400b01efdada542ec7d993 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
<66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 21 Feb 2023 04:40:00 +0000
Subject: [PATCH 13/45] [pre-commit.ci] pre-commit autoupdate
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
updates:
- [github.com/charliermarsh/ruff-pre-commit: v0.0.244 → v0.0.249](https://github.com/charliermarsh/ruff-pre-commit/compare/v0.0.244...v0.0.249)
---
.pre-commit-config.yaml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5ba8658505..0f8bffe411 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -26,7 +26,7 @@ repos:
additional_dependencies: ["click==8.0.4"]
- repo: https://github.com/charliermarsh/ruff-pre-commit
- rev: v0.0.244
+ rev: v0.0.249
hooks:
# Simulate isort via (the much faster) ruff
- id: ruff
From 8f0d10de2c409d562e3bb758998625a1a430836f Mon Sep 17 00:00:00 2001
From: Gnonpi
Date: Tue, 21 Feb 2023 12:56:53 +0100
Subject: [PATCH 14/45] Documentation update: adding missing n (#2362)
# Description
There's an "n" character missing at the end of the sentence, not a super
critical change
**Type of change**
- [x] Documentation update
**How Has This Been Tested**
By eye
**Checklist**
- [x] I have merged the original branch into my forked branch
- [x] I added relevant documentation
- [x] My code follows the style guidelines of this project
- [x] I did a self-review of my code
- [x] I added comments to my code
- [x] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
---
docs/_source/guides/query_datasets.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/_source/guides/query_datasets.md b/docs/_source/guides/query_datasets.md
index 0b80e19384..7ae8d1861c 100644
--- a/docs/_source/guides/query_datasets.md
+++ b/docs/_source/guides/query_datasets.md
@@ -35,7 +35,7 @@ If you do not retrieve any results after a version update, you should use the `w
The (arguably) most important fields are the `text` and `text.exact` fields.
They both contain the text of the records, however in two different forms:
-- the `text` field uses Elasticsearch's [standard analyzer](https://www.elastic.co/guide/en/elasticsearch/reference/7.10/analysis-standard-analyzer.html) that ignores capitalization and removes most of the punctuatio;
+- the `text` field uses Elasticsearch's [standard analyzer](https://www.elastic.co/guide/en/elasticsearch/reference/7.10/analysis-standard-analyzer.html) that ignores capitalization and removes most of the punctuation;
- the `text.exact` field uses the [whitespace analyzer](https://www.elastic.co/guide/en/elasticsearch/reference/7.10/analysis-whitespace-analyzer.html) that differentiates between lower and upper case, and does take into account punctuation;
Let's have a look at a few examples.
From 5ab38be2a188eb3f0d7a5159a4b75e358e208fd8 Mon Sep 17 00:00:00 2001
From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Date: Wed, 22 Feb 2023 11:41:47 +0100
Subject: [PATCH 15/45] ci: Remove Pyre from CI (#2358)
Hello!
## Pull Request overview
* Remove `pyre` from the CI.
* Remove the `pyre` configuration file.
## Details
Pyre has not proven useful. It only causes CI failures, and I don't
think anyone actually looks at the logs. Let's save ourselves the red
crosses and remove it. After all, the CI doesn't even run on the
`develop` branch, only on `main` (and `master` and `releases`).
See
[here](https://github.com/argilla-io/argilla/actions/runs/4184361432/jobs/7249906904)
for an example CI output of the `pyre` workflow.
**Type of change**
- [x] Removal of unused code
**Checklist**
- [x] I have merged the original branch into my forked branch
- [ ] I added relevant documentation
- [x] My code follows the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I added comments to my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
---
I intend to create a handful of other PRs to improve the CI situation.
In particular, I want green checkmarks to once again mean that a PR is
likely ready, and red crosses to mean that it most certainly is not.
- Tom Aarsen
---
.github/workflows/pyre-check.yml | 78 --------------------------------
.pyre_configuration | 33 --------------
2 files changed, 111 deletions(-)
delete mode 100644 .github/workflows/pyre-check.yml
delete mode 100644 .pyre_configuration
diff --git a/.github/workflows/pyre-check.yml b/.github/workflows/pyre-check.yml
deleted file mode 100644
index 3941c480b5..0000000000
--- a/.github/workflows/pyre-check.yml
+++ /dev/null
@@ -1,78 +0,0 @@
-name: pyre
-
-on:
- push:
- branches: [main, master, releases/*]
- pull_request:
- # The branches below must be a subset of the branches above
- branches: [main, master, releases/*]
-
-jobs:
- pyre:
- runs-on: ubuntu-latest
- defaults:
- run:
- shell: bash -l {0}
- steps:
- - uses: actions/checkout@v2
-
- - name: Setup Conda Env 🐍
- uses: conda-incubator/setup-miniconda@v2
- with:
- miniforge-variant: Mambaforge
- miniforge-version: latest
- activate-environment: argilla
- use-mamba: true
-
- - name: Get date for conda cache
- id: get-date
- run: echo "::set-output name=today::$(/bin/date -u '+%Y%m%d')"
- shell: bash
-
- - name: Cache Conda env
- uses: actions/cache@v2
- id: cache
- with:
- path: ${{ env.CONDA }}/envs
- key: conda-${{ runner.os }}-${{ runner.arch }}-${{ steps.get-date.outputs.today }}-${{ hashFiles('environment_dev.yml') }}-${{ env.CACHE_NUMBER }}
- env:
- # Increase this value to reset cache if etc/example-environment.yml has not changed
- CACHE_NUMBER: 0
-
- - name: Update environment
- if: steps.cache.outputs.cache-hit != 'true'
- run: mamba env update -n argilla -f environment_dev.yml
-
- - name: Cache pip 👜
- uses: actions/cache@v2
- if: steps.filter.outputs.python_code == 'true'
- env:
- # Increase this value to reset cache if pyproject.toml has not changed
- CACHE_NUMBER: 0
- with:
- path: ~/.cache/pip
- key: ${{ runner.os }}-pip-${{ env.CACHE_NUMBER }}-${{ hashFiles('pyproject.toml') }}
-
- - name: Install dependencies
- # Force install this version since older versions fail with click < 8.0 (included by spacy 3.x)
- run: pip install pyre-check==0.9.15
-
- - name: Run Pyre
- continue-on-error: false
- run: |
- pyre --output=text check
-
-# - name: Expose SARIF Resultss
-# uses: actions/upload-artifact@v2
-# with:
-# name: SARIF Results
-# path: sarif.json
-#
-# - name: Upload SARIF Results
-# uses: github/codeql-action/upload-sarif@v1
-# with:
-# sarif_file: sarif.json
-#
-# - name: Fail Command On Errors
-# run: |
-# if [ "$(cat sarif.json | grep 'PYRE-ERROR')" != "" ]; then cat sarif.json && exit 1; fi
diff --git a/.pyre_configuration b/.pyre_configuration
deleted file mode 100644
index 6fef55869e..0000000000
--- a/.pyre_configuration
+++ /dev/null
@@ -1,33 +0,0 @@
-{
- "site_package_search_strategy": "pep561",
-
- "source_directories": [
- "src"
- ],
- "search_path": [
- {"site-package": "pytest"},
- {"site-package": "httpx"},
- {"site-package": "datasets"},
- {"site-package": "pandas"},
- {"site-package": "pydantic"},
- {"site-package": "fastapi"},
- {"site-package": "snorkel"},
- {"site-package": "spacy"},
- {"site-package": "cleanlab"},
- {"site-package": "flair"},
- {"site-package": "uvicorn"},
- {"site-package": "tqdm"},
- {"site-package": "sklearn"},
- {"site-package": "flyingsquid"},
- {"site-package": "pgmpy"},
- {"site-package": "faiss"},
- {"site-package": "schedule"},
- {"site-package": "prodict"},
- {"site-package": "plotly"},
- {"site-package": "wrapt"},
- {"site-package": "stopwordsiso"},
- {"site-package": "luqum"},
- {"site-package": "jose"},
- {"site-package": "brotli_asgi"}
- ]
-}
From e999fe211eafbccfb5698ff40e6c06dddf6f5334 Mon Sep 17 00:00:00 2001
From: Francisco Aranda
Date: Wed, 22 Feb 2023 12:07:21 +0100
Subject: [PATCH 16/45] Refactor/deprecate dataset owner (#2386)
# Description
Adding a dataset `workspace` attribute and deprecating the `owner`
attribute on the current Dataset model. The `owner` attribute will be
removed in the next release.
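A minimal sketch of the kind of backward-compatible aliasing this implies
is below. This is a guess at the pattern, not the PR's actual code; the
real changes live in `src/argilla/server/daos/models/datasets.py` in the
diff that follows.

```python
from typing import Optional

from pydantic import BaseModel, root_validator


class DatasetModelSketch(BaseModel):
    """Hypothetical model keeping `owner` as a deprecated alias of `workspace`."""

    name: str
    workspace: Optional[str] = None
    # Deprecated: kept only for backward compatibility; use `workspace` instead.
    owner: Optional[str] = None

    @root_validator
    def _sync_workspace_and_owner(cls, values):
        # Mirror whichever field was provided so older clients keep working.
        workspace = values.get("workspace") or values.get("owner")
        values["workspace"] = workspace
        values["owner"] = workspace
        return values
```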
**Type of change**
(Please delete options that are not relevant. Remember to title the PR
according to the type of change)
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [x] Refactor (change restructuring the codebase without changing
functionality)
- [ ] Improvement (change adding some improvement to an existing
functionality)
- [ ] Documentation update
**Checklist**
- [x] I have merged the original branch into my forked branch
- [ ] I added relevant documentation
- [x] follows the style guidelines of this project
- [x] I did a self-review of my code
- [x] I added comments to my code
- [x] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
---------
Co-authored-by: keithCuniah
---
.../_source/guides/log_load_and_prepare_data.ipynb | 3 ++-
.../commons/header/help-info/HelpInfo.spec.js | 2 +-
.../components/commons/results/results.spec.js | 2 +-
.../components/commons/results/resultsList.spec.js | 2 +-
.../SimilarityRecordReference.component.spec.js | 2 +-
frontend/components/commons/taskSearch.spec.js | 2 +-
.../results/recordTextClassification.spec.js | 2 +-
.../text2text/results/recordText2Text.spec.js | 2 +-
.../text2text/results/text2TextResultsList.spec.js | 2 +-
.../token-classifier/results/textSpan.spec.js | 2 +-
.../results/tokenClassificationResultsList.spec.js | 2 +-
frontend/database/modules/datasets.js | 4 ++--
frontend/models/Dataset.js | 12 ++++++------
frontend/models/TextClassification.js | 14 +++++++-------
frontend/models/TokenClassification.js | 2 +-
frontend/pages/datasets/index.vue | 10 +++++-----
.../components/core/table/BaseTableInfo.spec.js | 8 ++++----
.../core/table/TableFiltrableColumn.spec.js | 6 +++---
.../table/__snapshots__/BaseTableInfo.spec.js.snap | 2 +-
.../results/EntitiesSelector.spec.js | 2 +-
src/argilla/client/apis/datasets.py | 1 +
src/argilla/client/sdk/datasets/models.py | 1 +
src/argilla/server/daos/models/datasets.py | 13 +++++++++++--
tests/datasets/test_datasets.py | 10 ++++++++++
tests/server/datasets/test_api.py | 6 +++---
25 files changed, 68 insertions(+), 46 deletions(-)
diff --git a/docs/_source/guides/log_load_and_prepare_data.ipynb b/docs/_source/guides/log_load_and_prepare_data.ipynb
index fe0861f2f9..ce21c887ee 100644
--- a/docs/_source/guides/log_load_and_prepare_data.ipynb
+++ b/docs/_source/guides/log_load_and_prepare_data.ipynb
@@ -452,7 +452,8 @@
"empty_workspace_datasets = [\n",
" ds[\"name\"]\n",
" for ds in rg_client.http_client.get(\"/api/datasets\")\n",
- " if not ds.get(\"owner\", None) # filtering dataset with no workspace (owner field)\n",
+ " # filtering dataset with no workspace (use `\"owner\"` if you're running this code with server versions <=1.3.0)\n",
+ " if not ds.get(\"workspace\", None)\n",
"]\n",
"\n",
"rg.set_workspace(\"\") # working from the \"empty\" workspace\n",
diff --git a/frontend/components/commons/header/help-info/HelpInfo.spec.js b/frontend/components/commons/header/help-info/HelpInfo.spec.js
index 4a39e0c005..ec4f9736fd 100644
--- a/frontend/components/commons/header/help-info/HelpInfo.spec.js
+++ b/frontend/components/commons/header/help-info/HelpInfo.spec.js
@@ -20,7 +20,7 @@ const options = {
component: "helpInfoExplain",
},
propsData: {
- datasetId: ["owner", "name"],
+ datasetId: ["workspace", "name"],
datasetTask: "TextClassification",
datasetName: "name",
},
diff --git a/frontend/components/commons/results/results.spec.js b/frontend/components/commons/results/results.spec.js
index d03da9e8c5..04ef8f98bb 100644
--- a/frontend/components/commons/results/results.spec.js
+++ b/frontend/components/commons/results/results.spec.js
@@ -9,7 +9,7 @@ const options = {
"Text2TextResultsList",
],
propsData: {
- datasetId: ["owner", "name"],
+ datasetId: ["workspace", "name"],
datasetTask: "TextClassification",
datasetName: "name",
},
diff --git a/frontend/components/commons/results/resultsList.spec.js b/frontend/components/commons/results/resultsList.spec.js
index 586d5e4062..1b093feda6 100644
--- a/frontend/components/commons/results/resultsList.spec.js
+++ b/frontend/components/commons/results/resultsList.spec.js
@@ -21,7 +21,7 @@ const options = {
localVue,
store,
propsData: {
- datasetId: ["owner", "name"],
+ datasetId: ["workspace", "name"],
datasetTask: "TextClassification",
},
};
diff --git a/frontend/components/commons/results/similarity/SimilarityRecordReference.component.spec.js b/frontend/components/commons/results/similarity/SimilarityRecordReference.component.spec.js
index 1224292455..4aac5320f2 100644
--- a/frontend/components/commons/results/similarity/SimilarityRecordReference.component.spec.js
+++ b/frontend/components/commons/results/similarity/SimilarityRecordReference.component.spec.js
@@ -6,7 +6,7 @@ let wrapper = null;
const options = {
stubs: ["nuxt", "results-record"],
propsData: {
- datasetId: ["owner", "name"],
+ datasetId: ["workspace", "name"],
datasetTask: "TextClassification",
dataset: {
type: Object,
diff --git a/frontend/components/commons/taskSearch.spec.js b/frontend/components/commons/taskSearch.spec.js
index b3f12b9bda..929f27ea64 100644
--- a/frontend/components/commons/taskSearch.spec.js
+++ b/frontend/components/commons/taskSearch.spec.js
@@ -5,7 +5,7 @@ let wrapper = null;
const options = {
stubs: ["results"],
propsData: {
- datasetId: ["owner", "name"],
+ datasetId: ["workspace", "name"],
datasetTask: "TextClassification",
datasetName: "name",
},
diff --git a/frontend/components/text-classifier/results/recordTextClassification.spec.js b/frontend/components/text-classifier/results/recordTextClassification.spec.js
index 94a3456d58..64090bae0b 100644
--- a/frontend/components/text-classifier/results/recordTextClassification.spec.js
+++ b/frontend/components/text-classifier/results/recordTextClassification.spec.js
@@ -6,7 +6,7 @@ const options = {
stubs: ["record-inputs", "classifier-exploration-area", "base-tag"],
propsData: {
viewSettings: {},
- datasetId: ["owner", "name"],
+ datasetId: ["workspace", "name"],
datasetName: "name",
datasetLabels: ["label 1", "label 2"],
record: {
diff --git a/frontend/components/text2text/results/recordText2Text.spec.js b/frontend/components/text2text/results/recordText2Text.spec.js
index 20857b1f97..26c2df6b04 100644
--- a/frontend/components/text2text/results/recordText2Text.spec.js
+++ b/frontend/components/text2text/results/recordText2Text.spec.js
@@ -6,7 +6,7 @@ let wrapper = null;
const options = {
stubs: ["text-2-text-list", "record-string-text-2-text"],
propsData: {
- datasetId: ["owner", "name"],
+ datasetId: ["workspace", "name"],
datasetName: "name",
viewSettings: {},
record: {
diff --git a/frontend/components/text2text/results/text2TextResultsList.spec.js b/frontend/components/text2text/results/text2TextResultsList.spec.js
index fc9b1a6f1f..bda7173df8 100644
--- a/frontend/components/text2text/results/text2TextResultsList.spec.js
+++ b/frontend/components/text2text/results/text2TextResultsList.spec.js
@@ -5,7 +5,7 @@ let wrapper = null;
const options = {
stubs: ["results-list"],
propsData: {
- datasetId: ["owner", "name"],
+ datasetId: ["workspace", "name"],
datasetTask: "TextClassification",
},
};
diff --git a/frontend/components/token-classifier/results/textSpan.spec.js b/frontend/components/token-classifier/results/textSpan.spec.js
index f327deca1d..d2879eb1ad 100644
--- a/frontend/components/token-classifier/results/textSpan.spec.js
+++ b/frontend/components/token-classifier/results/textSpan.spec.js
@@ -9,7 +9,7 @@ const options = {
},
},
propsData: {
- datasetId: ["owner", "name"],
+ datasetId: ["workspace", "name"],
datasetName: "name",
datasetLastSelectedEntity: {},
datasetEntities: [
diff --git a/frontend/components/token-classifier/results/tokenClassificationResultsList.spec.js b/frontend/components/token-classifier/results/tokenClassificationResultsList.spec.js
index 8a9c2f13e4..8951a7744c 100644
--- a/frontend/components/token-classifier/results/tokenClassificationResultsList.spec.js
+++ b/frontend/components/token-classifier/results/tokenClassificationResultsList.spec.js
@@ -5,7 +5,7 @@ let wrapper = null;
const options = {
stubs: ["results-list"],
propsData: {
- datasetId: ["owner", "name"],
+ datasetId: ["workspace", "name"],
datasetTask: "TextClassification",
},
};
diff --git a/frontend/database/modules/datasets.js b/frontend/database/modules/datasets.js
index 221699b533..d58006c1ec 100644
--- a/frontend/database/modules/datasets.js
+++ b/frontend/database/modules/datasets.js
@@ -69,7 +69,7 @@ async function _getOrFetchDataset({ workspace, name }) {
}
await ObservationDataset.api().get(`/datasets/${name}`, {
dataTransformer: ({ data }) => {
- data.owner = data.owner || workspace;
+ data.workspace = data.workspace || workspace;
return data;
},
});
@@ -600,7 +600,7 @@ const actions = {
persistBy: "create",
dataTransformer: ({ data }) => {
return data.map((datasource) => {
- datasource.owner = datasource.owner || NO_WORKSPACE;
+ datasource.workspace = datasource.workspace || NO_WORKSPACE;
return datasource;
});
},
diff --git a/frontend/models/Dataset.js b/frontend/models/Dataset.js
index 2d97418a3a..efc79c8eab 100644
--- a/frontend/models/Dataset.js
+++ b/frontend/models/Dataset.js
@@ -24,9 +24,9 @@ const USER_DATA_METADATA_KEY = "rubrix.recogn.ai/ui/custom/userData.v1";
class ObservationDataset extends Model {
static entity = "datasets";
- // TODO: Combine name + owner for primary key.
+ // TODO: Combine name + workspace for primary key.
// This should fix https://github.com/recognai/rubrix/issues/736
- static primaryKey = ["owner", "name"];
+ static primaryKey = ["workspace", "name"];
static #registeredDatasetClasses = {};
@@ -75,7 +75,7 @@ class ObservationDataset extends Model {
}
get id() {
- return [this.owner, this.name];
+ return [this.workspace, this.name];
}
get visibleRecords() {
@@ -85,19 +85,19 @@ class ObservationDataset extends Model {
static fields() {
return {
name: this.string(null),
- owner: this.string(null),
+ workspace: this.string(null),
metadata: this.attr(null),
tags: this.attr(null),
task: this.string(null),
created_at: this.string(null),
last_updated: this.string(null),
- // This will be normalized in a future PR using also owner for relational ids
+ // This will be normalized in a future PR using also workspace for relational ids
viewSettings: this.hasOne(DatasetViewSettings, "id", "name"),
};
}
}
-const getDatasetModelPrimaryKey = ({ owner, name }) => [owner, name];
+const getDatasetModelPrimaryKey = ({ workspace, name }) => [workspace, name];
export {
ObservationDataset,
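
An array-valued `primaryKey` declares a composite key in Vuex ORM: the store indexes records by the combined fields, which is why both the `id` getter and `getDatasetModelPrimaryKey` return a `[workspace, name]` tuple. A sketch of lookups against such a model, assuming the standard Vuex ORM API (the import path is illustrative):

// Composite-key lookup: find() takes the tuple in primaryKey order.
import { ObservationDataset } from "@/models/Dataset"; // assumed path

const dataset = ObservationDataset.find(["team-a", "my-ds"]);

// Equivalent query on the individual key fields:
const same = ObservationDataset.query()
  .where("workspace", "team-a")
  .where("name", "my-ds")
  .first();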
diff --git a/frontend/models/TextClassification.js b/frontend/models/TextClassification.js
index ccc9144272..0d26a8e02a 100644
--- a/frontend/models/TextClassification.js
+++ b/frontend/models/TextClassification.js
@@ -152,7 +152,7 @@ class TextClassificationDataset extends ObservationDataset {
where: this.id,
data: [
{
- owner: this.owner,
+ workspace: this.workspace,
name: this.name,
_labels: labels,
settings,
@@ -303,7 +303,7 @@ class TextClassificationDataset extends ObservationDataset {
return await TextClassificationDataset.insertOrUpdate({
where: this.id,
data: {
- owner: this.owner,
+ workspace: this.workspace,
name: this.name,
perRuleQueryMetrics,
rulesOveralMetrics: overalMetrics,
@@ -315,7 +315,7 @@ class TextClassificationDataset extends ObservationDataset {
const rules = await this._fetchAllRules();
await TextClassificationDataset.insertOrUpdate({
data: {
- owner: this.owner,
+ workspace: this.workspace,
name: this.name,
rules,
},
@@ -394,7 +394,7 @@ class TextClassificationDataset extends ObservationDataset {
await TextClassificationDataset.insertOrUpdate({
data: {
- owner: this.owner,
+ workspace: this.workspace,
name: this.name,
activeRule: rule || {
query,
@@ -427,7 +427,7 @@ class TextClassificationDataset extends ObservationDataset {
await TextClassificationDataset.insertOrUpdate({
data: {
- owner: this.owner,
+ workspace: this.workspace,
name: this.name,
rules,
activeRule,
@@ -454,7 +454,7 @@ class TextClassificationDataset extends ObservationDataset {
await TextClassificationDataset.insertOrUpdate({
data: {
- owner: this.owner,
+ workspace: this.workspace,
name: this.name,
rules,
perRuleQueryMetrics,
@@ -466,7 +466,7 @@ class TextClassificationDataset extends ObservationDataset {
async clearCurrentLabelingRule() {
await TextClassificationDataset.insertOrUpdate({
data: {
- owner: this.owner,
+ workspace: this.workspace,
name: this.name,
activeRule: null,
activeRuleMetrics: null,
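
Every mutation in this model repeats one `insertOrUpdate` shape: the payload carries the full composite key (`workspace` plus `name`) alongside whatever fields change, so the store can match the existing record instead of inserting a duplicate. The pattern, reduced to a helper sketch:

// The recurring mutation shape above, as a sketch; the workspace/name
// pair identifies the record, extra fields are merged into it.
async function patchDataset(dataset, fields) {
  return TextClassificationDataset.insertOrUpdate({
    where: dataset.id, // [workspace, name], mirroring the calls above
    data: {
      workspace: dataset.workspace,
      name: dataset.name,
      ...fields, // e.g. { rules }, { activeRule: null }
    },
  });
}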
diff --git a/frontend/models/TokenClassification.js b/frontend/models/TokenClassification.js
index c937d5d154..5dd03c4382 100644
--- a/frontend/models/TokenClassification.js
+++ b/frontend/models/TokenClassification.js
@@ -104,7 +104,7 @@ class TokenClassificationDataset extends ObservationDataset {
where: this.id,
data: [
{
- owner: this.owner,
+ workspace: this.workspace,
name: this.name,
settings,
},
diff --git a/frontend/pages/datasets/index.vue b/frontend/pages/datasets/index.vue
index 3521c9c3b6..6267d64acb 100644
--- a/frontend/pages/datasets/index.vue
+++ b/frontend/pages/datasets/index.vue
@@ -82,7 +82,7 @@ export default {
{ name: "Name", field: "name", class: "table-info__title", type: "link" },
{
name: "Workspace",
- field: "owner",
+ field: "workspace",
class: "text",
type: "text",
filtrable: "true",
@@ -142,7 +142,7 @@ export default {
const tasks = this.tasks;
const tags = this.tags;
return [
- { column: "owner", values: workspaces || [] },
+ { column: "workspace", values: workspaces || [] },
{ column: "task", values: tasks || [] },
{ column: "tags", values: tags || [] },
];
@@ -217,7 +217,7 @@ export default {
_deleteDataset: "entities/datasets/deleteDataset",
}),
onColumnFilterApplied({ column, values }) {
- if (column === "owner") {
+ if (column === "workspace") {
if (values !== this.workspaces) {
this.$router.push({
query: { ...this.$route.query, workspace: values },
@@ -243,7 +243,7 @@ export default {
}
},
datasetWorkspace(dataset) {
- var workspace = dataset.owner;
+ var workspace = dataset.workspace;
if (workspace === null || workspace === "null") {
workspace = this.workspace;
}
@@ -285,7 +285,7 @@ export default {
},
deleteDataset(dataset) {
this._deleteDataset({
- workspace: dataset.owner,
+ workspace: dataset.workspace,
name: dataset.name,
});
this.closeModal();
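
The workspace column is the only filter that round-trips through the URL: applying it pushes a `workspace` query parameter rather than filtering locally, so the listing stays shareable and survives reloads. A sketch of that branch as a component method (other columns are assumed to filter in place):

// Only the workspace filter is reflected into the route's query.
const methods = {
  onColumnFilterApplied({ column, values }) {
    if (column === "workspace" && values !== this.workspaces) {
      this.$router.push({
        query: { ...this.$route.query, workspace: values },
      });
    }
  },
};

On load, the table re-hydrates the selection from the same parameter via the `filterFromRoute` prop exercised in the BaseTableInfo spec below.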
diff --git a/frontend/specs/components/core/table/BaseTableInfo.spec.js b/frontend/specs/components/core/table/BaseTableInfo.spec.js
index 13b6b885a7..f6e02b996c 100644
--- a/frontend/specs/components/core/table/BaseTableInfo.spec.js
+++ b/frontend/specs/components/core/table/BaseTableInfo.spec.js
@@ -20,7 +20,7 @@ function mountBaseTableInfo() {
idx: 1,
key: "column1",
class: "text",
- field: "owner",
+ field: "workspace",
filtrable: "true",
name: "Workspace",
type: "text",
@@ -30,13 +30,13 @@ function mountBaseTableInfo() {
{
key: "data1",
name: "dataset_1",
- owner: "recognai",
+ workspace: "recognai",
task: "TokenClassification",
},
{
key: "data2",
name: "dataset_2",
- owner: "recognai",
+ workspace: "recognai",
task: "TokenClassification",
},
],
@@ -52,7 +52,7 @@ function mountBaseTableInfo() {
hideButton: false,
noDataInfo: undefined,
querySearch: undefined,
- filterFromRoute: "owner",
+ filterFromRoute: "workspace",
searchOn: "name",
sortedByField: "last_updated",
sortedOrder: "desc",
diff --git a/frontend/specs/components/core/table/TableFiltrableColumn.spec.js b/frontend/specs/components/core/table/TableFiltrableColumn.spec.js
index e41446f0d1..e6cc404f40 100644
--- a/frontend/specs/components/core/table/TableFiltrableColumn.spec.js
+++ b/frontend/specs/components/core/table/TableFiltrableColumn.spec.js
@@ -5,11 +5,11 @@ function mountTableFiltrableColumn() {
return mount(TableFiltrableColumn, {
propsData: {
filters: {
- owner: ["recognai"],
+ workspace: ["recognai"],
},
column: {
class: "text",
- field: "owner",
+ field: "workspace",
filtrable: "true",
name: "Workspace",
type: "text",
@@ -17,7 +17,7 @@ function mountTableFiltrableColumn() {
data: [
{
name: "dataset_1",
- owner: "recognai",
+ workspace: "recognai",
task: "TokenClassification",
},
],
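
The `filters` prop keys selections by column `field`, which is why the rename has to touch both the map key and the column descriptor in lockstep. How the component applies that map is internal to it, but a plausible sketch of the matching logic (the helper is hypothetical, not from the patch):

// Hypothetical filter application: keys are column fields, values are
// the selected options; an empty selection matches every row.
const filters = { workspace: ["recognai"] };

const applyFilters = (rows, filters) =>
  rows.filter((row) =>
    Object.entries(filters).every(
      ([field, values]) => values.length === 0 || values.includes(row[field])
    )
  );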
diff --git a/frontend/specs/components/core/table/__snapshots__/BaseTableInfo.spec.js.snap b/frontend/specs/components/core/table/__snapshots__/BaseTableInfo.spec.js.snap
index bdd477500b..672efd231d 100644
--- a/frontend/specs/components/core/table/__snapshots__/BaseTableInfo.spec.js.snap
+++ b/frontend/specs/components/core/table/__snapshots__/BaseTableInfo.spec.js.snap
@@ -1,7 +1,7 @@
// Jest Snapshot v1, https://goo.gl/fbAQLP
exports[`BaseTableInfo renders properly 1`] = `
[snapshot HTML not preserved in this excerpt; the rendered markup change mirrors the owner → workspace rename above]
@@ -74,6 +76,8 @@ export default {
...mapActions({
discard: "entities/datasets/discardAnnotations",
validate: "entities/datasets/validateAnnotations",
+ updateRecords: "entities/datasets/updateDatasetRecords",
+ resetRecords: "entities/datasets/resetRecords",
}),
async onDiscard(records) {
@@ -83,18 +87,49 @@ export default {
});
},
async onValidate(records) {
- await this.validate({
+ try {
+ await this.validate({
+ dataset: this.dataset,
+ agent: this.$auth.user.username,
+ records: records.map((record) => {
+ return {
+ ...record,
+ annotatedEntities: undefined,
+ annotation: {
+ entities: record.annotatedEntities,
+ },
+ };
+ }),
+ });
+ } catch (err) {
+ console.log(err);
+ }
+ },
+ async onClear(records) {
+ const clearedRecords = records.map((record) => {
+ return {
+ ...record,
+ annotatedEntities: [],
+ annotation: null,
+ selected: true,
+ status: "Edited",
+ };
+ });
+ await this.updateRecords({
+ dataset: this.dataset,
+ records: clearedRecords,
+ });
+ },
+ async onReset(records) {
+ const restartedRecords = records.map((record) => {
+ return {
+ ...record,
+ annotatedEntities: record.annotation?.entities,
+ };
+ });
+ await this.resetRecords({
dataset: this.dataset,
- agent: this.$auth.user.username,
- records: records.map((record) => {
- return {
- ...record,
- annotatedEntities: undefined,
- annotation: {
- entities: record.annotatedEntities,
- },
- };
- }),
+ records: restartedRecords,
});
},
async onNewLabel(label) {
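
The new `onClear` and `onReset` handlers are mirror images: clearing empties `annotatedEntities`, drops the stored `annotation`, and marks the record `"Edited"`, while resetting copies the last persisted `annotation.entities` back into `annotatedEntities`, discarding unsaved edits. The two record transforms, extracted as pure helpers:

// The record mappings above, as pure functions.
const clearRecord = (record) => ({
  ...record,
  annotatedEntities: [],
  annotation: null,
  selected: true,
  status: "Edited",
});

const resetRecord = (record) => ({
  ...record,
  // Restore from the last persisted annotation; undefined if none.
  annotatedEntities: record.annotation?.entities,
});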
diff --git a/frontend/components/token-classifier/results/RecordTokenClassification.vue b/frontend/components/token-classifier/results/RecordTokenClassification.vue
index 07fa754b8c..7d7276a5bb 100755
--- a/frontend/components/token-classifier/results/RecordTokenClassification.vue
+++ b/frontend/components/token-classifier/results/RecordTokenClassification.vue
@@ -46,21 +46,14 @@
:entities="getEntitiesByOrigin('annotation')"
/>
[template diff garbled in extraction; the removed markup rendered a {{ record.status === "Edited" ? "Save" : "Validate" }} button and a "Clear annotations" button, replaced by a single added element]
@@ -152,11 +145,42 @@ export default {
);
return visualTokens;
},
+ tokenClassifierActionButtons() {
+ return [
+ {
+ id: "validate",
+ name: "Validate",
+ allow: true,
+ active: this.record.status === "Validated",
+ },
+ {
+ id: "discard",
+ name: "Discard",
+ allow: true,
+ active: this.record.status === "Discarded",
+ },
+ {
+ id: "clear",
+ name: "Clear",
+ allow: true,
+ disable: !this.record.annotatedEntities?.length || false,
+ },
+ {
+ id: "reset",
+ name: "Reset",
+ allow: true,
+ disable: this.record.status !== "Edited",
+ },
+ ];
+ },
},
methods: {
...mapActions({
validate: "entities/datasets/validateAnnotations",
+ discard: "entities/datasets/discardAnnotations",
updateRecords: "entities/datasets/updateDatasetRecords",
+ changeStatusToDefault: "entities/datasets/changeStatusToDefault",
+ resetRecords: "entities/datasets/resetRecords",
}),
getEntitiesByOrigin(origin) {
if (this.interactionsEnabled) {
@@ -172,23 +196,44 @@ export default {
: [];
}
},
- async onValidate(record) {
+ async toggleValidateRecord() {
+ if (this.record.status === "Validated") {
+ await this.onChangeStatusToDefault();
+ } else {
+ await this.onValidate();
+ }
+ },
+ async toggleDiscardRecord() {
+ if (this.record.status === "Discarded") {
+ await this.onChangeStatusToDefault();
+ } else {
+ this.onDiscard();
+ }
+ },
+ async onValidate() {
await this.validate({
// TODO: Move this as part of token classification dataset logic
dataset: this.getTokenClassificationDataset(),
agent: this.$auth.user.username,
records: [
{
- ...record,
+ ...this.record,
annotatedEntities: undefined,
annotation: {
- entities: record.annotatedEntities,
+ entities: this.record.annotatedEntities,
origin: "annotation",
},
},
],
});
},
+ async onChangeStatusToDefault() {
+ const currentRecordAndDataset = {
+ dataset: this.getTokenClassificationDataset(),
+ records: [this.record],
+ };
+ await this.changeStatusToDefault(currentRecordAndDataset);
+ },
onClearAnnotations() {
this.updateRecords({
dataset: this.getTokenClassificationDataset(),
@@ -202,6 +247,20 @@ export default {
],
});
},
+ async onReset() {
+ await this.resetRecords({
+ dataset: this.getTokenClassificationDataset(),
+ records: [
+ {
+ ...this.record,
+ annotatedEntities: this.record.annotation?.entities,
+ },
+ ],
+ });
+ },
+ onDiscard() {
+ this.$emit("discard");
+ },
getTokenClassificationDataset() {
return getTokenClassificationDatasetById(this.datasetId);
},
@@ -211,11 +270,14 @@ export default {
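
The two toggles introduced above treat `"Validated"` and `"Discarded"` as latched states: clicking the button for the record's current status releases it back to the default status via `changeStatusToDefault`, while any other click applies the new status. A sketch of that state machine, with `setDefault` and `apply` standing in for the mapped store actions (names are placeholders):

// Latched-status toggle; a second click on the active state releases it.
async function toggleStatus(record, targetStatus, { setDefault, apply }) {
  if (record.status === targetStatus) {
    await setDefault(record);
  } else {
    await apply(record);
  }
}

// toggleStatus(record, "Validated", { setDefault, apply: validate });
// toggleStatus(record, "Discarded", { setDefault, apply: discard });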