From e985287f07c2f34dd3c2256b19301bfe2f1fb4a3 Mon Sep 17 00:00:00 2001 From: Andy Zhou Date: Fri, 14 Jun 2024 23:36:24 -0600 Subject: [PATCH 01/21] Fix unit tests for coverage issue and refactor request handling --- airflow/www/views.py | 7 +++- tests/www/test_validators.py | 6 +-- tests/www/views/test_views_grid.py | 61 +++++++++++++++++++++++++++++- 3 files changed, 67 insertions(+), 7 deletions(-) diff --git a/airflow/www/views.py b/airflow/www/views.py index 7ffd5ec17c4cca..75741689282b09 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -271,7 +271,12 @@ def get_date_time_num_runs_dag_runs_form_data(www_request, session, dag): base_date = (date_time + datetime.timedelta(seconds=1)).replace(microsecond=0) default_dag_run = conf.getint("webserver", "default_dag_run_display_number") - num_runs = www_request.args.get("num_runs", default=default_dag_run, type=int) + num_runs = www_request.args.get("num_runs") + if num_runs is None: + num_runs = default_dag_run + else: + num_runs = int(num_runs) + # When base_date has been rounded up because of the DateTimeField widget, we want # to use the execution_date as the starting point for our query just to ensure a diff --git a/tests/www/test_validators.py b/tests/www/test_validators.py index e92cc9d80453df..c8961f90073228 100644 --- a/tests/www/test_validators.py +++ b/tests/www/test_validators.py @@ -18,12 +18,11 @@ from __future__ import annotations from unittest import mock - import pytest from airflow.www import validators - + class TestGreaterEqualThan: def setup_method(self): self.form_field_mock = mock.MagicMock(data="2017-05-06") @@ -120,7 +119,6 @@ def test_validation_raises_custom_message(self): message="Invalid JSON: {}", ) - class TestValidKey: def setup_method(self): self.form_field_mock = mock.MagicMock(data="valid_key") @@ -164,4 +162,4 @@ def setup_method(self): def test_read_only_validator(self): validator = validators.ReadOnly() assert validator(self.form_mock, self.form_read_only_field_mock) is None - assert self.form_read_only_field_mock.flags.readonly is True + assert self.form_read_only_field_mock.flags.readonly is True \ No newline at end of file diff --git a/tests/www/views/test_views_grid.py b/tests/www/views/test_views_grid.py index 3d13dea4d1248f..fffb2f87286a93 100644 --- a/tests/www/views/test_views_grid.py +++ b/tests/www/views/test_views_grid.py @@ -20,6 +20,7 @@ from datetime import timedelta from typing import TYPE_CHECKING +from airflow.utils.session import create_session import pendulum import pytest from dateutil.tz import UTC @@ -34,13 +35,13 @@ from airflow.utils.state import DagRunState, TaskInstanceState from airflow.utils.task_group import TaskGroup from airflow.utils.types import DagRunType -from airflow.www.views import dag_to_grid +from airflow.www.views import _safe_parse_datetime, dag_to_grid, get_date_time_num_runs_dag_runs_form_data from tests.test_utils.asserts import assert_queries_count from tests.test_utils.db import clear_db_datasets, clear_db_runs from tests.test_utils.mock_operators import MockOperator pytestmark = pytest.mark.db_test - + if TYPE_CHECKING: from airflow.models.dagrun import DagRun @@ -514,3 +515,59 @@ def test_next_run_datasets_404(admin_client): resp = admin_client.get("/object/next_run_datasets/missingdag", follow_redirects=True) assert resp.status_code == 404, resp.json assert resp.json == {"error": "can't find dag missingdag"} + +def test_get_date_time_num_runs_dag_runs_form_data(dag_with_runs): + run1, _ = dag_with_runs + + class Request: + def __init__(self, 
form): + self.form = form + self.args = form + + def get(self, key, default=None): + return self.args.get(key, default) + + # Test case 1: run_id is provided + request_with_run_id = Request(form={ + "execution_date": run1.execution_date.isoformat(), + "num_runs": "5", + "run_id": run1.run_id + }) + + with create_session() as session: + data = get_date_time_num_runs_dag_runs_form_data(request_with_run_id, session, run1.dag) + + assert data['dttm'] == run1.execution_date + assert data['execution_date'] == run1.execution_date.isoformat() + assert data['num_runs'] == 5 + + # Test case 2: base_date is provided + base_date = "2023-01-01T00:00:00+00:00" + request_with_base_date = Request(form={ + "execution_date": run1.execution_date.isoformat(), + "num_runs": "5", + "base_date": base_date + }) + + with create_session() as session: + data = get_date_time_num_runs_dag_runs_form_data(request_with_base_date, session, run1.dag) + + assert data['base_date'] == _safe_parse_datetime(base_date) + assert data['execution_date'] == run1.execution_date.isoformat() + assert data['num_runs'] == 5 + + # Test case 3: both run_id and base_date are provided + request_with_run_id_and_base_date = Request(form={ + "execution_date": run1.execution_date.isoformat(), + "num_runs": "5", + "run_id": run1.run_id, + "base_date": base_date + }) + + with create_session() as session: + data = get_date_time_num_runs_dag_runs_form_data(request_with_run_id_and_base_date, session, run1.dag) + + assert data['dttm'] == run1.execution_date + assert data['base_date'] == _safe_parse_datetime(base_date) + assert data['execution_date'] == run1.execution_date.isoformat() + assert data['num_runs'] == 5 \ No newline at end of file From 9ad44045c282d4cfe78fea0f48b4596e054f6e4c Mon Sep 17 00:00:00 2001 From: Andy Zhou Date: Fri, 14 Jun 2024 23:46:10 -0600 Subject: [PATCH 02/21] Spaces fix --- tests/www/views/test_views_grid.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/www/views/test_views_grid.py b/tests/www/views/test_views_grid.py index fffb2f87286a93..c16444785b518f 100644 --- a/tests/www/views/test_views_grid.py +++ b/tests/www/views/test_views_grid.py @@ -20,7 +20,6 @@ from datetime import timedelta from typing import TYPE_CHECKING -from airflow.utils.session import create_session import pendulum import pytest from dateutil.tz import UTC @@ -35,13 +34,14 @@ from airflow.utils.state import DagRunState, TaskInstanceState from airflow.utils.task_group import TaskGroup from airflow.utils.types import DagRunType +from airflow.utils.session import create_session from airflow.www.views import _safe_parse_datetime, dag_to_grid, get_date_time_num_runs_dag_runs_form_data from tests.test_utils.asserts import assert_queries_count from tests.test_utils.db import clear_db_datasets, clear_db_runs from tests.test_utils.mock_operators import MockOperator pytestmark = pytest.mark.db_test - + if TYPE_CHECKING: from airflow.models.dagrun import DagRun From a20d7eb398b88bae856d1861d5d0cd74f2d8532e Mon Sep 17 00:00:00 2001 From: Andy Zhou Date: Sat, 15 Jun 2024 00:32:21 -0600 Subject: [PATCH 03/21] Adding coverage --- tests/www/views/test_views_grid.py | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/tests/www/views/test_views_grid.py b/tests/www/views/test_views_grid.py index c16444785b518f..0328567d316146 100644 --- a/tests/www/views/test_views_grid.py +++ b/tests/www/views/test_views_grid.py @@ -527,36 +527,7 @@ def __init__(self, form): def get(self, key, default=None): return 
self.args.get(key, default) - # Test case 1: run_id is provided - request_with_run_id = Request(form={ - "execution_date": run1.execution_date.isoformat(), - "num_runs": "5", - "run_id": run1.run_id - }) - - with create_session() as session: - data = get_date_time_num_runs_dag_runs_form_data(request_with_run_id, session, run1.dag) - - assert data['dttm'] == run1.execution_date - assert data['execution_date'] == run1.execution_date.isoformat() - assert data['num_runs'] == 5 - - # Test case 2: base_date is provided base_date = "2023-01-01T00:00:00+00:00" - request_with_base_date = Request(form={ - "execution_date": run1.execution_date.isoformat(), - "num_runs": "5", - "base_date": base_date - }) - - with create_session() as session: - data = get_date_time_num_runs_dag_runs_form_data(request_with_base_date, session, run1.dag) - - assert data['base_date'] == _safe_parse_datetime(base_date) - assert data['execution_date'] == run1.execution_date.isoformat() - assert data['num_runs'] == 5 - - # Test case 3: both run_id and base_date are provided request_with_run_id_and_base_date = Request(form={ "execution_date": run1.execution_date.isoformat(), "num_runs": "5", From e081547e86824eaf7c0d8f6d82c6e86ed93d454c Mon Sep 17 00:00:00 2001 From: Andy Zhou Date: Mon, 17 Jun 2024 12:10:51 -0600 Subject: [PATCH 04/21] Fixed spacing --- tests/www/test_validators.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/www/test_validators.py b/tests/www/test_validators.py index c8961f90073228..884401a16bc914 100644 --- a/tests/www/test_validators.py +++ b/tests/www/test_validators.py @@ -162,4 +162,5 @@ def setup_method(self): def test_read_only_validator(self): validator = validators.ReadOnly() assert validator(self.form_mock, self.form_read_only_field_mock) is None - assert self.form_read_only_field_mock.flags.readonly is True \ No newline at end of file + assert self.form_read_only_field_mock.flags.readonly is True + \ No newline at end of file From 6a921bbe8ec9c8f8227668932ac0863251ff0534 Mon Sep 17 00:00:00 2001 From: Andy Zhou Date: Mon, 17 Jun 2024 15:41:55 -0600 Subject: [PATCH 05/21] Changes before repo fix --- tests/www/views/test_views_grid.py | 1 - tests/www/views/test_views_tasks.py | 24 +++++++++++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/tests/www/views/test_views_grid.py b/tests/www/views/test_views_grid.py index 0328567d316146..468f3dd449d507 100644 --- a/tests/www/views/test_views_grid.py +++ b/tests/www/views/test_views_grid.py @@ -518,7 +518,6 @@ def test_next_run_datasets_404(admin_client): def test_get_date_time_num_runs_dag_runs_form_data(dag_with_runs): run1, _ = dag_with_runs - class Request: def __init__(self, form): self.form = form diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py index de86d9227bd646..9a741eb7f54224 100644 --- a/tests/www/views/test_views_tasks.py +++ b/tests/www/views/test_views_tasks.py @@ -43,7 +43,7 @@ from airflow.utils.session import create_session from airflow.utils.state import DagRunState, State from airflow.utils.types import DagRunType -from airflow.www.views import TaskInstanceModelView +from airflow.www.views import TaskInstanceModelView, _safe_parse_datetime from tests.test_utils.api_connexion_utils import create_user, delete_roles, delete_user from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_runs, clear_db_xcom @@ -1050,6 +1050,28 @@ def test_graph_view_doesnt_fail_on_recursion_error(app, dag_maker, admin_client) assert 
resp.status_code == 200 +def test_get_date_time_num_runs_dag_runs_form_data_graph_view(app, dag_maker, admin_client): + """Test the get_date_time_num_runs_dag_runs_form_data function.""" + from airflow.www.views import get_date_time_num_runs_dag_runs_form_data + + with dag_maker("test_get_date_time_num_runs_dag_runs_form_data") as dag: + BashOperator(task_id="task_1", bash_command="echo test") + + # May delete later as it's repetitive and unused in this test case + with unittest.mock.patch.object(app, "dag_bag") as mocked_dag_bag: + mocked_dag_bag.get_dag.return_value = dag + url = f"/dags/{dag.dag_id}/graph" + resp = admin_client.get(url, follow_redirects=True) + assert resp.status_code == 200 + + with create_session() as session: + data = get_date_time_num_runs_dag_runs_form_data(resp, session, dag) + + assert data["dttm"] == dag.execution_date + assert data["base_date"] == _safe_parse_datetime(dag.base_date) + assert data["execution_date"] == dag.execution_date.isoformat() + assert data["num_runs"] == 1 + def test_task_instances(admin_client): """Test task_instances view.""" resp = admin_client.get( From 581646a69fe11b5be862051f0a5e0cc420bcb910 Mon Sep 17 00:00:00 2001 From: Andy Zhou Date: Mon, 17 Jun 2024 21:24:57 -0600 Subject: [PATCH 06/21] Increased Coverage to 85%, adding proper mock dag generation and response --- airflow/www/views.py | 6 +---- tests/www/views/test_views_grid.py | 26 ------------------- tests/www/views/test_views_tasks.py | 39 ++++++++++++++++++++--------- 3 files changed, 28 insertions(+), 43 deletions(-) diff --git a/airflow/www/views.py b/airflow/www/views.py index 75741689282b09..396d7e6e78fa03 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -271,11 +271,7 @@ def get_date_time_num_runs_dag_runs_form_data(www_request, session, dag): base_date = (date_time + datetime.timedelta(seconds=1)).replace(microsecond=0) default_dag_run = conf.getint("webserver", "default_dag_run_display_number") - num_runs = www_request.args.get("num_runs") - if num_runs is None: - num_runs = default_dag_run - else: - num_runs = int(num_runs) + num_runs = www_request.args.get("num_runs", default=default_dag_run, type=int) # When base_date has been rounded up because of the DateTimeField widget, we want diff --git a/tests/www/views/test_views_grid.py b/tests/www/views/test_views_grid.py index 468f3dd449d507..2850f4e21e8a21 100644 --- a/tests/www/views/test_views_grid.py +++ b/tests/www/views/test_views_grid.py @@ -515,29 +515,3 @@ def test_next_run_datasets_404(admin_client): resp = admin_client.get("/object/next_run_datasets/missingdag", follow_redirects=True) assert resp.status_code == 404, resp.json assert resp.json == {"error": "can't find dag missingdag"} - -def test_get_date_time_num_runs_dag_runs_form_data(dag_with_runs): - run1, _ = dag_with_runs - class Request: - def __init__(self, form): - self.form = form - self.args = form - - def get(self, key, default=None): - return self.args.get(key, default) - - base_date = "2023-01-01T00:00:00+00:00" - request_with_run_id_and_base_date = Request(form={ - "execution_date": run1.execution_date.isoformat(), - "num_runs": "5", - "run_id": run1.run_id, - "base_date": base_date - }) - - with create_session() as session: - data = get_date_time_num_runs_dag_runs_form_data(request_with_run_id_and_base_date, session, run1.dag) - - assert data['dttm'] == run1.execution_date - assert data['base_date'] == _safe_parse_datetime(base_date) - assert data['execution_date'] == run1.execution_date.isoformat() - assert data['num_runs'] == 5 
\ No newline at end of file diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py index 9a741eb7f54224..f4f507048dd86d 100644 --- a/tests/www/views/test_views_tasks.py +++ b/tests/www/views/test_views_tasks.py @@ -66,7 +66,7 @@ def reset_dagruns(): @pytest.fixture(autouse=True) -def init_dagruns(app, reset_dagruns): +def init_dagruns(app): with time_machine.travel(DEFAULT_DATE, tick=False): app.dag_bag.get_dag("example_bash_operator").create_dagrun( run_id=DEFAULT_DAGRUN, @@ -1049,28 +1049,43 @@ def test_graph_view_doesnt_fail_on_recursion_error(app, dag_maker, admin_client) resp = admin_client.get(url, follow_redirects=True) assert resp.status_code == 200 - + def test_get_date_time_num_runs_dag_runs_form_data_graph_view(app, dag_maker, admin_client): """Test the get_date_time_num_runs_dag_runs_form_data function.""" from airflow.www.views import get_date_time_num_runs_dag_runs_form_data - with dag_maker("test_get_date_time_num_runs_dag_runs_form_data") as dag: + execution_date = pendulum.DateTime(2024, 1, 1, 0, 0, 0, tzinfo=pendulum.UTC) + with dag_maker( + dag_id="test_get_date_time_num_runs_dag_runs_form_data", + start_date=execution_date, + ) as dag: BashOperator(task_id="task_1", bash_command="echo test") - - # May delete later as it's repetitive and unused in this test case + + dag_run = dag_maker.create_dagrun( + run_id="test_dagrun_id", + run_type=DagRunType.SCHEDULED, + execution_date=execution_date, + start_date=execution_date, + state=DagRunState.RUNNING, + ) + with unittest.mock.patch.object(app, "dag_bag") as mocked_dag_bag: mocked_dag_bag.get_dag.return_value = dag url = f"/dags/{dag.dag_id}/graph" resp = admin_client.get(url, follow_redirects=True) assert resp.status_code == 200 - + with create_session() as session: - data = get_date_time_num_runs_dag_runs_form_data(resp, session, dag) - - assert data["dttm"] == dag.execution_date - assert data["base_date"] == _safe_parse_datetime(dag.base_date) - assert data["execution_date"] == dag.execution_date.isoformat() - assert data["num_runs"] == 1 + data = get_date_time_num_runs_dag_runs_form_data(resp.request, session, dag) + + dttm = pendulum.parse(data["dttm"].isoformat()) + base_date = pendulum.parse(data["base_date"].isoformat()) + + assert dttm.date() == execution_date.date() + assert dttm.time() == _safe_parse_datetime(execution_date.time().isoformat()).time() + assert base_date.date() == execution_date.date() + assert data["execution_date"] == execution_date.isoformat() + def test_task_instances(admin_client): """Test task_instances view.""" From f04434dad28acd4599ad450018112ba423fdf8c0 Mon Sep 17 00:00:00 2001 From: andyjianzhou Date: Wed, 10 Jul 2024 12:18:54 -0600 Subject: [PATCH 07/21] Fixing conflicts --- airflow/www/views.py | 1 - tests/www/test_utils.py | 52 ++ tests/www/test_utils_BACKUP_289691.py | 708 ++++++++++++++++++++++++++ tests/www/test_utils_BASE_289691.py | 479 +++++++++++++++++ tests/www/test_utils_LOCAL_289691.py | 656 ++++++++++++++++++++++++ tests/www/test_utils_REMOTE_289691.py | 531 +++++++++++++++++++ tests/www/test_validators.py | 5 +- tests/www/views/test_views_grid.py | 5 +- tests/www/views/test_views_tasks.py | 18 +- 9 files changed, 2436 insertions(+), 19 deletions(-) create mode 100644 tests/www/test_utils_BACKUP_289691.py create mode 100644 tests/www/test_utils_BASE_289691.py create mode 100644 tests/www/test_utils_LOCAL_289691.py create mode 100644 tests/www/test_utils_REMOTE_289691.py diff --git a/airflow/www/views.py b/airflow/www/views.py index 
396d7e6e78fa03..7ffd5ec17c4cca 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -272,7 +272,6 @@ def get_date_time_num_runs_dag_runs_form_data(www_request, session, dag): default_dag_run = conf.getint("webserver", "default_dag_run_display_number") num_runs = www_request.args.get("num_runs", default=default_dag_run, type=int) - # When base_date has been rounded up because of the DateTimeField widget, we want # to use the execution_date as the starting point for our query just to ensure a diff --git a/tests/www/test_utils.py b/tests/www/test_utils.py index a90d9246998d6f..02eb43526e1bd7 100644 --- a/tests/www/test_utils.py +++ b/tests/www/test_utils.py @@ -28,13 +28,16 @@ import pytest from bs4 import BeautifulSoup from flask_appbuilder.models.sqla.filters import get_field_setup_query, set_value_to_type +from flask_wtf import FlaskForm from markupsafe import Markup from sqlalchemy.orm import Query +from wtforms.fields import StringField, TextAreaField from airflow.models import DagRun from airflow.utils import json as utils_json from airflow.www import utils from airflow.www.utils import CustomSQLAInterface, DagRunCustomSQLAInterface, json_f, wrapped_markdown +from airflow.www.widgets import AirflowDateTimePickerROWidget, BS3TextAreaROWidget, BS3TextFieldROWidget from tests.test_utils.config import conf_vars @@ -654,3 +657,52 @@ def test_dag_run_custom_sqla_interface_delete_no_collateral_damage(dag_maker, se assert len(dag_runs) == 6 assert len(set(x.dag_id for x in dag_runs)) == 3 assert len(set(x.run_id for x in dag_runs)) == 3 + + +@pytest.fixture +def app(): + from flask import Flask + + app = Flask(__name__) + app.config["WTF_CSRF_ENABLED"] = False + app.config["SECRET_KEY"] = "secret" + with app.app_context(): + yield app + + +class TestWidgets: + def test_airflow_datetime_picker_ro_widget(self, app): + class TestForm(FlaskForm): + datetime_field = StringField(widget=AirflowDateTimePickerROWidget()) + + form = TestForm() + field = form.datetime_field + + html_output = field() + + assert 'readonly="true"' in html_output + assert "input-group datetime datetimepicker" in html_output + + def test_bs3_text_field_ro_widget(self, app): + class TestForm(FlaskForm): + text_field = StringField(widget=BS3TextFieldROWidget()) + + form = TestForm() + field = form.text_field + + html_output = field() + + assert 'readonly="true"' in html_output + assert "form-control" in html_output + + def test_bs3_text_area_ro_widget(self, app): + class TestForm(FlaskForm): + textarea_field = TextAreaField(widget=BS3TextAreaROWidget()) + + form = TestForm() + field = form.textarea_field + + html_output = field() + + assert 'readonly="true"' in html_output + assert "form-control" in html_output diff --git a/tests/www/test_utils_BACKUP_289691.py b/tests/www/test_utils_BACKUP_289691.py new file mode 100644 index 00000000000000..02eb43526e1bd7 --- /dev/null +++ b/tests/www/test_utils_BACKUP_289691.py @@ -0,0 +1,708 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import itertools +import re +import time +from datetime import datetime +from unittest.mock import Mock +from urllib.parse import parse_qs + +import pendulum +import pytest +from bs4 import BeautifulSoup +from flask_appbuilder.models.sqla.filters import get_field_setup_query, set_value_to_type +from flask_wtf import FlaskForm +from markupsafe import Markup +from sqlalchemy.orm import Query +from wtforms.fields import StringField, TextAreaField + +from airflow.models import DagRun +from airflow.utils import json as utils_json +from airflow.www import utils +from airflow.www.utils import CustomSQLAInterface, DagRunCustomSQLAInterface, json_f, wrapped_markdown +from airflow.www.widgets import AirflowDateTimePickerROWidget, BS3TextAreaROWidget, BS3TextFieldROWidget +from tests.test_utils.config import conf_vars + + +class TestUtils: + def check_generate_pages_html( + self, + current_page, + total_pages, + window=7, + check_middle=False, + sorting_key=None, + sorting_direction=None, + ): + extra_links = 4 # first, prev, next, last + search = "'>\"/>" + if sorting_key and sorting_direction: + html_str = utils.generate_pages( + current_page, + total_pages, + search=search, + sorting_key=sorting_key, + sorting_direction=sorting_direction, + ) + else: + html_str = utils.generate_pages(current_page, total_pages, search=search) + + assert search not in html_str, "The raw search string shouldn't appear in the output" + assert "search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E" in html_str + + assert callable(html_str.__html__), "Should return something that is HTML-escaping aware" + + dom = BeautifulSoup(html_str, "html.parser") + assert dom is not None + + ulist = dom.ul + ulist_items = ulist.find_all("li") + assert min(window, total_pages) + extra_links == len(ulist_items) + + page_items = ulist_items[2:-2] + mid = len(page_items) // 2 + all_nodes = [] + pages = [] + + if sorting_key and sorting_direction: + last_page = total_pages - 1 + + if current_page <= mid or total_pages < window: + pages = list(range(min(total_pages, window))) + elif mid < current_page < last_page - mid: + pages = list(range(current_page - mid, current_page + mid + 1)) + else: + pages = list(range(total_pages - window, last_page + 1)) + + pages.append(last_page + 1) + pages.sort(reverse=True if sorting_direction == "desc" else False) + + for i, item in enumerate(page_items): + a_node = item.a + href_link = a_node["href"] + node_text = a_node.string + all_nodes.append(node_text) + if node_text == str(current_page + 1): + if check_middle: + assert mid == i + assert "javascript:void(0)" == href_link + assert "active" in item["class"] + else: + assert re.search(r"^\?", href_link), "Link is page-relative" + query = parse_qs(href_link[1:]) + assert query["page"] == [str(int(node_text) - 1)] + assert query["search"] == [search] + + if sorting_key and sorting_direction: + if pages[0] == 0: + pages = [str(page) for page in pages[1:]] + + assert pages == all_nodes + + def test_generate_pager_current_start(self): + self.check_generate_pages_html(current_page=0, total_pages=6) + 
+ def test_generate_pager_current_middle(self): + self.check_generate_pages_html(current_page=10, total_pages=20, check_middle=True) + + def test_generate_pager_current_end(self): + self.check_generate_pages_html(current_page=38, total_pages=39) + + def test_generate_pager_current_start_with_sorting(self): + self.check_generate_pages_html( + current_page=0, total_pages=4, sorting_key="dag_id", sorting_direction="asc" + ) + + def test_params_no_values(self): + """Should return an empty string if no params are passed""" + assert "" == utils.get_params() + + def test_params_search(self): + assert "search=bash_" == utils.get_params(search="bash_") + + def test_params_none_and_zero(self): + query_str = utils.get_params(a=0, b=None, c="true") + # The order won't be consistent, but that doesn't affect behaviour of a browser + pairs = sorted(query_str.split("&")) + assert ["a=0", "c=true"] == pairs + + def test_params_all(self): + query = utils.get_params(tags=["tag1", "tag2"], status="active", page=3, search="bash_") + assert { + "tags": ["tag1", "tag2"], + "page": ["3"], + "search": ["bash_"], + "status": ["active"], + } == parse_qs(query) + + def test_params_escape(self): + assert "search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E" == utils.get_params( + search="'>\"/>" + ) + + def test_state_token(self): + # It's shouldn't possible to set these odd values anymore, but lets + # ensure they are escaped! + html = str(utils.state_token("")) + + assert "<script>alert(1)</script>" in html + assert "" not in html + + def test_nobr_f(self): + attr = {"attr_name": "attribute"} + f = attr.get("attr_name") + expected_markup = Markup("{}").format(f) + + nobr = utils.nobr_f("attr_name") + result_markup = nobr(attr) + + assert result_markup == expected_markup + + def test_nobr_f_empty_attr(self): + attr = {"attr_name": ""} + f = attr.get("attr_name") + expected_markup = Markup("{}").format(f) + + nobr = utils.nobr_f("attr_name") + result_markup = nobr(attr) + + assert result_markup == expected_markup + + def test_nobr_f_missing_attr(self): + attr = {} + f = None + expected_markup = Markup("{}").format(f) + + nobr = utils.nobr_f("attr_name") + result_markup = nobr(attr) + + assert result_markup == expected_markup + + def test_epoch(self): + test_datetime = datetime(2024, 6, 19, 12, 0, 0) + result = utils.epoch(test_datetime) + epoch_time = result[0] + + expected_epoch_time = int(time.mktime(test_datetime.timetuple())) * 1000 + + assert epoch_time == expected_epoch_time + + @pytest.mark.db_test + def test_make_cache_key(self): + from airflow.www.app import cached_app + + with cached_app(testing=True).test_request_context( + "/test/path", query_string={"key1": "value1", "key2": "value2"} + ): + expected_args = str(hash(frozenset({"key1": "value1", "key2": "value2"}.items()))) + expected_cache_key = ("/test/path" + expected_args).encode("ascii", "ignore") + result_cache_key = utils.make_cache_key() + assert result_cache_key == expected_cache_key + + @pytest.mark.db_test + def test_task_instance_link(self): + from airflow.www.app import cached_app + + with cached_app(testing=True).test_request_context(): + html = str( + utils.task_instance_link( + {"dag_id": "", "task_id": "", "map_index": 1, "execution_date": datetime.now()} + ) + ) + + html_map_index_none = str( + utils.task_instance_link( + {"dag_id": "", "task_id": "", "map_index": -1, "execution_date": datetime.now()} + ) + ) + + assert "%3Ca%261%3E" in html + assert "%3Cb2%3E" in html + assert "map_index" in html + assert "" not in html + 
assert "" not in html + + assert "%3Ca%261%3E" in html_map_index_none + assert "%3Cb2%3E" in html_map_index_none + assert "map_index" not in html_map_index_none + assert "" not in html_map_index_none + assert "" not in html_map_index_none + + @pytest.mark.db_test + def test_dag_link(self): + from airflow.www.app import cached_app + + with cached_app(testing=True).test_request_context(): + html = str(utils.dag_link({"dag_id": "", "execution_date": datetime.now()})) + + assert "%3Ca%261%3E" in html + assert "" not in html + + @pytest.mark.db_test + def test_dag_link_when_dag_is_none(self): + """Test that when there is no dag_id, dag_link does not contain hyperlink""" + from airflow.www.app import cached_app + + with cached_app(testing=True).test_request_context(): + html = str(utils.dag_link({})) + + assert "None" in html + assert "", "run_id": "", "execution_date": datetime.now()}) + ) + + assert "%3Ca%261%3E" in html + assert "%3Cb2%3E" in html + assert "" not in html + assert "" not in html + + +class TestAttrRenderer: + def setup_method(self): + self.attr_renderer = utils.get_attr_renderer() + + def test_python_callable(self): + def example_callable(unused_self): + print("example") + + rendered = self.attr_renderer["python_callable"](example_callable) + assert ""example"" in rendered + + def test_python_callable_none(self): + rendered = self.attr_renderer["python_callable"](None) + assert "" == rendered + + def test_markdown(self): + markdown = "* foo\n* bar" + rendered = self.attr_renderer["doc_md"](markdown) + assert "
<li>foo</li>" in rendered + assert "<li>bar</li>
  • " in rendered + + def test_markdown_none(self): + rendered = self.attr_renderer["doc_md"](None) + assert rendered is None + + def test_get_dag_run_conf(self): + dag_run_conf = { + "1": "string", + "2": b"bytes", + "3": 123, + "4": "à".encode("latin"), + "5": datetime(2023, 1, 1), + } + expected_encoded_dag_run_conf = ( + '{"1": "string", "2": "bytes", "3": 123, "4": "à", "5": "2023-01-01T00:00:00+00:00"}' + ) + encoded_dag_run_conf, conf_is_json = utils.get_dag_run_conf( + dag_run_conf, json_encoder=utils_json.WebEncoder + ) + assert expected_encoded_dag_run_conf == encoded_dag_run_conf + + def test_encode_dag_run_none(self): + no_dag_run_result = utils.encode_dag_run(None) + assert no_dag_run_result is None + + def test_json_f_webencoder(self): + dag_run_conf = { + "1": "string", + "2": b"bytes", + "3": 123, + "4": "à".encode("latin"), + "5": datetime(2023, 1, 1), + } + expected_encoded_dag_run_conf = ( + # HTML sanitization is insane + '{"1": "string", "2": "bytes", "3": 123, "4": "\\u00e0", "5": "2023-01-01T00:00:00+00:00"}' + ) + expected_markup = Markup("{}").format(expected_encoded_dag_run_conf) + + formatter = json_f("conf") + dagrun = Mock() + dagrun.get = Mock(return_value=dag_run_conf) + + assert formatter(dagrun) == expected_markup + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +def test_get_sensitive_variables_fields(): + with pytest.warns(DeprecationWarning) as warning: + result = utils.get_sensitive_variables_fields() + + # assert deprecation warning + assert len(warning) == 1 + assert "This function is deprecated." in str(warning[-1].message) + + from airflow.utils.log.secrets_masker import get_sensitive_variables_fields + + expected_result = get_sensitive_variables_fields() + assert result == expected_result + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +def test_should_hide_value_for_key(): + key_name = "key" + + with pytest.warns(DeprecationWarning) as warning: + result = utils.should_hide_value_for_key(key_name) + + # assert deprecation warning + assert len(warning) == 1 + assert "This function is deprecated." in str(warning[-1].message) + + from airflow.utils.log.secrets_masker import should_hide_value_for_key + + expected_result = should_hide_value_for_key(key_name) + assert result == expected_result + + +class TestWrappedMarkdown: + def test_wrapped_markdown_with_docstring_curly_braces(self): + rendered = wrapped_markdown("{braces}", css_class="a_class") + assert ( + """

