Skip to content

Commit

Permalink
make logger required
Browse files Browse the repository at this point in the history
  • Loading branch information
slicklash committed Jul 4, 2023
1 parent 3d03340 commit 882d591
Show file tree
Hide file tree
Showing 10 changed files with 46 additions and 62 deletions.
5 changes: 2 additions & 3 deletions granulate_utils/config_feeder/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@
from requests import Session

from granulate_utils.config_feeder.client.exceptions import MaximumRetriesExceeded
from granulate_utils.config_feeder.client.logging import get_logger

T = TypeVar("T")


class ConfigCollectorBase(ABC):
def __init__(self, *, max_retries: int = 20, logger: Union[logging.Logger, logging.LoggerAdapter] = None) -> None:
self.logger = logger or get_logger()
def __init__(self, *, max_retries: int = 20, logger: Union[logging.Logger, logging.LoggerAdapter]) -> None:
self.logger = logger
self._max_retries = max_retries
self._failed_requests = 0
self._init_session()
Expand Down
2 changes: 1 addition & 1 deletion granulate_utils/config_feeder/client/bigdata/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def get_databricks_node_info() -> Optional[NodeInfo]:
return NodeInfo(
external_id=instance_id,
external_cluster_id=properties["spark.databricks.clusterUsageTags.clusterId"],
is_master=instance_id == driver_instance_id,
is_master=(instance_id == driver_instance_id),
provider=provider,
bigdata_platform=BigDataPlatform.DATABRICKS,
properties=properties,
Expand Down
7 changes: 3 additions & 4 deletions granulate_utils/config_feeder/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from granulate_utils.config_feeder.client.bigdata import get_node_info
from granulate_utils.config_feeder.client.exceptions import APIError, ClientError
from granulate_utils.config_feeder.client.logging import get_logger
from granulate_utils.config_feeder.client.models import CollectionResult, ConfigType
from granulate_utils.config_feeder.client.yarn.collector import YarnConfigCollector
from granulate_utils.config_feeder.client.yarn.models import YarnConfig
Expand All @@ -36,21 +35,21 @@ def __init__(
token: str,
service: str,
*,
logger: Union[logging.Logger, logging.LoggerAdapter] = None,
logger: Union[logging.Logger, logging.LoggerAdapter],
server_address: Optional[str] = None,
yarn: bool = True,
collector=CollectorType.SAGENT,
) -> None:
if not token or not service:
raise ClientError("Token and service must be provided")
self.logger = logger or get_logger()
self.logger = logger
self._token = token
self._service = service
self._cluster_id: Optional[str] = None
self._collector = collector
self._server_address: str = server_address.rstrip("/") if server_address else DEFAULT_API_SERVER_ADDRESS
self._is_yarn_enabled = yarn
self._yarn_collector = YarnConfigCollector()
self._yarn_collector = YarnConfigCollector(logger=logger)
self._last_hash: DefaultDict[ConfigType, Dict[str, str]] = defaultdict(dict)
self._init_api_session()

Expand Down
20 changes: 0 additions & 20 deletions granulate_utils/config_feeder/client/logging.py

This file was deleted.

2 changes: 1 addition & 1 deletion granulate_utils/config_feeder/client/spark/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class SparkConfigCollector(ConfigCollectorBase):
def __init__(self, *, max_retries: int = 20, logger: Union[logging.Logger, logging.LoggerAdapter] = None) -> None:
def __init__(self, *, max_retries: int = 20, logger: Union[logging.Logger, logging.LoggerAdapter]) -> None:
super().__init__(max_retries=max_retries, logger=logger)
self._history_address = SPARK_HISTORY_DEFAULT_ADDRESS

Expand Down
2 changes: 1 addition & 1 deletion granulate_utils/config_feeder/client/yarn/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


class YarnConfigCollector(ConfigCollectorBase):
def __init__(self, *, max_retries: int = 20, logger: Union[logging.Logger, logging.LoggerAdapter] = None) -> None:
def __init__(self, *, max_retries: int = 20, logger: Union[logging.Logger, logging.LoggerAdapter]) -> None:
super().__init__(max_retries=max_retries, logger=logger)
self._resource_manager_address = RM_DEFAULT_ADDRESS
self._is_address_detected = False
Expand Down
8 changes: 8 additions & 0 deletions tests/granulate_utils/config_feeder/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import pytest
import logging

from granulate_utils.config_feeder.client.bigdata import get_node_info


@pytest.fixture(autouse=True)
def clear_cache() -> None:
    """Reset the memoized ``get_node_info`` result before every test.

    ``autouse=True`` makes pytest apply this fixture to each test in the
    package, so one test's cached node lookup can never leak into the next.
    """
    # Bind and invoke the cache reset exposed by the lru_cache wrapper.
    reset = get_node_info.cache_clear
    reset()


@pytest.fixture(scope="session")
def logger() -> logging.Logger:
    """Provide one shared, silent logger for the whole test session.

    The attached ``NullHandler`` discards all records, keeping test output
    clean while still satisfying code that requires a real logger instance.
    """
    test_logger = logging.getLogger("test-logger")
    test_logger.addHandler(logging.NullHandler())
    return test_logger
32 changes: 12 additions & 20 deletions tests/granulate_utils/config_feeder/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
from tests.granulate_utils.config_feeder.fixtures.api import ApiMock


def test_should_send_config_only_once_when_not_changed() -> None:
def test_should_send_config_only_once_when_not_changed(logger: logging.Logger) -> None:
with ApiMock(
collect_yarn_config=mock_yarn_config,
) as mock:
client = ConfigFeederClient("token1", "service1")
client = ConfigFeederClient("token1", "service1", logger=logger)

client.collect()
client.collect()
Expand Down Expand Up @@ -52,15 +52,15 @@ def test_should_send_config_only_once_when_not_changed() -> None:
}


