Skip to content

Commit

Permalink
Webserver: Sanitize values passed to origin param (#10334)
Browse files Browse the repository at this point in the history
(cherry-picked from 5c2bb7b)
  • Loading branch information
kaxil committed Aug 15, 2020
1 parent 004efff commit 9e8d566
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 20 deletions.
37 changes: 27 additions & 10 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from pygments import highlight, lexers
import six
from pygments.formatters.html import HtmlFormatter
from six.moves.urllib.parse import quote, unquote
from six.moves.urllib.parse import quote, unquote, urlparse

from sqlalchemy import or_, desc, and_, union_all
from wtforms import (
Expand Down Expand Up @@ -328,6 +328,23 @@ def get_chart_height(dag):
return 600 + len(dag.tasks) * 10


def get_safe_url(url):
"""Given a user-supplied URL, ensure it points to our web server"""
try:
valid_schemes = ['http', 'https', '']
valid_netlocs = [request.host, '']

parsed = urlparse(url)
if parsed.scheme in valid_schemes and parsed.netloc in valid_netlocs:
return url
except Exception as e: # pylint: disable=broad-except
log.debug("Error validating value in origin parameter passed to URL: %s", url)
log.debug("Error: %s", e)
pass

return "/admin/"


def get_date_time_num_runs_dag_runs_form_data(request, session, dag):
dttm = request.args.get('execution_date')
if dttm:
Expand Down Expand Up @@ -1108,7 +1125,7 @@ def xcom(self, session=None):
def run(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
origin = request.form.get('origin')
origin = get_safe_url(request.form.get('origin'))

dag = dagbag.get_dag(dag_id)
task = dag.get_task(task_id)
Expand Down Expand Up @@ -1179,7 +1196,7 @@ def delete(self):
from airflow.exceptions import DagNotFound, DagFileExists

dag_id = request.values.get('dag_id')
origin = request.values.get('origin') or "/admin/"
origin = get_safe_url(request.values.get('origin'))

try:
delete_dag.delete_dag(dag_id)
Expand All @@ -1203,7 +1220,7 @@ def delete(self):
@provide_session
def trigger(self, session=None):
dag_id = request.values.get('dag_id')
origin = request.values.get('origin') or "/admin/"
origin = get_safe_url(request.values.get('origin'))

if request.method == 'GET':
return self.render(
Expand Down Expand Up @@ -1304,7 +1321,7 @@ def _clear_dag_tis(self, dag, start_date, end_date, origin,
def clear(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
origin = request.form.get('origin')
origin = get_safe_url(request.form.get('origin'))
dag = dagbag.get_dag(dag_id)

execution_date = request.form.get('execution_date')
Expand Down Expand Up @@ -1334,7 +1351,7 @@ def clear(self):
@wwwutils.notify_owner
def dagrun_clear(self):
dag_id = request.form.get('dag_id')
origin = request.form.get('origin')
origin = get_safe_url(request.form.get('origin'))
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == "true"

Expand Down Expand Up @@ -1437,7 +1454,7 @@ def dagrun_failed(self):
dag_id = request.form.get('dag_id')
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == 'true'
origin = request.form.get('origin')
origin = get_safe_url(request.form.get('origin'))
return self._mark_dagrun_state_as_failed(dag_id, execution_date,
confirmed, origin)

Expand All @@ -1449,7 +1466,7 @@ def dagrun_success(self):
dag_id = request.form.get('dag_id')
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == 'true'
origin = request.form.get('origin')
origin = get_safe_url(request.form.get('origin'))
return self._mark_dagrun_state_as_success(dag_id, execution_date,
confirmed, origin)

Expand Down Expand Up @@ -1502,7 +1519,7 @@ def _mark_task_instance_state(self, dag_id, task_id, origin, execution_date,
def failed(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
origin = request.form.get('origin')
origin = get_safe_url(request.form.get('origin'))
execution_date = request.form.get('execution_date')

confirmed = request.form.get('confirmed') == "true"
Expand All @@ -1522,7 +1539,7 @@ def failed(self):
def success(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
origin = request.form.get('origin')
origin = get_safe_url(request.form.get('origin'))
execution_date = request.form.get('execution_date')

confirmed = request.form.get('confirmed') == "true"
Expand Down
37 changes: 27 additions & 10 deletions airflow/www_rbac/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from urllib.parse import unquote

import six
from six.moves.urllib.parse import quote
from six.moves.urllib.parse import quote, urlparse

import pendulum
import sqlalchemy as sqla
Expand Down Expand Up @@ -89,6 +89,23 @@
dagbag = models.DagBag(os.devnull, include_examples=False)


def get_safe_url(url):
"""Given a user-supplied URL, ensure it points to our web server"""
try:
valid_schemes = ['http', 'https', '']
valid_netlocs = [request.host, '']

parsed = urlparse(url)
if parsed.scheme in valid_schemes and parsed.netloc in valid_netlocs:
return url
except Exception as e: # pylint: disable=broad-except
logging.debug("Error validating value in origin parameter passed to URL: %s", url)
logging.debug("Error: %s", e)
pass

return url_for('Airflow.index')


def get_date_time_num_runs_dag_runs_form_data(request, session, dag):
dttm = request.args.get('execution_date')
if dttm:
Expand Down Expand Up @@ -930,7 +947,7 @@ def xcom(self, session=None):
def run(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
origin = request.form.get('origin')
origin = get_safe_url(request.form.get('origin'))
dag = dagbag.get_dag(dag_id)
task = dag.get_task(task_id)

Expand Down Expand Up @@ -1000,7 +1017,7 @@ def delete(self):
from airflow.exceptions import DagNotFound, DagFileExists

dag_id = request.values.get('dag_id')
origin = request.values.get('origin') or url_for('Airflow.index')
origin = get_safe_url(request.values.get('origin'))

try:
delete_dag.delete_dag(dag_id)
Expand All @@ -1027,7 +1044,7 @@ def delete(self):
def trigger(self, session=None):

dag_id = request.values.get('dag_id')
origin = request.values.get('origin') or url_for('Airflow.index')
origin = get_safe_url(request.values.get('origin'))

if request.method == 'GET':
return self.render_template(
Expand Down Expand Up @@ -1128,7 +1145,7 @@ def _clear_dag_tis(self, dag, start_date, end_date, origin,
def clear(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
origin = request.form.get('origin')
origin = get_safe_url(request.form.get('origin'))
dag = dagbag.get_dag(dag_id)

execution_date = request.form.get('execution_date')
Expand Down Expand Up @@ -1158,7 +1175,7 @@ def clear(self):
@action_logging
def dagrun_clear(self):
dag_id = request.form.get('dag_id')
origin = request.form.get('origin')
origin = get_safe_url(request.form.get('origin'))
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == "true"

Expand Down Expand Up @@ -1280,7 +1297,7 @@ def dagrun_failed(self):
dag_id = request.form.get('dag_id')
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == 'true'
origin = request.form.get('origin')
origin = get_safe_url(request.form.get('origin'))
return self._mark_dagrun_state_as_failed(dag_id, execution_date,
confirmed, origin)

Expand All @@ -1292,7 +1309,7 @@ def dagrun_success(self):
dag_id = request.form.get('dag_id')
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == 'true'
origin = request.form.get('origin')
origin = get_safe_url(request.form.get('origin'))
return self._mark_dagrun_state_as_success(dag_id, execution_date,
confirmed, origin)

Expand Down Expand Up @@ -1345,7 +1362,7 @@ def _mark_task_instance_state(self, dag_id, task_id, origin, execution_date,
def failed(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
origin = request.form.get('origin')
origin = get_safe_url(request.form.get('origin'))
execution_date = request.form.get('execution_date')

confirmed = request.form.get('confirmed') == "true"
Expand All @@ -1365,7 +1382,7 @@ def failed(self):
def success(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
origin = request.form.get('origin')
origin = get_safe_url(request.form.get('origin'))
execution_date = request.form.get('execution_date')

confirmed = request.form.get('confirmed') == "true"
Expand Down
23 changes: 23 additions & 0 deletions tests/www/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from airflow.operators.bash_operator import BashOperator
from airflow.utils import timezone
from airflow.utils.db import create_session
from parameterized import parameterized
from tests.compat import mock

from six.moves.urllib.parse import quote_plus
Expand Down Expand Up @@ -1115,6 +1116,28 @@ def test_trigger_serialized_dag(self, mock_os_isfile, mock_dagrun):
'Triggered example_bash_operator, it should start any moment now.',
response.data.decode('utf-8'))

@parameterized.expand([
("javascript:alert(1)", "/admin/"),
("http://google.com", "/admin/"),
(
"%2Fadmin%2Fairflow%2Ftree%3Fdag_id%3Dexample_bash_operator&dag_id=example_bash_operator",
"/admin/airflow/tree?dag_id=example_bash_operator"
),
(
"%2Fadmin%2Fairflow%2Fgraph%3Fdag_id%3Dexample_bash_operator&dag_id=example_bash_operator",
"/admin/airflow/graph?dag_id=example_bash_operator"
),
("", ""),
])
def test_trigger_dag_form_origin_url(self, test_origin, expected_origin):
test_dag_id = "example_bash_operator"
response = self.app.get(
'/admin/airflow/trigger?dag_id={}&origin={}'.format(test_dag_id, test_origin))
self.assertIn(
'<button class="btn" onclick="location.href = \'{}\'; return false">'.format(
expected_origin),
response.data.decode('utf-8'))


class HelpersTest(unittest.TestCase):
@classmethod
Expand Down
16 changes: 16 additions & 0 deletions tests/www_rbac/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2244,6 +2244,22 @@ def test_trigger_serialized_dag(self, mock_os_isfile, mock_dagrun):
self.check_content_in_response(
'Triggered example_bash_operator, it should start any moment now.', response)

@parameterized.expand([
("javascript:alert(1)", "/home"),
("http://google.com", "/home"),
("%2Ftree%3Fdag_id%3Dexample_bash_operator", "/tree?dag_id=example_bash_operator"),
("%2Fgraph%3Fdag_id%3Dexample_bash_operator", "/graph?dag_id=example_bash_operator"),
("", ""),
])
def test_trigger_dag_form_origin_url(self, test_origin, expected_origin):
test_dag_id = "example_bash_operator"

resp = self.client.get('trigger?dag_id={}&origin={}'.format(test_dag_id, test_origin))
self.check_content_in_response(
'<button class="btn" onclick="location.href = \'{}\'; return false">'.format(
expected_origin),
resp)

@mock.patch('airflow.www_rbac.views.dagbag.get_dag')
def test_trigger_endpoint_uses_existing_dagbag(self, mock_get_dag):
"""
Expand Down

0 comments on commit 9e8d566

Please sign in to comment.