From 9e8d5660d6fb29b6161835bcd9b0d2230467fab5 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Sat, 15 Aug 2020 16:01:33 +0100 Subject: [PATCH] Webserver: Sanitize values passed to origin param (#10334) (cherry-picked from 5c2bb7b0b0e717b11f093910b443243330ad93ca) --- airflow/www/views.py | 37 ++++++++++++++++++++++++++---------- airflow/www_rbac/views.py | 37 ++++++++++++++++++++++++++---------- tests/www/test_views.py | 23 ++++++++++++++++++++++ tests/www_rbac/test_views.py | 16 ++++++++++++++++ 4 files changed, 93 insertions(+), 20 deletions(-) diff --git a/airflow/www/views.py b/airflow/www/views.py index b496e7246833a..60873566dd158 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -54,7 +54,7 @@ from pygments import highlight, lexers import six from pygments.formatters.html import HtmlFormatter -from six.moves.urllib.parse import quote, unquote +from six.moves.urllib.parse import quote, unquote, urlparse from sqlalchemy import or_, desc, and_, union_all from wtforms import ( @@ -328,6 +328,23 @@ def get_chart_height(dag): return 600 + len(dag.tasks) * 10 +def get_safe_url(url): + """Given a user-supplied URL, ensure it points to our web server""" + try: + valid_schemes = ['http', 'https', ''] + valid_netlocs = [request.host, ''] + + parsed = urlparse(url) + if parsed.scheme in valid_schemes and parsed.netloc in valid_netlocs: + return url + except Exception as e: # pylint: disable=broad-except + log.debug("Error validating value in origin parameter passed to URL: %s", url) + log.debug("Error: %s", e) + pass + + return "/admin/" + + def get_date_time_num_runs_dag_runs_form_data(request, session, dag): dttm = request.args.get('execution_date') if dttm: @@ -1108,7 +1125,7 @@ def xcom(self, session=None): def run(self): dag_id = request.form.get('dag_id') task_id = request.form.get('task_id') - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) dag = dagbag.get_dag(dag_id) task = dag.get_task(task_id) @@ -1179,7 +1196,7 @@ def delete(self): from airflow.exceptions import DagNotFound, DagFileExists dag_id = request.values.get('dag_id') - origin = request.values.get('origin') or "/admin/" + origin = get_safe_url(request.values.get('origin')) try: delete_dag.delete_dag(dag_id) @@ -1203,7 +1220,7 @@ def delete(self): @provide_session def trigger(self, session=None): dag_id = request.values.get('dag_id') - origin = request.values.get('origin') or "/admin/" + origin = get_safe_url(request.values.get('origin')) if request.method == 'GET': return self.render( @@ -1304,7 +1321,7 @@ def _clear_dag_tis(self, dag, start_date, end_date, origin, def clear(self): dag_id = request.form.get('dag_id') task_id = request.form.get('task_id') - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) dag = dagbag.get_dag(dag_id) execution_date = request.form.get('execution_date') @@ -1334,7 +1351,7 @@ def clear(self): @wwwutils.notify_owner def dagrun_clear(self): dag_id = request.form.get('dag_id') - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) execution_date = request.form.get('execution_date') confirmed = request.form.get('confirmed') == "true" @@ -1437,7 +1454,7 @@ def dagrun_failed(self): dag_id = request.form.get('dag_id') execution_date = request.form.get('execution_date') confirmed = request.form.get('confirmed') == 'true' - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) return self._mark_dagrun_state_as_failed(dag_id, execution_date, confirmed, origin) @@ -1449,7 +1466,7 @@ def dagrun_success(self): dag_id = request.form.get('dag_id') execution_date = request.form.get('execution_date') confirmed = request.form.get('confirmed') == 'true' - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) return self._mark_dagrun_state_as_success(dag_id, execution_date, confirmed, origin) @@ -1502,7 +1519,7 @@ def _mark_task_instance_state(self, dag_id, task_id, origin, execution_date, def failed(self): dag_id = request.form.get('dag_id') task_id = request.form.get('task_id') - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) execution_date = request.form.get('execution_date') confirmed = request.form.get('confirmed') == "true" @@ -1522,7 +1539,7 @@ def failed(self): def success(self): dag_id = request.form.get('dag_id') task_id = request.form.get('task_id') - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) execution_date = request.form.get('execution_date') confirmed = request.form.get('confirmed') == "true" diff --git a/airflow/www_rbac/views.py b/airflow/www_rbac/views.py index f098b25aba34f..9d46d030c01d0 100644 --- a/airflow/www_rbac/views.py +++ b/airflow/www_rbac/views.py @@ -31,7 +31,7 @@ from urllib.parse import unquote import six -from six.moves.urllib.parse import quote +from six.moves.urllib.parse import quote, urlparse import pendulum import sqlalchemy as sqla @@ -89,6 +89,23 @@ dagbag = models.DagBag(os.devnull, include_examples=False) +def get_safe_url(url): + """Given a user-supplied URL, ensure it points to our web server""" + try: + valid_schemes = ['http', 'https', ''] + valid_netlocs = [request.host, ''] + + parsed = urlparse(url) + if parsed.scheme in valid_schemes and parsed.netloc in valid_netlocs: + return url + except Exception as e: # pylint: disable=broad-except + logging.debug("Error validating value in origin parameter passed to URL: %s", url) + logging.debug("Error: %s", e) + pass + + return url_for('Airflow.index') + + def get_date_time_num_runs_dag_runs_form_data(request, session, dag): dttm = request.args.get('execution_date') if dttm: @@ -930,7 +947,7 @@ def xcom(self, session=None): def run(self): dag_id = request.form.get('dag_id') task_id = request.form.get('task_id') - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) dag = dagbag.get_dag(dag_id) task = dag.get_task(task_id) @@ -1000,7 +1017,7 @@ def delete(self): from airflow.exceptions import DagNotFound, DagFileExists dag_id = request.values.get('dag_id') - origin = request.values.get('origin') or url_for('Airflow.index') + origin = get_safe_url(request.values.get('origin')) try: delete_dag.delete_dag(dag_id) @@ -1027,7 +1044,7 @@ def delete(self): def trigger(self, session=None): dag_id = request.values.get('dag_id') - origin = request.values.get('origin') or url_for('Airflow.index') + origin = get_safe_url(request.values.get('origin')) if request.method == 'GET': return self.render_template( @@ -1128,7 +1145,7 @@ def _clear_dag_tis(self, dag, start_date, end_date, origin, def clear(self): dag_id = request.form.get('dag_id') task_id = request.form.get('task_id') - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) dag = dagbag.get_dag(dag_id) execution_date = request.form.get('execution_date') @@ -1158,7 +1175,7 @@ def clear(self): @action_logging def dagrun_clear(self): dag_id = request.form.get('dag_id') - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) execution_date = request.form.get('execution_date') confirmed = request.form.get('confirmed') == "true" @@ -1280,7 +1297,7 @@ def dagrun_failed(self): dag_id = request.form.get('dag_id') execution_date = request.form.get('execution_date') confirmed = request.form.get('confirmed') == 'true' - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) return self._mark_dagrun_state_as_failed(dag_id, execution_date, confirmed, origin) @@ -1292,7 +1309,7 @@ def dagrun_success(self): dag_id = request.form.get('dag_id') execution_date = request.form.get('execution_date') confirmed = request.form.get('confirmed') == 'true' - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) return self._mark_dagrun_state_as_success(dag_id, execution_date, confirmed, origin) @@ -1345,7 +1362,7 @@ def _mark_task_instance_state(self, dag_id, task_id, origin, execution_date, def failed(self): dag_id = request.form.get('dag_id') task_id = request.form.get('task_id') - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) execution_date = request.form.get('execution_date') confirmed = request.form.get('confirmed') == "true" @@ -1365,7 +1382,7 @@ def failed(self): def success(self): dag_id = request.form.get('dag_id') task_id = request.form.get('task_id') - origin = request.form.get('origin') + origin = get_safe_url(request.form.get('origin')) execution_date = request.form.get('execution_date') confirmed = request.form.get('confirmed') == "true" diff --git a/tests/www/test_views.py b/tests/www/test_views.py index ac71ebb57d648..438830c99206b 100644 --- a/tests/www/test_views.py +++ b/tests/www/test_views.py @@ -37,6 +37,7 @@ from airflow.operators.bash_operator import BashOperator from airflow.utils import timezone from airflow.utils.db import create_session +from parameterized import parameterized from tests.compat import mock from six.moves.urllib.parse import quote_plus @@ -1115,6 +1116,28 @@ def test_trigger_serialized_dag(self, mock_os_isfile, mock_dagrun): 'Triggered example_bash_operator, it should start any moment now.', response.data.decode('utf-8')) + @parameterized.expand([ + ("javascript:alert(1)", "/admin/"), + ("http://google.com", "/admin/"), + ( + "%2Fadmin%2Fairflow%2Ftree%3Fdag_id%3Dexample_bash_operator&dag_id=example_bash_operator", + "/admin/airflow/tree?dag_id=example_bash_operator" + ), + ( + "%2Fadmin%2Fairflow%2Fgraph%3Fdag_id%3Dexample_bash_operator&dag_id=example_bash_operator", + "/admin/airflow/graph?dag_id=example_bash_operator" + ), + ("", ""), + ]) + def test_trigger_dag_form_origin_url(self, test_origin, expected_origin): + test_dag_id = "example_bash_operator" + response = self.app.get( + '/admin/airflow/trigger?dag_id={}&origin={}'.format(test_dag_id, test_origin)) + self.assertIn( + '