diff --git a/src/argilla/server/apis/v0/handlers/metrics.py b/src/argilla/server/apis/v0/handlers/metrics.py index 3f4e0b6743..9408f1e3f0 100644 --- a/src/argilla/server/apis/v0/handlers/metrics.py +++ b/src/argilla/server/apis/v0/handlers/metrics.py @@ -14,9 +14,9 @@ # limitations under the License. from dataclasses import dataclass -from typing import List, Optional +from typing import Any, Dict, List, Optional -from fastapi import APIRouter, Depends, Query, Security +from fastapi import APIRouter, Depends, Query, Request, Security from pydantic import BaseModel, Field from argilla.server.apis.v0.helpers import deprecate_endpoint @@ -36,6 +36,8 @@ class MetricInfo(BaseModel): @dataclass class MetricSummaryParams: + request: Request + interval: Optional[float] = Query( default=None, gt=0.0, @@ -47,6 +49,15 @@ class MetricSummaryParams: description="The number of terms for terminological summaries", ) + @property + def parameters(self) -> Dict[str, Any]: + """Returns dynamic metric args found in the request query params""" + return { + "interval": self.interval, + "size": self.size, + **{k: v for k, v in self.request.query_params.items() if k not in ["interval", "size"]}, + } + def configure_router(router: APIRouter, cfg: TaskConfig): base_metrics_endpoint = f"/{cfg.task}/{{name}}/metrics" @@ -112,7 +123,7 @@ def metric_summary( metric=metric_, record_class=record_class, query=query, - **vars(metric_params), + **metric_params.parameters, ) diff --git a/src/argilla/server/apis/v0/handlers/text_classification.py b/src/argilla/server/apis/v0/handlers/text_classification.py index e5d4ca840b..59f9436425 100644 --- a/src/argilla/server/apis/v0/handlers/text_classification.py +++ b/src/argilla/server/apis/v0/handlers/text_classification.py @@ -316,7 +316,7 @@ async def list_labeling_rules( as_dataset_class=TasksFactory.get_task_dataset(task_type), ) - return [LabelingRule.parse_obj(rule) for rule in service.get_labeling_rules(dataset)] + return 
[LabelingRule.parse_obj(rule) for rule in service.list_labeling_rules(dataset)] @deprecate_endpoint( path=f"{new_base_endpoint}/labeling/rules", @@ -379,7 +379,7 @@ async def compute_rule_metrics( as_dataset_class=TasksFactory.get_task_dataset(task_type), ) - return service.compute_rule_metrics(dataset, rule_query=query, labels=labels) + return service.compute_labeling_rule(dataset, rule_query=query, labels=labels) @deprecate_endpoint( path=f"{new_base_endpoint}/labeling/rules/metrics", @@ -404,7 +404,7 @@ async def compute_dataset_rules_metrics( workspace=common_params.workspace, as_dataset_class=TasksFactory.get_task_dataset(task_type), ) - metrics = service.compute_overall_rules_metrics(dataset) + metrics = service.compute_all_labeling_rules(dataset) return DatasetLabelingRulesMetricsSummary.parse_obj(metrics) @deprecate_endpoint( @@ -456,10 +456,7 @@ async def get_rule( workspace=common_params.workspace, as_dataset_class=TasksFactory.get_task_dataset(task_type), ) - rule = service.find_labeling_rule( - dataset, - rule_query=query, - ) + rule = service.find_labeling_rule(dataset, rule_query=query) return LabelingRule.parse_obj(rule) @deprecate_endpoint( diff --git a/src/argilla/server/daos/backend/metrics/base.py b/src/argilla/server/daos/backend/metrics/base.py index 82bd60a1e7..344dfeba18 100644 --- a/src/argilla/server/daos/backend/metrics/base.py +++ b/src/argilla/server/daos/backend/metrics/base.py @@ -13,6 +13,7 @@ # limitations under the License. 
import dataclasses +import logging from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from argilla.server.daos.backend.query_helpers import aggregations @@ -22,6 +23,9 @@ from argilla.server.daos.backend.client_adapters.base import IClientAdapter +_LOGGER = logging.getLogger(__name__) + + @dataclasses.dataclass class ElasticsearchMetric: id: str @@ -37,11 +41,14 @@ def __post_init__(self): def get_function_arg_names(func): return func.__code__.co_varnames - def aggregation_request(self, *args, **kwargs) -> Union[Dict[str, Any], List[Dict[str, Any]]]: + def aggregation_request(self, *args, **kwargs) -> Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]: """ Configures the summary es aggregation definition """ - return {self.id: self._build_aggregation(*args, **kwargs)} + try: + return {self.id: self._build_aggregation(*args, **kwargs)} + except TypeError as ex: + _LOGGER.warning(f"Cannot build metric for metric {self.id}. Error: {ex}. Skipping...") def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]: """ diff --git a/src/argilla/server/daos/backend/metrics/text_classification.py b/src/argilla/server/daos/backend/metrics/text_classification.py index 8d3a83228c..cf54f5e12d 100644 --- a/src/argilla/server/daos/backend/metrics/text_classification.py +++ b/src/argilla/server/daos/backend/metrics/text_classification.py @@ -43,7 +43,7 @@ def _build_aggregation(self, queries: List[str]) -> Dict[str, Any]: class LabelingRulesMetric(ElasticsearchMetric): id: str - def _build_aggregation(self, rule_query: str, labels: Optional[List[str]]) -> Dict[str, Any]: + def _build_aggregation(self, rule_query: str, labels: Optional[List[str]] = None) -> Dict[str, Any]: annotated_records_filter = filters.exists_field("annotated_as") rule_query_filter = filters.text_query(rule_query) aggr_filters = { diff --git a/src/argilla/server/security/settings.py b/src/argilla/server/security/settings.py deleted file mode 100644 index 
168dce9eb1..0000000000 --- a/src/argilla/server/security/settings.py +++ /dev/null @@ -1,14 +0,0 @@ -# coding=utf-8 -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/src/argilla/server/services/datasets.py b/src/argilla/server/services/datasets.py index b6af813d4b..1c4e7f6843 100644 --- a/src/argilla/server/services/datasets.py +++ b/src/argilla/server/services/datasets.py @@ -230,6 +230,9 @@ async def get_settings( raise EntityNotFoundError(name=dataset.name, type=class_type) return class_type.parse_obj(settings.dict()) + def raw_dataset_update(self, dataset): + self.__dao__.update_dataset(dataset) + async def save_settings( self, user: User, dataset: ServiceDataset, settings: ServiceDatasetSettings ) -> ServiceDatasetSettings: diff --git a/src/argilla/server/services/metrics/service.py b/src/argilla/server/services/metrics/service.py index 9989690387..90b86cf294 100644 --- a/src/argilla/server/services/metrics/service.py +++ b/src/argilla/server/services/metrics/service.py @@ -116,3 +116,16 @@ def summarize_metric( dataset=dataset, query=query, ) + + def annotated_records(self, dataset: ServiceDataset) -> int: + """Return the number of annotated records for a dataset""" + results = self.__dao__.search_records( + dataset, + size=0, + search=DaoRecordsSearch(query=ServiceBaseRecordsQuery(has_annotation=True)), + ) + return results.total + + def total_records(self, dataset: ServiceDataset) -> int: + """Return 
the total number of records for a given dataset""" + return self.__dao__.search_records(dataset, size=0).total diff --git a/src/argilla/server/services/tasks/text_classification/__init__.py b/src/argilla/server/services/tasks/text_classification/__init__.py index 1a0c078cea..b8298320b7 100644 --- a/src/argilla/server/services/tasks/text_classification/__init__.py +++ b/src/argilla/server/services/tasks/text_classification/__init__.py @@ -12,5 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .labeling_rules_service import LabelingService from .service import TextClassificationService diff --git a/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py b/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py deleted file mode 100644 index 38471bfccc..0000000000 --- a/src/argilla/server/services/tasks/text_classification/labeling_rules_service.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import List, Optional, Tuple - -from fastapi import Depends -from pydantic import BaseModel, Field - -from argilla.server.daos.datasets import DatasetsDAO -from argilla.server.daos.models.records import DaoRecordsSearch -from argilla.server.daos.records import DatasetRecordsDAO -from argilla.server.errors import EntityAlreadyExistsError, EntityNotFoundError -from argilla.server.services.search.model import ServiceBaseRecordsQuery -from argilla.server.services.tasks.text_classification.model import ( - ServiceLabelingRule, - ServiceTextClassificationDataset, -) - - -class DatasetLabelingRulesSummary(BaseModel): - covered_records: int - annotated_covered_records: int - - -class LabelingRuleSummary(BaseModel): - covered_records: int - annotated_covered_records: int - correct_records: int = Field(default=0) - incorrect_records: int = Field(default=0) - precision: Optional[float] = None - - -class LabelingService: - _INSTANCE = None - - @classmethod - def get_instance( - cls, - datasets: DatasetsDAO = Depends(DatasetsDAO.get_instance), - records: DatasetRecordsDAO = Depends(DatasetRecordsDAO.get_instance), - ): - if cls._INSTANCE is None: - cls._INSTANCE = cls(datasets, records) - return cls._INSTANCE - - def __init__(self, datasets: DatasetsDAO, records: DatasetRecordsDAO): - self.__datasets__ = datasets - self.__records__ = records - - # TODO(@frascuchon): Move all rules management methods to the common datasets service like settings - def list_rules(self, dataset: ServiceTextClassificationDataset) -> List[ServiceLabelingRule]: - """List a set of rules for a given dataset""" - return dataset.rules - - def delete_rule(self, dataset: ServiceTextClassificationDataset, rule_query: str): - """Delete a rule from a dataset by its defined query string""" - new_rules_set = [r for r in dataset.rules if r.query != rule_query] - if len(dataset.rules) != new_rules_set: - dataset.rules = new_rules_set - self.__datasets__.update_dataset(dataset) - - def add_rule(self, 
dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule) -> ServiceLabelingRule: - """Adds a rule to a dataset""" - for r in dataset.rules: - if r.query == rule.query: - raise EntityAlreadyExistsError(rule.query, type=ServiceLabelingRule) - dataset.rules.append(rule) - self.__datasets__.update_dataset(dataset) - return rule - - def compute_rule_metrics( - self, - dataset: ServiceTextClassificationDataset, - rule_query: str, - labels: Optional[List[str]] = None, - ) -> Tuple[int, int, LabelingRuleSummary]: - """Computes metrics for given rule query and optional label against a set of rules""" - - annotated_records = self._count_annotated_records(dataset) - dataset_records = self.__records__.search_records(dataset, size=0).total - metric_data = self.__records__.compute_metric( - dataset=dataset, - metric_id="labeling_rule", - metric_params=dict(rule_query=rule_query, labels=labels), - ) - - return ( - dataset_records, - annotated_records, - LabelingRuleSummary.parse_obj(metric_data), - ) - - def _count_annotated_records(self, dataset: ServiceTextClassificationDataset) -> int: - results = self.__records__.search_records( - dataset, - size=0, - search=DaoRecordsSearch(query=ServiceBaseRecordsQuery(has_annotation=True)), - ) - return results.total - - def all_rules_metrics( - self, dataset: ServiceTextClassificationDataset - ) -> Tuple[int, int, DatasetLabelingRulesSummary]: - annotated_records = self._count_annotated_records(dataset) - dataset_records = self.__records__.search_records(dataset, size=0).total - metric_data = self.__records__.compute_metric( - dataset=dataset, - metric_id="dataset_labeling_rules", - metric_params=dict(queries=[r.query for r in dataset.rules]), - ) - - return ( - dataset_records, - annotated_records, - DatasetLabelingRulesSummary.parse_obj(metric_data), - ) - - def find_rule_by_query(self, dataset: ServiceTextClassificationDataset, rule_query: str) -> ServiceLabelingRule: - rule_query = rule_query.strip() - for rule in 
dataset.rules: - if rule.query == rule_query: - return rule - raise EntityNotFoundError(rule_query, type=ServiceLabelingRule) - - def replace_rule(self, dataset: ServiceTextClassificationDataset, rule: ServiceLabelingRule): - for idx, r in enumerate(dataset.rules): - if r.query == rule.query: - dataset.rules[idx] = rule - break - self.__datasets__.update_dataset(dataset) diff --git a/src/argilla/server/services/tasks/text_classification/metrics.py b/src/argilla/server/services/tasks/text_classification/metrics.py index d09891620b..2b9bd006a9 100644 --- a/src/argilla/server/services/tasks/text_classification/metrics.py +++ b/src/argilla/server/services/tasks/text_classification/metrics.py @@ -159,5 +159,13 @@ class TextClassificationMetrics(CommonTasksMetrics[ServiceTextClassificationReco id="annotated_as", name="Annotated labels distribution", ), + ServiceBaseMetric( + id="labeling_rule", + name="Labeling rule metric based on a query rule and a set of labels", + ), + ServiceBaseMetric( + id="dataset_labeling_rules", + name="Computes the overall labeling rules stats", + ), ] ) diff --git a/src/argilla/server/services/tasks/text_classification/model.py b/src/argilla/server/services/tasks/text_classification/model.py index 73c5cd7d05..f8a82db328 100644 --- a/src/argilla/server/services/tasks/text_classification/model.py +++ b/src/argilla/server/services/tasks/text_classification/model.py @@ -307,3 +307,16 @@ class ServiceTextClassificationQuery(ServiceBaseRecordsQuery): predicted: Optional[PredictionStatus] = Field(default=None, nullable=True) uncovered_by_rules: List[str] = Field(default_factory=list) + + +class DatasetLabelingRulesSummary(BaseModel): + covered_records: int + annotated_covered_records: int + + +class LabelingRuleSummary(BaseModel): + covered_records: int + annotated_covered_records: int + correct_records: int = Field(default=0) + incorrect_records: int = Field(default=0) + precision: Optional[float] = None diff --git 
a/src/argilla/server/services/tasks/text_classification/service.py b/src/argilla/server/services/tasks/text_classification/service.py index 8a5407258c..79d6ddb473 100644 --- a/src/argilla/server/services/tasks/text_classification/service.py +++ b/src/argilla/server/services/tasks/text_classification/service.py @@ -13,12 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, List, Optional +from typing import Iterable, List, Optional, Tuple from fastapi import Depends -from argilla.server.commons.config import TasksFactory -from argilla.server.errors.base_errors import MissingDatasetRecordsError +from argilla.server.errors.base_errors import ( + EntityAlreadyExistsError, + EntityNotFoundError, + MissingDatasetRecordsError, +) +from argilla.server.services.datasets import DatasetsService +from argilla.server.services.metrics import MetricsService from argilla.server.services.search.model import ( ServiceSearchResults, ServiceSortableField, @@ -27,10 +32,14 @@ from argilla.server.services.search.service import SearchRecordsService from argilla.server.services.storage.service import RecordsStorageService from argilla.server.services.tasks.commons import BulkResponse -from argilla.server.services.tasks.text_classification import LabelingService +from argilla.server.services.tasks.text_classification.metrics import ( + TextClassificationMetrics, +) from argilla.server.services.tasks.text_classification.model import ( DatasetLabelingRulesMetricsSummary, + DatasetLabelingRulesSummary, LabelingRuleMetricsSummary, + LabelingRuleSummary, ServiceLabelingRule, ServiceTextClassificationDataset, ServiceTextClassificationQuery, @@ -49,23 +58,26 @@ class TextClassificationService: @classmethod def get_instance( cls, + datasets: DatasetsService = Depends(DatasetsService.get_instance), + metrics: MetricsService = Depends(MetricsService.get_instance), storage: RecordsStorageService = 
Depends(RecordsStorageService.get_instance), - labeling: LabelingService = Depends(LabelingService.get_instance), search: SearchRecordsService = Depends(SearchRecordsService.get_instance), ) -> "TextClassificationService": if not cls._INSTANCE: - cls._INSTANCE = cls(storage, labeling=labeling, search=search) + cls._INSTANCE = cls(datasets=datasets, metrics=metrics, storage=storage, search=search) return cls._INSTANCE def __init__( self, + datasets: DatasetsService, + metrics: MetricsService, storage: RecordsStorageService, search: SearchRecordsService, - labeling: LabelingService, ): self.__storage__ = storage self.__search__ = search - self.__labeling__ = labeling + self.__metrics__ = metrics + self.__datasets__ = datasets async def add_records( self, @@ -116,9 +128,9 @@ def search( """ - metrics = TasksFactory.find_task_metrics( - dataset.task, - metric_ids={ + metrics = [ + TextClassificationMetrics.find_metric(id) + for id in { "words_cloud", "predicted_by", "predicted_as", @@ -128,8 +140,8 @@ def search( "status_distribution", "metadata", "score", - }, - ) + } + ] results = self.__search__.search( dataset, @@ -209,8 +221,18 @@ def _is_dataset_multi_label(self, dataset: ServiceTextClassificationDataset) -> if results.records: return results.records[0].multi_label - def get_labeling_rules(self, dataset: ServiceTextClassificationDataset) -> Iterable[ServiceLabelingRule]: - return self.__labeling__.list_rules(dataset) + def find_labeling_rule( + self, dataset: ServiceTextClassificationDataset, rule_query: str, error_on_missing: bool = True + ) -> Optional[ServiceLabelingRule]: + rule_query = rule_query.strip() + for rule in dataset.rules: + if rule.query == rule_query: + return rule + if error_on_missing: + raise EntityNotFoundError(rule_query, type=ServiceLabelingRule) + + def list_labeling_rules(self, dataset: ServiceTextClassificationDataset) -> Iterable[ServiceLabelingRule]: + return dataset.rules def add_labeling_rule(self, dataset: 
ServiceTextClassificationDataset, rule: ServiceLabelingRule) -> None: """ @@ -224,8 +246,13 @@ def add_labeling_rule(self, dataset: ServiceTextClassificationDataset, rule: Ser rule: The rule """ + self.__normalized_rule__(rule) - self.__labeling__.add_rule(dataset, rule) + if self.find_labeling_rule(dataset, rule_query=rule.query, error_on_missing=False): + raise EntityAlreadyExistsError(rule.query, type=ServiceLabelingRule) + + dataset.rules.append(rule) + self.__datasets__.raw_dataset_update(dataset) def update_labeling_rule( self, @@ -234,7 +261,7 @@ def update_labeling_rule( labels: List[str], description: Optional[str] = None, ) -> ServiceLabelingRule: - found_rule = self.__labeling__.find_rule_by_query(dataset, rule_query) + found_rule = self.find_labeling_rule(dataset, rule_query) found_rule.labels = labels found_rule.label = labels[0] if len(labels) == 1 else None @@ -242,17 +269,22 @@ def update_labeling_rule( found_rule.description = description self.__normalized_rule__(found_rule) - self.__labeling__.replace_rule(dataset, found_rule) - return found_rule + for idx, r in enumerate(dataset.rules): + if r.query == found_rule.query: + dataset.rules[idx] = found_rule + break + self.__datasets__.raw_dataset_update(dataset) - def find_labeling_rule(self, dataset: ServiceTextClassificationDataset, rule_query: str) -> ServiceLabelingRule: - return self.__labeling__.find_rule_by_query(dataset, rule_query=rule_query) + return found_rule def delete_labeling_rule(self, dataset: ServiceTextClassificationDataset, rule_query: str): - if rule_query.strip(): - return self.__labeling__.delete_rule(dataset, rule_query) + """Delete a rule from a dataset by its defined query string""" + new_rules_set = [r for r in dataset.rules if r.query != rule_query] + if len(dataset.rules) != len(new_rules_set): + dataset.rules = new_rules_set + self.__datasets__.raw_dataset_update(dataset) - def compute_rule_metrics( + def compute_labeling_rule( self, dataset: ServiceTextClassificationDataset, 
rule_query: str, @@ -288,14 +320,20 @@ def compute_rule_metrics( rule_query = rule_query.strip() if labels is None: - for rule in self.get_labeling_rules(dataset): - if rule.query == rule_query: - labels = rule.labels - break + rule = self.find_labeling_rule(dataset, rule_query=rule_query, error_on_missing=False) + if rule: + labels = rule.labels - total, annotated, metrics = self.__labeling__.compute_rule_metrics( - dataset, rule_query=rule_query, labels=labels + metric_data = self.__metrics__.summarize_metric( + dataset=dataset, + metric=TextClassificationMetrics.find_metric("labeling_rule"), + rule_query=rule_query, + labels=labels, ) + annotated = self.__metrics__.annotated_records(dataset) + total = self.__metrics__.total_records(dataset) + + metrics = LabelingRuleSummary.parse_obj(metric_data) coverage = metrics.covered_records / total if total > 0 else None coverage_annotated = metrics.annotated_covered_records / annotated if annotated > 0 else None @@ -310,8 +348,8 @@ def compute_rule_metrics( precision=metrics.precision if annotated > 0 else None, ) - def compute_overall_rules_metrics(self, dataset: ServiceTextClassificationDataset): - total, annotated, metrics = self.__labeling__.all_rules_metrics(dataset) + def compute_all_labeling_rules(self, dataset: ServiceTextClassificationDataset): + total, annotated, metrics = self._compute_all_lb_rules_metrics(dataset) coverage = metrics.covered_records / total if total else None coverage_annotated = metrics.annotated_covered_records / annotated if annotated else None return DatasetLabelingRulesMetricsSummary( @@ -321,6 +359,23 @@ def compute_overall_rules_metrics(self, dataset: ServiceTextClassificationDatase annotated_records=annotated, ) + def _compute_all_lb_rules_metrics( + self, dataset: ServiceTextClassificationDataset + ) -> Tuple[int, int, DatasetLabelingRulesSummary]: + annotated_records = self.__metrics__.annotated_records(dataset) + dataset_records = self.__metrics__.total_records(dataset) + 
metric_data = self.__metrics__.summarize_metric( + dataset=dataset, + metric=TextClassificationMetrics.find_metric(id="dataset_labeling_rules"), + queries=[r.query for r in dataset.rules], + ) + + return ( + dataset_records, + annotated_records, + DatasetLabelingRulesSummary.parse_obj(metric_data), + ) + @staticmethod def __normalized_rule__(rule: ServiceLabelingRule) -> ServiceLabelingRule: if rule.labels and len(rule.labels) == 1: diff --git a/tests/server/metrics/test_api.py b/tests/server/metrics/test_api.py index 0dc30ba139..dec0169c54 100644 --- a/tests/server/metrics/test_api.py +++ b/tests/server/metrics/test_api.py @@ -26,11 +26,15 @@ TokenClassificationRecord, ) from argilla.server.services.metrics.models import CommonTasksMetrics +from argilla.server.services.tasks.text_classification.metrics import ( + TextClassificationMetrics, +) from argilla.server.services.tasks.token_classification.metrics import ( TokenClassificationMetrics, ) COMMON_METRICS_LENGTH = len(CommonTasksMetrics.metrics) +CLASSIFICATION_METRICS_LENGTH = len(TextClassificationMetrics.metrics) def test_wrong_dataset_metrics(mocked_client): @@ -183,7 +187,7 @@ def test_dataset_metrics(mocked_client): metrics = mocked_client.get(f"/api/datasets/TextClassification/{dataset}/metrics").json() - assert len(metrics) == COMMON_METRICS_LENGTH + 5 + assert len(metrics) == CLASSIFICATION_METRICS_LENGTH response = mocked_client.post( f"/api/datasets/TextClassification/{dataset}/metrics/missing_metric:summary", @@ -206,6 +210,38 @@ def test_dataset_metrics(mocked_client): assert response.status_code == 200, f"{metric}: {response.json()}" +def create_some_classification_data(mocked_client, dataset: str, records: list): + request = TextClassificationBulkRequest(records=[TextClassificationRecord.parse_obj(r) for r in records]) + + assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200 + assert ( + mocked_client.post( + f"/api/datasets/{dataset}/TextClassification:bulk", + 
json=request.dict(by_alias=True), + ).status_code + == 200 + ) + + +def test_labeling_rule_metric(mocked_client): + dataset = "test_labeling_rule_metric" + create_some_classification_data( + mocked_client, dataset, records=[{"inputs": {"text": "This is classification record"}}] * 10 + ) + + rule_query = "t*" + response = mocked_client.post( + f"/api/datasets/TextClassification/{dataset}/metrics/labeling_rule:summary?rule_query={rule_query}", + json={}, + ) + assert response.json() == { + "annotated_covered_records": 0, + "correct_records": 0, + "covered_records": 10, + "incorrect_records": 0, + } + + def test_dataset_labels_for_text_classification(mocked_client): records = [ TextClassificationRecord.parse_obj(data)