Skip to content

Commit

Permalink
[AIRFLOW-7025] Fix SparkSqlHook.run_query to handle its parameter pro…
Browse files Browse the repository at this point in the history
…perly (#7677)
  • Loading branch information
sekikn authored Mar 12, 2020
1 parent 421e7a2 commit 2327aa5
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 3 deletions.
14 changes: 11 additions & 3 deletions airflow/providers/apache/spark/hooks/spark_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _prepare_command(self, cmd):
as default.
:param cmd: command to append to the spark-sql command
:type cmd: str
:type cmd: str or list[str]
:return: full command to be executed
"""
connection_cmd = ["spark-sql"]
Expand Down Expand Up @@ -127,7 +127,13 @@ def _prepare_command(self, cmd):
if self._yarn_queue:
connection_cmd += ["--queue", self._yarn_queue]

connection_cmd += cmd
if isinstance(cmd, str):
connection_cmd += cmd.split()
elif isinstance(cmd, list):
connection_cmd += cmd
else:
raise AirflowException("Invalid additional command: {}".format(cmd))

self.log.debug("Spark-Sql cmd: %s", connection_cmd)

return connection_cmd
Expand All @@ -136,8 +142,10 @@ def run_query(self, cmd="", **kwargs):
"""
Remote Popen (actually execute the Spark-sql query)
:param cmd: command to remotely execute
:param cmd: command to append to the spark-sql command
:type cmd: str or list[str]
:param kwargs: extra arguments to Popen (see subprocess.Popen)
:type kwargs: dict
"""
spark_sql_cmd = self._prepare_command(cmd)
self._sp = subprocess.Popen(spark_sql_cmd,
Expand Down
38 changes: 38 additions & 0 deletions tests/providers/apache/spark/hooks/test_spark_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,44 @@ def test_spark_process_runcmd(self, mock_popen):
'--queue', 'default'], stderr=-2, stdout=-1)
)

@patch('airflow.providers.apache.spark.hooks.spark_sql.subprocess.Popen')
def test_spark_process_runcmd_with_str(self, mock_popen):
# Given
mock_popen.return_value.wait.return_value = 0

# When
hook = SparkSqlHook(
conn_id='spark_default',
sql='SELECT 1'
)
hook.run_query('--deploy-mode cluster')

# Then
self.assertEqual(
mock_popen.mock_calls[0],
call(['spark-sql', '-e', 'SELECT 1', '--master', 'yarn', '--name', 'default-name', '--verbose',
'--queue', 'default', '--deploy-mode', 'cluster'], stderr=-2, stdout=-1)
)

@patch('airflow.providers.apache.spark.hooks.spark_sql.subprocess.Popen')
def test_spark_process_runcmd_with_list(self, mock_popen):
# Given
mock_popen.return_value.wait.return_value = 0

# When
hook = SparkSqlHook(
conn_id='spark_default',
sql='SELECT 1'
)
hook.run_query(['--deploy-mode', 'cluster'])

# Then
self.assertEqual(
mock_popen.mock_calls[0],
call(['spark-sql', '-e', 'SELECT 1', '--master', 'yarn', '--name', 'default-name', '--verbose',
'--queue', 'default', '--deploy-mode', 'cluster'], stderr=-2, stdout=-1)
)


if __name__ == '__main__':
unittest.main()

0 comments on commit 2327aa5

Please sign in to comment.