<div class="rich_doc a_class" ><p>{braces}</p>
+</div>
    """ + == rendered + ) + + def test_wrapped_markdown_with_some_markdown(self): + rendered = wrapped_markdown( + """*italic* + **bold** + """, + css_class="a_class", + ) + + assert ( + """

<div class="rich_doc a_class" ><p><em>italic</em>
+<strong>bold</strong></p>
+</div>
    """ + == rendered + ) + + def test_wrapped_markdown_with_table(self): + rendered = wrapped_markdown( + """ +| Job | Duration | +| ----------- | ----------- | +| ETL | 14m | +""" + ) + + assert ( + """
<div class="rich_doc" ><table>
+<thead>
+<tr>
+<th>Job</th>
+<th>Duration</th>
+</tr>
+</thead>
+<tbody>
+<tr>
+<td>ETL</td>
+<td>14m</td>
+</tr>
+</tbody>
+</table>
+</div>
    """ + == rendered + ) + + def test_wrapped_markdown_with_indented_lines(self): + rendered = wrapped_markdown( + """ + # header + 1st line + 2nd line + """ + ) + + assert ( + """

<div class="rich_doc" ><h1>header</h1>\n<p>1st line\n2nd line</p>
+</div>
    """ + == rendered + ) + + def test_wrapped_markdown_with_raw_code_block(self): + rendered = wrapped_markdown( + """\ + # Markdown code block + + Inline `code` works well. + + Code block + does not + respect + newlines + + """ + ) + + assert ( + """

<div class="rich_doc" ><h1>Markdown code block</h1>
+<p>Inline <code>code</code> works well.</p>
+<pre><code>Code block\ndoes not\nrespect\nnewlines\n</code></pre>
+</div>
    """ + == rendered + ) + + def test_wrapped_markdown_with_nested_list(self): + rendered = wrapped_markdown( + """ + ### Docstring with a code block + + - And + - A nested list + """ + ) + + assert ( + """

<div class="rich_doc" ><h3>Docstring with a code block</h3>
+<ul>
+<li>And
+<ul>
+<li>A nested list</li>
+</ul>
+</li>
+</ul>
+</div>
    """ + == rendered + ) + + def test_wrapped_markdown_with_collapsible_section(self): + with conf_vars({("webserver", "allow_raw_html_descriptions"): "true"}): + rendered = wrapped_markdown( + """ +# A collapsible section with markdown +
<details>
+  <summary>Click to expand!</summary>
+
+  ## Heading
+  1. A numbered
+  2. list
+     * With some
+     * Sub bullets
+</details>
    + """ + ) + + assert ( + """

<div class="rich_doc" ><h1>A collapsible section with markdown</h1>
+<details>
+  <summary>Click to expand!</summary>
+<h2>Heading</h2>
+<ol>
+<li>A numbered</li>
+<li>list
+<ul>
+<li>With some</li>
+<li>Sub bullets</li>
+</ul>
+</li>
+</ol>
+</details>
+</div>
    """ + == rendered + ) + + @pytest.mark.parametrize("allow_html", [False, True]) + def test_wrapped_markdown_with_raw_html(self, allow_html): + with conf_vars({("webserver", "allow_raw_html_descriptions"): str(allow_html)}): + HTML = "test raw HTML" + rendered = wrapped_markdown(HTML) + if allow_html: + assert HTML in rendered + else: + from markupsafe import escape + + assert escape(HTML) in rendered + + +class TestFilter: + def setup_method(self): + self.mock_datamodel = Mock() + self.mock_query = Mock(spec=Query) + self.mock_column_name = "test_column" + + def test_filter_is_null_apply(self): + filter_is_null = utils.FilterIsNull(datamodel=self.mock_datamodel, column_name=self.mock_column_name) + + self.mock_query, mock_field = get_field_setup_query( + self.mock_query, self.mock_datamodel, self.mock_column_name + ) + mock_value = set_value_to_type(self.mock_datamodel, self.mock_column_name, None) + + result_query_filter = filter_is_null.apply(self.mock_query, None) + self.mock_query.filter.assert_called_once_with(mock_field == mock_value) + + expected_query_filter = self.mock_query.filter(mock_field == mock_value) + + assert result_query_filter == expected_query_filter + + def test_filter_is_not_null_apply(self): + filter_is_not_null = utils.FilterIsNotNull( + datamodel=self.mock_datamodel, column_name=self.mock_column_name + ) + + self.mock_query, mock_field = get_field_setup_query( + self.mock_query, self.mock_datamodel, self.mock_column_name + ) + mock_value = set_value_to_type(self.mock_datamodel, self.mock_column_name, None) + + result_query_filter = filter_is_not_null.apply(self.mock_query, None) + self.mock_query.filter.assert_called_once_with(mock_field != mock_value) + + expected_query_filter = self.mock_query.filter(mock_field != mock_value) + + assert result_query_filter == expected_query_filter + + def test_filter_gte_none_value_apply(self): + filter_gte = utils.FilterGreaterOrEqual( + datamodel=self.mock_datamodel, column_name=self.mock_column_name + ) + + self.mock_query, mock_field = get_field_setup_query( + self.mock_query, self.mock_datamodel, self.mock_column_name + ) + mock_value = set_value_to_type(self.mock_datamodel, self.mock_column_name, None) + + result_query_filter = filter_gte.apply(self.mock_query, mock_value) + + assert result_query_filter == self.mock_query + + def test_filter_lte_none_value_apply(self): + filter_lte = utils.FilterSmallerOrEqual( + datamodel=self.mock_datamodel, column_name=self.mock_column_name + ) + + self.mock_query, mock_field = get_field_setup_query( + self.mock_query, self.mock_datamodel, self.mock_column_name + ) + mock_value = set_value_to_type(self.mock_datamodel, self.mock_column_name, None) + + result_query_filter = filter_lte.apply(self.mock_query, mock_value) + + assert result_query_filter == self.mock_query + + +@pytest.mark.db_test +def test_get_col_default_not_existing(session): + interface = CustomSQLAInterface(obj=DagRun, session=session) + default_value = interface.get_col_default("column_not_existing") + assert default_value is None + + +@pytest.mark.db_test +def test_dag_run_custom_sqla_interface_delete_no_collateral_damage(dag_maker, session): + interface = DagRunCustomSQLAInterface(obj=DagRun, session=session) + dag_ids = (f"test_dag_{x}" for x in range(1, 4)) + dates = (pendulum.datetime(2023, 1, x) for x in range(1, 4)) + for dag_id, date in itertools.product(dag_ids, dates): + with dag_maker(dag_id=dag_id) as dag: + dag.create_dagrun( + execution_date=date, state="running", run_type="scheduled", 
data_interval=(date, date) + ) + dag_runs = session.query(DagRun).all() + assert len(dag_runs) == 9 + assert len(set(x.run_id for x in dag_runs)) == 3 + run_id_for_single_delete = "scheduled__2023-01-01T00:00:00+00:00" + # we have 3 runs with this same run_id + assert sum(1 for x in dag_runs if x.run_id == run_id_for_single_delete) == 3 + # each is a different dag + + # if we delete one, it shouldn't delete the others + one_run = next(x for x in dag_runs if x.run_id == run_id_for_single_delete) + assert interface.delete(item=one_run) is True + session.commit() + dag_runs = session.query(DagRun).all() + # we should have one fewer dag run now + assert len(dag_runs) == 8 + + # now let's try multi delete + run_id_for_multi_delete = "scheduled__2023-01-02T00:00:00+00:00" + # verify we have 3 + runs_of_interest = [x for x in dag_runs if x.run_id == run_id_for_multi_delete] + assert len(runs_of_interest) == 3 + # and that each is different dag + assert len(set(x.dag_id for x in dag_runs)) == 3 + + to_delete = runs_of_interest[:2] + # now try multi delete + assert interface.delete_all(items=to_delete) is True + session.commit() + dag_runs = session.query(DagRun).all() + assert len(dag_runs) == 6 + assert len(set(x.dag_id for x in dag_runs)) == 3 + assert len(set(x.run_id for x in dag_runs)) == 3 + + +@pytest.fixture +def app(): + from flask import Flask + + app = Flask(__name__) + app.config["WTF_CSRF_ENABLED"] = False + app.config["SECRET_KEY"] = "secret" + with app.app_context(): + yield app + + +class TestWidgets: + def test_airflow_datetime_picker_ro_widget(self, app): + class TestForm(FlaskForm): + datetime_field = StringField(widget=AirflowDateTimePickerROWidget()) + + form = TestForm() + field = form.datetime_field + + html_output = field() + + assert 'readonly="true"' in html_output + assert "input-group datetime datetimepicker" in html_output + + def test_bs3_text_field_ro_widget(self, app): + class TestForm(FlaskForm): + text_field = StringField(widget=BS3TextFieldROWidget()) + + form = TestForm() + field = form.text_field + + html_output = field() + + assert 'readonly="true"' in html_output + assert "form-control" in html_output + + def test_bs3_text_area_ro_widget(self, app): + class TestForm(FlaskForm): + textarea_field = TextAreaField(widget=BS3TextAreaROWidget()) + + form = TestForm() + field = form.textarea_field + + html_output = field() + + assert 'readonly="true"' in html_output + assert "form-control" in html_output diff --git a/tests/www/test_utils_BASE_289691.py b/tests/www/test_utils_BASE_289691.py new file mode 100644 index 00000000000000..1fc42c1fefb622 --- /dev/null +++ b/tests/www/test_utils_BASE_289691.py @@ -0,0 +1,479 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import itertools +import re +from datetime import datetime +from unittest.mock import Mock +from urllib.parse import parse_qs + +import pendulum +import pytest +from bs4 import BeautifulSoup +from markupsafe import Markup + +from airflow.models import DagRun +from airflow.utils import json as utils_json +from airflow.www import utils +from airflow.www.utils import DagRunCustomSQLAInterface, json_f, wrapped_markdown +from tests.test_utils.config import conf_vars + + +class TestUtils: + def check_generate_pages_html( + self, + current_page, + total_pages, + window=7, + check_middle=False, + sorting_key=None, + sorting_direction=None, + ): + extra_links = 4 # first, prev, next, last + search = "'>\"/>" + if sorting_key and sorting_direction: + html_str = utils.generate_pages( + current_page, + total_pages, + search=search, + sorting_key=sorting_key, + sorting_direction=sorting_direction, + ) + else: + html_str = utils.generate_pages(current_page, total_pages, search=search) + + assert search not in html_str, "The raw search string shouldn't appear in the output" + assert "search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E" in html_str + + assert callable(html_str.__html__), "Should return something that is HTML-escaping aware" + + dom = BeautifulSoup(html_str, "html.parser") + assert dom is not None + + ulist = dom.ul + ulist_items = ulist.find_all("li") + assert min(window, total_pages) + extra_links == len(ulist_items) + + page_items = ulist_items[2:-2] + mid = len(page_items) // 2 + all_nodes = [] + pages = [] + + if sorting_key and sorting_direction: + last_page = total_pages - 1 + + if current_page <= mid or total_pages < window: + pages = list(range(min(total_pages, window))) + elif mid < current_page < last_page - mid: + pages = list(range(current_page - mid, current_page + mid + 1)) + else: + pages = list(range(total_pages - window, last_page + 1)) + + pages.append(last_page + 1) + pages.sort(reverse=True if sorting_direction == "desc" else False) + + for i, item in enumerate(page_items): + a_node = item.a + href_link = a_node["href"] + node_text = a_node.string + all_nodes.append(node_text) + if node_text == str(current_page + 1): + if check_middle: + assert mid == i + assert "javascript:void(0)" == href_link + assert "active" in item["class"] + else: + assert re.search(r"^\?", href_link), "Link is page-relative" + query = parse_qs(href_link[1:]) + assert query["page"] == [str(int(node_text) - 1)] + assert query["search"] == [search] + + if sorting_key and sorting_direction: + if pages[0] == 0: + pages = [str(page) for page in pages[1:]] + + assert pages == all_nodes + + def test_generate_pager_current_start(self): + self.check_generate_pages_html(current_page=0, total_pages=6) + + def test_generate_pager_current_middle(self): + self.check_generate_pages_html(current_page=10, total_pages=20, check_middle=True) + + def test_generate_pager_current_end(self): + self.check_generate_pages_html(current_page=38, total_pages=39) + + def test_generate_pager_current_start_with_sorting(self): + self.check_generate_pages_html( + current_page=0, total_pages=4, sorting_key="dag_id", sorting_direction="asc" + ) + + def test_params_no_values(self): + """Should return an empty string if no params are passed""" + assert "" == utils.get_params() + + def test_params_search(self): + assert "search=bash_" == utils.get_params(search="bash_") + + def test_params_none_and_zero(self): + query_str = utils.get_params(a=0, b=None, c="true") + # The order won't 
be consistent, but that doesn't affect behaviour of a browser + pairs = sorted(query_str.split("&")) + assert ["a=0", "c=true"] == pairs + + def test_params_all(self): + query = utils.get_params(tags=["tag1", "tag2"], status="active", page=3, search="bash_") + assert { + "tags": ["tag1", "tag2"], + "page": ["3"], + "search": ["bash_"], + "status": ["active"], + } == parse_qs(query) + + def test_params_escape(self): + assert "search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E" == utils.get_params( + search="'>\"/>" + ) + + def test_state_token(self): + # It's shouldn't possible to set these odd values anymore, but lets + # ensure they are escaped! + html = str(utils.state_token("")) + + assert "<script>alert(1)</script>" in html + assert "" not in html + + @pytest.mark.db_test + def test_task_instance_link(self): + from airflow.www.app import cached_app + + with cached_app(testing=True).test_request_context(): + html = str( + utils.task_instance_link( + {"dag_id": "", "task_id": "", "execution_date": datetime.now()} + ) + ) + + assert "%3Ca%261%3E" in html + assert "%3Cb2%3E" in html + assert "" not in html + assert "" not in html + + @pytest.mark.db_test + def test_dag_link(self): + from airflow.www.app import cached_app + + with cached_app(testing=True).test_request_context(): + html = str(utils.dag_link({"dag_id": "", "execution_date": datetime.now()})) + + assert "%3Ca%261%3E" in html + assert "" not in html + + @pytest.mark.db_test + def test_dag_link_when_dag_is_none(self): + """Test that when there is no dag_id, dag_link does not contain hyperlink""" + from airflow.www.app import cached_app + + with cached_app(testing=True).test_request_context(): + html = str(utils.dag_link({})) + + assert "None" in html + assert "
    ", "run_id": "", "execution_date": datetime.now()}) + ) + + assert "%3Ca%261%3E" in html + assert "%3Cb2%3E" in html + assert "" not in html + assert "" not in html + + +class TestAttrRenderer: + def setup_method(self): + self.attr_renderer = utils.get_attr_renderer() + + def test_python_callable(self): + def example_callable(unused_self): + print("example") + + rendered = self.attr_renderer["python_callable"](example_callable) + assert ""example"" in rendered + + def test_python_callable_none(self): + rendered = self.attr_renderer["python_callable"](None) + assert "" == rendered + + def test_markdown(self): + markdown = "* foo\n* bar" + rendered = self.attr_renderer["doc_md"](markdown) + assert "
<li>foo</li>" in rendered + assert "<li>bar</li>
  • " in rendered + + def test_markdown_none(self): + rendered = self.attr_renderer["doc_md"](None) + assert rendered is None + + def test_get_dag_run_conf(self): + dag_run_conf = { + "1": "string", + "2": b"bytes", + "3": 123, + "4": "à".encode("latin"), + "5": datetime(2023, 1, 1), + } + expected_encoded_dag_run_conf = ( + '{"1": "string", "2": "bytes", "3": 123, "4": "à", "5": "2023-01-01T00:00:00+00:00"}' + ) + encoded_dag_run_conf, conf_is_json = utils.get_dag_run_conf( + dag_run_conf, json_encoder=utils_json.WebEncoder + ) + assert expected_encoded_dag_run_conf == encoded_dag_run_conf + + def test_json_f_webencoder(self): + dag_run_conf = { + "1": "string", + "2": b"bytes", + "3": 123, + "4": "à".encode("latin"), + "5": datetime(2023, 1, 1), + } + expected_encoded_dag_run_conf = ( + # HTML sanitization is insane + '{"1": "string", "2": "bytes", "3": 123, "4": "\\u00e0", "5": "2023-01-01T00:00:00+00:00"}' + ) + expected_markup = Markup("{}").format(expected_encoded_dag_run_conf) + + formatter = json_f("conf") + dagrun = Mock() + dagrun.get = Mock(return_value=dag_run_conf) + + assert formatter(dagrun) == expected_markup + + +class TestWrappedMarkdown: + def test_wrapped_markdown_with_docstring_curly_braces(self): + rendered = wrapped_markdown("{braces}", css_class="a_class") + assert ( + """

