Skip to content

Commit

Permalink
Add cluster info to CogniteAPIError (#1942)
Browse files Browse the repository at this point in the history
  • Loading branch information
haakonvt authored Sep 25, 2024
1 parent d3cd8c9 commit 3abca3a
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 23 deletions.
1 change: 1 addition & 0 deletions cognite/client/_api/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1483,6 +1483,7 @@ def _raise_latest_exception(self, exceptions: list[Exception], successful: list[
unknown=AssetList(self.unknown),
failed=AssetList(self.failed),
unwrap_fn=op.attrgetter("external_id"),
cluster=self.assets_api._config.cdf_cluster,
)
err_message = "One or more errors happened during asset creation. Latest error:"
if isinstance(latest_exception, CogniteAPIError):
Expand Down
24 changes: 18 additions & 6 deletions cognite/client/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1147,7 +1147,12 @@ def _upsert_multiple(
# The created call failed
failed.extend(item.external_id if item.external_id is not None else item.id for item in to_update) # type: ignore [attr-defined]
raise CogniteAPIError(
api_error.message, code=api_error.code, successful=successful, failed=failed, unknown=unknown
api_error.message,
code=api_error.code,
successful=successful,
failed=failed,
unknown=unknown,
cluster=self._config.cdf_cluster,
)
# Need to retrieve the successful updated items from the first call.
successful_resources: T_CogniteResourceList | None = None
Expand Down Expand Up @@ -1272,8 +1277,7 @@ def _clear_all_attributes(update_attributes: list[PropertySpec]) -> dict[str, di
def _status_ok(status_code: int) -> bool:
return status_code in {200, 201, 202, 204}

@classmethod
def _raise_api_error(cls, res: Response, payload: dict) -> NoReturn:
def _raise_api_error(self, res: Response, payload: dict) -> NoReturn:
x_request_id = res.headers.get("X-Request-Id")
code = res.status_code
missing = None
Expand Down Expand Up @@ -1303,8 +1307,8 @@ def _raise_api_error(cls, res: Response, payload: dict) -> NoReturn:
if duplicated:
error_details["duplicated"] = duplicated
error_details["headers"] = res.request.headers.copy()
cls._sanitize_headers(error_details["headers"])
error_details["response_payload"] = shorten(cls._get_response_content_safe(res), 500)
self._sanitize_headers(error_details["headers"])
error_details["response_payload"] = shorten(self._get_response_content_safe(res), 500)
error_details["response_headers"] = res.headers

if res.history:
Expand All @@ -1313,7 +1317,15 @@ def _raise_api_error(cls, res: Response, payload: dict) -> NoReturn:
f"REDIRECT AFTER HTTP Error {res_hist.status_code} {res_hist.request.method} {res_hist.request.url}: {res_hist.content.decode()}"
)
logger.debug(f"HTTP Error {code} {res.request.method} {res.request.url}: {msg}", extra=error_details)
raise CogniteAPIError(msg, code, x_request_id, missing=missing, duplicated=duplicated, extra=extra)
raise CogniteAPIError(
msg,
code,
x_request_id,
missing=missing,
duplicated=duplicated,
extra=extra,
cluster=self._config.cdf_cluster,
)

def _log_request(self, res: Response, **kwargs: Any) -> None:
method = res.request.method
Expand Down
14 changes: 12 additions & 2 deletions cognite/client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import getpass
import pprint
import re
import warnings
from contextlib import suppress
from typing import Any
Expand Down Expand Up @@ -184,6 +185,8 @@ def _validate_config(self) -> None:
raise ValueError(f"Invalid value for ClientConfig.project: <{self.project}>")
if not self.base_url:
raise ValueError(f"Invalid value for ClientConfig.base_url: <{self.base_url}>")
elif self.cdf_cluster is None:
warnings.warn(f"Given base URL may be invalid, please double-check: {self.base_url!r}", UserWarning)

def __str__(self) -> str:
return pprint.pformat({"max_workers": self.max_workers, **self.__dict__}, indent=4)
Expand Down Expand Up @@ -211,7 +214,7 @@ def default(
client_name=client_name or getpass.getuser(),
project=project,
credentials=credentials,
base_url=f"https://{cdf_cluster}.cognitedata.com/",
base_url=f"https://{cdf_cluster}.cognitedata.com",
)

@classmethod
Expand All @@ -233,7 +236,7 @@ def load(cls, config: dict[str, Any] | str) -> ClientConfig:
>>> config = {
... "client_name": "abcd",
... "project": "cdf-project",
... "base_url": "https://api.cognitedata.com/",
... "base_url": "https://api.cognitedata.com",
... "credentials": {
... "client_credentials": {
... "client_id": "abcd",
Expand Down Expand Up @@ -264,3 +267,10 @@ def load(cls, config: dict[str, Any] | str) -> ClientConfig:
file_transfer_timeout=loaded.get("file_transfer_timeout"),
debug=loaded.get("debug", False),
)

@property
def cdf_cluster(self) -> str | None:
# A best effort attempt to extract the cluster from the base url
if match := re.match(r"https?://([^/\.\s]+)\.cognitedata\.com(?::\d+)?(?:/|$)", self.base_url):
return match.group(1)
return None
5 changes: 5 additions & 0 deletions cognite/client/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ class CogniteAPIError(CogniteMultiException):
unknown (list | None): List of items which may or may not have been successfully processed.
skipped (list | None): List of items that were skipped due to "fail fast" mode.
unwrap_fn (Callable): Function to extract identifier from the Cognite resource.
cluster (str | None): Which Cognite cluster the user's project is on.
extra (dict | None): A dict of any additional information.
Examples:
Expand Down Expand Up @@ -170,18 +171,22 @@ def __init__(
unknown: list | None = None,
skipped: list | None = None,
unwrap_fn: Callable = no_op,
cluster: str | None = None,
extra: dict | None = None,
) -> None:
self.message = message
self.code = code
self.x_request_id = x_request_id
self.missing = missing
self.duplicated = duplicated
self.cluster = cluster
self.extra = extra
super().__init__(successful, failed, unknown, skipped, unwrap_fn)

def __str__(self) -> str:
msg = f"{self.message} | code: {self.code} | X-Request-ID: {self.x_request_id}"
if self.cluster:
msg += f" | cluster: {self.cluster}"
if self.missing:
msg += f"\nMissing: {self._truncate_elements(self.missing)}"
if self.duplicated:
Expand Down
10 changes: 6 additions & 4 deletions cognite/client/utils/_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
self.not_found_error: Exception | None = None
self.duplicated_error: Exception | None = None
self.unknown_error: Exception | None = None
self.missing, self.duplicated = self._inspect_exceptions(exceptions)
self.missing, self.duplicated, self.cluster = self._inspect_exceptions(exceptions)

def joined_results(self, unwrap_fn: Callable = no_op) -> list:
joined_results: list = []
Expand Down Expand Up @@ -81,13 +81,14 @@ def raise_compound_exception_if_failed_tasks(
if self.duplicated_error:
self._raise_duplicated_error(str_format_element_fn, **task_lists)

def _inspect_exceptions(self, exceptions: list[Exception]) -> tuple[list, list]:
missing, duplicated = [], []
def _inspect_exceptions(self, exceptions: list[Exception]) -> tuple[list, list, str | None]:
cluster, missing, duplicated = None, [], []
for exc in exceptions:
if not isinstance(exc, CogniteAPIError):
self.unknown_error = exc
continue

cluster = cluster or exc.cluster
if exc.code in (400, 422) and exc.missing is not None:
missing.extend(exc.missing)
self.not_found_error = exc
Expand All @@ -97,7 +98,7 @@ def _inspect_exceptions(self, exceptions: list[Exception]) -> tuple[list, list]:
self.duplicated_error = exc
else:
self.unknown_error = exc
return missing, duplicated
return missing, duplicated, cluster

def _raise_basic_api_error(self, unwrap_fn: Callable, **task_lists: list) -> NoReturn:
if isinstance(self.unknown_error, CogniteAPIError) and (task_lists["failed"] or task_lists["unknown"]):
Expand All @@ -109,6 +110,7 @@ def _raise_basic_api_error(self, unwrap_fn: Callable, **task_lists: list) -> NoR
duplicated=self.duplicated,
extra=self.unknown_error.extra,
unwrap_fn=unwrap_fn,
cluster=self.cluster,
**task_lists,
)
raise self.unknown_error # type: ignore [misc]
Expand Down
27 changes: 23 additions & 4 deletions tests/tests_unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,20 @@ def test_load_non_existent_attr(self):
assert global_config.max_workers != 0


class TestClientConfig:
def test_default(self):
config = {
@pytest.fixture
def client_config():
return ClientConfig.default(
**{
"project": "test-project",
"cdf_cluster": "test-cluster",
"credentials": Token("abc"),
"client_name": "test-client",
}
client_config = ClientConfig.default(**config)
)


class TestClientConfig:
def test_default(self, client_config):
assert client_config.project == "test-project"
assert client_config.base_url == "https://test-cluster.cognitedata.com"
assert isinstance(client_config.credentials, Token)
Expand All @@ -100,3 +105,17 @@ def test_load(self, credentials):
assert isinstance(client_config.credentials, Token)
assert "Authorization", "Bearer abc" == client_config.credentials.authorization_header()
assert client_config.client_name == "test-client"

@pytest.mark.parametrize("protocol", ("http", "https"))
@pytest.mark.parametrize("end", ("", "/", ":8080", "/api/v1/", ":8080/api/v1/"))
def test_extract_cdf_cluster(self, client_config, protocol, end):
for valid in ("3D", "my_clus-ter", "jazz-testing-asia-northeast1-1", "trial-00ed82e12d9cbadfe28e4"):
client_config.base_url = f"{protocol}://{valid}.cognitedata.com{end}"
assert client_config.cdf_cluster == valid

for invalid in ("", ".", "..", "huh.my_cluster."):
client_config.base_url = f"{protocol}://{valid}cognitedata.com{end}"
assert client_config.cdf_cluster is None

client_config.base_url = "invalid"
assert client_config.cdf_cluster is None
27 changes: 20 additions & 7 deletions tests/tests_unit/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
import pytest

from cognite.client.exceptions import CogniteAPIError


@pytest.fixture
def mock_get_400_error(rsps, cognite_client):
rsps.add(
rsps.GET,
cognite_client.assets._get_base_url_with_base_path() + "/any",
status=400,
json={"error": {"message": "bla", "extra": {"haha": "blabla"}, "other": "yup"}},
)


class TestAPIError:
def test_api_error(self):
e = CogniteAPIError(
Expand All @@ -19,15 +31,16 @@ def test_api_error(self):

assert "bla" in str(e)

def test_unknown_fields_in_api_error(self, rsps, cognite_client):
rsps.add(
rsps.GET,
cognite_client.assets._get_base_url_with_base_path() + "/any",
status=400,
json={"error": {"message": "bla", "extra": {"haha": "blabla"}, "other": "yup"}},
)
def test_unknown_fields_in_api_error(self, mock_get_400_error, cognite_client):
try:
cognite_client.get(url=f"/api/v1/projects/{cognite_client.config.project}/any")
assert False, "Call did not raise exception"
except CogniteAPIError as e:
assert {"extra": {"haha": "blabla"}, "other": "yup"} == e.extra

def test_api_errors_contain_cluster_info(self, mock_get_400_error, cognite_client):
try:
cognite_client.get(url=f"/api/v1/projects/{cognite_client.config.project}/any")
assert False, "Call did not raise exception"
except CogniteAPIError as e:
assert e.cluster == "api"

0 comments on commit 3abca3a

Please sign in to comment.