Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: remove the classification labeling rules service #2361

Merged
merged 8 commits into from
Feb 20, 2023
17 changes: 14 additions & 3 deletions src/argilla/server/apis/v0/handlers/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
# limitations under the License.

from dataclasses import dataclass
from typing import List, Optional
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, Depends, Query, Security
from fastapi import APIRouter, Depends, Query, Request, Security
from pydantic import BaseModel, Field

from argilla.server.apis.v0.helpers import deprecate_endpoint
Expand All @@ -36,6 +36,8 @@ class MetricInfo(BaseModel):

@dataclass
class MetricSummaryParams:
request: Request

interval: Optional[float] = Query(
default=None,
gt=0.0,
Expand All @@ -47,6 +49,15 @@ class MetricSummaryParams:
description="The number of terms for terminological summaries",
)

@property
def parameters(self) -> Dict[str, Any]:
"""Returns dynamic metric args found in the request query params"""
return {
"interval": self.interval,
"size": self.size,
**{k: v for k, v in self.request.query_params.items() if k not in ["interval", "size"]},
}


def configure_router(router: APIRouter, cfg: TaskConfig):
base_metrics_endpoint = f"/{cfg.task}/{{name}}/metrics"
Expand Down Expand Up @@ -112,7 +123,7 @@ def metric_summary(
metric=metric_,
record_class=record_class,
query=query,
**vars(metric_params),
**metric_params.parameters,
)


Expand Down
11 changes: 4 additions & 7 deletions src/argilla/server/apis/v0/handlers/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ async def list_labeling_rules(
as_dataset_class=TasksFactory.get_task_dataset(task_type),
)

return [LabelingRule.parse_obj(rule) for rule in service.get_labeling_rules(dataset)]
return [LabelingRule.parse_obj(rule) for rule in service.list_labeling_rules(dataset)]

@deprecate_endpoint(
path=f"{new_base_endpoint}/labeling/rules",
Expand Down Expand Up @@ -379,7 +379,7 @@ async def compute_rule_metrics(
as_dataset_class=TasksFactory.get_task_dataset(task_type),
)

return service.compute_rule_metrics(dataset, rule_query=query, labels=labels)
return service.compute_labeling_rule(dataset, rule_query=query, labels=labels)

@deprecate_endpoint(
path=f"{new_base_endpoint}/labeling/rules/metrics",
Expand All @@ -404,7 +404,7 @@ async def compute_dataset_rules_metrics(
workspace=common_params.workspace,
as_dataset_class=TasksFactory.get_task_dataset(task_type),
)
metrics = service.compute_overall_rules_metrics(dataset)
metrics = service.compute_all_labeling_rules(dataset)
return DatasetLabelingRulesMetricsSummary.parse_obj(metrics)

@deprecate_endpoint(
Expand Down Expand Up @@ -456,10 +456,7 @@ async def get_rule(
workspace=common_params.workspace,
as_dataset_class=TasksFactory.get_task_dataset(task_type),
)
rule = service.find_labeling_rule(
dataset,
rule_query=query,
)
rule = service.find_labeling_rule(dataset, rule_query=query)
return LabelingRule.parse_obj(rule)

@deprecate_endpoint(
Expand Down
11 changes: 9 additions & 2 deletions src/argilla/server/daos/backend/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import dataclasses
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from argilla.server.daos.backend.query_helpers import aggregations
Expand All @@ -22,6 +23,9 @@
from argilla.server.daos.backend.client_adapters.base import IClientAdapter


_LOGGER = logging.getLogger(__file__)


@dataclasses.dataclass
class ElasticsearchMetric:
id: str
Expand All @@ -37,11 +41,14 @@ def __post_init__(self):
def get_function_arg_names(func):
return func.__code__.co_varnames

def aggregation_request(self, *args, **kwargs) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
def aggregation_request(self, *args, **kwargs) -> Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]:
"""
Configures the summary es aggregation definition
"""
return {self.id: self._build_aggregation(*args, **kwargs)}
try:
return {self.id: self._build_aggregation(*args, **kwargs)}
except TypeError as ex:
_LOGGER.warning(f"Cannot build metric for metric {self.id}. Error: {ex}. Skipping...")

def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _build_aggregation(self, queries: List[str]) -> Dict[str, Any]:
class LabelingRulesMetric(ElasticsearchMetric):
id: str

def _build_aggregation(self, rule_query: str, labels: Optional[List[str]]) -> Dict[str, Any]:
def _build_aggregation(self, rule_query: str, labels: Optional[List[str]] = None) -> Dict[str, Any]:
annotated_records_filter = filters.exists_field("annotated_as")
rule_query_filter = filters.text_query(rule_query)
aggr_filters = {
Expand Down
14 changes: 0 additions & 14 deletions src/argilla/server/security/settings.py

This file was deleted.

3 changes: 3 additions & 0 deletions src/argilla/server/services/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,9 @@ async def get_settings(
raise EntityNotFoundError(name=dataset.name, type=class_type)
return class_type.parse_obj(settings.dict())

def raw_dataset_update(self, dataset):
self.__dao__.update_dataset(dataset)

async def save_settings(
self, user: User, dataset: ServiceDataset, settings: ServiceDatasetSettings
) -> ServiceDatasetSettings:
Expand Down
13 changes: 13 additions & 0 deletions src/argilla/server/services/metrics/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,16 @@ def summarize_metric(
dataset=dataset,
query=query,
)

def annotated_records(self, dataset: ServiceDataset) -> int:
"""Return the number of annotated records for a dataset"""
results = self.__dao__.search_records(
dataset,
size=0,
search=DaoRecordsSearch(query=ServiceBaseRecordsQuery(has_annotation=True)),
)
return results.total

def total_records(self, dataset: ServiceDataset) -> int:
"""Return the total number of records for a given dataset"""
return self.__dao__.search_records(dataset, size=0).total
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .labeling_rules_service import LabelingService
from .service import TextClassificationService

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -159,5 +159,13 @@ class TextClassificationMetrics(CommonTasksMetrics[ServiceTextClassificationReco
id="annotated_as",
name="Annotated labels distribution",
),
ServiceBaseMetric(
id="labeling_rule",
name="Labeling rule metric based on a query rule and a set of labels",
),
ServiceBaseMetric(
id="dataset_labeling_rules",
name="Computes the overall labeling rules stats",
),
]
)
13 changes: 13 additions & 0 deletions src/argilla/server/services/tasks/text_classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,16 @@ class ServiceTextClassificationQuery(ServiceBaseRecordsQuery):
predicted: Optional[PredictionStatus] = Field(default=None, nullable=True)

uncovered_by_rules: List[str] = Field(default_factory=list)


class DatasetLabelingRulesSummary(BaseModel):
covered_records: int
annotated_covered_records: int


class LabelingRuleSummary(BaseModel):
covered_records: int
annotated_covered_records: int
correct_records: int = Field(default=0)
incorrect_records: int = Field(default=0)
precision: Optional[float] = None
Loading