<div class="rich_doc a_class" ><p>{braces}</p>
+</div>
    """ + == rendered + ) + + def test_wrapped_markdown_with_some_markdown(self): + rendered = wrapped_markdown( + """*italic* + **bold** + """, + css_class="a_class", + ) + + assert ( + """

<div class="rich_doc a_class" ><p><em>italic</em>
+<strong>bold</strong></p>
+</div>
    """ + == rendered + ) + + def test_wrapped_markdown_with_table(self): + rendered = wrapped_markdown( + """ +| Job | Duration | +| ----------- | ----------- | +| ETL | 14m | +""" + ) + + assert ( + """
<div class="rich_doc" ><table>
+<thead>
+<tr>
+<th>Job</th>
+<th>Duration</th>
+</tr>
+</thead>
+<tbody>
+<tr>
+<td>ETL</td>
+<td>14m</td>
+</tr>
+</tbody>
+</table>
+</div>
    """ + == rendered + ) + + def test_wrapped_markdown_with_indented_lines(self): + rendered = wrapped_markdown( + """ + # header + 1st line + 2nd line + """ + ) + + assert ( + """

<div class="rich_doc" ><h1>header</h1>\n<p>1st line\n2nd line</p>
+</div>
    """ + == rendered + ) + + def test_wrapped_markdown_with_raw_code_block(self): + rendered = wrapped_markdown( + """\ + # Markdown code block + + Inline `code` works well. + + Code block + does not + respect + newlines + + """ + ) + + assert ( + """

<div class="rich_doc" ><h1>Markdown code block</h1>
+<p>Inline <code>code</code> works well.</p>
+<pre><code>Code block\ndoes not\nrespect\nnewlines\n</code></pre>
+</div>
    """ + == rendered + ) + + def test_wrapped_markdown_with_nested_list(self): + rendered = wrapped_markdown( + """ + ### Docstring with a code block + + - And + - A nested list + """ + ) + + assert ( + """

<div class="rich_doc" ><h3>Docstring with a code block</h3>
+<ul>
+<li>And
+<ul>
+<li>A nested list</li>
+</ul>
+</li>
+</ul>
+</div>
    """ + == rendered + ) + + def test_wrapped_markdown_with_collapsible_section(self): + with conf_vars({("webserver", "allow_raw_html_descriptions"): "true"}): + rendered = wrapped_markdown( + """ +# A collapsible section with markdown +
<details>
+  <summary>Click to expand!</summary>
+
+  ## Heading
+  1. A numbered
+  2. list
+     * With some
+     * Sub bullets
+</details>
    + """ + ) + + assert ( + """

