diff --git a/granulate_utils/config_feeder/client/base.py b/granulate_utils/config_feeder/client/base.py
index 8c0bb608..a2eb66eb 100644
--- a/granulate_utils/config_feeder/client/base.py
+++ b/granulate_utils/config_feeder/client/base.py
@@ -8,14 +8,13 @@
 from requests import Session

 from granulate_utils.config_feeder.client.exceptions import MaximumRetriesExceeded
-from granulate_utils.config_feeder.client.logging import get_logger

 T = TypeVar("T")


 class ConfigCollectorBase(ABC):
-    def __init__(self, *, max_retries: int = 20, logger: Union[logging.Logger, logging.LoggerAdapter] = None) -> None:
-        self.logger = logger or get_logger()
+    def __init__(self, *, max_retries: int = 20, logger: Union[logging.Logger, logging.LoggerAdapter]) -> None:
+        self.logger = logger
         self._max_retries = max_retries
         self._failed_requests = 0
         self._init_session()
diff --git a/granulate_utils/config_feeder/client/bigdata/databricks.py b/granulate_utils/config_feeder/client/bigdata/databricks.py
index cab9f7c9..b1853211 100644
--- a/granulate_utils/config_feeder/client/bigdata/databricks.py
+++ b/granulate_utils/config_feeder/client/bigdata/databricks.py
@@ -15,7 +15,7 @@ def get_databricks_node_info() -> Optional[NodeInfo]:
         return NodeInfo(
             external_id=instance_id,
             external_cluster_id=properties["spark.databricks.clusterUsageTags.clusterId"],
-            is_master=instance_id == driver_instance_id,
+            is_master=(instance_id == driver_instance_id),
             provider=provider,
             bigdata_platform=BigDataPlatform.DATABRICKS,
             properties=properties,
diff --git a/granulate_utils/config_feeder/client/client.py b/granulate_utils/config_feeder/client/client.py
index 4eeefcaa..bd96fe95 100644
--- a/granulate_utils/config_feeder/client/client.py
+++ b/granulate_utils/config_feeder/client/client.py
@@ -10,7 +10,6 @@
 from granulate_utils.config_feeder.client.bigdata import get_node_info
 from granulate_utils.config_feeder.client.exceptions import APIError, ClientError
-from granulate_utils.config_feeder.client.logging import get_logger
 from granulate_utils.config_feeder.client.models import CollectionResult, ConfigType
 from granulate_utils.config_feeder.client.yarn.collector import YarnConfigCollector
 from granulate_utils.config_feeder.client.yarn.models import YarnConfig

@@ -36,21 +35,21 @@ def __init__(
         self,
         token: str,
         service: str,
         *,
-        logger: Union[logging.Logger, logging.LoggerAdapter] = None,
+        logger: Union[logging.Logger, logging.LoggerAdapter],
         server_address: Optional[str] = None,
         yarn: bool = True,
         collector=CollectorType.SAGENT,
     ) -> None:
         if not token or not service:
             raise ClientError("Token and service must be provided")
-        self.logger = logger or get_logger()
+        self.logger = logger
         self._token = token
         self._service = service
         self._cluster_id: Optional[str] = None
         self._collector = collector
         self._server_address: str = server_address.rstrip("/") if server_address else DEFAULT_API_SERVER_ADDRESS
         self._is_yarn_enabled = yarn
-        self._yarn_collector = YarnConfigCollector()
+        self._yarn_collector = YarnConfigCollector(logger=logger)
         self._last_hash: DefaultDict[ConfigType, Dict[str, str]] = defaultdict(dict)
         self._init_api_session()
diff --git a/granulate_utils/config_feeder/client/logging.py b/granulate_utils/config_feeder/client/logging.py
deleted file mode 100644
index 94c9946e..00000000
--- a/granulate_utils/config_feeder/client/logging.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import logging
-from functools import lru_cache
-
-LOGGER_NAME = "config-feeder-client"
-
-
-class Extra(logging.Filter):
-    def filter(self, record: logging.LogRecord) -> bool:
-        if not hasattr(record, "extra"):
-            record.extra = {}
-        return True
-
-
-@lru_cache(maxsize=None)
-def get_logger() -> logging.Logger:
-    logger = logging.getLogger(LOGGER_NAME)
-    logger.addFilter(Extra())
-    logger.setLevel(logging.INFO)
-    logger.addHandler(logging.NullHandler())
-    return logger
diff --git a/granulate_utils/config_feeder/client/spark/collector.py b/granulate_utils/config_feeder/client/spark/collector.py
index 54884162..9c6a6be0 100644
--- a/granulate_utils/config_feeder/client/spark/collector.py
+++ b/granulate_utils/config_feeder/client/spark/collector.py
@@ -8,7 +8,7 @@


 class SparkConfigCollector(ConfigCollectorBase):
-    def __init__(self, *, max_retries: int = 20, logger: Union[logging.Logger, logging.LoggerAdapter] = None) -> None:
+    def __init__(self, *, max_retries: int = 20, logger: Union[logging.Logger, logging.LoggerAdapter]) -> None:
         super().__init__(max_retries=max_retries, logger=logger)
         self._history_address = SPARK_HISTORY_DEFAULT_ADDRESS
diff --git a/granulate_utils/config_feeder/client/yarn/collector.py b/granulate_utils/config_feeder/client/yarn/collector.py
index 3f8bc287..13c353be 100644
--- a/granulate_utils/config_feeder/client/yarn/collector.py
+++ b/granulate_utils/config_feeder/client/yarn/collector.py
@@ -15,7 +15,7 @@


 class YarnConfigCollector(ConfigCollectorBase):
-    def __init__(self, *, max_retries: int = 20, logger: Union[logging.Logger, logging.LoggerAdapter] = None) -> None:
+    def __init__(self, *, max_retries: int = 20, logger: Union[logging.Logger, logging.LoggerAdapter]) -> None:
         super().__init__(max_retries=max_retries, logger=logger)
         self._resource_manager_address = RM_DEFAULT_ADDRESS
         self._is_address_detected = False
diff --git a/tests/granulate_utils/config_feeder/conftest.py b/tests/granulate_utils/config_feeder/conftest.py
index 3ef9d88f..cc33605b 100644
--- a/tests/granulate_utils/config_feeder/conftest.py
+++ b/tests/granulate_utils/config_feeder/conftest.py
@@ -1,4 +1,5 @@
 import pytest
+import logging

 from granulate_utils.config_feeder.client.bigdata import get_node_info

@@ -6,3 +7,10 @@
 @pytest.fixture(autouse=True)
 def clear_cache() -> None:
     get_node_info.cache_clear()
+
+
+@pytest.fixture(scope="session")
+def logger() -> logging.Logger:
+    logger = logging.getLogger("test-logger")
+    logger.addHandler(logging.NullHandler())
+    return logger
diff --git a/tests/granulate_utils/config_feeder/test_client.py b/tests/granulate_utils/config_feeder/test_client.py
index adf42cff..9c304168 100644
--- a/tests/granulate_utils/config_feeder/test_client.py
+++ b/tests/granulate_utils/config_feeder/test_client.py
@@ -13,11 +13,11 @@
 from tests.granulate_utils.config_feeder.fixtures.api import ApiMock


-def test_should_send_config_only_once_when_not_changed() -> None:
+def test_should_send_config_only_once_when_not_changed(logger: logging.Logger) -> None:
     with ApiMock(
         collect_yarn_config=mock_yarn_config,
     ) as mock:
-        client = ConfigFeederClient("token1", "service1")
+        client = ConfigFeederClient("token1", "service1", logger=logger)

         client.collect()
         client.collect()
@@ -52,7 +52,7 @@ def test_should_send_config_only_once_when_not_changed() -> None:
     }


-def test_should_send_config_only_when_changed() -> None:
+def test_should_send_config_only_when_changed(logger: logging.Logger) -> None:
     yarn_configs = [
         mock_yarn_config(thread_count=128),
         mock_yarn_config(thread_count=128),
@@ -60,7 +60,7 @@ def test_should_send_config_only_when_changed() -> None:
     ]

     with ApiMock(collect_yarn_config=lambda _: yarn_configs.pop()) as mock:
-        client = ConfigFeederClient("token1", "service1")
+        client = ConfigFeederClient("token1", "service1", logger=logger)

         client.collect()
         client.collect()
@@ -71,9 +71,9 @@ def test_should_send_config_only_when_changed() -> None:
     assert len(requests[f"{API_URL}/nodes/node-1/configs"]) == 2


-def test_should_not_send_anything() -> None:
+def test_should_not_send_anything(logger: logging.Logger) -> None:
     with ApiMock(collect_yarn_config=mock_yarn_config) as mock:
-        client = ConfigFeederClient("token1", "service1", yarn=False)
+        client = ConfigFeederClient("token1", "service1", yarn=False, logger=logger)

         client.collect()
         client.collect()
@@ -84,16 +84,16 @@ def test_should_not_send_anything() -> None:
     assert len(requests) == 0


-def test_should_fail_with_client_error() -> None:
+def test_should_fail_with_client_error(logger: logging.Logger) -> None:
     with ApiMock(
         collect_yarn_config=mock_yarn_config,
         register_cluster_response={"exc": ConnectionError("Connection refused")},
     ):
         with pytest.raises(ClientError, match=f"could not connect to {API_URL}"):
-            ConfigFeederClient("token1", "service1").collect()
+            ConfigFeederClient("token1", "service1", logger=logger).collect()


-def test_should_fail_with_invalid_token_exception() -> None:
+def test_should_fail_with_invalid_token_exception(logger: logging.Logger) -> None:
     with ApiMock(
         collect_yarn_config=mock_yarn_config,
         register_cluster_response={
@@ -102,24 +102,16 @@ def test_should_fail_with_invalid_token_exception() -> None:
         },
     ):
         with pytest.raises(InvalidTokenException, match="Invalid token"):
-            ConfigFeederClient("token1", "service1").collect()
+            ConfigFeederClient("token1", "service1", logger=logger).collect()


-def test_should_fail_with_api_error() -> None:
+def test_should_fail_with_api_error(logger: logging.Logger) -> None:
     with ApiMock(
         collect_yarn_config=mock_yarn_config,
         register_cluster_response={"status_code": 400, "text": "unexpected error"},
     ):
         with pytest.raises(APIError, match="400 unexpected error /clusters"):
-            ConfigFeederClient("token1", "service1").collect()
-
-
-def test_should_have_logger_with_null_handler() -> None:
-    client = ConfigFeederClient("token1", "service1")
-
-    assert isinstance(client.logger, logging.Logger)
-    assert len(client.logger.handlers) == 1
-    assert isinstance(client.logger.handlers[0], type(logging.NullHandler()))
+            ConfigFeederClient("token1", "service1", logger=logger).collect()


 def mock_yarn_config(*args: Any, thread_count: int = 64) -> YarnConfig:
diff --git a/tests/granulate_utils/config_feeder/test_dataproc_yarn_collector.py b/tests/granulate_utils/config_feeder/test_dataproc_yarn_collector.py
index 331f1bb3..11699b0d 100644
--- a/tests/granulate_utils/config_feeder/test_dataproc_yarn_collector.py
+++ b/tests/granulate_utils/config_feeder/test_dataproc_yarn_collector.py
@@ -1,3 +1,5 @@
+import logging
+
 import pytest

 from granulate_utils.config_feeder.client.yarn.collector import YarnConfigCollector
@@ -6,7 +8,7 @@


 @pytest.mark.asyncio
-async def test_collect_from_master_node() -> None:
+async def test_collect_from_master_node(logger: logging.Logger) -> None:
     instance_id = "7203450965080656748"
     cluster_uuid = "824afc19-cf18-4b23-99b0-51a6b20b35d"
@@ -16,13 +18,13 @@ async def test_collect_from_master_node() -> None:
         instance_id=instance_id,
         is_master=True,
     ) as mock:
-        assert await YarnConfigCollector().collect(mock.node_info) == YarnConfig(
+        assert await YarnConfigCollector(logger=logger).collect(mock.node_info) == YarnConfig(
             config=mock.expected_config,
         )


 @pytest.mark.asyncio
-async def test_collect_from_worker_noder() -> None:
+async def test_collect_from_worker_noder(logger: logging.Logger) -> None:
     instance_id = "3344294988448254828"
     cluster_uuid = "824afc19-cf18-4b23-99b0-51a6b20b35d"
@@ -33,6 +35,6 @@ async def test_collect_from_worker_noder() -> None:
         is_master=False,
         web_address="http://0.0.0.0:8042",
     ) as mock:
-        assert await YarnConfigCollector().collect(mock.node_info) == YarnConfig(
+        assert await YarnConfigCollector(logger=logger).collect(mock.node_info) == YarnConfig(
             config=mock.expected_config,
         )
diff --git a/tests/granulate_utils/config_feeder/test_emr_yarn_collector.py b/tests/granulate_utils/config_feeder/test_emr_yarn_collector.py
index cfa2b70e..f1140fd6 100644
--- a/tests/granulate_utils/config_feeder/test_emr_yarn_collector.py
+++ b/tests/granulate_utils/config_feeder/test_emr_yarn_collector.py
@@ -1,3 +1,5 @@
+import logging
+
 import pytest
 from requests.exceptions import ConnectionError

@@ -8,7 +10,7 @@


 @pytest.mark.asyncio
-async def test_collect_from_master_node() -> None:
+async def test_collect_from_master_node(logger: logging.Logger) -> None:
     job_flow_id = "j-1234567890"
     instance_id = "i-06828639fa954e04c"
@@ -28,13 +30,13 @@ async def test_collect_from_master_node() -> None:
         yarn_site_xml=yarn_site_xml,
         web_address="http://0.0.0.0:8001",
     ) as mock:
-        assert await YarnConfigCollector().collect(mock.node_info) == YarnConfig(
+        assert await YarnConfigCollector(logger=logger).collect(mock.node_info) == YarnConfig(
             config=mock.expected_config,
         )


 @pytest.mark.asyncio
-async def test_collect_from_worker_noder() -> None:
+async def test_collect_from_worker_noder(logger: logging.Logger) -> None:
     job_flow_id = "j-1234567890"
     instance_id = "i-0c97511ec7fa849a3"
@@ -45,7 +47,7 @@ async def test_collect_from_worker_noder() -> None:
         is_master=False,
         web_address="http://0.0.0.0:8042",
     ) as mock:
-        assert await YarnConfigCollector().collect(mock.node_info) == YarnConfig(
+        assert await YarnConfigCollector(logger=logger).collect(mock.node_info) == YarnConfig(
             config=mock.expected_config,
         )
@@ -58,7 +60,9 @@ async def test_collect_from_worker_noder() -> None:
     ],
 )
 @pytest.mark.asyncio
-async def test_should_fail_with_max_retries_exception(is_master: bool, web_address: str) -> None:
+async def test_should_fail_with_max_retries_exception(
+    is_master: bool, web_address: str, logger: logging.Logger
+) -> None:
     with YarnNodeMock(
         is_master=is_master,
         web_address=web_address,
@@ -66,13 +70,13 @@ async def test_should_fail_with_max_retries_exception(is_master: bool, web_addre
         response={"exc": ConnectionError},
     ) as mock:
         with pytest.raises(MaximumRetriesExceeded, match="maximum number of failed requests reached"):
-            collector = YarnConfigCollector(max_retries=3)
+            collector = YarnConfigCollector(max_retries=3, logger=logger)
             while True:
                 await collector.collect(mock.node_info)


 @pytest.mark.asyncio
-async def test_should_mask_sensitive_values() -> None:
+async def test_should_mask_sensitive_values(logger: logging.Logger) -> None:
     with YarnNodeMock(
         provider="aws",
         job_flow_id="j-1234567890",
@@ -87,7 +91,7 @@ async def test_should_mask_sensitive_values() -> None:
             }
         ],
     ) as mock:
-        result = await YarnConfigCollector().collect(mock.node_info)
+        result = await YarnConfigCollector(logger=logger).collect(mock.node_info)

         assert result is not None
         assert result.config == {
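
Net effect of this diff: the bundled `logging.py` module and its cached `get_logger()` default are removed, and `logger` becomes a required keyword argument on `ConfigFeederClient`, `ConfigCollectorBase`, and the collector subclasses. A minimal caller-side sketch of the API after this change — the import path is inferred from the file layout above (the package may re-export the class elsewhere), and the logger name `my-agent` is purely illustrative:

```python
import logging

# Assumed import path, based on granulate_utils/config_feeder/client/client.py above.
from granulate_utils.config_feeder.client.client import ConfigFeederClient

# The deleted logging.py used to attach a NullHandler on the library's behalf;
# after this change, silencing unconfigured logging is the caller's job.
logger = logging.getLogger("my-agent")  # hypothetical logger name
logger.addHandler(logging.NullHandler())

# logger is now a required keyword argument; omitting it raises TypeError.
client = ConfigFeederClient("token1", "service1", logger=logger)
client.collect()
```

Requiring the logger rather than defaulting to a cached module-level one pushes handler and level configuration to the embedding application and avoids mutating global logging state from library code, which is also why the now-redundant `test_should_have_logger_with_null_handler` is deleted and the test suite injects a session-scoped `logger` fixture instead.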