def test_should_send_config_only_when_changed() -> None:
def test_should_send_config_only_when_changed(logger: logging.Logger) -> None:
yarn_configs = [
mock_yarn_config(thread_count=128),
mock_yarn_config(thread_count=128),
mock_yarn_config(thread_count=64),
]

with ApiMock(collect_yarn_config=lambda _: yarn_configs.pop()) as mock:
client = ConfigFeederClient("token1", "service1")
client = ConfigFeederClient("token1", "service1", logger=logger)

client.collect()
client.collect()
Expand All @@ -71,9 +71,9 @@ def test_should_send_config_only_when_changed() -> None:
assert len(requests[f"{API_URL}/nodes/node-1/configs"]) == 2


def test_should_not_send_anything() -> None:
def test_should_not_send_anything(logger: logging.Logger) -> None:
with ApiMock(collect_yarn_config=mock_yarn_config) as mock:
client = ConfigFeederClient("token1", "service1", yarn=False)
client = ConfigFeederClient("token1", "service1", yarn=False, logger=logger)

client.collect()
client.collect()
Expand All @@ -84,16 +84,16 @@ def test_should_not_send_anything() -> None:
assert len(requests) == 0


def test_should_fail_with_client_error() -> None:
def test_should_fail_with_client_error(logger: logging.Logger) -> None:
with ApiMock(
collect_yarn_config=mock_yarn_config,
register_cluster_response={"exc": ConnectionError("Connection refused")},
):
with pytest.raises(ClientError, match=f"could not connect to {API_URL}"):
ConfigFeederClient("token1", "service1").collect()
ConfigFeederClient("token1", "service1", logger=logger).collect()


def test_should_fail_with_invalid_token_exception() -> None:
def test_should_fail_with_invalid_token_exception(logger: logging.Logger) -> None:
with ApiMock(
collect_yarn_config=mock_yarn_config,
register_cluster_response={
Expand All @@ -102,24 +102,16 @@ def test_should_fail_with_invalid_token_exception() -> None:
},
):
with pytest.raises(InvalidTokenException, match="Invalid token"):
ConfigFeederClient("token1", "service1").collect()
ConfigFeederClient("token1", "service1", logger=logger).collect()


def test_should_fail_with_api_error() -> None:
def test_should_fail_with_api_error(logger: logging.Logger) -> None:
with ApiMock(
collect_yarn_config=mock_yarn_config,
register_cluster_response={"status_code": 400, "text": "unexpected error"},
):
with pytest.raises(APIError, match="400 unexpected error /clusters"):
ConfigFeederClient("token1", "service1").collect()


def test_should_have_logger_with_null_handler() -> None:
client = ConfigFeederClient("token1", "service1")

assert isinstance(client.logger, logging.Logger)
assert len(client.logger.handlers) == 1
assert isinstance(client.logger.handlers[0], type(logging.NullHandler()))
ConfigFeederClient("token1", "service1", logger=logger).collect()


def mock_yarn_config(*args: Any, thread_count: int = 64) -> YarnConfig:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

