From 071810abb50931fd8a8b72a6c9f5886f09e18be2 Mon Sep 17 00:00:00 2001
From: tooptoop4 <33283496+tooptoop4@users.noreply.github.com>
Date: Tue, 31 Dec 2019 15:22:00 +0000
Subject: [PATCH] [AIRFLOW-5385] spark hook does not work on spark 2.3/2.4
 (#6976)

---
 airflow/contrib/hooks/spark_submit_hook.py    | 45 ++++++++++++++-----
 tests/contrib/hooks/test_spark_submit_hook.py | 30 +++++++++++++
 2 files changed, 64 insertions(+), 11 deletions(-)

diff --git a/airflow/contrib/hooks/spark_submit_hook.py b/airflow/contrib/hooks/spark_submit_hook.py
index bb668a34d413f8..a5f0fe0abd8150 100644
--- a/airflow/contrib/hooks/spark_submit_hook.py
+++ b/airflow/contrib/hooks/spark_submit_hook.py
@@ -330,18 +330,41 @@ def _build_track_driver_status_command(self):
 
         :return: full command to be executed
         """
-        connection_cmd = self._get_spark_binary_path()
-
-        # The url ot the spark master
-        connection_cmd += ["--master", self._connection['master']]
+        curl_max_wait_time = 30
+        spark_host = self._connection['master']
+        if spark_host.endswith(':6066'):
+            spark_host = spark_host.replace("spark://", "http://")
+            connection_cmd = [
+                "/usr/bin/curl",
+                "--max-time",
+                str(curl_max_wait_time),
+                "{host}/v1/submissions/status/{submission_id}".format(
+                    host=spark_host,
+                    submission_id=self._driver_id)]
+            self.log.info(connection_cmd)
+
+            # The driver id so we can poll for its status
+            if self._driver_id:
+                pass
+            else:
+                raise AirflowException(
+                    "Invalid status: attempted to poll driver " +
+                    "status but no driver id is known. Giving up.")
 
-        # The driver id so we can poll for its status
-        if self._driver_id:
-            connection_cmd += ["--status", self._driver_id]
         else:
-            raise AirflowException(
-                "Invalid status: attempted to poll driver " +
-                "status but no driver id is known. Giving up.")
+
+            connection_cmd = self._get_spark_binary_path()
+
+            # The url to the spark master
+            connection_cmd += ["--master", self._connection['master']]
+
+            # The driver id so we can poll for its status
+            if self._driver_id:
+                connection_cmd += ["--status", self._driver_id]
+            else:
+                raise AirflowException(
+                    "Invalid status: attempted to poll driver " +
+                    "status but no driver id is known. Giving up.")
 
         self.log.debug("Poll driver status cmd: %s", connection_cmd)
 
@@ -556,7 +579,7 @@ def _build_spark_driver_kill_command(self):
         else:
             connection_cmd = [self._connection['spark_binary']]
 
-        # The url ot the spark master
+        # The url to the spark master
         connection_cmd += ["--master", self._connection['master']]
 
         # The actual kill command
diff --git a/tests/contrib/hooks/test_spark_submit_hook.py b/tests/contrib/hooks/test_spark_submit_hook.py
index f7165bea9e7574..d8312225d100ab 100644
--- a/tests/contrib/hooks/test_spark_submit_hook.py
+++ b/tests/contrib/hooks/test_spark_submit_hook.py
@@ -166,6 +166,36 @@ def test_build_spark_submit_command(self):
         ]
         self.assertEqual(expected_build_cmd, cmd)
 
+    def test_build_track_driver_status_command(self):
+        # note this function is only relevant for spark setups matching the condition
+        # 'spark://' in self._connection['master'] and self._connection['deploy_mode'] == 'cluster'
+
+        # Given
+        hook_spark_standalone_cluster = SparkSubmitHook(
+            conn_id='spark_standalone_cluster')
+        hook_spark_standalone_cluster._driver_id = 'driver-20171128111416-0001'
+        hook_spark_yarn_cluster = SparkSubmitHook(
+            conn_id='spark_yarn_cluster')
+        hook_spark_yarn_cluster._driver_id = 'driver-20171128111417-0001'
+
+        # When
+        build_track_driver_status_spark_standalone_cluster = \
+            hook_spark_standalone_cluster._build_track_driver_status_command()
+        build_track_driver_status_spark_yarn_cluster = \
+            hook_spark_yarn_cluster._build_track_driver_status_command()
+
+        # Then
+        expected_spark_standalone_cluster = [
+            '/usr/bin/curl',
+            '--max-time',
+            '30',
+            'http://spark-standalone-master:6066/v1/submissions/status/driver-20171128111416-0001']
+        expected_spark_yarn_cluster = [
+            'spark-submit', '--master', 'yarn://yarn-master', '--status', 'driver-20171128111417-0001']
+
+        assert expected_spark_standalone_cluster == build_track_driver_status_spark_standalone_cluster
+        assert expected_spark_yarn_cluster == build_track_driver_status_spark_yarn_cluster
+
     @patch('airflow.contrib.hooks.spark_submit_hook.subprocess.Popen')
     def test_spark_process_runcmd(self, mock_popen):
         # Given