Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: import errors when importing from argilla.feedback #3471

Merged
merged 8 commits into from
Jul 27, 2023
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ These are the section headers that we use:

## [Unreleased]

## [1.13.3](https://github.com/argilla-io/argilla/compare/v1.13.2...v1.13.3)

### Fixed

- Fixed `ModuleNotFoundError` raised because the `argilla.utils.telemetry` module, used by the `ArgillaTrainer`, imported an optional dependency that is not installed by default ([#3471](https://github.com/argilla-io/argilla/pull/3471)).
- Fixed `ImportError` raised because the `argilla.client.feedback.config` module imported the optional `pyyaml` dependency, which is not installed by default ([#3471](https://github.com/argilla-io/argilla/pull/3471)).

## [1.13.2](https://github.com/argilla-io/argilla/compare/v1.13.1...v1.13.2)

### Fixed
Expand Down
10 changes: 9 additions & 1 deletion src/argilla/client/feedback/integrations/huggingface/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from packaging.version import parse as parse_version

from argilla.client.feedback.config import DatasetConfig, DeprecatedDatasetConfig
from argilla.client.feedback.constants import FIELD_TYPE_TO_PYTHON_TYPE
from argilla.client.feedback.schemas import FeedbackRecord
from argilla.client.feedback.types import AllowedQuestionTypes
Expand Down Expand Up @@ -188,6 +187,9 @@ def push_to_huggingface(
import huggingface_hub
from huggingface_hub import DatasetCardData, HfApi

# https://github.com/argilla-io/argilla/issues/3468
from argilla.client.feedback.config import DatasetConfig
gabrielmbmb marked this conversation as resolved.
Show resolved Hide resolved

if parse_version(huggingface_hub.__version__) < parse_version("0.14.0"):
_LOGGER.warning(
"Recommended `huggingface_hub` version is 0.14.0 or higher, and you have"
Expand Down Expand Up @@ -261,6 +263,12 @@ def from_huggingface(cls: Type["FeedbackDataset"], repo_id: str, *args: Any, **k
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

# https://github.com/argilla-io/argilla/issues/3468
from argilla.client.feedback.config import (
gabrielmbmb marked this conversation as resolved.
Show resolved Hide resolved
DatasetConfig,
DeprecatedDatasetConfig,
)

if parse_version(huggingface_hub.__version__) < parse_version("0.14.0"):
_LOGGER.warning(
"Recommended `huggingface_hub` version is 0.14.0 or higher, and you have"
Expand Down
24 changes: 19 additions & 5 deletions src/argilla/server/errors/api_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Any, Dict

from fastapi import HTTPException, Request
from fastapi.exception_handlers import http_exception_handler
from pydantic import BaseModel

from argilla.server.errors.adapter import exception_to_argilla_error
from argilla.server.errors.base_errors import ServerError
from argilla.server.errors.base_errors import (
EntityAlreadyExistsError,
EntityNotFoundError,
GenericServerError,
ServerError,
)
from argilla.utils import telemetry

_LOGGER = logging.getLogger("argilla")


class ErrorDetail(BaseModel):
code: str
Expand All @@ -41,10 +43,22 @@ def __init__(self, error: ServerError):


class APIErrorHandler:
    """Turns exceptions raised while serving a request into HTTP error responses.

    Every handled error is also reported once to the telemetry client.
    """

    @staticmethod
    async def track_error(error: ServerError, request: Request):
        """Send a ``ServerErrorFound`` telemetry event describing ``error``.

        Args:
            error: the server error to report; its ``code`` is always sent.
            request: the request being served; its ``user-agent`` and
                ``accept-language`` headers are attached to the event.
        """
        data = {
            "code": error.code,
            "user-agent": request.headers.get("user-agent"),
            "accept-language": request.headers.get("accept-language"),
        }
        # Only these error classes are read for a `type` attribute here.
        if isinstance(error, (GenericServerError, EntityNotFoundError, EntityAlreadyExistsError)):
            data["type"] = error.type

        telemetry.get_telemetry_client().track_data(action="ServerErrorFound", data=data)

    @staticmethod
    async def common_exception_handler(request: Request, error: Exception):
        """Wraps errors as custom generic error"""
        argilla_error = exception_to_argilla_error(error)
        # Report through the class-local tracker only: also calling
        # `telemetry.track_error` would send the same event twice.
        await APIErrorHandler.track_error(argilla_error, request=request)

        return await http_exception_handler(request, ServerHTTPException(argilla_error))
27 changes: 6 additions & 21 deletions src/argilla/utils/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,14 @@
import logging
import platform
import uuid
from typing import Any, Dict, Optional

from fastapi import Request
from typing import TYPE_CHECKING, Any, Dict, Optional

from argilla.server.commons.models import TaskType
from argilla.server.errors.base_errors import (
EntityAlreadyExistsError,
EntityNotFoundError,
GenericServerError,
ServerError,
)
from argilla.server.settings import settings

if TYPE_CHECKING:
from fastapi import Request

try:
from analytics import Client # This import works only for version 2.2.0
except (ImportError, ModuleNotFoundError):
Expand Down Expand Up @@ -89,25 +84,15 @@ def track_data(self, action: str, data: Dict[str, Any], include_system_info: boo
_CLIENT = TelemetryClient()


def _process_request_info(request: Request):
def _process_request_info(request: "Request"):
return {header: request.headers.get(header) for header in ["user-agent", "accept-language"]}


async def track_error(error: ServerError, request: Request):
    """Report ``error`` to the telemetry client as a ``ServerErrorFound`` event.

    The payload carries the error code, the error ``type`` for the error
    classes listed below, and the headers picked by ``_process_request_info``.
    """
    typed_errors = (GenericServerError, EntityNotFoundError, EntityAlreadyExistsError)

    payload = {"code": error.code}
    if isinstance(error, typed_errors):
        payload["type"] = error.type
    payload.update(_process_request_info(request))

    _CLIENT.track_data(action="ServerErrorFound", data=payload)


async def track_bulk(task: TaskType, records: int):
    """Report a ``LogRecordsRequested`` event with the task and record count."""
    payload = {"task": task, "records": records}
    _CLIENT.track_data(action="LogRecordsRequested", data=payload)


async def track_login(request: Request, username: str):
async def track_login(request: "Request", username: str):
_CLIENT.track_data(
action="UserInfoRequested",
data={
Expand Down
54 changes: 0 additions & 54 deletions tests/server/commons/test_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,6 @@

import pytest
from argilla.server.commons.models import TaskType
from argilla.server.errors import (
EntityAlreadyExistsError,
EntityNotFoundError,
GenericServerError,
ServerError,
)
from argilla.server.schemas.datasets import Dataset
from argilla.utils import telemetry
from argilla.utils.telemetry import TelemetryClient, get_telemetry_client
from fastapi import Request
Expand Down Expand Up @@ -57,50 +50,3 @@ async def test_track_bulk(test_telemetry):

await telemetry.track_bulk(task=task, records=records)
test_telemetry.assert_called_once_with("LogRecordsRequested", {"task": task, "records": records})


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "error, expected_event",
    [
        # Error classes that expose a `type`: it must appear in the event.
        (
            EntityNotFoundError(name="mock-name", type="MockType"),
            {
                "code": "argilla.api.errors::EntityNotFoundError",
                "type": "MockType",
                "user-agent": None,
                "accept-language": None,
            },
        ),
        (
            EntityAlreadyExistsError(name="mock-name", type=Dataset, workspace="mock-workspace"),
            {
                "code": "argilla.api.errors::EntityAlreadyExistsError",
                "type": "Dataset",
                "user-agent": None,
                "accept-language": None,
            },
        ),
        (
            GenericServerError(RuntimeError("This is a mock error")),
            {
                "code": "argilla.api.errors::GenericServerError",
                "type": "builtins.RuntimeError",
                "user-agent": None,
                "accept-language": None,
            },
        ),
        # Plain server error: no `type` key is expected.
        (
            ServerError(),
            {
                "code": "argilla.api.errors::ServerError",
                "user-agent": None,
                "accept-language": None,
            },
        ),
    ],
)
async def test_track_error(test_telemetry, error, expected_event):
    """``telemetry.track_error`` emits one ``ServerErrorFound`` event with the expected payload."""
    await telemetry.track_error(error, request=mock_request)

    test_telemetry.assert_called_once_with("ServerErrorFound", expected_event)
13 changes: 13 additions & 0 deletions tests/server/errors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
75 changes: 75 additions & 0 deletions tests/server/errors/test_api_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from argilla.server.errors.api_errors import APIErrorHandler
from argilla.server.errors.base_errors import (
EntityAlreadyExistsError,
EntityNotFoundError,
GenericServerError,
ServerError,
)
from argilla.server.schemas.datasets import Dataset
from fastapi import Request

mock_request = Request(scope={"type": "http", "headers": {}})


@pytest.mark.asyncio
class TestAPIErrorHandler:
    """Checks the telemetry event built by ``APIErrorHandler.track_error``."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "error, expected_event",
        [
            # Error classes that expose a `type`: it must appear in the event.
            (
                EntityNotFoundError(name="mock-name", type="MockType"),
                {
                    "code": "argilla.api.errors::EntityNotFoundError",
                    "type": "MockType",
                    "user-agent": None,
                    "accept-language": None,
                },
            ),
            (
                EntityAlreadyExistsError(name="mock-name", type=Dataset, workspace="mock-workspace"),
                {
                    "code": "argilla.api.errors::EntityAlreadyExistsError",
                    "type": "Dataset",
                    "user-agent": None,
                    "accept-language": None,
                },
            ),
            (
                GenericServerError(RuntimeError("This is a mock error")),
                {
                    "code": "argilla.api.errors::GenericServerError",
                    "type": "builtins.RuntimeError",
                    "user-agent": None,
                    "accept-language": None,
                },
            ),
            # Plain server error: no `type` key is expected.
            (
                ServerError(),
                {
                    "code": "argilla.api.errors::ServerError",
                    "user-agent": None,
                    "accept-language": None,
                },
            ),
        ],
    )
    async def test_track_error(self, test_telemetry, error, expected_event):
        """One ``ServerErrorFound`` event is emitted, carrying the expected payload."""
        await APIErrorHandler.track_error(error, request=mock_request)

        test_telemetry.assert_called_once_with("ServerErrorFound", expected_event)
Loading