From 84e80b729209cdb1295efe3e19702c796211ae75 Mon Sep 17 00:00:00 2001 From: Beni Ben Zikry Date: Wed, 30 Sep 2020 06:33:16 +0300 Subject: [PATCH] Spark-on-k8s logs - ensure driver pod name and namespace are based on SparkApplication CRD --- .../cncf/kubernetes/sensors/spark_kubernetes.py | 15 +++++++++++---- .../kubernetes/sensors/test_spark_kubernetes.py | 4 ++-- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py b/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py index 6b4ddee94c890d..03fd14b5dd9080 100644 --- a/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py +++ b/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py @@ -64,14 +64,21 @@ def __init__( self.kubernetes_conn_id = kubernetes_conn_id self.hook = KubernetesHook(conn_id=self.kubernetes_conn_id) - def _log_driver(self, application_state: str) -> None: + def _log_driver(self, application_state: str, response: dict) -> None: if not self.attach_log: return - driver_pod_name = f"{self.application_name}-driver" + status_info = response["status"] + if "driverInfo" not in status_info: + return + driver_info = status_info["driverInfo"] + if "podName" not in driver_info: + return + driver_pod_name = driver_info["podName"] + namespace = response["metadata"]["namespace"] log_method = self.log.error if application_state in self.FAILURE_STATES else self.log.info try: log = "" - for line in self.hook.get_pod_logs(driver_pod_name): + for line in self.hook.get_pod_logs(driver_pod_name, namespace=namespace): log += line.decode() log_method(log) except client.rest.ApiException as e: @@ -97,7 +104,7 @@ def poke(self, context: Dict) -> bool: except KeyError: return False if self.attach_log and application_state in self.FAILURE_STATES + self.SUCCESS_STATES: - self._log_driver(application_state) + self._log_driver(application_state, response) if application_state in self.FAILURE_STATES: raise AirflowException("Spark application failed with state: %s" % application_state) elif application_state in self.SUCCESS_STATES: diff --git a/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py b/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py index ea90834f38dcfa..3e187719c97838 100644 --- a/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py +++ b/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py @@ -697,7 +697,7 @@ def test_driver_logging_failure( task_id="test_task_id", ) self.assertRaises(AirflowException, sensor.poke, None) - mock_log_call.assert_called_once_with("spark_pi-driver") + mock_log_call.assert_called_once_with("spark-pi-driver", namespace="default") error_log_call.assert_called_once_with(TEST_POD_LOG_RESULT) @patch( @@ -719,7 +719,7 @@ def test_driver_logging_completed( task_id="test_task_id", ) sensor.poke(None) - mock_log_call.assert_called_once_with("spark_pi-driver") + mock_log_call.assert_called_once_with("spark-pi-2020-02-24-1-driver", namespace="default") log_info_call = info_log_call.mock_calls[1] log_value = log_info_call[1][0] self.assertEqual(log_value, TEST_POD_LOG_RESULT)