From 2327aa5a263f25beeaf4ba79670f10f001daf0bf Mon Sep 17 00:00:00 2001 From: Kengo Seki Date: Fri, 13 Mar 2020 01:13:33 +0900 Subject: [PATCH] [AIRFLOW-7025] Fix SparkSqlHook.run_query to handle its parameter properly (#7677) --- .../providers/apache/spark/hooks/spark_sql.py | 14 +++++-- .../apache/spark/hooks/test_spark_sql.py | 38 +++++++++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/airflow/providers/apache/spark/hooks/spark_sql.py b/airflow/providers/apache/spark/hooks/spark_sql.py index 5182a43a246727..d056ca090386ed 100644 --- a/airflow/providers/apache/spark/hooks/spark_sql.py +++ b/airflow/providers/apache/spark/hooks/spark_sql.py @@ -93,7 +93,7 @@ def _prepare_command(self, cmd): as default. :param cmd: command to append to the spark-sql command - :type cmd: str + :type cmd: str or list[str] :return: full command to be executed """ connection_cmd = ["spark-sql"] @@ -127,7 +127,13 @@ def _prepare_command(self, cmd): if self._yarn_queue: connection_cmd += ["--queue", self._yarn_queue] - connection_cmd += cmd + if isinstance(cmd, str): + connection_cmd += cmd.split() + elif isinstance(cmd, list): + connection_cmd += cmd + else: + raise AirflowException("Invalid additional command: {}".format(cmd)) + self.log.debug("Spark-Sql cmd: %s", connection_cmd) return connection_cmd @@ -136,8 +142,10 @@ def run_query(self, cmd="", **kwargs): """ Remote Popen (actually execute the Spark-sql query) - :param cmd: command to remotely execute + :param cmd: command to append to the spark-sql command + :type cmd: str or list[str] :param kwargs: extra arguments to Popen (see subprocess.Popen) + :type kwargs: dict """ spark_sql_cmd = self._prepare_command(cmd) self._sp = subprocess.Popen(spark_sql_cmd, diff --git a/tests/providers/apache/spark/hooks/test_spark_sql.py b/tests/providers/apache/spark/hooks/test_spark_sql.py index 4cec168debf83a..13fd47d55168a9 100644 --- a/tests/providers/apache/spark/hooks/test_spark_sql.py +++ b/tests/providers/apache/spark/hooks/test_spark_sql.py @@ -108,6 +108,44 @@ def test_spark_process_runcmd(self, mock_popen): '--queue', 'default'], stderr=-2, stdout=-1) ) + @patch('airflow.providers.apache.spark.hooks.spark_sql.subprocess.Popen') + def test_spark_process_runcmd_with_str(self, mock_popen): + # Given + mock_popen.return_value.wait.return_value = 0 + + # When + hook = SparkSqlHook( + conn_id='spark_default', + sql='SELECT 1' + ) + hook.run_query('--deploy-mode cluster') + + # Then + self.assertEqual( + mock_popen.mock_calls[0], + call(['spark-sql', '-e', 'SELECT 1', '--master', 'yarn', '--name', 'default-name', '--verbose', + '--queue', 'default', '--deploy-mode', 'cluster'], stderr=-2, stdout=-1) + ) + + @patch('airflow.providers.apache.spark.hooks.spark_sql.subprocess.Popen') + def test_spark_process_runcmd_with_list(self, mock_popen): + # Given + mock_popen.return_value.wait.return_value = 0 + + # When + hook = SparkSqlHook( + conn_id='spark_default', + sql='SELECT 1' + ) + hook.run_query(['--deploy-mode', 'cluster']) + + # Then + self.assertEqual( + mock_popen.mock_calls[0], + call(['spark-sql', '-e', 'SELECT 1', '--master', 'yarn', '--name', 'default-name', '--verbose', + '--queue', 'default', '--deploy-mode', 'cluster'], stderr=-2, stdout=-1) + ) + if __name__ == '__main__': unittest.main()