<div class="rich_doc" ><h1>A collapsible section with markdown</h1>
+<details>
+  <summary>Click to expand!</summary>
+<h2>Heading</h2>
+<ol>
+<li>A numbered</li>
+<li>list
+<ul>
+<li>With some</li>
+<li>Sub bullets</li>
+</ul>
+</li>
+</ol>
+</details>
+</div>
    """ + == rendered + ) + + @pytest.mark.parametrize("allow_html", [False, True]) + def test_wrapped_markdown_with_raw_html(self, allow_html): + with conf_vars({("webserver", "allow_raw_html_descriptions"): str(allow_html)}): + HTML = "test raw HTML" + rendered = wrapped_markdown(HTML) + if allow_html: + assert HTML in rendered + else: + from markupsafe import escape + + assert escape(HTML) in rendered + + +@pytest.mark.db_test +def test_dag_run_custom_sqla_interface_delete_no_collateral_damage(dag_maker, session): + interface = DagRunCustomSQLAInterface(obj=DagRun, session=session) + dag_ids = (f"test_dag_{x}" for x in range(1, 4)) + dates = (pendulum.datetime(2023, 1, x) for x in range(1, 4)) + for dag_id, date in itertools.product(dag_ids, dates): + with dag_maker(dag_id=dag_id) as dag: + dag.create_dagrun( + execution_date=date, state="running", run_type="scheduled", data_interval=(date, date) + ) + dag_runs = session.query(DagRun).all() + assert len(dag_runs) == 9 + assert len(set(x.run_id for x in dag_runs)) == 3 + run_id_for_single_delete = "scheduled__2023-01-01T00:00:00+00:00" + # we have 3 runs with this same run_id + assert sum(1 for x in dag_runs if x.run_id == run_id_for_single_delete) == 3 + # each is a different dag + + # if we delete one, it shouldn't delete the others + one_run = next(x for x in dag_runs if x.run_id == run_id_for_single_delete) + assert interface.delete(item=one_run) is True + session.commit() + dag_runs = session.query(DagRun).all() + # we should have one fewer dag run now + assert len(dag_runs) == 8 + + # now let's try multi delete + run_id_for_multi_delete = "scheduled__2023-01-02T00:00:00+00:00" + # verify we have 3 + runs_of_interest = [x for x in dag_runs if x.run_id == run_id_for_multi_delete] + assert len(runs_of_interest) == 3 + # and that each is different dag + assert len(set(x.dag_id for x in dag_runs)) == 3 + + to_delete = runs_of_interest[:2] + # now try multi delete + assert interface.delete_all(items=to_delete) is True + session.commit() + dag_runs = session.query(DagRun).all() + assert len(dag_runs) == 6 + assert len(set(x.dag_id for x in dag_runs)) == 3 + assert len(set(x.run_id for x in dag_runs)) == 3 diff --git a/tests/www/test_utils_LOCAL_289691.py b/tests/www/test_utils_LOCAL_289691.py new file mode 100644 index 00000000000000..a90d9246998d6f --- /dev/null +++ b/tests/www/test_utils_LOCAL_289691.py @@ -0,0 +1,656 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import itertools +import re +import time +from datetime import datetime +from unittest.mock import Mock +from urllib.parse import parse_qs + +import pendulum +import pytest +from bs4 import BeautifulSoup +from flask_appbuilder.models.sqla.filters import get_field_setup_query, set_value_to_type +from markupsafe import Markup +from sqlalchemy.orm import Query + +from airflow.models import DagRun +from airflow.utils import json as utils_json +from airflow.www import utils +from airflow.www.utils import CustomSQLAInterface, DagRunCustomSQLAInterface, json_f, wrapped_markdown +from tests.test_utils.config import conf_vars + + +class TestUtils: + def check_generate_pages_html( + self, + current_page, + total_pages, + window=7, + check_middle=False, + sorting_key=None, + sorting_direction=None, + ): + extra_links = 4 # first, prev, next, last + search = "'>\"/>" + if sorting_key and sorting_direction: + html_str = utils.generate_pages( + current_page, + total_pages, + search=search, + sorting_key=sorting_key, + sorting_direction=sorting_direction, + ) + else: + html_str = utils.generate_pages(current_page, total_pages, search=search) + + assert search not in html_str, "The raw search string shouldn't appear in the output" + assert "search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E" in html_str + + assert callable(html_str.__html__), "Should return something that is HTML-escaping aware" + + dom = BeautifulSoup(html_str, "html.parser") + assert dom is not None + + ulist = dom.ul + ulist_items = ulist.find_all("li") + assert min(window, total_pages) + extra_links == len(ulist_items) + + page_items = ulist_items[2:-2] + mid = len(page_items) // 2 + all_nodes = [] + pages = [] + + if sorting_key and sorting_direction: + last_page = total_pages - 1 + + if current_page <= mid or total_pages < window: + pages = list(range(min(total_pages, window))) + elif mid < current_page < last_page - mid: + pages = list(range(current_page - mid, current_page + mid + 1)) + else: + pages = list(range(total_pages - window, last_page + 1)) + + pages.append(last_page + 1) + pages.sort(reverse=True if sorting_direction == "desc" else False) + + for i, item in enumerate(page_items): + a_node = item.a + href_link = a_node["href"] + node_text = a_node.string + all_nodes.append(node_text) + if node_text == str(current_page + 1): + if check_middle: + assert mid == i + assert "javascript:void(0)" == href_link + assert "active" in item["class"] + else: + assert re.search(r"^\?", href_link), "Link is page-relative" + query = parse_qs(href_link[1:]) + assert query["page"] == [str(int(node_text) - 1)] + assert query["search"] == [search] + + if sorting_key and sorting_direction: + if pages[0] == 0: + pages = [str(page) for page in pages[1:]] + + assert pages == all_nodes + + def test_generate_pager_current_start(self): + self.check_generate_pages_html(current_page=0, total_pages=6) + + def test_generate_pager_current_middle(self): + self.check_generate_pages_html(current_page=10, total_pages=20, check_middle=True) + + def test_generate_pager_current_end(self): + self.check_generate_pages_html(current_page=38, total_pages=39) + + def test_generate_pager_current_start_with_sorting(self): + self.check_generate_pages_html( + current_page=0, total_pages=4, sorting_key="dag_id", sorting_direction="asc" + ) + + def test_params_no_values(self): + """Should return an empty string if no params are passed""" + assert "" == utils.get_params() + + def test_params_search(self): + assert 
"search=bash_" == utils.get_params(search="bash_") + + def test_params_none_and_zero(self): + query_str = utils.get_params(a=0, b=None, c="true") + # The order won't be consistent, but that doesn't affect behaviour of a browser + pairs = sorted(query_str.split("&")) + assert ["a=0", "c=true"] == pairs + + def test_params_all(self): + query = utils.get_params(tags=["tag1", "tag2"], status="active", page=3, search="bash_") + assert { + "tags": ["tag1", "tag2"], + "page": ["3"], + "search": ["bash_"], + "status": ["active"], + } == parse_qs(query) + + def test_params_escape(self): + assert "search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E" == utils.get_params( + search="'>\"/>" + ) + + def test_state_token(self): + # It's shouldn't possible to set these odd values anymore, but lets + # ensure they are escaped! + html = str(utils.state_token("")) + + assert "<script>alert(1)</script>" in html + assert "" not in html + + def test_nobr_f(self): + attr = {"attr_name": "attribute"} + f = attr.get("attr_name") + expected_markup = Markup("{}").format(f) + + nobr = utils.nobr_f("attr_name") + result_markup = nobr(attr) + + assert result_markup == expected_markup + + def test_nobr_f_empty_attr(self): + attr = {"attr_name": ""} + f = attr.get("attr_name") + expected_markup = Markup("{}").format(f) + + nobr = utils.nobr_f("attr_name") + result_markup = nobr(attr) + + assert result_markup == expected_markup + + def test_nobr_f_missing_attr(self): + attr = {} + f = None + expected_markup = Markup("{}").format(f) + + nobr = utils.nobr_f("attr_name") + result_markup = nobr(attr) + + assert result_markup == expected_markup + + def test_epoch(self): + test_datetime = datetime(2024, 6, 19, 12, 0, 0) + result = utils.epoch(test_datetime) + epoch_time = result[0] + + expected_epoch_time = int(time.mktime(test_datetime.timetuple())) * 1000 + + assert epoch_time == expected_epoch_time + + @pytest.mark.db_test + def test_make_cache_key(self): + from airflow.www.app import cached_app + + with cached_app(testing=True).test_request_context( + "/test/path", query_string={"key1": "value1", "key2": "value2"} + ): + expected_args = str(hash(frozenset({"key1": "value1", "key2": "value2"}.items()))) + expected_cache_key = ("/test/path" + expected_args).encode("ascii", "ignore") + result_cache_key = utils.make_cache_key() + assert result_cache_key == expected_cache_key + + @pytest.mark.db_test + def test_task_instance_link(self): + from airflow.www.app import cached_app + + with cached_app(testing=True).test_request_context(): + html = str( + utils.task_instance_link( + {"dag_id": "", "task_id": "", "map_index": 1, "execution_date": datetime.now()} + ) + ) + + html_map_index_none = str( + utils.task_instance_link( + {"dag_id": "", "task_id": "", "map_index": -1, "execution_date": datetime.now()} + ) + ) + + assert "%3Ca%261%3E" in html + assert "%3Cb2%3E" in html + assert "map_index" in html + assert "" not in html + assert "" not in html + + assert "%3Ca%261%3E" in html_map_index_none + assert "%3Cb2%3E" in html_map_index_none + assert "map_index" not in html_map_index_none + assert "" not in html_map_index_none + assert "" not in html_map_index_none + + @pytest.mark.db_test + def test_dag_link(self): + from airflow.www.app import cached_app + + with cached_app(testing=True).test_request_context(): + html = str(utils.dag_link({"dag_id": "", "execution_date": datetime.now()})) + + assert "%3Ca%261%3E" in html + assert "" not in html + + @pytest.mark.db_test + def test_dag_link_when_dag_is_none(self): + """Test 
that when there is no dag_id, dag_link does not contain hyperlink""" + from airflow.www.app import cached_app + + with cached_app(testing=True).test_request_context(): + html = str(utils.dag_link({})) + + assert "None" in html + assert "
    ", "run_id": "", "execution_date": datetime.now()}) + ) + + assert "%3Ca%261%3E" in html + assert "%3Cb2%3E" in html + assert "" not in html + assert "" not in html + + +class TestAttrRenderer: + def setup_method(self): + self.attr_renderer = utils.get_attr_renderer() + + def test_python_callable(self): + def example_callable(unused_self): + print("example") + + rendered = self.attr_renderer["python_callable"](example_callable) + assert ""example"" in rendered + + def test_python_callable_none(self): + rendered = self.attr_renderer["python_callable"](None) + assert "" == rendered + + def test_markdown(self): + markdown = "* foo\n* bar" + rendered = self.attr_renderer["doc_md"](markdown) + assert "
<li>foo</li>" in rendered + assert "<li>bar</li>
  • " in rendered + + def test_markdown_none(self): + rendered = self.attr_renderer["doc_md"](None) + assert rendered is None + + def test_get_dag_run_conf(self): + dag_run_conf = { + "1": "string", + "2": b"bytes", + "3": 123, + "4": "à".encode("latin"), + "5": datetime(2023, 1, 1), + } + expected_encoded_dag_run_conf = ( + '{"1": "string", "2": "bytes", "3": 123, "4": "à", "5": "2023-01-01T00:00:00+00:00"}' + ) + encoded_dag_run_conf, conf_is_json = utils.get_dag_run_conf( + dag_run_conf, json_encoder=utils_json.WebEncoder + ) + assert expected_encoded_dag_run_conf == encoded_dag_run_conf + + def test_encode_dag_run_none(self): + no_dag_run_result = utils.encode_dag_run(None) + assert no_dag_run_result is None + + def test_json_f_webencoder(self): + dag_run_conf = { + "1": "string", + "2": b"bytes", + "3": 123, + "4": "à".encode("latin"), + "5": datetime(2023, 1, 1), + } + expected_encoded_dag_run_conf = ( + # HTML sanitization is insane + '{"1": "string", "2": "bytes", "3": 123, "4": "\\u00e0", "5": "2023-01-01T00:00:00+00:00"}' + ) + expected_markup = Markup("{}").format(expected_encoded_dag_run_conf) + + formatter = json_f("conf") + dagrun = Mock() + dagrun.get = Mock(return_value=dag_run_conf) + + assert formatter(dagrun) == expected_markup + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +def test_get_sensitive_variables_fields(): + with pytest.warns(DeprecationWarning) as warning: + result = utils.get_sensitive_variables_fields() + + # assert deprecation warning + assert len(warning) == 1 + assert "This function is deprecated." in str(warning[-1].message) + + from airflow.utils.log.secrets_masker import get_sensitive_variables_fields + + expected_result = get_sensitive_variables_fields() + assert result == expected_result + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +def test_should_hide_value_for_key(): + key_name = "key" + + with pytest.warns(DeprecationWarning) as warning: + result = utils.should_hide_value_for_key(key_name) + + # assert deprecation warning + assert len(warning) == 1 + assert "This function is deprecated." in str(warning[-1].message) + + from airflow.utils.log.secrets_masker import should_hide_value_for_key + + expected_result = should_hide_value_for_key(key_name) + assert result == expected_result + + +class TestWrappedMarkdown: + def test_wrapped_markdown_with_docstring_curly_braces(self): + rendered = wrapped_markdown("{braces}", css_class="a_class") + assert ( + """

    {braces}

    +
    """ + == rendered + ) + + def test_wrapped_markdown_with_some_markdown(self): + rendered = wrapped_markdown( + """*italic* + **bold** + """, + css_class="a_class", + ) + + assert ( + """

    italic +bold

    +
    """ + == rendered + ) + + def test_wrapped_markdown_with_table(self): + rendered = wrapped_markdown( + """ +| Job | Duration | +| ----------- | ----------- | +| ETL | 14m | +""" + ) + + assert ( + """
    + + + + + + + + + + + + +
    JobDuration
    ETL14m
    +
    """ + == rendered + ) + + def test_wrapped_markdown_with_indented_lines(self): + rendered = wrapped_markdown( + """ + # header + 1st line + 2nd line + """ + ) + + assert ( + """

    header

    \n

    1st line\n2nd line

    +
    """ + == rendered + ) + + def test_wrapped_markdown_with_raw_code_block(self): + rendered = wrapped_markdown( + """\ + # Markdown code block + + Inline `code` works well. + + Code block + does not + respect + newlines + + """ + ) + + assert ( + """

    Markdown code block

    +

    Inline code works well.

    +
    Code block\ndoes not\nrespect\nnewlines\n
    +
    """ + == rendered + ) + + def test_wrapped_markdown_with_nested_list(self): + rendered = wrapped_markdown( + """ + ### Docstring with a code block + + - And + - A nested list + """ + ) + + assert ( + """

    Docstring with a code block

    +
      +
    • And +
        +
      • A nested list
      • +
      +
    • +
    +
    """ + == rendered + ) + + def test_wrapped_markdown_with_collapsible_section(self): + with conf_vars({("webserver", "allow_raw_html_descriptions"): "true"}): + rendered = wrapped_markdown( + """ +# A collapsible section with markdown +
    <details> +   <summary>Click to expand!</summary> + + ## Heading + 1. A numbered + 2. list + * With some + * Sub bullets +</details> + """ + ) + + assert ( + """

    A collapsible section with markdown

    +
    + Click to expand! +

    Heading

    +
      +
    1. A numbered
    2. +
    3. list +
        +
      • With some
      • +
      • Sub bullets
      • +
      +
    4. +
    +
    +
    """ + == rendered + ) + + @pytest.mark.parametrize("allow_html", [False, True]) + def test_wrapped_markdown_with_raw_html(self, allow_html): + with conf_vars({("webserver", "allow_raw_html_descriptions"): str(allow_html)}): + HTML = "test raw HTML" + rendered = wrapped_markdown(HTML) + if allow_html: + assert HTML in rendered + else: + from markupsafe import escape + + assert escape(HTML) in rendered + + +class TestFilter: + def setup_method(self): + self.mock_datamodel = Mock() + self.mock_query = Mock(spec=Query) + self.mock_column_name = "test_column" + + def test_filter_is_null_apply(self): + filter_is_null = utils.FilterIsNull(datamodel=self.mock_datamodel, column_name=self.mock_column_name) + + self.mock_query, mock_field = get_field_setup_query( + self.mock_query, self.mock_datamodel, self.mock_column_name + ) + mock_value = set_value_to_type(self.mock_datamodel, self.mock_column_name, None) + + result_query_filter = filter_is_null.apply(self.mock_query, None) + self.mock_query.filter.assert_called_once_with(mock_field == mock_value) + + expected_query_filter = self.mock_query.filter(mock_field == mock_value) + + assert result_query_filter == expected_query_filter + + def test_filter_is_not_null_apply(self): + filter_is_not_null = utils.FilterIsNotNull( + datamodel=self.mock_datamodel, column_name=self.mock_column_name + ) + + self.mock_query, mock_field = get_field_setup_query( + self.mock_query, self.mock_datamodel, self.mock_column_name + ) + mock_value = set_value_to_type(self.mock_datamodel, self.mock_column_name, None) + + result_query_filter = filter_is_not_null.apply(self.mock_query, None) + self.mock_query.filter.assert_called_once_with(mock_field != mock_value) + + expected_query_filter = self.mock_query.filter(mock_field != mock_value) + + assert result_query_filter == expected_query_filter + + def test_filter_gte_none_value_apply(self): + filter_gte = utils.FilterGreaterOrEqual( + datamodel=self.mock_datamodel, column_name=self.mock_column_name + ) + + self.mock_query, mock_field = get_field_setup_query( + self.mock_query, self.mock_datamodel, self.mock_column_name + ) + mock_value = set_value_to_type(self.mock_datamodel, self.mock_column_name, None) + + result_query_filter = filter_gte.apply(self.mock_query, mock_value) + + assert result_query_filter == self.mock_query + + def test_filter_lte_none_value_apply(self): + filter_lte = utils.FilterSmallerOrEqual( + datamodel=self.mock_datamodel, column_name=self.mock_column_name + ) + + self.mock_query, mock_field = get_field_setup_query( + self.mock_query, self.mock_datamodel, self.mock_column_name + ) + mock_value = set_value_to_type(self.mock_datamodel, self.mock_column_name, None) + + result_query_filter = filter_lte.apply(self.mock_query, mock_value) + + assert result_query_filter == self.mock_query + + +@pytest.mark.db_test +def test_get_col_default_not_existing(session): + interface = CustomSQLAInterface(obj=DagRun, session=session) + default_value = interface.get_col_default("column_not_existing") + assert default_value is None + + +@pytest.mark.db_test +def test_dag_run_custom_sqla_interface_delete_no_collateral_damage(dag_maker, session): + interface = DagRunCustomSQLAInterface(obj=DagRun, session=session) + dag_ids = (f"test_dag_{x}" for x in range(1, 4)) + dates = (pendulum.datetime(2023, 1, x) for x in range(1, 4)) + for dag_id, date in itertools.product(dag_ids, dates): + with dag_maker(dag_id=dag_id) as dag: + dag.create_dagrun( + execution_date=date, state="running", run_type="scheduled", 
data_interval=(date, date) + ) + dag_runs = session.query(DagRun).all() + assert len(dag_runs) == 9 + assert len(set(x.run_id for x in dag_runs)) == 3 + run_id_for_single_delete = "scheduled__2023-01-01T00:00:00+00:00" + # we have 3 runs with this same run_id + assert sum(1 for x in dag_runs if x.run_id == run_id_for_single_delete) == 3 + # each is a different dag + + # if we delete one, it shouldn't delete the others + one_run = next(x for x in dag_runs if x.run_id == run_id_for_single_delete) + assert interface.delete(item=one_run) is True + session.commit() + dag_runs = session.query(DagRun).all() + # we should have one fewer dag run now + assert len(dag_runs) == 8 + + # now let's try multi delete + run_id_for_multi_delete = "scheduled__2023-01-02T00:00:00+00:00" + # verify we have 3 + runs_of_interest = [x for x in dag_runs if x.run_id == run_id_for_multi_delete] + assert len(runs_of_interest) == 3 + # and that each is different dag + assert len(set(x.dag_id for x in dag_runs)) == 3 + + to_delete = runs_of_interest[:2] + # now try multi delete + assert interface.delete_all(items=to_delete) is True + session.commit() + dag_runs = session.query(DagRun).all() + assert len(dag_runs) == 6 + assert len(set(x.dag_id for x in dag_runs)) == 3 + assert len(set(x.run_id for x in dag_runs)) == 3 diff --git a/tests/www/test_utils_REMOTE_289691.py b/tests/www/test_utils_REMOTE_289691.py new file mode 100644 index 00000000000000..640a4a10de2ae8 --- /dev/null +++ b/tests/www/test_utils_REMOTE_289691.py @@ -0,0 +1,531 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import itertools +import re +from datetime import datetime +from unittest.mock import Mock +from urllib.parse import parse_qs + +import pendulum +import pytest +from bs4 import BeautifulSoup +from flask_wtf import FlaskForm +from markupsafe import Markup +from wtforms.fields import StringField, TextAreaField + +from airflow.models import DagRun +from airflow.utils import json as utils_json +from airflow.www import utils +from airflow.www.utils import DagRunCustomSQLAInterface, json_f, wrapped_markdown +from airflow.www.widgets import AirflowDateTimePickerROWidget, BS3TextAreaROWidget, BS3TextFieldROWidget +from tests.test_utils.config import conf_vars + + +class TestUtils: + def check_generate_pages_html( + self, + current_page, + total_pages, + window=7, + check_middle=False, + sorting_key=None, + sorting_direction=None, + ): + extra_links = 4 # first, prev, next, last + search = "'>\"/>" + if sorting_key and sorting_direction: + html_str = utils.generate_pages( + current_page, + total_pages, + search=search, + sorting_key=sorting_key, + sorting_direction=sorting_direction, + ) + else: + html_str = utils.generate_pages(current_page, total_pages, search=search) + + assert search not in html_str, "The raw search string shouldn't appear in the output" + assert "search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E" in html_str + + assert callable(html_str.__html__), "Should return something that is HTML-escaping aware" + + dom = BeautifulSoup(html_str, "html.parser") + assert dom is not None + + ulist = dom.ul + ulist_items = ulist.find_all("li") + assert min(window, total_pages) + extra_links == len(ulist_items) + + page_items = ulist_items[2:-2] + mid = len(page_items) // 2 + all_nodes = [] + pages = [] + + if sorting_key and sorting_direction: + last_page = total_pages - 1 + + if current_page <= mid or total_pages < window: + pages = list(range(min(total_pages, window))) + elif mid < current_page < last_page - mid: + pages = list(range(current_page - mid, current_page + mid + 1)) + else: + pages = list(range(total_pages - window, last_page + 1)) + + pages.append(last_page + 1) + pages.sort(reverse=True if sorting_direction == "desc" else False) + + for i, item in enumerate(page_items): + a_node = item.a + href_link = a_node["href"] + node_text = a_node.string + all_nodes.append(node_text) + if node_text == str(current_page + 1): + if check_middle: + assert mid == i + assert "javascript:void(0)" == href_link + assert "active" in item["class"] + else: + assert re.search(r"^\?", href_link), "Link is page-relative" + query = parse_qs(href_link[1:]) + assert query["page"] == [str(int(node_text) - 1)] + assert query["search"] == [search] + + if sorting_key and sorting_direction: + if pages[0] == 0: + pages = [str(page) for page in pages[1:]] + + assert pages == all_nodes + + def test_generate_pager_current_start(self): + self.check_generate_pages_html(current_page=0, total_pages=6) + + def test_generate_pager_current_middle(self): + self.check_generate_pages_html(current_page=10, total_pages=20, check_middle=True) + + def test_generate_pager_current_end(self): + self.check_generate_pages_html(current_page=38, total_pages=39) + + def test_generate_pager_current_start_with_sorting(self): + self.check_generate_pages_html( + current_page=0, total_pages=4, sorting_key="dag_id", sorting_direction="asc" + ) + + def test_params_no_values(self): + """Should return an empty string if no params are passed""" + assert "" == utils.get_params() + + def 
test_params_search(self): + assert "search=bash_" == utils.get_params(search="bash_") + + def test_params_none_and_zero(self): + query_str = utils.get_params(a=0, b=None, c="true") + # The order won't be consistent, but that doesn't affect behaviour of a browser + pairs = sorted(query_str.split("&")) + assert ["a=0", "c=true"] == pairs + + def test_params_all(self): + query = utils.get_params(tags=["tag1", "tag2"], status="active", page=3, search="bash_") + assert { + "tags": ["tag1", "tag2"], + "page": ["3"], + "search": ["bash_"], + "status": ["active"], + } == parse_qs(query) + + def test_params_escape(self): + assert "search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E" == utils.get_params( + search="'>\"/>" + ) + + def test_state_token(self): + # It's shouldn't possible to set these odd values anymore, but lets + # ensure they are escaped! + html = str(utils.state_token("")) + + assert "<script>alert(1)</script>" in html + assert "" not in html + + @pytest.mark.db_test + def test_task_instance_link(self): + from airflow.www.app import cached_app + + with cached_app(testing=True).test_request_context(): + html = str( + utils.task_instance_link( + {"dag_id": "", "task_id": "", "execution_date": datetime.now()} + ) + ) + + assert "%3Ca%261%3E" in html + assert "%3Cb2%3E" in html + assert "" not in html + assert "" not in html + + @pytest.mark.db_test + def test_dag_link(self): + from airflow.www.app import cached_app + + with cached_app(testing=True).test_request_context(): + html = str(utils.dag_link({"dag_id": "", "execution_date": datetime.now()})) + + assert "%3Ca%261%3E" in html + assert "" not in html + + @pytest.mark.db_test + def test_dag_link_when_dag_is_none(self): + """Test that when there is no dag_id, dag_link does not contain hyperlink""" + from airflow.www.app import cached_app + + with cached_app(testing=True).test_request_context(): + html = str(utils.dag_link({})) + + assert "None" in html + assert "
    ", "run_id": "", "execution_date": datetime.now()}) + ) + + assert "%3Ca%261%3E" in html + assert "%3Cb2%3E" in html + assert "" not in html + assert "" not in html + + +class TestAttrRenderer: + def setup_method(self): + self.attr_renderer = utils.get_attr_renderer() + + def test_python_callable(self): + def example_callable(unused_self): + print("example") + + rendered = self.attr_renderer["python_callable"](example_callable) + assert ""example"" in rendered + + def test_python_callable_none(self): + rendered = self.attr_renderer["python_callable"](None) + assert "" == rendered + + def test_markdown(self): + markdown = "* foo\n* bar" + rendered = self.attr_renderer["doc_md"](markdown) + assert "
<li>foo</li>" in rendered + assert "<li>bar</li>
  • " in rendered + + def test_markdown_none(self): + rendered = self.attr_renderer["doc_md"](None) + assert rendered is None + + def test_get_dag_run_conf(self): + dag_run_conf = { + "1": "string", + "2": b"bytes", + "3": 123, + "4": "à".encode("latin"), + "5": datetime(2023, 1, 1), + } + expected_encoded_dag_run_conf = ( + '{"1": "string", "2": "bytes", "3": 123, "4": "à", "5": "2023-01-01T00:00:00+00:00"}' + ) + encoded_dag_run_conf, conf_is_json = utils.get_dag_run_conf( + dag_run_conf, json_encoder=utils_json.WebEncoder + ) + assert expected_encoded_dag_run_conf == encoded_dag_run_conf + + def test_json_f_webencoder(self): + dag_run_conf = { + "1": "string", + "2": b"bytes", + "3": 123, + "4": "à".encode("latin"), + "5": datetime(2023, 1, 1), + } + expected_encoded_dag_run_conf = ( + # HTML sanitization is insane + '{"1": "string", "2": "bytes", "3": 123, "4": "\\u00e0", "5": "2023-01-01T00:00:00+00:00"}' + ) + expected_markup = Markup("{}").format(expected_encoded_dag_run_conf) + + formatter = json_f("conf") + dagrun = Mock() + dagrun.get = Mock(return_value=dag_run_conf) + + assert formatter(dagrun) == expected_markup + + +class TestWrappedMarkdown: + def test_wrapped_markdown_with_docstring_curly_braces(self): + rendered = wrapped_markdown("{braces}", css_class="a_class") + assert ( + """

    {braces}

    +
    """ + == rendered + ) + + def test_wrapped_markdown_with_some_markdown(self): + rendered = wrapped_markdown( + """*italic* + **bold** + """, + css_class="a_class", + ) + + assert ( + """

    italic +bold

    +
    """ + == rendered + ) + + def test_wrapped_markdown_with_table(self): + rendered = wrapped_markdown( + """ +| Job | Duration | +| ----------- | ----------- | +| ETL | 14m | +""" + ) + + assert ( + """
    + + + + + + + + + + + + +
    JobDuration
    ETL14m
    +
    """ + == rendered + ) + + def test_wrapped_markdown_with_indented_lines(self): + rendered = wrapped_markdown( + """ + # header + 1st line + 2nd line + """ + ) + + assert ( + """

    header

    \n

    1st line\n2nd line

    +
    """ + == rendered + ) + + def test_wrapped_markdown_with_raw_code_block(self): + rendered = wrapped_markdown( + """\ + # Markdown code block + + Inline `code` works well. + + Code block + does not + respect + newlines + + """ + ) + + assert ( + """

    Markdown code block

    +

    Inline code works well.

    +
    Code block\ndoes not\nrespect\nnewlines\n
    +
    """ + == rendered + ) + + def test_wrapped_markdown_with_nested_list(self): + rendered = wrapped_markdown( + """ + ### Docstring with a code block + + - And + - A nested list + """ + ) + + assert ( + """

    Docstring with a code block

    +
      +
    • And +
        +
      • A nested list
      • +
      +
    • +
    +
    """ + == rendered + ) + + def test_wrapped_markdown_with_collapsible_section(self): + with conf_vars({("webserver", "allow_raw_html_descriptions"): "true"}): + rendered = wrapped_markdown( + """ +# A collapsible section with markdown +
    <details> +   <summary>Click to expand!</summary> + + ## Heading + 1. A numbered + 2. list + * With some + * Sub bullets +</details> + """ + ) + + assert ( + """

    A collapsible section with markdown

    +
    + Click to expand! +

    Heading

    +
      +
    1. A numbered
    2. +
    3. list +
        +
      • With some
      • +
      • Sub bullets
      • +
      +
    4. +
    +
    +
    """ + == rendered + ) + + @pytest.mark.parametrize("allow_html", [False, True]) + def test_wrapped_markdown_with_raw_html(self, allow_html): + with conf_vars({("webserver", "allow_raw_html_descriptions"): str(allow_html)}): + HTML = "test raw HTML" + rendered = wrapped_markdown(HTML) + if allow_html: + assert HTML in rendered + else: + from markupsafe import escape + + assert escape(HTML) in rendered + + +@pytest.mark.db_test +def test_dag_run_custom_sqla_interface_delete_no_collateral_damage(dag_maker, session): + interface = DagRunCustomSQLAInterface(obj=DagRun, session=session) + dag_ids = (f"test_dag_{x}" for x in range(1, 4)) + dates = (pendulum.datetime(2023, 1, x) for x in range(1, 4)) + for dag_id, date in itertools.product(dag_ids, dates): + with dag_maker(dag_id=dag_id) as dag: + dag.create_dagrun( + execution_date=date, state="running", run_type="scheduled", data_interval=(date, date) + ) + dag_runs = session.query(DagRun).all() + assert len(dag_runs) == 9 + assert len(set(x.run_id for x in dag_runs)) == 3 + run_id_for_single_delete = "scheduled__2023-01-01T00:00:00+00:00" + # we have 3 runs with this same run_id + assert sum(1 for x in dag_runs if x.run_id == run_id_for_single_delete) == 3 + # each is a different dag + + # if we delete one, it shouldn't delete the others + one_run = next(x for x in dag_runs if x.run_id == run_id_for_single_delete) + assert interface.delete(item=one_run) is True + session.commit() + dag_runs = session.query(DagRun).all() + # we should have one fewer dag run now + assert len(dag_runs) == 8 + + # now let's try multi delete + run_id_for_multi_delete = "scheduled__2023-01-02T00:00:00+00:00" + # verify we have 3 + runs_of_interest = [x for x in dag_runs if x.run_id == run_id_for_multi_delete] + assert len(runs_of_interest) == 3 + # and that each is different dag + assert len(set(x.dag_id for x in dag_runs)) == 3 + + to_delete = runs_of_interest[:2] + # now try multi delete + assert interface.delete_all(items=to_delete) is True + session.commit() + dag_runs = session.query(DagRun).all() + assert len(dag_runs) == 6 + assert len(set(x.dag_id for x in dag_runs)) == 3 + assert len(set(x.run_id for x in dag_runs)) == 3 + + +@pytest.fixture +def app(): + from flask import Flask + + app = Flask(__name__) + app.config["WTF_CSRF_ENABLED"] = False + app.config["SECRET_KEY"] = "secret" + with app.app_context(): + yield app + + +class TestWidgets: + def test_airflow_datetime_picker_ro_widget(self, app): + class TestForm(FlaskForm): + datetime_field = StringField(widget=AirflowDateTimePickerROWidget()) + + form = TestForm() + field = form.datetime_field + + html_output = field() + + assert 'readonly="true"' in html_output + assert "input-group datetime datetimepicker" in html_output + + def test_bs3_text_field_ro_widget(self, app): + class TestForm(FlaskForm): + text_field = StringField(widget=BS3TextFieldROWidget()) + + form = TestForm() + field = form.text_field + + html_output = field() + + assert 'readonly="true"' in html_output + assert "form-control" in html_output + + def test_bs3_text_area_ro_widget(self, app): + class TestForm(FlaskForm): + textarea_field = TextAreaField(widget=BS3TextAreaROWidget()) + + form = TestForm() + field = form.textarea_field + + html_output = field() + + assert 'readonly="true"' in html_output + assert "form-control" in html_output diff --git a/tests/www/test_validators.py b/tests/www/test_validators.py index 884401a16bc914..e92cc9d80453df 100644 --- a/tests/www/test_validators.py +++ b/tests/www/test_validators.py @@ -18,11 
+18,12 @@ from __future__ import annotations from unittest import mock + import pytest from airflow.www import validators - + class TestGreaterEqualThan: def setup_method(self): self.form_field_mock = mock.MagicMock(data="2017-05-06") @@ -119,6 +120,7 @@ def test_validation_raises_custom_message(self): message="Invalid JSON: {}", ) + class TestValidKey: def setup_method(self): self.form_field_mock = mock.MagicMock(data="valid_key") @@ -163,4 +165,3 @@ def test_read_only_validator(self): validator = validators.ReadOnly() assert validator(self.form_mock, self.form_read_only_field_mock) is None assert self.form_read_only_field_mock.flags.readonly is True - \ No newline at end of file diff --git a/tests/www/views/test_views_grid.py b/tests/www/views/test_views_grid.py index 2850f4e21e8a21..3d13dea4d1248f 100644 --- a/tests/www/views/test_views_grid.py +++ b/tests/www/views/test_views_grid.py @@ -34,14 +34,13 @@ from airflow.utils.state import DagRunState, TaskInstanceState from airflow.utils.task_group import TaskGroup from airflow.utils.types import DagRunType -from airflow.utils.session import create_session -from airflow.www.views import _safe_parse_datetime, dag_to_grid, get_date_time_num_runs_dag_runs_form_data +from airflow.www.views import dag_to_grid from tests.test_utils.asserts import assert_queries_count from tests.test_utils.db import clear_db_datasets, clear_db_runs from tests.test_utils.mock_operators import MockOperator pytestmark = pytest.mark.db_test - + if TYPE_CHECKING: from airflow.models.dagrun import DagRun diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py index f4f507048dd86d..4e4b8d27afc836 100644 --- a/tests/www/views/test_views_tasks.py +++ b/tests/www/views/test_views_tasks.py @@ -1049,26 +1049,18 @@ def test_graph_view_doesnt_fail_on_recursion_error(app, dag_maker, admin_client) resp = admin_client.get(url, follow_redirects=True) assert resp.status_code == 200 - + def test_get_date_time_num_runs_dag_runs_form_data_graph_view(app, dag_maker, admin_client): """Test the get_date_time_num_runs_dag_runs_form_data function.""" from airflow.www.views import get_date_time_num_runs_dag_runs_form_data - - execution_date = pendulum.DateTime(2024, 1, 1, 0, 0, 0, tzinfo=pendulum.UTC) + + execution_date = pendulum.now(tz="UTC") with dag_maker( dag_id="test_get_date_time_num_runs_dag_runs_form_data", start_date=execution_date, ) as dag: BashOperator(task_id="task_1", bash_command="echo test") - dag_run = dag_maker.create_dagrun( - run_id="test_dagrun_id", - run_type=DagRunType.SCHEDULED, - execution_date=execution_date, - start_date=execution_date, - state=DagRunState.RUNNING, - ) - with unittest.mock.patch.object(app, "dag_bag") as mocked_dag_bag: mocked_dag_bag.get_dag.return_value = dag url = f"/dags/{dag.dag_id}/graph" @@ -1082,9 +1074,9 @@ def test_get_date_time_num_runs_dag_runs_form_data_graph_view(app, dag_maker, ad base_date = pendulum.parse(data["base_date"].isoformat()) assert dttm.date() == execution_date.date() - assert dttm.time() == _safe_parse_datetime(execution_date.time().isoformat()).time() + assert dttm.time().hour == _safe_parse_datetime(execution_date.time().isoformat()).time().hour + assert dttm.time().minute == _safe_parse_datetime(execution_date.time().isoformat()).time().minute assert base_date.date() == execution_date.date() - assert data["execution_date"] == execution_date.isoformat() def test_task_instances(admin_client): From 6160f641f2b6a68b881cd78a696c7cff43a6d00e Mon Sep 17 00:00:00 2001 From: andyjianzhou Date: 
Wed, 10 Jul 2024 12:34:40 -0600 Subject: [PATCH 08/21] Fixing conflicts --- tests/models/test_xcom_arg_map.py | 41 +++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/models/test_xcom_arg_map.py b/tests/models/test_xcom_arg_map.py index 90085ba6df644b..b997cf024ebba3 100644 --- a/tests/models/test_xcom_arg_map.py +++ b/tests/models/test_xcom_arg_map.py @@ -20,6 +20,9 @@ import pytest from airflow.exceptions import AirflowSkipException +from airflow.models.taskinstance import TaskInstance +from airflow.models.taskmap import TaskMap, TaskMapVariant +from airflow.operators.empty import EmptyOperator from airflow.utils.state import TaskInstanceState from airflow.utils.trigger_rule import TriggerRule @@ -188,6 +191,44 @@ def does_not_work_with_c(v): ] +def test_task_map_from_task_instance_xcom(): + task = EmptyOperator(task_id="test_task") + ti = TaskInstance(task=task, run_id="test_run", map_index=0) + ti.dag_id = "test_dag" + value = {"key1": "value1", "key2": "value2"} + + # Test case where run_id is not None + task_map = TaskMap.from_task_instance_xcom(ti, value) + assert task_map.dag_id == ti.dag_id + assert task_map.task_id == ti.task_id + assert task_map.run_id == ti.run_id + assert task_map.map_index == ti.map_index + assert task_map.length == len(value) + assert task_map.keys == list(value) + + # Test case where run_id is None + ti.run_id = None + with pytest.raises(ValueError, match="cannot record task map for unrun task instance"): + TaskMap.from_task_instance_xcom(ti, value) + + +def test_task_map_variant(): + # Test case where keys is None + task_map = TaskMap( + dag_id="test_dag", + task_id="test_task", + run_id="test_run", + map_index=0, + length=3, + keys=None, + ) + assert task_map.variant == TaskMapVariant.LIST + + # Test case where keys is not None + task_map.keys = ["key1", "key2"] + assert task_map.variant == TaskMapVariant.DICT + + def test_xcom_map_raise_to_skip(dag_maker, session): result = None From 8224047773f3ba55983b212ad1faa85b34652ad7 Mon Sep 17 00:00:00 2001 From: Andy Zhou Date: Mon, 17 Jun 2024 22:25:06 -0600 Subject: [PATCH 09/21] Re-added reset_dagruns mistake --- tests/www/views/test_views_tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py index 4e4b8d27afc836..139a7e6427e3f1 100644 --- a/tests/www/views/test_views_tasks.py +++ b/tests/www/views/test_views_tasks.py @@ -66,7 +66,7 @@ def reset_dagruns(): @pytest.fixture(autouse=True) -def init_dagruns(app): +def init_dagruns(app, reset_dagruns): with time_machine.travel(DEFAULT_DATE, tick=False): app.dag_bag.get_dag("example_bash_operator").create_dagrun( run_id=DEFAULT_DAGRUN, From 52b3b24fe530a6c88101cf08f966b73c9638f91c Mon Sep 17 00:00:00 2001 From: Andy Zhou Date: Tue, 18 Jun 2024 12:59:59 -0600 Subject: [PATCH 10/21] Fixed linting issues --- airflow/www/views.py | 4 +++ tests/models/test_xcom_arg_map.py | 41 ----------------------------- tests/www/views/test_views_tasks.py | 16 ++++++++--- 3 files changed, 16 insertions(+), 45 deletions(-) diff --git a/airflow/www/views.py b/airflow/www/views.py index 7ffd5ec17c4cca..8812deb87a5982 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -272,6 +272,10 @@ def get_date_time_num_runs_dag_runs_form_data(www_request, session, dag): default_dag_run = conf.getint("webserver", "default_dag_run_display_number") num_runs = www_request.args.get("num_runs", default=default_dag_run, type=int) +<<<<<<< HEAD +======= + +>>>>>>> 
0b672caaf2 (Fixed linting issues) # When base_date has been rounded up because of the DateTimeField widget, we want # to use the execution_date as the starting point for our query just to ensure a diff --git a/tests/models/test_xcom_arg_map.py b/tests/models/test_xcom_arg_map.py index b997cf024ebba3..90085ba6df644b 100644 --- a/tests/models/test_xcom_arg_map.py +++ b/tests/models/test_xcom_arg_map.py @@ -20,9 +20,6 @@ import pytest from airflow.exceptions import AirflowSkipException -from airflow.models.taskinstance import TaskInstance -from airflow.models.taskmap import TaskMap, TaskMapVariant -from airflow.operators.empty import EmptyOperator from airflow.utils.state import TaskInstanceState from airflow.utils.trigger_rule import TriggerRule @@ -191,44 +188,6 @@ def does_not_work_with_c(v): ] -def test_task_map_from_task_instance_xcom(): - task = EmptyOperator(task_id="test_task") - ti = TaskInstance(task=task, run_id="test_run", map_index=0) - ti.dag_id = "test_dag" - value = {"key1": "value1", "key2": "value2"} - - # Test case where run_id is not None - task_map = TaskMap.from_task_instance_xcom(ti, value) - assert task_map.dag_id == ti.dag_id - assert task_map.task_id == ti.task_id - assert task_map.run_id == ti.run_id - assert task_map.map_index == ti.map_index - assert task_map.length == len(value) - assert task_map.keys == list(value) - - # Test case where run_id is None - ti.run_id = None - with pytest.raises(ValueError, match="cannot record task map for unrun task instance"): - TaskMap.from_task_instance_xcom(ti, value) - - -def test_task_map_variant(): - # Test case where keys is None - task_map = TaskMap( - dag_id="test_dag", - task_id="test_task", - run_id="test_run", - map_index=0, - length=3, - keys=None, - ) - assert task_map.variant == TaskMapVariant.LIST - - # Test case where keys is not None - task_map.keys = ["key1", "key2"] - assert task_map.variant == TaskMapVariant.DICT - - def test_xcom_map_raise_to_skip(dag_maker, session): result = None diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py index 139a7e6427e3f1..601e4658cf45fc 100644 --- a/tests/www/views/test_views_tasks.py +++ b/tests/www/views/test_views_tasks.py @@ -66,7 +66,7 @@ def reset_dagruns(): @pytest.fixture(autouse=True) -def init_dagruns(app, reset_dagruns): +def init_dagruns(app): with time_machine.travel(DEFAULT_DATE, tick=False): app.dag_bag.get_dag("example_bash_operator").create_dagrun( run_id=DEFAULT_DAGRUN, @@ -1054,13 +1054,21 @@ def test_get_date_time_num_runs_dag_runs_form_data_graph_view(app, dag_maker, ad """Test the get_date_time_num_runs_dag_runs_form_data function.""" from airflow.www.views import get_date_time_num_runs_dag_runs_form_data - execution_date = pendulum.now(tz="UTC") + execution_date = pendulum.DateTime(2024, 1, 1, 0, 0, 0, tzinfo=pendulum.UTC) with dag_maker( dag_id="test_get_date_time_num_runs_dag_runs_form_data", start_date=execution_date, ) as dag: BashOperator(task_id="task_1", bash_command="echo test") + dag_run = dag_maker.create_dagrun( + run_id="test_dagrun_id", + run_type=DagRunType.SCHEDULED, + execution_date=execution_date, + start_date=execution_date, + state=DagRunState.RUNNING, + ) + with unittest.mock.patch.object(app, "dag_bag") as mocked_dag_bag: mocked_dag_bag.get_dag.return_value = dag url = f"/dags/{dag.dag_id}/graph" @@ -1074,9 +1082,9 @@ def test_get_date_time_num_runs_dag_runs_form_data_graph_view(app, dag_maker, ad base_date = pendulum.parse(data["base_date"].isoformat()) assert dttm.date() == execution_date.date() - 
assert dttm.time().hour == _safe_parse_datetime(execution_date.time().isoformat()).time().hour - assert dttm.time().minute == _safe_parse_datetime(execution_date.time().isoformat()).time().minute + assert dttm.time() == _safe_parse_datetime(execution_date.time().isoformat()).time() assert base_date.date() == execution_date.date() + assert data["execution_date"] == execution_date.isoformat() def test_task_instances(admin_client): From 2890acb300cbb1ae3d93e6a8b62e8799b0c440c3 Mon Sep 17 00:00:00 2001 From: Andy Zhou Date: Wed, 19 Jun 2024 12:15:32 -0600 Subject: [PATCH 11/21] linting issues --- tests/www/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/www/test_utils.py b/tests/www/test_utils.py index 02eb43526e1bd7..604edd6bc44936 100644 --- a/tests/www/test_utils.py +++ b/tests/www/test_utils.py @@ -705,4 +705,4 @@ class TestForm(FlaskForm): html_output = field() assert 'readonly="true"' in html_output - assert "form-control" in html_output + assert 'form-control' in html_output From 98a9a76f68f1f365656f8a6788648b5b312b2f22 Mon Sep 17 00:00:00 2001 From: Andy Zhou Date: Wed, 19 Jun 2024 23:35:51 -0600 Subject: [PATCH 12/21] Added Missing testCases for taskmap.py to 100% --- tests/models/test_xcom_arg_map.py | 41 +++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/models/test_xcom_arg_map.py b/tests/models/test_xcom_arg_map.py index 90085ba6df644b..65bfe479ba7aad 100644 --- a/tests/models/test_xcom_arg_map.py +++ b/tests/models/test_xcom_arg_map.py @@ -20,8 +20,12 @@ import pytest from airflow.exceptions import AirflowSkipException +from airflow.operators.empty import EmptyOperator from airflow.utils.state import TaskInstanceState from airflow.utils.trigger_rule import TriggerRule +from airflow.models.taskinstance import TaskInstance +from airflow.models.taskmap import TaskMap, TaskMapVariant +from airflow.utils.state import TaskInstanceState pytestmark = pytest.mark.db_test @@ -188,6 +192,43 @@ def does_not_work_with_c(v): ] +def test_task_map_from_task_instance_xcom(): + task = EmptyOperator(task_id='test_task') + ti = TaskInstance(task=task, run_id='test_run', map_index=0) + ti.dag_id = 'test_dag' + value = {"key1": "value1", "key2": "value2"} + + # Test case where run_id is not None + task_map = TaskMap.from_task_instance_xcom(ti, value) + assert task_map.dag_id == ti.dag_id + assert task_map.task_id == ti.task_id + assert task_map.run_id == ti.run_id + assert task_map.map_index == ti.map_index + assert task_map.length == len(value) + assert task_map.keys == list(value) + + # Test case where run_id is None + ti.run_id = None + with pytest.raises(ValueError, match="cannot record task map for unrun task instance"): + TaskMap.from_task_instance_xcom(ti, value) + +def test_task_map_variant(): + # Test case where keys is None + task_map = TaskMap( + dag_id='test_dag', + task_id='test_task', + run_id='test_run', + map_index=0, + length=3, + keys=None, + ) + assert task_map.variant == TaskMapVariant.LIST + + # Test case where keys is not None + task_map.keys = ["key1", "key2"] + assert task_map.variant == TaskMapVariant.DICT + + def test_xcom_map_raise_to_skip(dag_maker, session): result = None From cb8ccb136422293453ff6225f42dc79fedb75c7b Mon Sep 17 00:00:00 2001 From: Andy Zhou Date: Mon, 24 Jun 2024 14:39:02 -0600 Subject: [PATCH 13/21] Pre-commit fixes --- tests/models/test_xcom_arg_map.py | 18 +++++++++--------- tests/www/test_utils.py | 2 +- tests/www/views/test_views_grid.py | 9 +++++++-- 3 files changed, 17 
insertions(+), 12 deletions(-) diff --git a/tests/models/test_xcom_arg_map.py b/tests/models/test_xcom_arg_map.py index 65bfe479ba7aad..b997cf024ebba3 100644 --- a/tests/models/test_xcom_arg_map.py +++ b/tests/models/test_xcom_arg_map.py @@ -20,12 +20,11 @@ import pytest from airflow.exceptions import AirflowSkipException -from airflow.operators.empty import EmptyOperator -from airflow.utils.state import TaskInstanceState -from airflow.utils.trigger_rule import TriggerRule from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap, TaskMapVariant +from airflow.operators.empty import EmptyOperator from airflow.utils.state import TaskInstanceState +from airflow.utils.trigger_rule import TriggerRule pytestmark = pytest.mark.db_test @@ -193,9 +192,9 @@ def does_not_work_with_c(v): def test_task_map_from_task_instance_xcom(): - task = EmptyOperator(task_id='test_task') - ti = TaskInstance(task=task, run_id='test_run', map_index=0) - ti.dag_id = 'test_dag' + task = EmptyOperator(task_id="test_task") + ti = TaskInstance(task=task, run_id="test_run", map_index=0) + ti.dag_id = "test_dag" value = {"key1": "value1", "key2": "value2"} # Test case where run_id is not None @@ -212,12 +211,13 @@ def test_task_map_from_task_instance_xcom(): with pytest.raises(ValueError, match="cannot record task map for unrun task instance"): TaskMap.from_task_instance_xcom(ti, value) + def test_task_map_variant(): # Test case where keys is None task_map = TaskMap( - dag_id='test_dag', - task_id='test_task', - run_id='test_run', + dag_id="test_dag", + task_id="test_task", + run_id="test_run", map_index=0, length=3, keys=None, diff --git a/tests/www/test_utils.py b/tests/www/test_utils.py index 604edd6bc44936..02eb43526e1bd7 100644 --- a/tests/www/test_utils.py +++ b/tests/www/test_utils.py @@ -705,4 +705,4 @@ class TestForm(FlaskForm): html_output = field() assert 'readonly="true"' in html_output - assert 'form-control' in html_output + assert "form-control" in html_output diff --git a/tests/www/views/test_views_grid.py b/tests/www/views/test_views_grid.py index 3d13dea4d1248f..9dc94a1fb39113 100644 --- a/tests/www/views/test_views_grid.py +++ b/tests/www/views/test_views_grid.py @@ -89,7 +89,10 @@ def mapped_task_group(arg1): def dag_with_runs(dag_without_runs): date = dag_without_runs.dag.start_date run_1 = dag_without_runs.create_dagrun( - run_id="run_1", state=DagRunState.SUCCESS, run_type=DagRunType.SCHEDULED, execution_date=date + run_id="run_1", + state=DagRunState.SUCCESS, + run_type=DagRunType.SCHEDULED, + execution_date=date ) run_2 = dag_without_runs.create_dagrun( run_id="run_2", @@ -475,7 +478,9 @@ def test_next_run_datasets(admin_client, dag_maker, session, app, monkeypatch): ds1_id = session.query(DatasetModel.id).filter_by(uri=datasets[0].uri).scalar() ds2_id = session.query(DatasetModel.id).filter_by(uri=datasets[1].uri).scalar() ddrq = DatasetDagRunQueue( - target_dag_id=DAG_ID, dataset_id=ds1_id, created_at=pendulum.DateTime(2022, 8, 2, tzinfo=UTC) + target_dag_id=DAG_ID, + dataset_id=ds1_id, + created_at=pendulum.DateTime(2022, 8, 2, tzinfo=UTC) ) session.add(ddrq) dataset_events = [ From d368d2791a1d27986bf78af254634c16cd49975a Mon Sep 17 00:00:00 2001 From: Andy Zhou Date: Mon, 24 Jun 2024 14:41:39 -0600 Subject: [PATCH 14/21] Pre-commit fixes --- tests/www/views/test_views_grid.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/www/views/test_views_grid.py b/tests/www/views/test_views_grid.py index 9dc94a1fb39113..3d13dea4d1248f 
100644 --- a/tests/www/views/test_views_grid.py +++ b/tests/www/views/test_views_grid.py @@ -89,10 +89,7 @@ def mapped_task_group(arg1): def dag_with_runs(dag_without_runs): date = dag_without_runs.dag.start_date run_1 = dag_without_runs.create_dagrun( - run_id="run_1", - state=DagRunState.SUCCESS, - run_type=DagRunType.SCHEDULED, - execution_date=date + run_id="run_1", state=DagRunState.SUCCESS, run_type=DagRunType.SCHEDULED, execution_date=date ) run_2 = dag_without_runs.create_dagrun( run_id="run_2", @@ -478,9 +475,7 @@ def test_next_run_datasets(admin_client, dag_maker, session, app, monkeypatch): ds1_id = session.query(DatasetModel.id).filter_by(uri=datasets[0].uri).scalar() ds2_id = session.query(DatasetModel.id).filter_by(uri=datasets[1].uri).scalar() ddrq = DatasetDagRunQueue( - target_dag_id=DAG_ID, - dataset_id=ds1_id, - created_at=pendulum.DateTime(2022, 8, 2, tzinfo=UTC) + target_dag_id=DAG_ID, dataset_id=ds1_id, created_at=pendulum.DateTime(2022, 8, 2, tzinfo=UTC) ) session.add(ddrq) dataset_events = [ From 44388b8a2e7932b92e5c54792b55b019113d34c5 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 25 Jun 2024 14:06:03 -0600 Subject: [PATCH 15/21] Removed unassigned dag_run functionality --- tests/www/views/test_views_tasks.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py index 601e4658cf45fc..8f5fcdc647ddad 100644 --- a/tests/www/views/test_views_tasks.py +++ b/tests/www/views/test_views_tasks.py @@ -1061,14 +1061,6 @@ def test_get_date_time_num_runs_dag_runs_form_data_graph_view(app, dag_maker, ad ) as dag: BashOperator(task_id="task_1", bash_command="echo test") - dag_run = dag_maker.create_dagrun( - run_id="test_dagrun_id", - run_type=DagRunType.SCHEDULED, - execution_date=execution_date, - start_date=execution_date, - state=DagRunState.RUNNING, - ) - with unittest.mock.patch.object(app, "dag_bag") as mocked_dag_bag: mocked_dag_bag.get_dag.return_value = dag url = f"/dags/{dag.dag_id}/graph" From 7ac808a201600c834cbc48e94705be2b6e7e843c Mon Sep 17 00:00:00 2001 From: andyjianzhou Date: Tue, 2 Jul 2024 17:46:43 -0600 Subject: [PATCH 16/21] Fix unittest date error --- tests/www/views/test_views_tasks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py index 8f5fcdc647ddad..7775cfe7a0983f 100644 --- a/tests/www/views/test_views_tasks.py +++ b/tests/www/views/test_views_tasks.py @@ -1054,7 +1054,7 @@ def test_get_date_time_num_runs_dag_runs_form_data_graph_view(app, dag_maker, ad """Test the get_date_time_num_runs_dag_runs_form_data function.""" from airflow.www.views import get_date_time_num_runs_dag_runs_form_data - execution_date = pendulum.DateTime(2024, 1, 1, 0, 0, 0, tzinfo=pendulum.UTC) + execution_date = pendulum.now(tz='UTC') with dag_maker( dag_id="test_get_date_time_num_runs_dag_runs_form_data", start_date=execution_date, @@ -1074,9 +1074,9 @@ def test_get_date_time_num_runs_dag_runs_form_data_graph_view(app, dag_maker, ad base_date = pendulum.parse(data["base_date"].isoformat()) assert dttm.date() == execution_date.date() - assert dttm.time() == _safe_parse_datetime(execution_date.time().isoformat()).time() + assert dttm.time().hour == _safe_parse_datetime(execution_date.time().isoformat()).time().hour + assert dttm.time().minute == _safe_parse_datetime(execution_date.time().isoformat()).time().minute assert base_date.date() == execution_date.date() - assert data["execution_date"] == 
execution_date.isoformat() def test_task_instances(admin_client): From dd9bd273fdec69c86ed43c5a9427f084d6a76648 Mon Sep 17 00:00:00 2001 From: andyjianzhou Date: Tue, 2 Jul 2024 23:47:05 -0600 Subject: [PATCH 17/21] Resolving Pre-commit single quotation marks --- tests/www/views/test_views_tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py index 7775cfe7a0983f..4e4b8d27afc836 100644 --- a/tests/www/views/test_views_tasks.py +++ b/tests/www/views/test_views_tasks.py @@ -1054,7 +1054,7 @@ def test_get_date_time_num_runs_dag_runs_form_data_graph_view(app, dag_maker, ad """Test the get_date_time_num_runs_dag_runs_form_data function.""" from airflow.www.views import get_date_time_num_runs_dag_runs_form_data - execution_date = pendulum.now(tz='UTC') + execution_date = pendulum.now(tz="UTC") with dag_maker( dag_id="test_get_date_time_num_runs_dag_runs_form_data", start_date=execution_date, From 4ccbe4814dbc3f59f2b2c695ce23872aff17e75d Mon Sep 17 00:00:00 2001 From: andyjianzhou Date: Fri, 12 Jul 2024 12:18:06 -0600 Subject: [PATCH 18/21] Resolving hidden conflicts after rebasing --- airflow/www/views.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/airflow/www/views.py b/airflow/www/views.py index 8812deb87a5982..7ffd5ec17c4cca 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -272,10 +272,6 @@ def get_date_time_num_runs_dag_runs_form_data(www_request, session, dag): default_dag_run = conf.getint("webserver", "default_dag_run_display_number") num_runs = www_request.args.get("num_runs", default=default_dag_run, type=int) -<<<<<<< HEAD -======= - ->>>>>>> 0b672caaf2 (Fixed linting issues) # When base_date has been rounded up because of the DateTimeField widget, we want # to use the execution_date as the starting point for our query just to ensure a From 0b806add370265640eb355712981f66253877a92 Mon Sep 17 00:00:00 2001 From: andyjianzhou Date: Wed, 17 Jul 2024 11:23:47 -0600 Subject: [PATCH 19/21] fix pre-commit --- tests/www/test_utils.py | 2 +- tests/www/test_utils_BACKUP_289691.py | 708 -------------------------- tests/www/test_utils_BASE_289691.py | 479 ----------------- tests/www/test_utils_LOCAL_289691.py | 656 ------------------------ tests/www/test_utils_REMOTE_289691.py | 531 ------------------- 5 files changed, 1 insertion(+), 2375 deletions(-) delete mode 100644 tests/www/test_utils_BACKUP_289691.py delete mode 100644 tests/www/test_utils_BASE_289691.py delete mode 100644 tests/www/test_utils_LOCAL_289691.py delete mode 100644 tests/www/test_utils_REMOTE_289691.py diff --git a/tests/www/test_utils.py b/tests/www/test_utils.py index 732088a76eda33..f592b93e71d4ea 100644 --- a/tests/www/test_utils.py +++ b/tests/www/test_utils.py @@ -36,7 +36,6 @@ from airflow.models import DagRun from airflow.utils import json as utils_json from airflow.www import utils -from airflow.www.widgets import AirflowDateTimePickerROWidget, BS3TextAreaROWidget, BS3TextFieldROWidget from airflow.www.utils import ( CustomSQLAInterface, DagRunCustomSQLAInterface, @@ -44,6 +43,7 @@ json_f, wrapped_markdown, ) +from airflow.www.widgets import AirflowDateTimePickerROWidget, BS3TextAreaROWidget, BS3TextFieldROWidget from tests.test_utils.config import conf_vars diff --git a/tests/www/test_utils_BACKUP_289691.py b/tests/www/test_utils_BACKUP_289691.py deleted file mode 100644 index 02eb43526e1bd7..00000000000000 --- a/tests/www/test_utils_BACKUP_289691.py +++ /dev/null @@ -1,708 +0,0 @@ -# -# Licensed to 
the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import itertools -import re -import time -from datetime import datetime -from unittest.mock import Mock -from urllib.parse import parse_qs - -import pendulum -import pytest -from bs4 import BeautifulSoup -from flask_appbuilder.models.sqla.filters import get_field_setup_query, set_value_to_type -from flask_wtf import FlaskForm -from markupsafe import Markup -from sqlalchemy.orm import Query -from wtforms.fields import StringField, TextAreaField - -from airflow.models import DagRun -from airflow.utils import json as utils_json -from airflow.www import utils -from airflow.www.utils import CustomSQLAInterface, DagRunCustomSQLAInterface, json_f, wrapped_markdown -from airflow.www.widgets import AirflowDateTimePickerROWidget, BS3TextAreaROWidget, BS3TextFieldROWidget -from tests.test_utils.config import conf_vars - - -class TestUtils: - def check_generate_pages_html( - self, - current_page, - total_pages, - window=7, - check_middle=False, - sorting_key=None, - sorting_direction=None, - ): - extra_links = 4 # first, prev, next, last - search = "'>\"/>" - if sorting_key and sorting_direction: - html_str = utils.generate_pages( - current_page, - total_pages, - search=search, - sorting_key=sorting_key, - sorting_direction=sorting_direction, - ) - else: - html_str = utils.generate_pages(current_page, total_pages, search=search) - - assert search not in html_str, "The raw search string shouldn't appear in the output" - assert "search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E" in html_str - - assert callable(html_str.__html__), "Should return something that is HTML-escaping aware" - - dom = BeautifulSoup(html_str, "html.parser") - assert dom is not None - - ulist = dom.ul - ulist_items = ulist.find_all("li") - assert min(window, total_pages) + extra_links == len(ulist_items) - - page_items = ulist_items[2:-2] - mid = len(page_items) // 2 - all_nodes = [] - pages = [] - - if sorting_key and sorting_direction: - last_page = total_pages - 1 - - if current_page <= mid or total_pages < window: - pages = list(range(min(total_pages, window))) - elif mid < current_page < last_page - mid: - pages = list(range(current_page - mid, current_page + mid + 1)) - else: - pages = list(range(total_pages - window, last_page + 1)) - - pages.append(last_page + 1) - pages.sort(reverse=True if sorting_direction == "desc" else False) - - for i, item in enumerate(page_items): - a_node = item.a - href_link = a_node["href"] - node_text = a_node.string - all_nodes.append(node_text) - if node_text == str(current_page + 1): - if check_middle: - assert mid == i - assert "javascript:void(0)" == href_link - assert "active" in item["class"] - else: - assert re.search(r"^\?", href_link), "Link is page-relative" - query = 
parse_qs(href_link[1:]) - assert query["page"] == [str(int(node_text) - 1)] - assert query["search"] == [search] - - if sorting_key and sorting_direction: - if pages[0] == 0: - pages = [str(page) for page in pages[1:]] - - assert pages == all_nodes - - def test_generate_pager_current_start(self): - self.check_generate_pages_html(current_page=0, total_pages=6) - - def test_generate_pager_current_middle(self): - self.check_generate_pages_html(current_page=10, total_pages=20, check_middle=True) - - def test_generate_pager_current_end(self): - self.check_generate_pages_html(current_page=38, total_pages=39) - - def test_generate_pager_current_start_with_sorting(self): - self.check_generate_pages_html( - current_page=0, total_pages=4, sorting_key="dag_id", sorting_direction="asc" - ) - - def test_params_no_values(self): - """Should return an empty string if no params are passed""" - assert "" == utils.get_params() - - def test_params_search(self): - assert "search=bash_" == utils.get_params(search="bash_") - - def test_params_none_and_zero(self): - query_str = utils.get_params(a=0, b=None, c="true") - # The order won't be consistent, but that doesn't affect behaviour of a browser - pairs = sorted(query_str.split("&")) - assert ["a=0", "c=true"] == pairs - - def test_params_all(self): - query = utils.get_params(tags=["tag1", "tag2"], status="active", page=3, search="bash_") - assert { - "tags": ["tag1", "tag2"], - "page": ["3"], - "search": ["bash_"], - "status": ["active"], - } == parse_qs(query) - - def test_params_escape(self): - assert "search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E" == utils.get_params( - search="'>\"/>" - ) - - def test_state_token(self): - # It's shouldn't possible to set these odd values anymore, but lets - # ensure they are escaped! 
- html = str(utils.state_token("")) - - assert "<script>alert(1)</script>" in html - assert "" not in html - - def test_nobr_f(self): - attr = {"attr_name": "attribute"} - f = attr.get("attr_name") - expected_markup = Markup("{}").format(f) - - nobr = utils.nobr_f("attr_name") - result_markup = nobr(attr) - - assert result_markup == expected_markup - - def test_nobr_f_empty_attr(self): - attr = {"attr_name": ""} - f = attr.get("attr_name") - expected_markup = Markup("{}").format(f) - - nobr = utils.nobr_f("attr_name") - result_markup = nobr(attr) - - assert result_markup == expected_markup - - def test_nobr_f_missing_attr(self): - attr = {} - f = None - expected_markup = Markup("{}").format(f) - - nobr = utils.nobr_f("attr_name") - result_markup = nobr(attr) - - assert result_markup == expected_markup - - def test_epoch(self): - test_datetime = datetime(2024, 6, 19, 12, 0, 0) - result = utils.epoch(test_datetime) - epoch_time = result[0] - - expected_epoch_time = int(time.mktime(test_datetime.timetuple())) * 1000 - - assert epoch_time == expected_epoch_time - - @pytest.mark.db_test - def test_make_cache_key(self): - from airflow.www.app import cached_app - - with cached_app(testing=True).test_request_context( - "/test/path", query_string={"key1": "value1", "key2": "value2"} - ): - expected_args = str(hash(frozenset({"key1": "value1", "key2": "value2"}.items()))) - expected_cache_key = ("/test/path" + expected_args).encode("ascii", "ignore") - result_cache_key = utils.make_cache_key() - assert result_cache_key == expected_cache_key - - @pytest.mark.db_test - def test_task_instance_link(self): - from airflow.www.app import cached_app - - with cached_app(testing=True).test_request_context(): - html = str( - utils.task_instance_link( - {"dag_id": "", "task_id": "", "map_index": 1, "execution_date": datetime.now()} - ) - ) - - html_map_index_none = str( - utils.task_instance_link( - {"dag_id": "", "task_id": "", "map_index": -1, "execution_date": datetime.now()} - ) - ) - - assert "%3Ca%261%3E" in html - assert "%3Cb2%3E" in html - assert "map_index" in html - assert "" not in html - assert "" not in html - - assert "%3Ca%261%3E" in html_map_index_none - assert "%3Cb2%3E" in html_map_index_none - assert "map_index" not in html_map_index_none - assert "" not in html_map_index_none - assert "" not in html_map_index_none - - @pytest.mark.db_test - def test_dag_link(self): - from airflow.www.app import cached_app - - with cached_app(testing=True).test_request_context(): - html = str(utils.dag_link({"dag_id": "", "execution_date": datetime.now()})) - - assert "%3Ca%261%3E" in html - assert "" not in html - - @pytest.mark.db_test - def test_dag_link_when_dag_is_none(self): - """Test that when there is no dag_id, dag_link does not contain hyperlink""" - from airflow.www.app import cached_app - - with cached_app(testing=True).test_request_context(): - html = str(utils.dag_link({})) - - assert "None" in html - assert "
    ", "run_id": "", "execution_date": datetime.now()}) - ) - - assert "%3Ca%261%3E" in html - assert "%3Cb2%3E" in html - assert "" not in html - assert "" not in html - - -class TestAttrRenderer: - def setup_method(self): - self.attr_renderer = utils.get_attr_renderer() - - def test_python_callable(self): - def example_callable(unused_self): - print("example") - - rendered = self.attr_renderer["python_callable"](example_callable) - assert ""example"" in rendered - - def test_python_callable_none(self): - rendered = self.attr_renderer["python_callable"](None) - assert "" == rendered - - def test_markdown(self): - markdown = "* foo\n* bar" - rendered = self.attr_renderer["doc_md"](markdown) - assert "
<li>foo</li>" in rendered - assert "<li>bar</li>
  • " in rendered - - def test_markdown_none(self): - rendered = self.attr_renderer["doc_md"](None) - assert rendered is None - - def test_get_dag_run_conf(self): - dag_run_conf = { - "1": "string", - "2": b"bytes", - "3": 123, - "4": "à".encode("latin"), - "5": datetime(2023, 1, 1), - } - expected_encoded_dag_run_conf = ( - '{"1": "string", "2": "bytes", "3": 123, "4": "à", "5": "2023-01-01T00:00:00+00:00"}' - ) - encoded_dag_run_conf, conf_is_json = utils.get_dag_run_conf( - dag_run_conf, json_encoder=utils_json.WebEncoder - ) - assert expected_encoded_dag_run_conf == encoded_dag_run_conf - - def test_encode_dag_run_none(self): - no_dag_run_result = utils.encode_dag_run(None) - assert no_dag_run_result is None - - def test_json_f_webencoder(self): - dag_run_conf = { - "1": "string", - "2": b"bytes", - "3": 123, - "4": "à".encode("latin"), - "5": datetime(2023, 1, 1), - } - expected_encoded_dag_run_conf = ( - # HTML sanitization is insane - '{"1": "string", "2": "bytes", "3": 123, "4": "\\u00e0", "5": "2023-01-01T00:00:00+00:00"}' - ) - expected_markup = Markup("{}").format(expected_encoded_dag_run_conf) - - formatter = json_f("conf") - dagrun = Mock() - dagrun.get = Mock(return_value=dag_run_conf) - - assert formatter(dagrun) == expected_markup - - -@pytest.mark.filterwarnings("ignore::DeprecationWarning") -def test_get_sensitive_variables_fields(): - with pytest.warns(DeprecationWarning) as warning: - result = utils.get_sensitive_variables_fields() - - # assert deprecation warning - assert len(warning) == 1 - assert "This function is deprecated." in str(warning[-1].message) - - from airflow.utils.log.secrets_masker import get_sensitive_variables_fields - - expected_result = get_sensitive_variables_fields() - assert result == expected_result - - -@pytest.mark.filterwarnings("ignore::DeprecationWarning") -def test_should_hide_value_for_key(): - key_name = "key" - - with pytest.warns(DeprecationWarning) as warning: - result = utils.should_hide_value_for_key(key_name) - - # assert deprecation warning - assert len(warning) == 1 - assert "This function is deprecated." in str(warning[-1].message) - - from airflow.utils.log.secrets_masker import should_hide_value_for_key - - expected_result = should_hide_value_for_key(key_name) - assert result == expected_result - - -class TestWrappedMarkdown: - def test_wrapped_markdown_with_docstring_curly_braces(self): - rendered = wrapped_markdown("{braces}", css_class="a_class") - assert ( - """