import pytest

from granulate_utils.config_feeder.client.yarn.collector import YarnConfigCollector
Expand All @@ -6,7 +8,7 @@


@pytest.mark.asyncio
async def test_collect_from_master_node() -> None:
async def test_collect_from_master_node(logger: logging.Logger) -> None:
instance_id = "7203450965080656748"
cluster_uuid = "824afc19-cf18-4b23-99b0-51a6b20b35d"

Expand All @@ -16,13 +18,13 @@ async def test_collect_from_master_node() -> None:
instance_id=instance_id,
is_master=True,
) as mock:
assert await YarnConfigCollector().collect(mock.node_info) == YarnConfig(
assert await YarnConfigCollector(logger=logger).collect(mock.node_info) == YarnConfig(
config=mock.expected_config,
)


@pytest.mark.asyncio
async def test_collect_from_worker_noder() -> None:
async def test_collect_from_worker_noder(logger: logging.Logger) -> None:
instance_id = "3344294988448254828"
cluster_uuid = "824afc19-cf18-4b23-99b0-51a6b20b35d"

Expand All @@ -33,6 +35,6 @@ async def test_collect_from_worker_noder() -> None:
is_master=False,
web_address="http://0.0.0.0:8042",
) as mock:
assert await YarnConfigCollector().collect(mock.node_info) == YarnConfig(
assert await YarnConfigCollector(logger=logger).collect(mock.node_info) == YarnConfig(
config=mock.expected_config,
)
20 changes: 12 additions & 8 deletions tests/granulate_utils/config_feeder/test_emr_yarn_collector.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

import pytest
from requests.exceptions import ConnectionError

Expand All @@ -8,7 +10,7 @@


@pytest.mark.asyncio
async def test_collect_from_master_node() -> None:
async def test_collect_from_master_node(logger: logging.Logger) -> None:
job_flow_id = "j-1234567890"
instance_id = "i-06828639fa954e04c"

Expand All @@ -28,13 +30,13 @@ async def test_collect_from_master_node() -> None:
yarn_site_xml=yarn_site_xml,
web_address="http://0.0.0.0:8001",
) as mock:
assert await YarnConfigCollector().collect(mock.node_info) == YarnConfig(
assert await YarnConfigCollector(logger=logger).collect(mock.node_info) == YarnConfig(
config=mock.expected_config,
)


@pytest.mark.asyncio
async def test_collect_from_worker_noder() -> None:
async def test_collect_from_worker_noder(logger: logging.Logger) -> None:
job_flow_id = "j-1234567890"
instance_id = "i-0c97511ec7fa849a3"

Expand All @@ -45,7 +47,7 @@ async def test_collect_from_worker_noder() -> None:
is_master=False,
web_address="http://0.0.0.0:8042",
) as mock:
assert await YarnConfigCollector().collect(mock.node_info) == YarnConfig(
assert await YarnConfigCollector(logger=logger).collect(mock.node_info) == YarnConfig(
config=mock.expected_config,
)

Expand All @@ -58,21 +60,23 @@ async def test_collect_from_worker_noder() -> None:
],
)
@pytest.mark.asyncio
async def test_should_fail_with_max_retries_exception(is_master: bool, web_address: str) -> None:
async def test_should_fail_with_max_retries_exception(
is_master: bool, web_address: str, logger: logging.Logger
) -> None:
with YarnNodeMock(
is_master=is_master,
web_address=web_address,
home_dir="/home/not-hadoop",
response={"exc": ConnectionError},
) as mock:
with pytest.raises(MaximumRetriesExceeded, match="maximum number of failed requests reached"):
collector = YarnConfigCollector(max_retries=3)
collector = YarnConfigCollector(max_retries=3, logger=logger)
while True:
await collector.collect(mock.node_info)


@pytest.mark.asyncio
async def test_should_mask_sensitive_values() -> None:
async def test_should_mask_sensitive_values(logger: logging.Logger) -> None:
with YarnNodeMock(
provider="aws",
job_flow_id="j-1234567890",
Expand All @@ -87,7 +91,7 @@ async def test_should_mask_sensitive_values() -> None:
}
],
) as mock:
result = await YarnConfigCollector().collect(mock.node_info)
result = await YarnConfigCollector(logger=logger).collect(mock.node_info)

assert result is not None
assert result.config == {
Expand Down

0 comments on commit 882d591

Please sign in to comment.