Skip to content

Commit

Permalink
Validate only task commands are run by executors (apache#9178)
Browse files Browse the repository at this point in the history
(cherry picked from commit 6943b17)

(cherry picked from commit 4aea266)
  • Loading branch information
ashb authored and kaxil committed Jun 15, 2020
1 parent d819da6 commit 22c5f2b
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 2 deletions.
5 changes: 5 additions & 0 deletions airflow/contrib/executors/kubernetes_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,11 @@ def run_next(self, next_job):
self.log.info('Kubernetes job is %s', str(next_job))
key, command, kube_executor_config = next_job
dag_id, task_id, execution_date, try_number = key
if isinstance(command, str):
command = [command]

if command[0] != "airflow":
raise ValueError('The first element of command must be equal to "airflow".')
self.log.debug("Kubernetes running for command %s", command)
self.log.debug("Kubernetes launching image %s", self.kube_config.kube_image)
pod = self.worker_configuration.make_pod(
Expand Down
2 changes: 2 additions & 0 deletions airflow/executors/celery_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
@app.task
def execute_command(command_to_exec):
log = LoggingMixin().log
if command_to_exec[0:2] != ["airflow", "run"]:
raise ValueError('The command must start with ["airflow", "run"].')
log.info("Executing command in Celery: %s", command_to_exec)
env = os.environ.copy()
try:
Expand Down
3 changes: 3 additions & 0 deletions airflow/executors/dask_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ def execute_async(self, key, command, queue=None, executor_config=None):
'All tasks will be run in the same cluster'
)

if command[0:2] != ["airflow", "run"]:
raise ValueError('The command must start with ["airflow", "run"].')

def airflow_run():
return subprocess.check_call(command, close_fds=True)

Expand Down
2 changes: 2 additions & 0 deletions airflow/executors/local_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ def start(self):
self.impl.start()

def execute_async(self, key, command, queue=None, executor_config=None):
    """Validate ``command`` and hand it to the implementation backend.

    :param key: the task instance key used to track the running command
    :param command: the CLI command, as an argv-style list; it must begin
        with ``["airflow", "run"]``
    :param queue: unused by this executor
    :param executor_config: unused by this executor
    :raises ValueError: if ``command`` is not an ``airflow run`` command
    """
    # Only allow task-run commands to reach the worker implementation.
    expected_prefix = ["airflow", "run"]
    if command[:2] != expected_prefix:
        raise ValueError('The command must start with ["airflow", "run"].')
    self.impl.execute_async(key=key, command=command)

def sync(self):
Expand Down
2 changes: 2 additions & 0 deletions airflow/executors/sequential_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def __init__(self):
self.commands_to_run = []

def execute_async(self, key, command, queue=None, executor_config=None):
    """Validate ``command`` and record it for later serial execution.

    :param key: the task instance key used to track the queued command
    :param command: the CLI command, as an argv-style list; it must begin
        with ``["airflow", "run"]``
    :param queue: unused by this executor
    :param executor_config: unused by this executor
    :raises ValueError: if ``command`` is not an ``airflow run`` command
    """
    # Refuse anything that is not a task-run command before queuing it.
    expected_prefix = ["airflow", "run"]
    if command[:2] != expected_prefix:
        raise ValueError('The command must start with ["airflow", "run"].')
    self.commands_to_run.append((key, command,))

def sync(self):
Expand Down
24 changes: 22 additions & 2 deletions tests/executors/test_celery_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
from multiprocessing import Pool

import mock
import pytest
from celery.contrib.testing.worker import start_worker
from parameterized import parameterized

from airflow.executors import celery_executor
from airflow.executors.celery_executor import (CeleryExecutor, celery_configuration,
Expand All @@ -43,8 +45,8 @@ def test_celery_integration(self):
executor = CeleryExecutor()
executor.start()
with start_worker(app=app, logfile=sys.stdout, loglevel='debug'):
success_command = ['true', 'some_parameter']
fail_command = ['false', 'some_parameter']
success_command = ['airflow', 'run', 'true', 'some_parameter']
fail_command = ['airflow', 'run', 'false']

cached_celery_backend = execute_command.backend
task_tuples_to_send = [('success', 'fake_simple_ti', success_command,
Expand Down Expand Up @@ -135,6 +137,24 @@ def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock
mock.call('executor.running_tasks', mock.ANY)]
mock_stats_gauge.assert_has_calls(calls)

@parameterized.expand((
[['true'], ValueError],
[['airflow', 'version'], ValueError],
[['airflow', 'run'], None]
))
@mock.patch('subprocess.check_call')
def test_command_validation(self, command, expected_exception, mock_check_call):
    """Check that execute_command validates on the receiving (worker) side.

    Commands that do not start with ``["airflow", "run"]`` must raise
    ``expected_exception`` and never reach ``subprocess.check_call``;
    valid commands must be passed straight through to it.
    """
    if expected_exception is None:
        # Valid command: it should be handed to subprocess unchanged.
        celery_executor.execute_command(command)
        mock_check_call.assert_called_once_with(
            command, stderr=mock.ANY, close_fds=mock.ANY, env=mock.ANY,
        )
    else:
        # Invalid command: rejected before any subprocess is spawned.
        with pytest.raises(expected_exception):
            celery_executor.execute_command(command)
        mock_check_call.assert_not_called()


if __name__ == '__main__':
unittest.main()

0 comments on commit 22c5f2b

Please sign in to comment.