Skip to content

Commit

Permalink
Webserver: Sanitize values passed to origin param (#10334)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaxil authored Aug 15, 2020
1 parent 4454224 commit 5c2bb7b
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 10 deletions.
33 changes: 23 additions & 10 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from datetime import datetime, timedelta
from json import JSONDecodeError
from typing import Dict, List, Optional, Tuple
from urllib.parse import unquote
from urllib.parse import unquote, urlparse

import lazy_object_proxy
import nvd3
Expand Down Expand Up @@ -81,6 +81,19 @@
FILTER_STATUS_COOKIE = 'dag_status_filter'


def get_safe_url(url):
"""Given a user-supplied URL, ensure it points to our web server"""
valid_schemes = ['http', 'https', '']
valid_netlocs = [request.host, '']

parsed = urlparse(url)

if parsed.scheme in valid_schemes and parsed.netloc in valid_netlocs:
return url

return url_for('Airflow.index')


def get_date_time_num_runs_dag_runs_form_data(request, session, dag):
"""Get Execution Data, Base Date & Number of runs from a Request """
dttm = request.args.get('execution_date')
Expand Down Expand Up @@ -921,7 +934,7 @@ def xcom(self, session=None):
def run(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
origin = request.form.get('origin')
origin = get_safe_url(request.form.get('origin'))
dag = current_app.dag_bag.get_dag(dag_id)
task = dag.get_task(task_id)

Expand Down Expand Up @@ -990,7 +1003,7 @@ def delete(self):
from airflow.exceptions import DagFileExists, DagNotFound

dag_id = request.values.get('dag_id')
origin = request.values.get('origin') or url_for('Airflow.index')
origin = get_safe_url(request.values.get('origin'))

try:
delete_dag.delete_dag(dag_id)
Expand All @@ -1017,7 +1030,7 @@ def delete(self):
def trigger(self, session=None):

dag_id = request.values.get('dag_id')
origin = request.values.get('origin') or url_for('Airflow.index')
origin = get_safe_url(request.values.get('origin'))

if request.method == 'GET':
return self.render_template(
Expand Down Expand Up @@ -1115,7 +1128,7 @@ def _clear_dag_tis(self, dag, start_date, end_date, origin,
def clear(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
origin = request.form.get('origin')
origin = get_safe_url(request.form.get('origin'))
dag = current_app.dag_bag.get_dag(dag_id)

execution_date = request.form.get('execution_date')
Expand Down Expand Up @@ -1145,7 +1158,7 @@ def clear(self):
@action_logging
def dagrun_clear(self):
dag_id = request.form.get('dag_id')
origin = request.form.get('origin')
origin = get_safe_url(request.form.get('origin'))
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == "true"

Expand Down Expand Up @@ -1267,7 +1280,7 @@ def dagrun_failed(self):
dag_id = request.form.get('dag_id')
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == 'true'
origin = request.form.get('origin')
origin = get_safe_url(request.form.get('origin'))
return self._mark_dagrun_state_as_failed(dag_id, execution_date,
confirmed, origin)

Expand All @@ -1279,7 +1292,7 @@ def dagrun_success(self):
dag_id = request.form.get('dag_id')
execution_date = request.form.get('execution_date')
confirmed = request.form.get('confirmed') == 'true'
origin = request.form.get('origin')
origin = get_safe_url(request.form.get('origin'))
return self._mark_dagrun_state_as_success(dag_id, execution_date,
confirmed, origin)

Expand Down Expand Up @@ -1329,7 +1342,7 @@ def _mark_task_instance_state(self, dag_id, task_id, origin, execution_date,
def failed(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
origin = request.form.get('origin')
origin = get_safe_url(request.form.get('origin'))
execution_date = request.form.get('execution_date')

confirmed = request.form.get('confirmed') == "true"
Expand All @@ -1349,7 +1362,7 @@ def failed(self):
def success(self):
dag_id = request.form.get('dag_id')
task_id = request.form.get('task_id')
origin = request.form.get('origin')
origin = get_safe_url(request.form.get('origin'))
execution_date = request.form.get('execution_date')

confirmed = request.form.get('confirmed') == "true"
Expand Down
15 changes: 15 additions & 0 deletions tests/www/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2468,6 +2468,21 @@ def test_trigger_dag_form(self):
self.assertEqual(resp.status_code, 200)
self.check_content_in_response('Trigger DAG: {}'.format(test_dag_id), resp)

@parameterized.expand([
("javascript:alert(1)", "/home"),
("http://google.com", "/home"),
("%2Ftree%3Fdag_id%3Dexample_bash_operator", "/tree?dag_id=example_bash_operator"),
("%2Fgraph%3Fdag_id%3Dexample_bash_operator", "/graph?dag_id=example_bash_operator"),
])
def test_trigger_dag_form_origin_url(self, test_origin, expected_origin):
test_dag_id = "example_bash_operator"

resp = self.client.get('trigger?dag_id={}&origin={}'.format(test_dag_id, test_origin))
self.check_content_in_response(
'<button class="btn" onclick="location.href = \'{}\'; return false">'.format(
expected_origin),
resp)

def test_trigger_endpoint_uses_existing_dagbag(self):
"""
Test that Trigger Endpoint uses the DagBag already created in views.py
Expand Down

0 comments on commit 5c2bb7b

Please sign in to comment.