<div class="rich_doc a_class" ><p>{braces}</p> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_some_markdown(self): - rendered = wrapped_markdown( - """*italic* - **bold** - """, - css_class="a_class", - ) - - assert ( - """

<div class="rich_doc a_class" ><p><em>italic</em> -<strong>bold</strong></p> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_table(self): - rendered = wrapped_markdown( - """ -| Job | Duration | -| ----------- | ----------- | -| ETL | 14m | -""" - ) - - assert ( - """
<div class="rich_doc" ><table> -<thead> -<tr> -<th>Job</th> -<th>Duration</th> -</tr> -</thead> -<tbody> -<tr> -<td>ETL</td> -<td>14m</td> -</tr> -</tbody> -</table> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_indented_lines(self): - rendered = wrapped_markdown( - """ - # header - 1st line - 2nd line - """ - ) - - assert ( - """

<div class="rich_doc" ><h1>header</h1>\n<p>1st line\n2nd line</p> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_raw_code_block(self): - rendered = wrapped_markdown( - """\ - # Markdown code block - - Inline `code` works well. - - Code block - does not - respect - newlines - - """ - ) - - assert ( - """

<div class="rich_doc" ><h1>Markdown code block</h1> -<p>Inline <code>code</code> works well.</p> -<pre><code>Code block\ndoes not\nrespect\nnewlines\n</code></pre> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_nested_list(self): - rendered = wrapped_markdown( - """ - ### Docstring with a code block - - - And - - A nested list - """ - ) - - assert ( - """

<div class="rich_doc" ><h3>Docstring with a code block</h3> -<ul> -<li>And -<ul> -<li>A nested list</li> -</ul> -</li> -</ul> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_collapsible_section(self): - with conf_vars({("webserver", "allow_raw_html_descriptions"): "true"}): - rendered = wrapped_markdown( - """ -# A collapsible section with markdown -
<details> - <summary>Click to expand!</summary> - - ## Heading - 1. A numbered - 2. list - * With some - * Sub bullets -</details>
    - """ - ) - - assert ( - """

    A collapsible section with markdown

    -
    - Click to expand! -

    Heading

    -
      -
    1. A numbered
    2. -
    3. list -
        -
      • With some
      • -
      • Sub bullets
      • -
      -
    4. -
    -
    -
    """ - == rendered - ) - - @pytest.mark.parametrize("allow_html", [False, True]) - def test_wrapped_markdown_with_raw_html(self, allow_html): - with conf_vars({("webserver", "allow_raw_html_descriptions"): str(allow_html)}): - HTML = "test raw HTML" - rendered = wrapped_markdown(HTML) - if allow_html: - assert HTML in rendered - else: - from markupsafe import escape - - assert escape(HTML) in rendered - - -class TestFilter: - def setup_method(self): - self.mock_datamodel = Mock() - self.mock_query = Mock(spec=Query) - self.mock_column_name = "test_column" - - def test_filter_is_null_apply(self): - filter_is_null = utils.FilterIsNull(datamodel=self.mock_datamodel, column_name=self.mock_column_name) - - self.mock_query, mock_field = get_field_setup_query( - self.mock_query, self.mock_datamodel, self.mock_column_name - ) - mock_value = set_value_to_type(self.mock_datamodel, self.mock_column_name, None) - - result_query_filter = filter_is_null.apply(self.mock_query, None) - self.mock_query.filter.assert_called_once_with(mock_field == mock_value) - - expected_query_filter = self.mock_query.filter(mock_field == mock_value) - - assert result_query_filter == expected_query_filter - - def test_filter_is_not_null_apply(self): - filter_is_not_null = utils.FilterIsNotNull( - datamodel=self.mock_datamodel, column_name=self.mock_column_name - ) - - self.mock_query, mock_field = get_field_setup_query( - self.mock_query, self.mock_datamodel, self.mock_column_name - ) - mock_value = set_value_to_type(self.mock_datamodel, self.mock_column_name, None) - - result_query_filter = filter_is_not_null.apply(self.mock_query, None) - self.mock_query.filter.assert_called_once_with(mock_field != mock_value) - - expected_query_filter = self.mock_query.filter(mock_field != mock_value) - - assert result_query_filter == expected_query_filter - - def test_filter_gte_none_value_apply(self): - filter_gte = utils.FilterGreaterOrEqual( - datamodel=self.mock_datamodel, column_name=self.mock_column_name - ) - - self.mock_query, mock_field = get_field_setup_query( - self.mock_query, self.mock_datamodel, self.mock_column_name - ) - mock_value = set_value_to_type(self.mock_datamodel, self.mock_column_name, None) - - result_query_filter = filter_gte.apply(self.mock_query, mock_value) - - assert result_query_filter == self.mock_query - - def test_filter_lte_none_value_apply(self): - filter_lte = utils.FilterSmallerOrEqual( - datamodel=self.mock_datamodel, column_name=self.mock_column_name - ) - - self.mock_query, mock_field = get_field_setup_query( - self.mock_query, self.mock_datamodel, self.mock_column_name - ) - mock_value = set_value_to_type(self.mock_datamodel, self.mock_column_name, None) - - result_query_filter = filter_lte.apply(self.mock_query, mock_value) - - assert result_query_filter == self.mock_query - - -@pytest.mark.db_test -def test_get_col_default_not_existing(session): - interface = CustomSQLAInterface(obj=DagRun, session=session) - default_value = interface.get_col_default("column_not_existing") - assert default_value is None - - -@pytest.mark.db_test -def test_dag_run_custom_sqla_interface_delete_no_collateral_damage(dag_maker, session): - interface = DagRunCustomSQLAInterface(obj=DagRun, session=session) - dag_ids = (f"test_dag_{x}" for x in range(1, 4)) - dates = (pendulum.datetime(2023, 1, x) for x in range(1, 4)) - for dag_id, date in itertools.product(dag_ids, dates): - with dag_maker(dag_id=dag_id) as dag: - dag.create_dagrun( - execution_date=date, state="running", run_type="scheduled", 
data_interval=(date, date) - ) - dag_runs = session.query(DagRun).all() - assert len(dag_runs) == 9 - assert len(set(x.run_id for x in dag_runs)) == 3 - run_id_for_single_delete = "scheduled__2023-01-01T00:00:00+00:00" - # we have 3 runs with this same run_id - assert sum(1 for x in dag_runs if x.run_id == run_id_for_single_delete) == 3 - # each is a different dag - - # if we delete one, it shouldn't delete the others - one_run = next(x for x in dag_runs if x.run_id == run_id_for_single_delete) - assert interface.delete(item=one_run) is True - session.commit() - dag_runs = session.query(DagRun).all() - # we should have one fewer dag run now - assert len(dag_runs) == 8 - - # now let's try multi delete - run_id_for_multi_delete = "scheduled__2023-01-02T00:00:00+00:00" - # verify we have 3 - runs_of_interest = [x for x in dag_runs if x.run_id == run_id_for_multi_delete] - assert len(runs_of_interest) == 3 - # and that each is different dag - assert len(set(x.dag_id for x in dag_runs)) == 3 - - to_delete = runs_of_interest[:2] - # now try multi delete - assert interface.delete_all(items=to_delete) is True - session.commit() - dag_runs = session.query(DagRun).all() - assert len(dag_runs) == 6 - assert len(set(x.dag_id for x in dag_runs)) == 3 - assert len(set(x.run_id for x in dag_runs)) == 3 - - -@pytest.fixture -def app(): - from flask import Flask - - app = Flask(__name__) - app.config["WTF_CSRF_ENABLED"] = False - app.config["SECRET_KEY"] = "secret" - with app.app_context(): - yield app - - -class TestWidgets: - def test_airflow_datetime_picker_ro_widget(self, app): - class TestForm(FlaskForm): - datetime_field = StringField(widget=AirflowDateTimePickerROWidget()) - - form = TestForm() - field = form.datetime_field - - html_output = field() - - assert 'readonly="true"' in html_output - assert "input-group datetime datetimepicker" in html_output - - def test_bs3_text_field_ro_widget(self, app): - class TestForm(FlaskForm): - text_field = StringField(widget=BS3TextFieldROWidget()) - - form = TestForm() - field = form.text_field - - html_output = field() - - assert 'readonly="true"' in html_output - assert "form-control" in html_output - - def test_bs3_text_area_ro_widget(self, app): - class TestForm(FlaskForm): - textarea_field = TextAreaField(widget=BS3TextAreaROWidget()) - - form = TestForm() - field = form.textarea_field - - html_output = field() - - assert 'readonly="true"' in html_output - assert "form-control" in html_output diff --git a/tests/www/test_utils_BASE_289691.py b/tests/www/test_utils_BASE_289691.py deleted file mode 100644 index 1fc42c1fefb622..00000000000000 --- a/tests/www/test_utils_BASE_289691.py +++ /dev/null @@ -1,479 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from __future__ import annotations - -import itertools -import re -from datetime import datetime -from unittest.mock import Mock -from urllib.parse import parse_qs - -import pendulum -import pytest -from bs4 import BeautifulSoup -from markupsafe import Markup - -from airflow.models import DagRun -from airflow.utils import json as utils_json -from airflow.www import utils -from airflow.www.utils import DagRunCustomSQLAInterface, json_f, wrapped_markdown -from tests.test_utils.config import conf_vars - - -class TestUtils: - def check_generate_pages_html( - self, - current_page, - total_pages, - window=7, - check_middle=False, - sorting_key=None, - sorting_direction=None, - ): - extra_links = 4 # first, prev, next, last - search = "'>\"/>" - if sorting_key and sorting_direction: - html_str = utils.generate_pages( - current_page, - total_pages, - search=search, - sorting_key=sorting_key, - sorting_direction=sorting_direction, - ) - else: - html_str = utils.generate_pages(current_page, total_pages, search=search) - - assert search not in html_str, "The raw search string shouldn't appear in the output" - assert "search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E" in html_str - - assert callable(html_str.__html__), "Should return something that is HTML-escaping aware" - - dom = BeautifulSoup(html_str, "html.parser") - assert dom is not None - - ulist = dom.ul - ulist_items = ulist.find_all("li") - assert min(window, total_pages) + extra_links == len(ulist_items) - - page_items = ulist_items[2:-2] - mid = len(page_items) // 2 - all_nodes = [] - pages = [] - - if sorting_key and sorting_direction: - last_page = total_pages - 1 - - if current_page <= mid or total_pages < window: - pages = list(range(min(total_pages, window))) - elif mid < current_page < last_page - mid: - pages = list(range(current_page - mid, current_page + mid + 1)) - else: - pages = list(range(total_pages - window, last_page + 1)) - - pages.append(last_page + 1) - pages.sort(reverse=True if sorting_direction == "desc" else False) - - for i, item in enumerate(page_items): - a_node = item.a - href_link = a_node["href"] - node_text = a_node.string - all_nodes.append(node_text) - if node_text == str(current_page + 1): - if check_middle: - assert mid == i - assert "javascript:void(0)" == href_link - assert "active" in item["class"] - else: - assert re.search(r"^\?", href_link), "Link is page-relative" - query = parse_qs(href_link[1:]) - assert query["page"] == [str(int(node_text) - 1)] - assert query["search"] == [search] - - if sorting_key and sorting_direction: - if pages[0] == 0: - pages = [str(page) for page in pages[1:]] - - assert pages == all_nodes - - def test_generate_pager_current_start(self): - self.check_generate_pages_html(current_page=0, total_pages=6) - - def test_generate_pager_current_middle(self): - self.check_generate_pages_html(current_page=10, total_pages=20, check_middle=True) - - def test_generate_pager_current_end(self): - self.check_generate_pages_html(current_page=38, total_pages=39) - - def test_generate_pager_current_start_with_sorting(self): - self.check_generate_pages_html( - current_page=0, total_pages=4, sorting_key="dag_id", sorting_direction="asc" - ) - - def test_params_no_values(self): - """Should return an empty string if no params are passed""" - assert "" == utils.get_params() - - def test_params_search(self): - assert "search=bash_" == utils.get_params(search="bash_") - - def test_params_none_and_zero(self): - query_str = utils.get_params(a=0, b=None, c="true") - # The order won't 
be consistent, but that doesn't affect behaviour of a browser - pairs = sorted(query_str.split("&")) - assert ["a=0", "c=true"] == pairs - - def test_params_all(self): - query = utils.get_params(tags=["tag1", "tag2"], status="active", page=3, search="bash_") - assert { - "tags": ["tag1", "tag2"], - "page": ["3"], - "search": ["bash_"], - "status": ["active"], - } == parse_qs(query) - - def test_params_escape(self): - assert "search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E" == utils.get_params( - search="'>\"/>" - ) - - def test_state_token(self): - # It's shouldn't possible to set these odd values anymore, but lets - # ensure they are escaped! - html = str(utils.state_token("")) - - assert "<script>alert(1)</script>" in html - assert "" not in html - - @pytest.mark.db_test - def test_task_instance_link(self): - from airflow.www.app import cached_app - - with cached_app(testing=True).test_request_context(): - html = str( - utils.task_instance_link( - {"dag_id": "", "task_id": "", "execution_date": datetime.now()} - ) - ) - - assert "%3Ca%261%3E" in html - assert "%3Cb2%3E" in html - assert "" not in html - assert "" not in html - - @pytest.mark.db_test - def test_dag_link(self): - from airflow.www.app import cached_app - - with cached_app(testing=True).test_request_context(): - html = str(utils.dag_link({"dag_id": "", "execution_date": datetime.now()})) - - assert "%3Ca%261%3E" in html - assert "" not in html - - @pytest.mark.db_test - def test_dag_link_when_dag_is_none(self): - """Test that when there is no dag_id, dag_link does not contain hyperlink""" - from airflow.www.app import cached_app - - with cached_app(testing=True).test_request_context(): - html = str(utils.dag_link({})) - - assert "None" in html - assert "
    ", "run_id": "", "execution_date": datetime.now()}) - ) - - assert "%3Ca%261%3E" in html - assert "%3Cb2%3E" in html - assert "" not in html - assert "" not in html - - -class TestAttrRenderer: - def setup_method(self): - self.attr_renderer = utils.get_attr_renderer() - - def test_python_callable(self): - def example_callable(unused_self): - print("example") - - rendered = self.attr_renderer["python_callable"](example_callable) - assert ""example"" in rendered - - def test_python_callable_none(self): - rendered = self.attr_renderer["python_callable"](None) - assert "" == rendered - - def test_markdown(self): - markdown = "* foo\n* bar" - rendered = self.attr_renderer["doc_md"](markdown) - assert "
<li>foo</li>" in rendered - assert "<li>bar</li>
  • " in rendered - - def test_markdown_none(self): - rendered = self.attr_renderer["doc_md"](None) - assert rendered is None - - def test_get_dag_run_conf(self): - dag_run_conf = { - "1": "string", - "2": b"bytes", - "3": 123, - "4": "à".encode("latin"), - "5": datetime(2023, 1, 1), - } - expected_encoded_dag_run_conf = ( - '{"1": "string", "2": "bytes", "3": 123, "4": "à", "5": "2023-01-01T00:00:00+00:00"}' - ) - encoded_dag_run_conf, conf_is_json = utils.get_dag_run_conf( - dag_run_conf, json_encoder=utils_json.WebEncoder - ) - assert expected_encoded_dag_run_conf == encoded_dag_run_conf - - def test_json_f_webencoder(self): - dag_run_conf = { - "1": "string", - "2": b"bytes", - "3": 123, - "4": "à".encode("latin"), - "5": datetime(2023, 1, 1), - } - expected_encoded_dag_run_conf = ( - # HTML sanitization is insane - '{"1": "string", "2": "bytes", "3": 123, "4": "\\u00e0", "5": "2023-01-01T00:00:00+00:00"}' - ) - expected_markup = Markup("{}").format(expected_encoded_dag_run_conf) - - formatter = json_f("conf") - dagrun = Mock() - dagrun.get = Mock(return_value=dag_run_conf) - - assert formatter(dagrun) == expected_markup - - -class TestWrappedMarkdown: - def test_wrapped_markdown_with_docstring_curly_braces(self): - rendered = wrapped_markdown("{braces}", css_class="a_class") - assert ( - """

<div class="rich_doc a_class" ><p>{braces}</p> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_some_markdown(self): - rendered = wrapped_markdown( - """*italic* - **bold** - """, - css_class="a_class", - ) - - assert ( - """

<div class="rich_doc a_class" ><p><em>italic</em> -<strong>bold</strong></p> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_table(self): - rendered = wrapped_markdown( - """ -| Job | Duration | -| ----------- | ----------- | -| ETL | 14m | -""" - ) - - assert ( - """
<div class="rich_doc" ><table> -<thead> -<tr> -<th>Job</th> -<th>Duration</th> -</tr> -</thead> -<tbody> -<tr> -<td>ETL</td> -<td>14m</td> -</tr> -</tbody> -</table> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_indented_lines(self): - rendered = wrapped_markdown( - """ - # header - 1st line - 2nd line - """ - ) - - assert ( - """

<div class="rich_doc" ><h1>header</h1>\n<p>1st line\n2nd line</p> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_raw_code_block(self): - rendered = wrapped_markdown( - """\ - # Markdown code block - - Inline `code` works well. - - Code block - does not - respect - newlines - - """ - ) - - assert ( - """

<div class="rich_doc" ><h1>Markdown code block</h1> -<p>Inline <code>code</code> works well.</p> -<pre><code>Code block\ndoes not\nrespect\nnewlines\n</code></pre> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_nested_list(self): - rendered = wrapped_markdown( - """ - ### Docstring with a code block - - - And - - A nested list - """ - ) - - assert ( - """

<div class="rich_doc" ><h3>Docstring with a code block</h3> -<ul> -<li>And -<ul> -<li>A nested list</li> -</ul> -</li> -</ul> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_collapsible_section(self): - with conf_vars({("webserver", "allow_raw_html_descriptions"): "true"}): - rendered = wrapped_markdown( - """ -# A collapsible section with markdown -
<details> - <summary>Click to expand!</summary> - - ## Heading - 1. A numbered - 2. list - * With some - * Sub bullets -</details>
    - """ - ) - - assert ( - """

    A collapsible section with markdown

    -
    - Click to expand! -

    Heading

    -
      -
    1. A numbered
    2. -
    3. list -
        -
      • With some
      • -
      • Sub bullets
      • -
      -
    4. -
    -
    -
    """ - == rendered - ) - - @pytest.mark.parametrize("allow_html", [False, True]) - def test_wrapped_markdown_with_raw_html(self, allow_html): - with conf_vars({("webserver", "allow_raw_html_descriptions"): str(allow_html)}): - HTML = "test raw HTML" - rendered = wrapped_markdown(HTML) - if allow_html: - assert HTML in rendered - else: - from markupsafe import escape - - assert escape(HTML) in rendered - - -@pytest.mark.db_test -def test_dag_run_custom_sqla_interface_delete_no_collateral_damage(dag_maker, session): - interface = DagRunCustomSQLAInterface(obj=DagRun, session=session) - dag_ids = (f"test_dag_{x}" for x in range(1, 4)) - dates = (pendulum.datetime(2023, 1, x) for x in range(1, 4)) - for dag_id, date in itertools.product(dag_ids, dates): - with dag_maker(dag_id=dag_id) as dag: - dag.create_dagrun( - execution_date=date, state="running", run_type="scheduled", data_interval=(date, date) - ) - dag_runs = session.query(DagRun).all() - assert len(dag_runs) == 9 - assert len(set(x.run_id for x in dag_runs)) == 3 - run_id_for_single_delete = "scheduled__2023-01-01T00:00:00+00:00" - # we have 3 runs with this same run_id - assert sum(1 for x in dag_runs if x.run_id == run_id_for_single_delete) == 3 - # each is a different dag - - # if we delete one, it shouldn't delete the others - one_run = next(x for x in dag_runs if x.run_id == run_id_for_single_delete) - assert interface.delete(item=one_run) is True - session.commit() - dag_runs = session.query(DagRun).all() - # we should have one fewer dag run now - assert len(dag_runs) == 8 - - # now let's try multi delete - run_id_for_multi_delete = "scheduled__2023-01-02T00:00:00+00:00" - # verify we have 3 - runs_of_interest = [x for x in dag_runs if x.run_id == run_id_for_multi_delete] - assert len(runs_of_interest) == 3 - # and that each is different dag - assert len(set(x.dag_id for x in dag_runs)) == 3 - - to_delete = runs_of_interest[:2] - # now try multi delete - assert interface.delete_all(items=to_delete) is True - session.commit() - dag_runs = session.query(DagRun).all() - assert len(dag_runs) == 6 - assert len(set(x.dag_id for x in dag_runs)) == 3 - assert len(set(x.run_id for x in dag_runs)) == 3 diff --git a/tests/www/test_utils_LOCAL_289691.py b/tests/www/test_utils_LOCAL_289691.py deleted file mode 100644 index a90d9246998d6f..00000000000000 --- a/tests/www/test_utils_LOCAL_289691.py +++ /dev/null @@ -1,656 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from __future__ import annotations - -import itertools -import re -import time -from datetime import datetime -from unittest.mock import Mock -from urllib.parse import parse_qs - -import pendulum -import pytest -from bs4 import BeautifulSoup -from flask_appbuilder.models.sqla.filters import get_field_setup_query, set_value_to_type -from markupsafe import Markup -from sqlalchemy.orm import Query - -from airflow.models import DagRun -from airflow.utils import json as utils_json -from airflow.www import utils -from airflow.www.utils import CustomSQLAInterface, DagRunCustomSQLAInterface, json_f, wrapped_markdown -from tests.test_utils.config import conf_vars - - -class TestUtils: - def check_generate_pages_html( - self, - current_page, - total_pages, - window=7, - check_middle=False, - sorting_key=None, - sorting_direction=None, - ): - extra_links = 4 # first, prev, next, last - search = "'>\"/>" - if sorting_key and sorting_direction: - html_str = utils.generate_pages( - current_page, - total_pages, - search=search, - sorting_key=sorting_key, - sorting_direction=sorting_direction, - ) - else: - html_str = utils.generate_pages(current_page, total_pages, search=search) - - assert search not in html_str, "The raw search string shouldn't appear in the output" - assert "search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E" in html_str - - assert callable(html_str.__html__), "Should return something that is HTML-escaping aware" - - dom = BeautifulSoup(html_str, "html.parser") - assert dom is not None - - ulist = dom.ul - ulist_items = ulist.find_all("li") - assert min(window, total_pages) + extra_links == len(ulist_items) - - page_items = ulist_items[2:-2] - mid = len(page_items) // 2 - all_nodes = [] - pages = [] - - if sorting_key and sorting_direction: - last_page = total_pages - 1 - - if current_page <= mid or total_pages < window: - pages = list(range(min(total_pages, window))) - elif mid < current_page < last_page - mid: - pages = list(range(current_page - mid, current_page + mid + 1)) - else: - pages = list(range(total_pages - window, last_page + 1)) - - pages.append(last_page + 1) - pages.sort(reverse=True if sorting_direction == "desc" else False) - - for i, item in enumerate(page_items): - a_node = item.a - href_link = a_node["href"] - node_text = a_node.string - all_nodes.append(node_text) - if node_text == str(current_page + 1): - if check_middle: - assert mid == i - assert "javascript:void(0)" == href_link - assert "active" in item["class"] - else: - assert re.search(r"^\?", href_link), "Link is page-relative" - query = parse_qs(href_link[1:]) - assert query["page"] == [str(int(node_text) - 1)] - assert query["search"] == [search] - - if sorting_key and sorting_direction: - if pages[0] == 0: - pages = [str(page) for page in pages[1:]] - - assert pages == all_nodes - - def test_generate_pager_current_start(self): - self.check_generate_pages_html(current_page=0, total_pages=6) - - def test_generate_pager_current_middle(self): - self.check_generate_pages_html(current_page=10, total_pages=20, check_middle=True) - - def test_generate_pager_current_end(self): - self.check_generate_pages_html(current_page=38, total_pages=39) - - def test_generate_pager_current_start_with_sorting(self): - self.check_generate_pages_html( - current_page=0, total_pages=4, sorting_key="dag_id", sorting_direction="asc" - ) - - def test_params_no_values(self): - """Should return an empty string if no params are passed""" - assert "" == utils.get_params() - - def test_params_search(self): - assert 
"search=bash_" == utils.get_params(search="bash_") - - def test_params_none_and_zero(self): - query_str = utils.get_params(a=0, b=None, c="true") - # The order won't be consistent, but that doesn't affect behaviour of a browser - pairs = sorted(query_str.split("&")) - assert ["a=0", "c=true"] == pairs - - def test_params_all(self): - query = utils.get_params(tags=["tag1", "tag2"], status="active", page=3, search="bash_") - assert { - "tags": ["tag1", "tag2"], - "page": ["3"], - "search": ["bash_"], - "status": ["active"], - } == parse_qs(query) - - def test_params_escape(self): - assert "search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E" == utils.get_params( - search="'>\"/>" - ) - - def test_state_token(self): - # It's shouldn't possible to set these odd values anymore, but lets - # ensure they are escaped! - html = str(utils.state_token("")) - - assert "<script>alert(1)</script>" in html - assert "" not in html - - def test_nobr_f(self): - attr = {"attr_name": "attribute"} - f = attr.get("attr_name") - expected_markup = Markup("{}").format(f) - - nobr = utils.nobr_f("attr_name") - result_markup = nobr(attr) - - assert result_markup == expected_markup - - def test_nobr_f_empty_attr(self): - attr = {"attr_name": ""} - f = attr.get("attr_name") - expected_markup = Markup("{}").format(f) - - nobr = utils.nobr_f("attr_name") - result_markup = nobr(attr) - - assert result_markup == expected_markup - - def test_nobr_f_missing_attr(self): - attr = {} - f = None - expected_markup = Markup("{}").format(f) - - nobr = utils.nobr_f("attr_name") - result_markup = nobr(attr) - - assert result_markup == expected_markup - - def test_epoch(self): - test_datetime = datetime(2024, 6, 19, 12, 0, 0) - result = utils.epoch(test_datetime) - epoch_time = result[0] - - expected_epoch_time = int(time.mktime(test_datetime.timetuple())) * 1000 - - assert epoch_time == expected_epoch_time - - @pytest.mark.db_test - def test_make_cache_key(self): - from airflow.www.app import cached_app - - with cached_app(testing=True).test_request_context( - "/test/path", query_string={"key1": "value1", "key2": "value2"} - ): - expected_args = str(hash(frozenset({"key1": "value1", "key2": "value2"}.items()))) - expected_cache_key = ("/test/path" + expected_args).encode("ascii", "ignore") - result_cache_key = utils.make_cache_key() - assert result_cache_key == expected_cache_key - - @pytest.mark.db_test - def test_task_instance_link(self): - from airflow.www.app import cached_app - - with cached_app(testing=True).test_request_context(): - html = str( - utils.task_instance_link( - {"dag_id": "", "task_id": "", "map_index": 1, "execution_date": datetime.now()} - ) - ) - - html_map_index_none = str( - utils.task_instance_link( - {"dag_id": "", "task_id": "", "map_index": -1, "execution_date": datetime.now()} - ) - ) - - assert "%3Ca%261%3E" in html - assert "%3Cb2%3E" in html - assert "map_index" in html - assert "" not in html - assert "" not in html - - assert "%3Ca%261%3E" in html_map_index_none - assert "%3Cb2%3E" in html_map_index_none - assert "map_index" not in html_map_index_none - assert "" not in html_map_index_none - assert "" not in html_map_index_none - - @pytest.mark.db_test - def test_dag_link(self): - from airflow.www.app import cached_app - - with cached_app(testing=True).test_request_context(): - html = str(utils.dag_link({"dag_id": "", "execution_date": datetime.now()})) - - assert "%3Ca%261%3E" in html - assert "" not in html - - @pytest.mark.db_test - def test_dag_link_when_dag_is_none(self): - """Test 
that when there is no dag_id, dag_link does not contain hyperlink""" - from airflow.www.app import cached_app - - with cached_app(testing=True).test_request_context(): - html = str(utils.dag_link({})) - - assert "None" in html - assert "
    ", "run_id": "", "execution_date": datetime.now()}) - ) - - assert "%3Ca%261%3E" in html - assert "%3Cb2%3E" in html - assert "" not in html - assert "" not in html - - -class TestAttrRenderer: - def setup_method(self): - self.attr_renderer = utils.get_attr_renderer() - - def test_python_callable(self): - def example_callable(unused_self): - print("example") - - rendered = self.attr_renderer["python_callable"](example_callable) - assert ""example"" in rendered - - def test_python_callable_none(self): - rendered = self.attr_renderer["python_callable"](None) - assert "" == rendered - - def test_markdown(self): - markdown = "* foo\n* bar" - rendered = self.attr_renderer["doc_md"](markdown) - assert "
<li>foo</li>" in rendered - assert "<li>bar</li>
  • " in rendered - - def test_markdown_none(self): - rendered = self.attr_renderer["doc_md"](None) - assert rendered is None - - def test_get_dag_run_conf(self): - dag_run_conf = { - "1": "string", - "2": b"bytes", - "3": 123, - "4": "à".encode("latin"), - "5": datetime(2023, 1, 1), - } - expected_encoded_dag_run_conf = ( - '{"1": "string", "2": "bytes", "3": 123, "4": "à", "5": "2023-01-01T00:00:00+00:00"}' - ) - encoded_dag_run_conf, conf_is_json = utils.get_dag_run_conf( - dag_run_conf, json_encoder=utils_json.WebEncoder - ) - assert expected_encoded_dag_run_conf == encoded_dag_run_conf - - def test_encode_dag_run_none(self): - no_dag_run_result = utils.encode_dag_run(None) - assert no_dag_run_result is None - - def test_json_f_webencoder(self): - dag_run_conf = { - "1": "string", - "2": b"bytes", - "3": 123, - "4": "à".encode("latin"), - "5": datetime(2023, 1, 1), - } - expected_encoded_dag_run_conf = ( - # HTML sanitization is insane - '{"1": "string", "2": "bytes", "3": 123, "4": "\\u00e0", "5": "2023-01-01T00:00:00+00:00"}' - ) - expected_markup = Markup("{}").format(expected_encoded_dag_run_conf) - - formatter = json_f("conf") - dagrun = Mock() - dagrun.get = Mock(return_value=dag_run_conf) - - assert formatter(dagrun) == expected_markup - - -@pytest.mark.filterwarnings("ignore::DeprecationWarning") -def test_get_sensitive_variables_fields(): - with pytest.warns(DeprecationWarning) as warning: - result = utils.get_sensitive_variables_fields() - - # assert deprecation warning - assert len(warning) == 1 - assert "This function is deprecated." in str(warning[-1].message) - - from airflow.utils.log.secrets_masker import get_sensitive_variables_fields - - expected_result = get_sensitive_variables_fields() - assert result == expected_result - - -@pytest.mark.filterwarnings("ignore::DeprecationWarning") -def test_should_hide_value_for_key(): - key_name = "key" - - with pytest.warns(DeprecationWarning) as warning: - result = utils.should_hide_value_for_key(key_name) - - # assert deprecation warning - assert len(warning) == 1 - assert "This function is deprecated." in str(warning[-1].message) - - from airflow.utils.log.secrets_masker import should_hide_value_for_key - - expected_result = should_hide_value_for_key(key_name) - assert result == expected_result - - -class TestWrappedMarkdown: - def test_wrapped_markdown_with_docstring_curly_braces(self): - rendered = wrapped_markdown("{braces}", css_class="a_class") - assert ( - """

<div class="rich_doc a_class" ><p>{braces}</p> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_some_markdown(self): - rendered = wrapped_markdown( - """*italic* - **bold** - """, - css_class="a_class", - ) - - assert ( - """

<div class="rich_doc a_class" ><p><em>italic</em> -<strong>bold</strong></p> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_table(self): - rendered = wrapped_markdown( - """ -| Job | Duration | -| ----------- | ----------- | -| ETL | 14m | -""" - ) - - assert ( - """
<div class="rich_doc" ><table> -<thead> -<tr> -<th>Job</th> -<th>Duration</th> -</tr> -</thead> -<tbody> -<tr> -<td>ETL</td> -<td>14m</td> -</tr> -</tbody> -</table> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_indented_lines(self): - rendered = wrapped_markdown( - """ - # header - 1st line - 2nd line - """ - ) - - assert ( - """

<div class="rich_doc" ><h1>header</h1>\n<p>1st line\n2nd line</p> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_raw_code_block(self): - rendered = wrapped_markdown( - """\ - # Markdown code block - - Inline `code` works well. - - Code block - does not - respect - newlines - - """ - ) - - assert ( - """

<div class="rich_doc" ><h1>Markdown code block</h1> -<p>Inline <code>code</code> works well.</p> -<pre><code>Code block\ndoes not\nrespect\nnewlines\n</code></pre> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_nested_list(self): - rendered = wrapped_markdown( - """ - ### Docstring with a code block - - - And - - A nested list - """ - ) - - assert ( - """

<div class="rich_doc" ><h3>Docstring with a code block</h3> -<ul> -<li>And -<ul> -<li>A nested list</li> -</ul> -</li> -</ul> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_collapsible_section(self): - with conf_vars({("webserver", "allow_raw_html_descriptions"): "true"}): - rendered = wrapped_markdown( - """ -# A collapsible section with markdown -
<details> - <summary>Click to expand!</summary> - - ## Heading - 1. A numbered - 2. list - * With some - * Sub bullets -</details>
    - """ - ) - - assert ( - """

    A collapsible section with markdown

    -
    - Click to expand! -

    Heading

    -
      -
    1. A numbered
    2. -
    3. list -
        -
      • With some
      • -
      • Sub bullets
      • -
      -
    4. -
    -
    -
    """ - == rendered - ) - - @pytest.mark.parametrize("allow_html", [False, True]) - def test_wrapped_markdown_with_raw_html(self, allow_html): - with conf_vars({("webserver", "allow_raw_html_descriptions"): str(allow_html)}): - HTML = "test raw HTML" - rendered = wrapped_markdown(HTML) - if allow_html: - assert HTML in rendered - else: - from markupsafe import escape - - assert escape(HTML) in rendered - - -class TestFilter: - def setup_method(self): - self.mock_datamodel = Mock() - self.mock_query = Mock(spec=Query) - self.mock_column_name = "test_column" - - def test_filter_is_null_apply(self): - filter_is_null = utils.FilterIsNull(datamodel=self.mock_datamodel, column_name=self.mock_column_name) - - self.mock_query, mock_field = get_field_setup_query( - self.mock_query, self.mock_datamodel, self.mock_column_name - ) - mock_value = set_value_to_type(self.mock_datamodel, self.mock_column_name, None) - - result_query_filter = filter_is_null.apply(self.mock_query, None) - self.mock_query.filter.assert_called_once_with(mock_field == mock_value) - - expected_query_filter = self.mock_query.filter(mock_field == mock_value) - - assert result_query_filter == expected_query_filter - - def test_filter_is_not_null_apply(self): - filter_is_not_null = utils.FilterIsNotNull( - datamodel=self.mock_datamodel, column_name=self.mock_column_name - ) - - self.mock_query, mock_field = get_field_setup_query( - self.mock_query, self.mock_datamodel, self.mock_column_name - ) - mock_value = set_value_to_type(self.mock_datamodel, self.mock_column_name, None) - - result_query_filter = filter_is_not_null.apply(self.mock_query, None) - self.mock_query.filter.assert_called_once_with(mock_field != mock_value) - - expected_query_filter = self.mock_query.filter(mock_field != mock_value) - - assert result_query_filter == expected_query_filter - - def test_filter_gte_none_value_apply(self): - filter_gte = utils.FilterGreaterOrEqual( - datamodel=self.mock_datamodel, column_name=self.mock_column_name - ) - - self.mock_query, mock_field = get_field_setup_query( - self.mock_query, self.mock_datamodel, self.mock_column_name - ) - mock_value = set_value_to_type(self.mock_datamodel, self.mock_column_name, None) - - result_query_filter = filter_gte.apply(self.mock_query, mock_value) - - assert result_query_filter == self.mock_query - - def test_filter_lte_none_value_apply(self): - filter_lte = utils.FilterSmallerOrEqual( - datamodel=self.mock_datamodel, column_name=self.mock_column_name - ) - - self.mock_query, mock_field = get_field_setup_query( - self.mock_query, self.mock_datamodel, self.mock_column_name - ) - mock_value = set_value_to_type(self.mock_datamodel, self.mock_column_name, None) - - result_query_filter = filter_lte.apply(self.mock_query, mock_value) - - assert result_query_filter == self.mock_query - - -@pytest.mark.db_test -def test_get_col_default_not_existing(session): - interface = CustomSQLAInterface(obj=DagRun, session=session) - default_value = interface.get_col_default("column_not_existing") - assert default_value is None - - -@pytest.mark.db_test -def test_dag_run_custom_sqla_interface_delete_no_collateral_damage(dag_maker, session): - interface = DagRunCustomSQLAInterface(obj=DagRun, session=session) - dag_ids = (f"test_dag_{x}" for x in range(1, 4)) - dates = (pendulum.datetime(2023, 1, x) for x in range(1, 4)) - for dag_id, date in itertools.product(dag_ids, dates): - with dag_maker(dag_id=dag_id) as dag: - dag.create_dagrun( - execution_date=date, state="running", run_type="scheduled", 
data_interval=(date, date) - ) - dag_runs = session.query(DagRun).all() - assert len(dag_runs) == 9 - assert len(set(x.run_id for x in dag_runs)) == 3 - run_id_for_single_delete = "scheduled__2023-01-01T00:00:00+00:00" - # we have 3 runs with this same run_id - assert sum(1 for x in dag_runs if x.run_id == run_id_for_single_delete) == 3 - # each is a different dag - - # if we delete one, it shouldn't delete the others - one_run = next(x for x in dag_runs if x.run_id == run_id_for_single_delete) - assert interface.delete(item=one_run) is True - session.commit() - dag_runs = session.query(DagRun).all() - # we should have one fewer dag run now - assert len(dag_runs) == 8 - - # now let's try multi delete - run_id_for_multi_delete = "scheduled__2023-01-02T00:00:00+00:00" - # verify we have 3 - runs_of_interest = [x for x in dag_runs if x.run_id == run_id_for_multi_delete] - assert len(runs_of_interest) == 3 - # and that each is different dag - assert len(set(x.dag_id for x in dag_runs)) == 3 - - to_delete = runs_of_interest[:2] - # now try multi delete - assert interface.delete_all(items=to_delete) is True - session.commit() - dag_runs = session.query(DagRun).all() - assert len(dag_runs) == 6 - assert len(set(x.dag_id for x in dag_runs)) == 3 - assert len(set(x.run_id for x in dag_runs)) == 3 diff --git a/tests/www/test_utils_REMOTE_289691.py b/tests/www/test_utils_REMOTE_289691.py deleted file mode 100644 index 640a4a10de2ae8..00000000000000 --- a/tests/www/test_utils_REMOTE_289691.py +++ /dev/null @@ -1,531 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from __future__ import annotations - -import itertools -import re -from datetime import datetime -from unittest.mock import Mock -from urllib.parse import parse_qs - -import pendulum -import pytest -from bs4 import BeautifulSoup -from flask_wtf import FlaskForm -from markupsafe import Markup -from wtforms.fields import StringField, TextAreaField - -from airflow.models import DagRun -from airflow.utils import json as utils_json -from airflow.www import utils -from airflow.www.utils import DagRunCustomSQLAInterface, json_f, wrapped_markdown -from airflow.www.widgets import AirflowDateTimePickerROWidget, BS3TextAreaROWidget, BS3TextFieldROWidget -from tests.test_utils.config import conf_vars - - -class TestUtils: - def check_generate_pages_html( - self, - current_page, - total_pages, - window=7, - check_middle=False, - sorting_key=None, - sorting_direction=None, - ): - extra_links = 4 # first, prev, next, last - search = "'>\"/>" - if sorting_key and sorting_direction: - html_str = utils.generate_pages( - current_page, - total_pages, - search=search, - sorting_key=sorting_key, - sorting_direction=sorting_direction, - ) - else: - html_str = utils.generate_pages(current_page, total_pages, search=search) - - assert search not in html_str, "The raw search string shouldn't appear in the output" - assert "search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E" in html_str - - assert callable(html_str.__html__), "Should return something that is HTML-escaping aware" - - dom = BeautifulSoup(html_str, "html.parser") - assert dom is not None - - ulist = dom.ul - ulist_items = ulist.find_all("li") - assert min(window, total_pages) + extra_links == len(ulist_items) - - page_items = ulist_items[2:-2] - mid = len(page_items) // 2 - all_nodes = [] - pages = [] - - if sorting_key and sorting_direction: - last_page = total_pages - 1 - - if current_page <= mid or total_pages < window: - pages = list(range(min(total_pages, window))) - elif mid < current_page < last_page - mid: - pages = list(range(current_page - mid, current_page + mid + 1)) - else: - pages = list(range(total_pages - window, last_page + 1)) - - pages.append(last_page + 1) - pages.sort(reverse=True if sorting_direction == "desc" else False) - - for i, item in enumerate(page_items): - a_node = item.a - href_link = a_node["href"] - node_text = a_node.string - all_nodes.append(node_text) - if node_text == str(current_page + 1): - if check_middle: - assert mid == i - assert "javascript:void(0)" == href_link - assert "active" in item["class"] - else: - assert re.search(r"^\?", href_link), "Link is page-relative" - query = parse_qs(href_link[1:]) - assert query["page"] == [str(int(node_text) - 1)] - assert query["search"] == [search] - - if sorting_key and sorting_direction: - if pages[0] == 0: - pages = [str(page) for page in pages[1:]] - - assert pages == all_nodes - - def test_generate_pager_current_start(self): - self.check_generate_pages_html(current_page=0, total_pages=6) - - def test_generate_pager_current_middle(self): - self.check_generate_pages_html(current_page=10, total_pages=20, check_middle=True) - - def test_generate_pager_current_end(self): - self.check_generate_pages_html(current_page=38, total_pages=39) - - def test_generate_pager_current_start_with_sorting(self): - self.check_generate_pages_html( - current_page=0, total_pages=4, sorting_key="dag_id", sorting_direction="asc" - ) - - def test_params_no_values(self): - """Should return an empty string if no params are passed""" - assert "" == utils.get_params() - - def 
test_params_search(self): - assert "search=bash_" == utils.get_params(search="bash_") - - def test_params_none_and_zero(self): - query_str = utils.get_params(a=0, b=None, c="true") - # The order won't be consistent, but that doesn't affect behaviour of a browser - pairs = sorted(query_str.split("&")) - assert ["a=0", "c=true"] == pairs - - def test_params_all(self): - query = utils.get_params(tags=["tag1", "tag2"], status="active", page=3, search="bash_") - assert { - "tags": ["tag1", "tag2"], - "page": ["3"], - "search": ["bash_"], - "status": ["active"], - } == parse_qs(query) - - def test_params_escape(self): - assert "search=%27%3E%22%2F%3E%3Cimg+src%3Dx+onerror%3Dalert%281%29%3E" == utils.get_params( - search="'>\"/>" - ) - - def test_state_token(self): - # It's shouldn't possible to set these odd values anymore, but lets - # ensure they are escaped! - html = str(utils.state_token("")) - - assert "<script>alert(1)</script>" in html - assert "" not in html - - @pytest.mark.db_test - def test_task_instance_link(self): - from airflow.www.app import cached_app - - with cached_app(testing=True).test_request_context(): - html = str( - utils.task_instance_link( - {"dag_id": "", "task_id": "", "execution_date": datetime.now()} - ) - ) - - assert "%3Ca%261%3E" in html - assert "%3Cb2%3E" in html - assert "" not in html - assert "" not in html - - @pytest.mark.db_test - def test_dag_link(self): - from airflow.www.app import cached_app - - with cached_app(testing=True).test_request_context(): - html = str(utils.dag_link({"dag_id": "", "execution_date": datetime.now()})) - - assert "%3Ca%261%3E" in html - assert "" not in html - - @pytest.mark.db_test - def test_dag_link_when_dag_is_none(self): - """Test that when there is no dag_id, dag_link does not contain hyperlink""" - from airflow.www.app import cached_app - - with cached_app(testing=True).test_request_context(): - html = str(utils.dag_link({})) - - assert "None" in html - assert "
    ", "run_id": "", "execution_date": datetime.now()}) - ) - - assert "%3Ca%261%3E" in html - assert "%3Cb2%3E" in html - assert "" not in html - assert "" not in html - - -class TestAttrRenderer: - def setup_method(self): - self.attr_renderer = utils.get_attr_renderer() - - def test_python_callable(self): - def example_callable(unused_self): - print("example") - - rendered = self.attr_renderer["python_callable"](example_callable) - assert ""example"" in rendered - - def test_python_callable_none(self): - rendered = self.attr_renderer["python_callable"](None) - assert "" == rendered - - def test_markdown(self): - markdown = "* foo\n* bar" - rendered = self.attr_renderer["doc_md"](markdown) - assert "
<li>foo</li>" in rendered - assert "<li>bar</li>
  • " in rendered - - def test_markdown_none(self): - rendered = self.attr_renderer["doc_md"](None) - assert rendered is None - - def test_get_dag_run_conf(self): - dag_run_conf = { - "1": "string", - "2": b"bytes", - "3": 123, - "4": "à".encode("latin"), - "5": datetime(2023, 1, 1), - } - expected_encoded_dag_run_conf = ( - '{"1": "string", "2": "bytes", "3": 123, "4": "à", "5": "2023-01-01T00:00:00+00:00"}' - ) - encoded_dag_run_conf, conf_is_json = utils.get_dag_run_conf( - dag_run_conf, json_encoder=utils_json.WebEncoder - ) - assert expected_encoded_dag_run_conf == encoded_dag_run_conf - - def test_json_f_webencoder(self): - dag_run_conf = { - "1": "string", - "2": b"bytes", - "3": 123, - "4": "à".encode("latin"), - "5": datetime(2023, 1, 1), - } - expected_encoded_dag_run_conf = ( - # HTML sanitization is insane - '{"1": "string", "2": "bytes", "3": 123, "4": "\\u00e0", "5": "2023-01-01T00:00:00+00:00"}' - ) - expected_markup = Markup("{}").format(expected_encoded_dag_run_conf) - - formatter = json_f("conf") - dagrun = Mock() - dagrun.get = Mock(return_value=dag_run_conf) - - assert formatter(dagrun) == expected_markup - - -class TestWrappedMarkdown: - def test_wrapped_markdown_with_docstring_curly_braces(self): - rendered = wrapped_markdown("{braces}", css_class="a_class") - assert ( - """

<div class="rich_doc a_class" ><p>{braces}</p> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_some_markdown(self): - rendered = wrapped_markdown( - """*italic* - **bold** - """, - css_class="a_class", - ) - - assert ( - """

<div class="rich_doc a_class" ><p><em>italic</em> -<strong>bold</strong></p> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_table(self): - rendered = wrapped_markdown( - """ -| Job | Duration | -| ----------- | ----------- | -| ETL | 14m | -""" - ) - - assert ( - """
<div class="rich_doc" ><table> -<thead> -<tr> -<th>Job</th> -<th>Duration</th> -</tr> -</thead> -<tbody> -<tr> -<td>ETL</td> -<td>14m</td> -</tr> -</tbody> -</table> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_indented_lines(self): - rendered = wrapped_markdown( - """ - # header - 1st line - 2nd line - """ - ) - - assert ( - """

<div class="rich_doc" ><h1>header</h1>\n<p>1st line\n2nd line</p> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_raw_code_block(self): - rendered = wrapped_markdown( - """\ - # Markdown code block - - Inline `code` works well. - - Code block - does not - respect - newlines - - """ - ) - - assert ( - """

<div class="rich_doc" ><h1>Markdown code block</h1> -<p>Inline <code>code</code> works well.</p> -<pre><code>Code block\ndoes not\nrespect\nnewlines\n</code></pre> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_nested_list(self): - rendered = wrapped_markdown( - """ - ### Docstring with a code block - - - And - - A nested list - """ - ) - - assert ( - """

<div class="rich_doc" ><h3>Docstring with a code block</h3> -<ul> -<li>And -<ul> -<li>A nested list</li> -</ul> -</li> -</ul> -</div>
    """ - == rendered - ) - - def test_wrapped_markdown_with_collapsible_section(self): - with conf_vars({("webserver", "allow_raw_html_descriptions"): "true"}): - rendered = wrapped_markdown( - """ -# A collapsible section with markdown -
<details> - <summary>Click to expand!</summary> - - ## Heading - 1. A numbered - 2. list - * With some - * Sub bullets -</details>
    - """ - ) - - assert ( - """

    A collapsible section with markdown

    -
    - Click to expand! -

    Heading

    -
      -
    1. A numbered
    2. -
    3. list -
        -
      • With some
      • -
      • Sub bullets
      • -
      -
    4. -
    -
    -
    """ - == rendered - ) - - @pytest.mark.parametrize("allow_html", [False, True]) - def test_wrapped_markdown_with_raw_html(self, allow_html): - with conf_vars({("webserver", "allow_raw_html_descriptions"): str(allow_html)}): - HTML = "test raw HTML" - rendered = wrapped_markdown(HTML) - if allow_html: - assert HTML in rendered - else: - from markupsafe import escape - - assert escape(HTML) in rendered - - -@pytest.mark.db_test -def test_dag_run_custom_sqla_interface_delete_no_collateral_damage(dag_maker, session): - interface = DagRunCustomSQLAInterface(obj=DagRun, session=session) - dag_ids = (f"test_dag_{x}" for x in range(1, 4)) - dates = (pendulum.datetime(2023, 1, x) for x in range(1, 4)) - for dag_id, date in itertools.product(dag_ids, dates): - with dag_maker(dag_id=dag_id) as dag: - dag.create_dagrun( - execution_date=date, state="running", run_type="scheduled", data_interval=(date, date) - ) - dag_runs = session.query(DagRun).all() - assert len(dag_runs) == 9 - assert len(set(x.run_id for x in dag_runs)) == 3 - run_id_for_single_delete = "scheduled__2023-01-01T00:00:00+00:00" - # we have 3 runs with this same run_id - assert sum(1 for x in dag_runs if x.run_id == run_id_for_single_delete) == 3 - # each is a different dag - - # if we delete one, it shouldn't delete the others - one_run = next(x for x in dag_runs if x.run_id == run_id_for_single_delete) - assert interface.delete(item=one_run) is True - session.commit() - dag_runs = session.query(DagRun).all() - # we should have one fewer dag run now - assert len(dag_runs) == 8 - - # now let's try multi delete - run_id_for_multi_delete = "scheduled__2023-01-02T00:00:00+00:00" - # verify we have 3 - runs_of_interest = [x for x in dag_runs if x.run_id == run_id_for_multi_delete] - assert len(runs_of_interest) == 3 - # and that each is different dag - assert len(set(x.dag_id for x in dag_runs)) == 3 - - to_delete = runs_of_interest[:2] - # now try multi delete - assert interface.delete_all(items=to_delete) is True - session.commit() - dag_runs = session.query(DagRun).all() - assert len(dag_runs) == 6 - assert len(set(x.dag_id for x in dag_runs)) == 3 - assert len(set(x.run_id for x in dag_runs)) == 3 - - -@pytest.fixture -def app(): - from flask import Flask - - app = Flask(__name__) - app.config["WTF_CSRF_ENABLED"] = False - app.config["SECRET_KEY"] = "secret" - with app.app_context(): - yield app - - -class TestWidgets: - def test_airflow_datetime_picker_ro_widget(self, app): - class TestForm(FlaskForm): - datetime_field = StringField(widget=AirflowDateTimePickerROWidget()) - - form = TestForm() - field = form.datetime_field - - html_output = field() - - assert 'readonly="true"' in html_output - assert "input-group datetime datetimepicker" in html_output - - def test_bs3_text_field_ro_widget(self, app): - class TestForm(FlaskForm): - text_field = StringField(widget=BS3TextFieldROWidget()) - - form = TestForm() - field = form.text_field - - html_output = field() - - assert 'readonly="true"' in html_output - assert "form-control" in html_output - - def test_bs3_text_area_ro_widget(self, app): - class TestForm(FlaskForm): - textarea_field = TextAreaField(widget=BS3TextAreaROWidget()) - - form = TestForm() - field = form.textarea_field - - html_output = field() - - assert 'readonly="true"' in html_output - assert "form-control" in html_output From f8accc130781cf270cd9f123fac8f8a80e702aba Mon Sep 17 00:00:00 2001 From: andyjianzhou Date: Wed, 17 Jul 2024 13:59:50 -0600 Subject: [PATCH 20/21] fix pre-commit --- 
airflow/cli/commands/rotate_fernet_key_command.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/airflow/cli/commands/rotate_fernet_key_command.py b/airflow/cli/commands/rotate_fernet_key_command.py index dc2ade361c2279..b95e8f3752cc13 100644 --- a/airflow/cli/commands/rotate_fernet_key_command.py +++ b/airflow/cli/commands/rotate_fernet_key_command.py @@ -46,7 +46,8 @@ def rotate_fernet_key(args): def rotate_items_in_batches_v1(session, model_class, filter_condition=None, batch_size=100): - """Rotates Fernet keys for items of a given model in batches to avoid excessive memory usage. + """ + Rotates Fernet keys for items of a given model in batches to avoid excessive memory usage. This function is a replacement for yield_per, which is not available in SQLAlchemy 1.x. """ @@ -65,7 +66,8 @@ def rotate_items_in_batches_v1(session, model_class, filter_condition=None, batc def rotate_items_in_batches_v2(session, model_class, filter_condition=None, batch_size=100): - """Rotates Fernet keys for items of a given model in batches to avoid excessive memory usage. + """ + Rotates Fernet keys for items of a given model in batches to avoid excessive memory usage. This function is taking advantage of yield_per available in SQLAlchemy 2.x. """ From a6d102a46697a9b9455c2e903a7a0f6afef136c7 Mon Sep 17 00:00:00 2001 From: vboxuser Date: Fri, 19 Jul 2024 13:43:43 -0600 Subject: [PATCH 21/21] added test case for invalid XCom TasKInstance --- tests/models/test_xcom_arg_map.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/models/test_xcom_arg_map.py b/tests/models/test_xcom_arg_map.py index b997cf024ebba3..b2e885e9408333 100644 --- a/tests/models/test_xcom_arg_map.py +++ b/tests/models/test_xcom_arg_map.py @@ -212,6 +212,18 @@ def test_task_map_from_task_instance_xcom(): TaskMap.from_task_instance_xcom(ti, value) +def test_task_map_with_invalid_task_instance(): + task = EmptyOperator(task_id="test_task") + ti = TaskInstance(task=task, run_id=None, map_index=0) + ti.dag_id = "test_dag" + + # Define some arbitrary XCom-like value data + value = {"example_key": "example_value"} + + with pytest.raises(ValueError, match="cannot record task map for unrun task instance"): + TaskMap.from_task_instance_xcom(ti, value) + + def test_task_map_variant(): # Test case where keys is None task_map = TaskMap(