From f5262694894c8d72d296fd03f46163123dbc28ae Mon Sep 17 00:00:00 2001 From: Ulada Zakharava Date: Thu, 23 Nov 2023 14:33:45 +0000 Subject: [PATCH] Switch to Connexion 3 framework This is a huge PR being result of over a 100 commits made by a number of people in ##36052 and #37638. It switches to Connexion 3 as the driving backend implementation for both - Airflow REST APIs and Flask app that powers Airflow UI. It should be largely backwards compatible when it comes to behaviour of both APIs and Airflow Webserver views, however due to decisions made by Connexion 3 maintainers, it changes heavily the technology stack used under-the-hood: 1) Connexion App is an ASGI-compatible Open-API spec-first framework using ASGI as an interface between webserver and Python web application. ASGI is an asynchronous successor of WSGI. 2) Connexion itself is using Starlette to run asynchronous web services in Python. 3) We continue using gunicorn appliation server that still uses WSGI standard, which means that we can continue using Flask and we are usig standard Uvicorn ASGI webserver that converts the ASGI interface to WSGI interface of Gunicorn Some of the problems handled in this PR There were two problem was with session handling: * the get_session_cookie - did not get the right cookie - it returned "session" string. The right fix was to change cookie_jar into cookie.jar because this is where apparently TestClient of starlette is holding the cookies (visible when you debug) * The client does not accept "set_cookie" method - it accepts passing cookies via "cookies" dictionary - this is the usual httpx client - see https://www.starlette.io/testclient/ - so we have to set cookie directly in the get method to try it out Add "flask_client_with_login" for tests that neeed flask client Some tests require functionality not available to Starlette test client as they use Flask test client specific features - for those we have an option to get flask test client instead of starlette one. Fix error handling for new connection 3 approach Error handling for Connexion 3 integration needed to be reworked. The way it behaves is much the same as it works in main: * for API errors - we get application/problem+json responses * for UI erros - we have rendered views * for redirection - we have correct location header (it's been missing) * the api error handled was not added as available middleware in the www tests It should fix all test_views_base.py tests which were failing on lack of location header for redirection. Fix wrong response is tests_view_cluster_activity The problem in the test was that Starlette Test Client opens a new connection and start new session, while flask test client uses the same database session. The test did not show data because the data was not committed and session was not closed - which also failed sqlite local tests with "database is locked" error. Fix test_extra_links The tests were failing again because the dagrun created was not committed and session not closed. This worked with flask client that used the same session accidentally but did not work with test client from Starlette. Also it caused "database locked" in sqlite / local tests. Switch to non-deprecated auth manager Fix to test_views_log.py This PR partially fixes sessions and request parameter for test_views_log. Some tests are still failing but for different reasons - to be investigated. Fix views_custom_user_views tests The problem in those tests was that the check in security manager was based on the assumption that the security manager was shared between the client and test flask application - because they were coming from the same flask app. But when we use starlette, the call goes to a new process started and the user is deleted in the database - so the shortcut of checking the security manager did not work. The change is that we are now checking if the user is deleted by calling /users/show (we need a new users READ permission for that) - this way we go to the database and check if the user was indeed deleted. Fix test_task_instance_endpoint tests There were two reasons for the test failed: * when the Job was added to task instance, the task instance was not merged in session, which means that commit did not store the added Job * some of the tests were expecting a call with specific session and they failed because session was different. Replacing the session with mock.ANY tells pytest that this parameter can be anything - we will have different session when when the call will be made with ASGI/Starlette Fix parameter validation * added default value for limit parameter across the board. Connexion 3 does not like if the parameter had no default and we had not provided one - even if our custom decorated was adding it. Adding default value and updating our decorator to treat None as `default` fixed a number of problems where limits were not passed * swapped openapi specification for /datasets/{uri} and /dataset/events. Since `{uri}` was defined first, connection matched `events` with `{uri}` and chose parameter definitions from `{uri}` not events Fix test_log_enpoint tests The problem here was that some sessions should be committed/closed but also in order to run it standalone we wanted to create log templates in the database - as it relied implcitly on log templates created by other tests. Fix test_views_dagrun, test_views_tasks and test_views_log Fixed by switching to use flask client for testing rather than starlette. Starlette client in this case has some side effects that are also impacting Sqlite's session being created in a different thread and deleted with close_all_sessions fixture. Fix test_views_dagrun Fixed by switching to use flask client for testing rather than starlette. Starlette client in this case has some side effects that are also impacting Sqlite's session being created in a different thread and deleted with close_all_sessions fixture. Co-authored-by: sudipto baral Co-authored-by: satoshi-sh Co-authored-by: Maksim Yermakou Co-authored-by: Ulada Zakharava --- .github/workflows/basic-tests.yml | 2 + .../endpoints/connection_endpoint.py | 2 +- .../api_connexion/endpoints/dag_endpoint.py | 2 +- .../endpoints/dag_warning_endpoint.py | 2 +- .../endpoints/dataset_endpoint.py | 6 +- .../endpoints/event_log_endpoint.py | 2 +- .../endpoints/import_error_endpoint.py | 2 +- .../api_connexion/endpoints/log_endpoint.py | 5 +- .../api_connexion/endpoints/pool_endpoint.py | 2 +- .../endpoints/task_instance_endpoint.py | 2 +- airflow/api_connexion/exceptions.py | 55 ++- airflow/api_connexion/openapi/v1.yaml | 61 +-- airflow/api_connexion/parameters.py | 14 +- airflow/auth/managers/base_auth_manager.py | 6 +- airflow/cli/commands/internal_api_command.py | 17 +- airflow/cli/commands/webserver_command.py | 8 +- .../0074_2_0_0_resource_based_permissions.py | 4 +- ...1_remove_can_read_permission_on_config_.py | 4 +- ...resource_based_permissions_for_default_.py | 4 +- .../fab/auth_manager/fab_auth_manager.py | 25 +- airflow/utils/json.py | 3 +- airflow/www/app.py | 40 +- .../www/extensions/init_appbuilder_links.py | 2 +- airflow/www/extensions/init_views.py | 160 ++++---- airflow/www/package.json | 3 +- airflow/www/static/js/types/api-generated.ts | 80 ++-- airflow/www/views.py | 3 +- airflow/www/yarn.lock | 20 +- .../core-concepts/auth-manager.rst | 2 +- docs/apache-airflow/img/airflow_erd.sha256 | 2 +- docs/apache-airflow/img/airflow_erd.svg | 4 +- hatch_build.py | 3 +- newsfragments/37638.significant.rst | 4 + tests/api_connexion/conftest.py | 35 +- .../endpoints/test_config_endpoint.py | 98 ++--- .../endpoints/test_connection_endpoint.py | 164 +++----- .../endpoints/test_dag_endpoint.py | 309 +++++++------- .../endpoints/test_dag_run_endpoint.py | 381 +++++++++--------- .../endpoints/test_dag_source_endpoint.py | 60 ++- .../endpoints/test_dag_warning_endpoint.py | 57 ++- .../endpoints/test_dataset_endpoint.py | 178 ++++---- .../endpoints/test_event_log_endpoint.py | 134 +++--- .../endpoints/test_extra_link_endpoint.py | 61 +-- .../endpoints/test_forward_to_fab_endpoint.py | 63 +-- .../endpoints/test_health_endpoint.py | 14 +- .../endpoints/test_import_error_endpoint.py | 92 ++--- .../endpoints/test_log_endpoint.py | 132 +++--- .../test_mapped_task_instance_endpoint.py | 161 ++++---- .../endpoints/test_plugin_endpoint.py | 56 ++- .../endpoints/test_pool_endpoint.py | 150 ++++--- .../endpoints/test_provider_endpoint.py | 28 +- .../endpoints/test_task_endpoint.py | 79 ++-- .../endpoints/test_task_instance_endpoint.py | 323 +++++++-------- .../endpoints/test_variable_endpoint.py | 124 +++--- .../endpoints/test_version_endpoint.py | 6 +- .../endpoints/test_xcom_endpoint.py | 66 +-- .../schemas/test_dag_run_schema.py | 2 +- .../test_role_and_permission_schema.py | 14 +- tests/api_connexion/test_auth.py | 55 +-- tests/api_connexion/test_cors.py | 45 ++- tests/api_connexion/test_error_handling.py | 14 +- tests/api_connexion/test_security.py | 20 +- .../auth/backend/test_basic_auth.py | 12 +- .../endpoints/test_rpc_api_endpoint.py | 16 +- tests/auth/managers/test_base_auth_manager.py | 5 +- .../cli/commands/test_internal_api_command.py | 5 +- tests/cli/commands/test_webserver_command.py | 8 +- tests/conftest.py | 29 +- .../auth/backend/test_kerberos_auth.py | 15 +- tests/plugins/test_plugins_manager.py | 15 +- .../aws/auth_manager/test_aws_auth_manager.py | 6 +- .../aws/auth_manager/views/test_auth.py | 16 +- .../api/auth/backend/test_basic_auth.py | 6 +- .../test_role_and_permission_endpoint.py | 160 ++++---- .../api_endpoints/test_user_endpoint.py | 208 +++++----- .../api_endpoints/test_user_schema.py | 17 +- tests/providers/fab/auth_manager/conftest.py | 14 +- .../fab/auth_manager/decorators/test_auth.py | 18 +- .../fab/auth_manager/test_security.py | 116 +++--- .../auth_manager/views/test_permissions.py | 4 +- .../fab/auth_manager/views/test_roles_list.py | 8 +- .../fab/auth_manager/views/test_user.py | 8 +- .../fab/auth_manager/views/test_user_edit.py | 8 +- .../fab/auth_manager/views/test_user_stats.py | 10 +- .../common/auth_backend/test_google_openid.py | 16 +- tests/sensors/test_external_task_sensor.py | 4 +- tests/test_utils/api_connexion_utils.py | 2 +- tests/test_utils/decorators.py | 2 +- tests/test_utils/mock_cors_middeleware.py | 35 ++ .../remote_user_api_auth_backend.py | 2 +- tests/test_utils/www.py | 18 +- tests/utils/test_helpers.py | 6 +- tests/www/api/experimental/conftest.py | 8 +- .../experimental/test_dag_runs_endpoint.py | 16 +- tests/www/api/experimental/test_endpoints.py | 91 ++--- tests/www/test_app.py | 30 +- tests/www/test_auth.py | 8 +- tests/www/test_security_manager.py | 2 +- tests/www/test_utils.py | 14 +- tests/www/views/conftest.py | 39 +- .../www/views/test_anonymous_as_admin_role.py | 5 +- tests/www/views/test_session.py | 47 ++- tests/www/views/test_views.py | 21 +- tests/www/views/test_views_acl.py | 80 ++-- tests/www/views/test_views_base.py | 56 +-- tests/www/views/test_views_blocked.py | 2 +- .../www/views/test_views_cluster_activity.py | 6 +- tests/www/views/test_views_connection.py | 4 +- .../www/views/test_views_custom_user_views.py | 64 +-- tests/www/views/test_views_dagrun.py | 82 ++-- tests/www/views/test_views_dataset.py | 40 +- tests/www/views/test_views_extra_links.py | 61 ++- tests/www/views/test_views_grid.py | 48 ++- tests/www/views/test_views_home.py | 18 +- tests/www/views/test_views_log.py | 60 +-- tests/www/views/test_views_mount.py | 4 +- tests/www/views/test_views_paused.py | 8 +- tests/www/views/test_views_pool.py | 2 +- tests/www/views/test_views_rate_limit.py | 20 +- tests/www/views/test_views_rendered.py | 6 +- tests/www/views/test_views_robots.py | 6 +- tests/www/views/test_views_task_norun.py | 4 +- tests/www/views/test_views_tasks.py | 95 +++-- tests/www/views/test_views_trigger_dag.py | 28 +- tests/www/views/test_views_variable.py | 24 +- 125 files changed, 2702 insertions(+), 2558 deletions(-) create mode 100644 newsfragments/37638.significant.rst create mode 100644 tests/test_utils/mock_cors_middeleware.py diff --git a/.github/workflows/basic-tests.yml b/.github/workflows/basic-tests.yml index db84bae38e2e2a..3bf42b1ce815d8 100644 --- a/.github/workflows/basic-tests.yml +++ b/.github/workflows/basic-tests.yml @@ -148,6 +148,8 @@ jobs: env: HATCH_ENV: "test" working-directory: ./clients/python + - name: Compile www assets + run: breeze compile-www-assets - name: "Install Airflow in editable mode with fab for webserver tests" run: pip install -e ".[fab]" - name: "Install Python client" diff --git a/airflow/api_connexion/endpoints/connection_endpoint.py b/airflow/api_connexion/endpoints/connection_endpoint.py index c17a9280d78f84..452ccb42cfbbe3 100644 --- a/airflow/api_connexion/endpoints/connection_endpoint.py +++ b/airflow/api_connexion/endpoints/connection_endpoint.py @@ -91,7 +91,7 @@ def get_connection(*, connection_id: str, session: Session = NEW_SESSION) -> API @provide_session def get_connections( *, - limit: int, + limit: int | None = None, offset: int = 0, order_by: str = "id", session: Session = NEW_SESSION, diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py index 1895bfeaec762a..1efecbbbba5db4 100644 --- a/airflow/api_connexion/endpoints/dag_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_endpoint.py @@ -94,7 +94,7 @@ def get_dag_details( @provide_session def get_dags( *, - limit: int, + limit: int | None = None, offset: int = 0, tags: Collection[str] | None = None, dag_id_pattern: str | None = None, diff --git a/airflow/api_connexion/endpoints/dag_warning_endpoint.py b/airflow/api_connexion/endpoints/dag_warning_endpoint.py index d59db8c3d30826..f1eeddf0c81044 100644 --- a/airflow/api_connexion/endpoints/dag_warning_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_warning_endpoint.py @@ -43,7 +43,7 @@ @provide_session def get_dag_warnings( *, - limit: int, + limit: int | None = None, dag_id: str | None = None, warning_type: str | None = None, offset: int | None = None, diff --git a/airflow/api_connexion/endpoints/dataset_endpoint.py b/airflow/api_connexion/endpoints/dataset_endpoint.py index bfdb8d0a5e7ee2..bbc91f85eac3ad 100644 --- a/airflow/api_connexion/endpoints/dataset_endpoint.py +++ b/airflow/api_connexion/endpoints/dataset_endpoint.py @@ -82,7 +82,7 @@ def get_dataset(*, uri: str, session: Session = NEW_SESSION) -> APIResponse: @provide_session def get_datasets( *, - limit: int, + limit: int | None = None, offset: int = 0, uri_pattern: str | None = None, dag_ids: str | None = None, @@ -113,11 +113,11 @@ def get_datasets( @security.requires_access_dataset("GET") -@provide_session @format_parameters({"limit": check_limit}) +@provide_session def get_dataset_events( *, - limit: int, + limit: int | None = None, offset: int = 0, order_by: str = "timestamp", dataset_id: int | None = None, diff --git a/airflow/api_connexion/endpoints/event_log_endpoint.py b/airflow/api_connexion/endpoints/event_log_endpoint.py index 3b3dbe6efd4905..23caee37556861 100644 --- a/airflow/api_connexion/endpoints/event_log_endpoint.py +++ b/airflow/api_connexion/endpoints/event_log_endpoint.py @@ -64,7 +64,7 @@ def get_event_logs( included_events: str | None = None, before: str | None = None, after: str | None = None, - limit: int, + limit: int | None = None, offset: int | None = None, order_by: str = "event_log_id", session: Session = NEW_SESSION, diff --git a/airflow/api_connexion/endpoints/import_error_endpoint.py b/airflow/api_connexion/endpoints/import_error_endpoint.py index 76b706eac1ae40..b63d0c30115d45 100644 --- a/airflow/api_connexion/endpoints/import_error_endpoint.py +++ b/airflow/api_connexion/endpoints/import_error_endpoint.py @@ -77,7 +77,7 @@ def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) -> @provide_session def get_import_errors( *, - limit: int, + limit: int | None = None, offset: int | None = None, order_by: str = "import_error_id", session: Session = NEW_SESSION, diff --git a/airflow/api_connexion/endpoints/log_endpoint.py b/airflow/api_connexion/endpoints/log_endpoint.py index 239f08ecdaf404..5493b6278d10b6 100644 --- a/airflow/api_connexion/endpoints/log_endpoint.py +++ b/airflow/api_connexion/endpoints/log_endpoint.py @@ -107,7 +107,10 @@ def get_log( logs = logs[0] if task_try_number is not None else logs # we must have token here, so we can safely ignore it token = URLSafeSerializer(key).dumps(metadata) # type: ignore[assignment] - return logs_schema.dump(LogResponseObject(continuation_token=token, content=logs)) + return Response( + logs_schema.dumps(LogResponseObject(continuation_token=token, content=logs)), + headers={"Content-Type": "application/json"}, + ) # text/plain. Stream logs = task_log_reader.read_log_stream(ti, task_try_number, metadata) diff --git a/airflow/api_connexion/endpoints/pool_endpoint.py b/airflow/api_connexion/endpoints/pool_endpoint.py index 553d50c7464b7a..ef59ed21b6321d 100644 --- a/airflow/api_connexion/endpoints/pool_endpoint.py +++ b/airflow/api_connexion/endpoints/pool_endpoint.py @@ -68,7 +68,7 @@ def get_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIResponse: @provide_session def get_pools( *, - limit: int, + limit: int | None = None, order_by: str = "id", offset: int | None = None, session: Session = NEW_SESSION, diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py index a58aaee86f295c..2302bab0049223 100644 --- a/airflow/api_connexion/endpoints/task_instance_endpoint.py +++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py @@ -296,7 +296,7 @@ def _apply_range_filter(query: Select, key: ClauseElement, value_range: tuple[T, @provide_session def get_task_instances( *, - limit: int, + limit: int | None = None, dag_id: str | None = None, dag_run_id: str | None = None, execution_date_gte: str | None = None, diff --git a/airflow/api_connexion/exceptions.py b/airflow/api_connexion/exceptions.py index 75d9261ef6d444..fa2015a2dea1c4 100644 --- a/airflow/api_connexion/exceptions.py +++ b/airflow/api_connexion/exceptions.py @@ -19,13 +19,12 @@ from http import HTTPStatus from typing import TYPE_CHECKING, Any -import werkzeug -from connexion import FlaskApi, ProblemException, problem +from connexion import ProblemException, problem from airflow.utils.docs import get_docs_url if TYPE_CHECKING: - import flask + from connexion.lifecycle import ConnexionRequest, ConnexionResponse doc_link = get_docs_url("stable-rest-api-ref.html") @@ -40,37 +39,29 @@ } -def common_error_handler(exception: BaseException) -> flask.Response: +def problem_error_handler(_request: ConnexionRequest, exception: ProblemException) -> ConnexionResponse: """Use to capture connexion exceptions and add link to the type field.""" - if isinstance(exception, ProblemException): - link = EXCEPTIONS_LINK_MAP.get(exception.status) - if link: - response = problem( - status=exception.status, - title=exception.title, - detail=exception.detail, - type=link, - instance=exception.instance, - headers=exception.headers, - ext=exception.ext, - ) - else: - response = problem( - status=exception.status, - title=exception.title, - detail=exception.detail, - type=exception.type, - instance=exception.instance, - headers=exception.headers, - ext=exception.ext, - ) + link = EXCEPTIONS_LINK_MAP.get(exception.status) + if link: + return problem( + status=exception.status, + title=exception.title, + detail=exception.detail, + type=link, + instance=exception.instance, + headers=exception.headers, + ext=exception.ext, + ) else: - if not isinstance(exception, werkzeug.exceptions.HTTPException): - exception = werkzeug.exceptions.InternalServerError() - - response = problem(title=exception.name, detail=exception.description, status=exception.code) - - return FlaskApi.get_response(response) + return problem( + status=exception.status, + title=exception.title, + detail=exception.detail, + type=exception.type, + instance=exception.instance, + headers=exception.headers, + ext=exception.ext, + ) class NotFound(ProblemException): diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index b5e3ef72e1c615..66182ebdb2d720 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -1367,6 +1367,10 @@ paths: responses: "204": description: Success. + content: + text/html: + schema: + type: string "400": $ref: "#/components/responses/BadRequest" "401": @@ -1743,6 +1747,10 @@ paths: responses: "204": description: Success. + content: + text/html: + schema: + type: string "400": $ref: "#/components/responses/BadRequest" "401": @@ -1885,8 +1893,8 @@ paths: response = self.client.get( request_url, query_string={"token": token}, - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "text/plain","REMOTE_USER": "test"}, + ) continuation_token = response.json["continuation_token"] metadata = URLSafeSerializer(key).loads(continuation_token) @@ -2020,7 +2028,7 @@ paths: properties: content: type: string - plain/text: + text/plain: schema: type: string @@ -2106,29 +2114,6 @@ paths: "403": $ref: "#/components/responses/PermissionDenied" - /datasets/{uri}: - parameters: - - $ref: "#/components/parameters/DatasetURI" - get: - summary: Get a dataset - description: Get a dataset by uri. - x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint - operationId: get_dataset - tags: [Dataset] - responses: - "200": - description: Success. - content: - application/json: - schema: - $ref: "#/components/schemas/Dataset" - "401": - $ref: "#/components/responses/Unauthenticated" - "403": - $ref: "#/components/responses/PermissionDenied" - "404": - $ref: "#/components/responses/NotFound" - /datasets/events: get: summary: Get dataset events @@ -2186,6 +2171,30 @@ paths: '404': $ref: '#/components/responses/NotFound' + /datasets/{uri}: + parameters: + - $ref: "#/components/parameters/DatasetURI" + get: + summary: Get a dataset + description: Get a dataset by uri. + x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint + operationId: get_dataset + tags: [Dataset] + responses: + "200": + description: Success. + content: + application/json: + schema: + $ref: "#/components/schemas/Dataset" + "401": + $ref: "#/components/responses/Unauthenticated" + "403": + $ref: "#/components/responses/PermissionDenied" + "404": + $ref: "#/components/responses/NotFound" + + /config: get: summary: Get current configuration diff --git a/airflow/api_connexion/parameters.py b/airflow/api_connexion/parameters.py index a05ded37614d47..79e34feecef3d1 100644 --- a/airflow/api_connexion/parameters.py +++ b/airflow/api_connexion/parameters.py @@ -41,7 +41,7 @@ def validate_istimezone(value: datetime) -> None: raise BadRequest("Invalid datetime format", detail="Naive datetime is disallowed") -def format_datetime(value: str) -> datetime: +def format_datetime(value: str | None) -> datetime | None: """ Format datetime objects. @@ -50,6 +50,8 @@ def format_datetime(value: str) -> datetime: This should only be used within connection views because it raises 400 """ + if value is None: + return None value = value.strip() if value[-1] != "Z": value = value.replace(" ", "+") @@ -59,7 +61,7 @@ def format_datetime(value: str) -> datetime: raise BadRequest("Incorrect datetime argument", detail=str(err)) -def check_limit(value: int) -> int: +def check_limit(value: int | None) -> int: """ Check the limit does not exceed configured value. @@ -68,7 +70,8 @@ def check_limit(value: int) -> int: """ max_val = conf.getint("api", "maximum_page_limit") # user configured max page limit fallback = conf.getint("api", "fallback_page_limit") - + if value is None: + return fallback if value > max_val: log.warning( "The limit param value %s passed in API exceeds the configured maximum page limit %s", @@ -99,8 +102,9 @@ def format_parameters_decorator(func: T) -> T: @wraps(func) def wrapped_function(*args, **kwargs): for key, formatter in params_formatters.items(): - if key in kwargs: - kwargs[key] = formatter(kwargs[key]) + value = formatter(kwargs.get(key)) + if value: + kwargs[key] = value return func(*args, **kwargs) return cast(T, wrapped_function) diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index 4d5c249235a69e..1475f198b3ead8 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -33,7 +33,7 @@ from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: - from flask import Blueprint + import connexion from flask_appbuilder.menu import MenuItem from sqlalchemy.orm import Session @@ -81,8 +81,8 @@ def get_cli_commands() -> list[CLICommand]: """ return [] - def get_api_endpoints(self) -> None | Blueprint: - """Return API endpoint(s) definition for the auth manager.""" + def set_api_endpoints(self, connexion_app: connexion.FlaskApp) -> None: + """Set API endpoint(s) definition for the auth manager.""" return None def get_user_name(self) -> str: diff --git a/airflow/cli/commands/internal_api_command.py b/airflow/cli/commands/internal_api_command.py index 8c25d1fa5ae588..379393bd225844 100644 --- a/airflow/cli/commands/internal_api_command.py +++ b/airflow/cli/commands/internal_api_command.py @@ -29,8 +29,8 @@ from tempfile import gettempdir from time import sleep +import connexion import psutil -from flask import Flask from flask_appbuilder import SQLA from flask_caching import Cache from flask_wtf.csrf import CSRFProtect @@ -55,7 +55,7 @@ from airflow.www.extensions.init_views import init_api_internal, init_error_handlers log = logging.getLogger(__name__) -app: Flask | None = None +app: connexion.FlaskApp | None = None @cli_utils.action_cli @@ -74,8 +74,8 @@ def internal_api(args): log.info("Starting the Internal API server on port %s and host %s.", args.port, args.hostname) app = create_app(testing=conf.getboolean("core", "unit_test_mode")) app.run( - debug=True, # nosec - use_reloader=not app.config["TESTING"], + log_level="debug", + # reload=not app.app.config["TESTING"], port=args.port, host=args.hostname, ) @@ -102,7 +102,7 @@ def internal_api(args): "--workers", str(num_workers), "--worker-class", - str(args.workerclass), + "uvicorn.workers.UvicornWorker", "--timeout", str(worker_timeout), "--bind", @@ -198,7 +198,8 @@ def start_and_monitor_gunicorn(args): def create_app(config=None, testing=False): """Create a new instance of Airflow Internal API app.""" - flask_app = Flask(__name__) + connexion_app = connexion.FlaskApp(__name__) + flask_app = connexion_app.app flask_app.config["APP_NAME"] = "Airflow Internal API" flask_app.config["TESTING"] = testing @@ -243,11 +244,11 @@ def create_app(config=None, testing=False): with flask_app.app_context(): init_error_handlers(flask_app) - init_api_internal(flask_app, standalone_api=True) + init_api_internal(connexion_app, standalone_api=True) init_jinja_globals(flask_app) init_xframe_protection(flask_app) - return flask_app + return connexion_app def cached_app(config=None, testing=False): diff --git a/airflow/cli/commands/webserver_command.py b/airflow/cli/commands/webserver_command.py index 4285564e1fd17f..1906ba55acdebf 100644 --- a/airflow/cli/commands/webserver_command.py +++ b/airflow/cli/commands/webserver_command.py @@ -356,11 +356,11 @@ def webserver(args): print(f"Starting the web server on port {args.port} and host {args.hostname}.") app = create_app(testing=conf.getboolean("core", "unit_test_mode")) app.run( - debug=True, - use_reloader=not app.config["TESTING"], + log_level="debug", port=args.port, host=args.hostname, - ssl_context=(ssl_cert, ssl_key) if ssl_cert and ssl_key else None, + ssl_keyfile=ssl_key if ssl_cert and ssl_key else None, + ssl_certfile=ssl_cert if ssl_cert and ssl_key else None, ) else: print( @@ -384,7 +384,7 @@ def webserver(args): "--workers", str(num_workers), "--worker-class", - str(args.workerclass), + "uvicorn.workers.UvicornWorker", "--timeout", str(worker_timeout), "--bind", diff --git a/airflow/migrations/versions/0074_2_0_0_resource_based_permissions.py b/airflow/migrations/versions/0074_2_0_0_resource_based_permissions.py index 1748ca3d5f3aa1..175f5ad380f91f 100644 --- a/airflow/migrations/versions/0074_2_0_0_resource_based_permissions.py +++ b/airflow/migrations/versions/0074_2_0_0_resource_based_permissions.py @@ -288,7 +288,7 @@ def remap_permissions(): """Apply Map Airflow permissions.""" - appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder + appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).app.appbuilder for old, new in mapping.items(): (old_resource_name, old_action_name) = old old_permission = appbuilder.sm.get_permission(old_action_name, old_resource_name) @@ -313,7 +313,7 @@ def remap_permissions(): def undo_remap_permissions(): """Unapply Map Airflow permissions""" - appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder + appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).app.appbuilder for old, new in mapping.items(): (new_resource_name, new_action_name) = new[0] new_permission = appbuilder.sm.get_permission(new_action_name, new_resource_name) diff --git a/airflow/migrations/versions/0078_2_0_1_remove_can_read_permission_on_config_.py b/airflow/migrations/versions/0078_2_0_1_remove_can_read_permission_on_config_.py index b9bc66d01e094f..33fbcfbf37db13 100644 --- a/airflow/migrations/versions/0078_2_0_1_remove_can_read_permission_on_config_.py +++ b/airflow/migrations/versions/0078_2_0_1_remove_can_read_permission_on_config_.py @@ -42,7 +42,7 @@ def upgrade(): log = logging.getLogger() handlers = log.handlers[:] - appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder + appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).app.appbuilder roles_to_modify = [role for role in appbuilder.sm.get_all_roles() if role.name in ["User", "Viewer"]] can_read_on_config_perm = appbuilder.sm.get_permission( permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG @@ -59,7 +59,7 @@ def upgrade(): def downgrade(): """Add can_read action on config resource for User and Viewer role""" - appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder + appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).app.appbuilder roles_to_modify = [role for role in appbuilder.sm.get_all_roles() if role.name in ["User", "Viewer"]] can_read_on_config_perm = appbuilder.sm.get_permission( permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG diff --git a/airflow/migrations/versions/0084_2_1_0_resource_based_permissions_for_default_.py b/airflow/migrations/versions/0084_2_1_0_resource_based_permissions_for_default_.py index f5e8706c09d54d..c3f1003cafb886 100644 --- a/airflow/migrations/versions/0084_2_1_0_resource_based_permissions_for_default_.py +++ b/airflow/migrations/versions/0084_2_1_0_resource_based_permissions_for_default_.py @@ -140,7 +140,7 @@ def remap_permissions(): """Apply Map Airflow permissions.""" - appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder + appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).app.appbuilder for old, new in mapping.items(): (old_resource_name, old_action_name) = old old_permission = appbuilder.sm.get_permission(old_action_name, old_resource_name) @@ -165,7 +165,7 @@ def remap_permissions(): def undo_remap_permissions(): """Unapply Map Airflow permissions""" - appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder + appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).app.appbuilder for old, new in mapping.items(): (new_resource_name, new_action_name) = new[0] new_permission = appbuilder.sm.get_permission(new_action_name, new_resource_name) diff --git a/airflow/providers/fab/auth_manager/fab_auth_manager.py b/airflow/providers/fab/auth_manager/fab_auth_manager.py index d01b3526bf204c..05d18db0bdc28c 100644 --- a/airflow/providers/fab/auth_manager/fab_auth_manager.py +++ b/airflow/providers/fab/auth_manager/fab_auth_manager.py @@ -22,8 +22,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Container -from connexion import FlaskApi -from flask import Blueprint, url_for +from connexion.options import SwaggerUIOptions +from flask import url_for from sqlalchemy import select from sqlalchemy.orm import Session, joinedload @@ -82,10 +82,12 @@ ) from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.yaml import safe_load -from airflow.www.constants import SWAGGER_BUNDLE, SWAGGER_ENABLED -from airflow.www.extensions.init_views import _CustomErrorRequestBodyValidator, _LazyResolver +from airflow.www.constants import SWAGGER_BUNDLE +from airflow.www.extensions.init_views import _LazyResolver if TYPE_CHECKING: + import connexion + from airflow.auth.managers.models.base_user import BaseUser from airflow.cli.cli_config import ( CLICommand, @@ -147,19 +149,24 @@ def get_cli_commands() -> list[CLICommand]: SYNC_PERM_COMMAND, # not in a command group ] - def get_api_endpoints(self) -> None | Blueprint: + def set_api_endpoints(self, connexion_app: connexion.FlaskApp) -> None: folder = Path(__file__).parents[0].resolve() # this is airflow/auth/managers/fab/ with folder.joinpath("openapi", "v1.yaml").open() as f: specification = safe_load(f) - return FlaskApi( + + swagger_ui_options = SwaggerUIOptions( + swagger_ui=conf.getboolean("webserver", "enable_swagger_ui", fallback=True), + swagger_ui_template_dir=SWAGGER_BUNDLE, + ) + + connexion_app.add_api( specification=specification, resolver=_LazyResolver(), base_path="/auth/fab/v1", - options={"swagger_ui": SWAGGER_ENABLED, "swagger_path": SWAGGER_BUNDLE.__fspath__()}, + swagger_ui_options=swagger_ui_options, strict_validation=True, validate_responses=True, - validator_map={"body": _CustomErrorRequestBodyValidator}, - ).blueprint + ) def get_user_display_name(self) -> str: """Return the user's display name associated to the user in session.""" diff --git a/airflow/utils/json.py b/airflow/utils/json.py index 4d89e340c1cd48..2540edf9a0cbb9 100644 --- a/airflow/utils/json.py +++ b/airflow/utils/json.py @@ -37,7 +37,8 @@ class AirflowJsonProvider(JSONProvider): def dumps(self, obj, **kwargs): kwargs.setdefault("ensure_ascii", self.ensure_ascii) kwargs.setdefault("sort_keys", self.sort_keys) - return json.dumps(obj, **kwargs, cls=WebEncoder) + kwargs.setdefault("cls", WebEncoder) + return json.dumps(obj, **kwargs) def loads(self, s: str | bytes, **kwargs): return json.loads(s, **kwargs) diff --git a/airflow/www/app.py b/airflow/www/app.py index 50e1ba2629786c..dc142127e2886d 100644 --- a/airflow/www/app.py +++ b/airflow/www/app.py @@ -20,7 +20,8 @@ import warnings from datetime import timedelta -from flask import Flask +import connexion +from flask import request from flask_appbuilder import SQLA from flask_wtf.csrf import CSRFProtect from markupsafe import Markup @@ -49,19 +50,20 @@ ) from airflow.www.extensions.init_session import init_airflow_session_interface from airflow.www.extensions.init_views import ( - init_api_auth_provider, + init_api_auth_manager, init_api_connexion, init_api_error_handlers, init_api_experimental, init_api_internal, init_appbuilder_views, + init_cors_middleware, init_error_handlers, init_flash_views, init_plugins, ) from airflow.www.extensions.init_wsgi_middlewares import init_wsgi_middleware -app: Flask | None = None +app: connexion.FlaskApp | None = None # Initializes at the module level, so plugins can access it. # See: /docs/plugins.rst @@ -70,7 +72,20 @@ def create_app(config=None, testing=False): """Create a new instance of Airflow WWW app.""" - flask_app = Flask(__name__) + connexion_app = connexion.FlaskApp(__name__) + + @connexion_app.app.before_request + def before_request(): + """Exempts the view function associated with '/api/v1' requests from CSRF protection.""" + if request.path.startswith("/api/v1"): # TODO: make sure this path is correct + view_function = connexion_app.app.view_functions.get(request.endpoint) + if view_function: + # Exempt the view function from CSRF protection + connexion_app.app.extensions["csrf"].exempt(view_function) + + init_cors_middleware(connexion_app) + + flask_app = connexion_app.app flask_app.secret_key = conf.get("webserver", "SECRET_KEY") flask_app.config["PERMANENT_SESSION_LIFETIME"] = timedelta(minutes=settings.get_session_lifetime_config()) @@ -163,14 +178,16 @@ def create_app(config=None, testing=False): init_appbuilder_links(flask_app) init_plugins(flask_app) init_error_handlers(flask_app) - init_api_connexion(flask_app) + init_api_connexion(connexion_app) if conf.getboolean("webserver", "run_internal_api", fallback=False): if not _ENABLE_AIP_44: raise RuntimeError("The AIP_44 is not enabled so you cannot use it.") - init_api_internal(flask_app) + init_api_internal(connexion_app) init_api_experimental(flask_app) - init_api_auth_provider(flask_app) - init_api_error_handlers(flask_app) # needs to be after all api inits to let them add their path first + init_api_auth_manager(connexion_app) + init_api_error_handlers( + connexion_app + ) # needs to be after all api inits to let them add their path first get_auth_manager().init() @@ -178,7 +195,7 @@ def create_app(config=None, testing=False): init_xframe_protection(flask_app) init_airflow_session_interface(flask_app) init_check_user_active(flask_app) - return flask_app + return connexion_app def cached_app(config=None, testing=False): @@ -193,3 +210,8 @@ def purge_cached_app(): """Remove the cached version of the app in global state.""" global app app = None + + +def cached_flask_app(config=None, testing=False): + """Return flask app from connexion_app.""" + return cached_app(config=config, testing=testing).app diff --git a/airflow/www/extensions/init_appbuilder_links.py b/airflow/www/extensions/init_appbuilder_links.py index 0d2f4e13e92935..933fdd42393336 100644 --- a/airflow/www/extensions/init_appbuilder_links.py +++ b/airflow/www/extensions/init_appbuilder_links.py @@ -53,7 +53,7 @@ def init_appbuilder_links(app): appbuilder.add_link( name=RESOURCE_DOCS, label="REST API Reference (Swagger UI)", - href="/api/v1./api/v1_swagger_ui_index", + href="/api/v1/ui", category=RESOURCE_DOCS_MENU, ) appbuilder.add_link( diff --git a/airflow/www/extensions/init_views.py b/airflow/www/extensions/init_views.py index bf6cdfdcfe84cb..3d639a47b1f601 100644 --- a/airflow/www/extensions/init_views.py +++ b/airflow/www/extensions/init_views.py @@ -22,12 +22,13 @@ from pathlib import Path from typing import TYPE_CHECKING -from connexion import FlaskApi, ProblemException, Resolver -from connexion.decorators.validation import RequestBodyValidator -from connexion.exceptions import BadRequestProblem -from flask import request +import connexion +import starlette.exceptions +from connexion import ProblemException, Resolver +from connexion.options import SwaggerUIOptions +from connexion.problem import problem -from airflow.api_connexion.exceptions import common_error_handler +from airflow.api_connexion.exceptions import problem_error_handler from airflow.configuration import conf from airflow.exceptions import RemovedInAirflow3Warning from airflow.security import permissions @@ -36,6 +37,8 @@ from airflow.www.extensions.init_auth_manager import get_auth_manager if TYPE_CHECKING: + import starlette.exceptions + from connexion.lifecycle import ConnexionRequest, ConnexionResponse from flask import Flask log = logging.getLogger(__name__) @@ -167,26 +170,6 @@ def init_error_handlers(app: Flask): from airflow.www import views app.register_error_handler(500, views.show_traceback) - app.register_error_handler(404, views.not_found) - - -def set_cors_headers_on_response(response): - """Add response headers.""" - allow_headers = conf.get("api", "access_control_allow_headers") - allow_methods = conf.get("api", "access_control_allow_methods") - allow_origins = conf.get("api", "access_control_allow_origins") - if allow_headers: - response.headers["Access-Control-Allow-Headers"] = allow_headers - if allow_methods: - response.headers["Access-Control-Allow-Methods"] = allow_methods - if allow_origins == "*": - response.headers["Access-Control-Allow-Origin"] = "*" - elif allow_origins: - allowed_origins = allow_origins.split(" ") - origin = request.environ.get("HTTP_ORIGIN", allowed_origins[0]) - if origin in allowed_origins: - response.headers["Access-Control-Allow-Origin"] = origin - return response class _LazyResolution: @@ -220,71 +203,81 @@ def resolve(self, operation): return _LazyResolution(self.resolve_function_from_operation_id, operation_id) -class _CustomErrorRequestBodyValidator(RequestBodyValidator): - """Custom request body validator that overrides error messages. - - By default, Connextion emits a very generic *None is not of type 'object'* - error when receiving an empty request body (with the view specifying the - body as non-nullable). We overrides it to provide a more useful message. - """ - - def validate_schema(self, data, url): - if not self.is_null_value_valid and data is None: - raise BadRequestProblem(detail="Request body must not be empty") - return super().validate_schema(data, url) +base_paths: list[str] = ["/auth/fab/v1"] # contains the list of base paths that have api endpoints -base_paths: list[str] = [] # contains the list of base paths that have api endpoints - - -def init_api_error_handlers(app: Flask) -> None: +def init_api_error_handlers(connexion_app: connexion.FlaskApp) -> None: """Add error handlers for 404 and 405 errors for existing API paths.""" from airflow.www import views - @app.errorhandler(404) - def _handle_api_not_found(ex): + def _handle_api_not_found(error) -> ConnexionResponse | str: + from flask.globals import request + if any([request.path.startswith(p) for p in base_paths]): # 404 errors are never handled on the blueprint level # unless raised from a view func so actual 404 errors, # i.e. "no route for it" defined, need to be handled # here on the application level - return common_error_handler(ex) - else: - return views.not_found(ex) - - @app.errorhandler(405) - def _handle_method_not_allowed(ex): - if any([request.path.startswith(p) for p in base_paths]): - return common_error_handler(ex) - else: - return views.method_not_allowed(ex) - - app.register_error_handler(ProblemException, common_error_handler) + return connexion_app._http_exception(error) + return views.not_found(error) + def _handle_api_method_not_allowed(error) -> ConnexionResponse | str: + from flask.globals import request -def init_api_connexion(app: Flask) -> None: + if any([request.path.startswith(p) for p in base_paths]): + return connexion_app._http_exception(error) + return views.method_not_allowed(error) + + def _handle_redirect( + request: ConnexionRequest, ex: starlette.exceptions.HTTPException + ) -> ConnexionResponse: + return problem( + title=connexion.http_facts.HTTP_STATUS_CODES.get(ex.status_code), + detail=ex.detail, + headers={"Location": ex.detail}, + status=ex.status_code, + ) + + # in case of 404 and 405 we handle errors at the Flask APP level in order to have access to + # context and be able to render the error page for the UI + connexion_app.app.register_error_handler(404, _handle_api_not_found) + connexion_app.app.register_error_handler(405, _handle_api_method_not_allowed) + + # We should handle redirects at connexion_app level - the requests will be redirected to the target + # location - so they can return application/problem+json response with the Location header regardless + # ot the request path - does not matter if it is API or UI request + connexion_app.add_error_handler(301, _handle_redirect) + connexion_app.add_error_handler(302, _handle_redirect) + connexion_app.add_error_handler(307, _handle_redirect) + connexion_app.add_error_handler(308, _handle_redirect) + + # Everything else we handle at the connexion_app level by default error handler + connexion_app.add_error_handler(ProblemException, problem_error_handler) + + +def init_api_connexion(connexion_app: connexion.FlaskApp) -> None: """Initialize Stable API.""" base_path = "/api/v1" base_paths.append(base_path) with ROOT_APP_DIR.joinpath("api_connexion", "openapi", "v1.yaml").open() as f: specification = safe_load(f) - api_bp = FlaskApi( + swagger_ui_options = SwaggerUIOptions( + swagger_ui=SWAGGER_ENABLED, + swagger_ui_template_dir=SWAGGER_BUNDLE, + ) + + connexion_app.add_api( specification=specification, resolver=_LazyResolver(), base_path=base_path, - options={"swagger_ui": SWAGGER_ENABLED, "swagger_path": SWAGGER_BUNDLE.__fspath__()}, + swagger_ui_options=swagger_ui_options, strict_validation=True, validate_responses=True, - validator_map={"body": _CustomErrorRequestBodyValidator}, - ).blueprint - api_bp.after_request(set_cors_headers_on_response) - - app.register_blueprint(api_bp) - app.extensions["csrf"].exempt(api_bp) + ) -def init_api_internal(app: Flask, standalone_api: bool = False) -> None: +def init_api_internal(connexion_app: connexion.FlaskApp, standalone_api: bool = False) -> None: """Initialize Internal API.""" if not standalone_api and not conf.getboolean("webserver", "run_internal_api", fallback=False): return @@ -292,18 +285,18 @@ def init_api_internal(app: Flask, standalone_api: bool = False) -> None: base_paths.append("/internal_api/v1") with ROOT_APP_DIR.joinpath("api_internal", "openapi", "internal_api_v1.yaml").open() as f: specification = safe_load(f) - api_bp = FlaskApi( + swagger_ui_options = SwaggerUIOptions( + swagger_ui=SWAGGER_ENABLED, + swagger_ui_template_dir=SWAGGER_BUNDLE, + ) + + connexion_app.add_api( specification=specification, base_path="/internal_api/v1", - options={"swagger_ui": SWAGGER_ENABLED, "swagger_path": SWAGGER_BUNDLE.__fspath__()}, + swagger_ui_options=swagger_ui_options, strict_validation=True, validate_responses=True, - ).blueprint - api_bp.after_request(set_cors_headers_on_response) - - app.register_blueprint(api_bp) - app.after_request_funcs.setdefault(api_bp.name, []).append(set_cors_headers_on_response) - app.extensions["csrf"].exempt(api_bp) + ) def init_api_experimental(app): @@ -324,11 +317,20 @@ def init_api_experimental(app): app.extensions["csrf"].exempt(endpoints.api_experimental) -def init_api_auth_provider(app): +def init_api_auth_manager(connexion_app: connexion.FlaskApp): """Initialize the API offered by the auth manager.""" auth_mgr = get_auth_manager() - blueprint = auth_mgr.get_api_endpoints() - if blueprint: - base_paths.append(blueprint.url_prefix) - app.register_blueprint(blueprint) - app.extensions["csrf"].exempt(blueprint) + auth_mgr.set_api_endpoints(connexion_app) + + +def init_cors_middleware(connexion_app: connexion.FlaskApp): + from starlette.middleware.cors import CORSMiddleware + + connexion_app.add_middleware( + CORSMiddleware, + connexion.middleware.MiddlewarePosition.BEFORE_ROUTING, + allow_origins=conf.get("api", "access_control_allow_origins"), + allow_credentials=True, + allow_methods=conf.get("api", "access_control_allow_methods"), + allow_headers=conf.get("api", "access_control_allow_headers"), + ) diff --git a/airflow/www/package.json b/airflow/www/package.json index 22b6f882d3ed76..2699d49b9f5df8 100644 --- a/airflow/www/package.json +++ b/airflow/www/package.json @@ -141,7 +141,8 @@ "reactflow": "^11.7.4", "redoc": "^2.0.0-rc.72", "remark-gfm": "^3.0.1", - "swagger-ui-dist": "4.1.3", + "sanitize-html": "^2.12.1", + "swagger-ui-dist": "5.11.8", "tsconfig-paths": "^3.14.2", "type-fest": "^2.17.0", "url-search-params-polyfill": "^8.1.0", diff --git a/airflow/www/static/js/types/api-generated.ts b/airflow/www/static/js/types/api-generated.ts index b8da89e55604b1..657d339c95903f 100644 --- a/airflow/www/static/js/types/api-generated.ts +++ b/airflow/www/static/js/types/api-generated.ts @@ -572,8 +572,8 @@ export interface paths { * response = self.client.get( * request_url, * query_string={"token": token}, - * headers={"Accept": "text/plain"}, - * environ_overrides={"REMOTE_USER": "test"}, + * headers={"Accept": "text/plain","REMOTE_USER": "test"}, + * * ) * continuation_token = response.json["continuation_token"] * metadata = URLSafeSerializer(key).loads(continuation_token) @@ -671,6 +671,12 @@ export interface paths { "/datasets": { get: operations["get_datasets"]; }; + "/datasets/events": { + /** Get dataset events */ + get: operations["get_dataset_events"]; + /** Create dataset event */ + post: operations["create_dataset_event"]; + }; "/datasets/{uri}": { /** Get a dataset by uri. */ get: operations["get_dataset"]; @@ -681,12 +687,6 @@ export interface paths { }; }; }; - "/datasets/events": { - /** Get dataset events */ - get: operations["get_dataset_events"]; - /** Create dataset event */ - post: operations["create_dataset_event"]; - }; "/config": { get: operations["get_config"]; }; @@ -3681,7 +3681,11 @@ export interface operations { }; responses: { /** Success. */ - 204: never; + 204: { + content: { + "text/html": string; + }; + }; 400: components["responses"]["BadRequest"]; 401: components["responses"]["Unauthenticated"]; 403: components["responses"]["PermissionDenied"]; @@ -4165,7 +4169,11 @@ export interface operations { }; responses: { /** Success. */ - 204: never; + 204: { + content: { + "text/html": string; + }; + }; 400: components["responses"]["BadRequest"]; 401: components["responses"]["Unauthenticated"]; 403: components["responses"]["PermissionDenied"]; @@ -4320,8 +4328,8 @@ export interface operations { * response = self.client.get( * request_url, * query_string={"token": token}, - * headers={"Accept": "text/plain"}, - * environ_overrides={"REMOTE_USER": "test"}, + * headers={"Accept": "text/plain","REMOTE_USER": "test"}, + * * ) * continuation_token = response.json["continuation_token"] * metadata = URLSafeSerializer(key).loads(continuation_token) @@ -4468,7 +4476,7 @@ export interface operations { "application/json": { content?: string; }; - "plain/text": string; + "text/plain": string; }; }; 401: components["responses"]["Unauthenticated"]; @@ -4543,26 +4551,6 @@ export interface operations { 403: components["responses"]["PermissionDenied"]; }; }; - /** Get a dataset by uri. */ - get_dataset: { - parameters: { - path: { - /** The encoded Dataset URI */ - uri: components["parameters"]["DatasetURI"]; - }; - }; - responses: { - /** Success. */ - 200: { - content: { - "application/json": components["schemas"]["Dataset"]; - }; - }; - 401: components["responses"]["Unauthenticated"]; - 403: components["responses"]["PermissionDenied"]; - 404: components["responses"]["NotFound"]; - }; - }; /** Get dataset events */ get_dataset_events: { parameters: { @@ -4622,6 +4610,26 @@ export interface operations { }; }; }; + /** Get a dataset by uri. */ + get_dataset: { + parameters: { + path: { + /** The encoded Dataset URI */ + uri: components["parameters"]["DatasetURI"]; + }; + }; + responses: { + /** Success. */ + 200: { + content: { + "application/json": components["schemas"]["Dataset"]; + }; + }; + 401: components["responses"]["Unauthenticated"]; + 403: components["responses"]["PermissionDenied"]; + 404: components["responses"]["NotFound"]; + }; + }; get_config: { parameters: { query: { @@ -5502,15 +5510,15 @@ export type GetDagWarningsVariables = CamelCasedPropertiesDeep< export type GetDatasetsVariables = CamelCasedPropertiesDeep< operations["get_datasets"]["parameters"]["query"] >; -export type GetDatasetVariables = CamelCasedPropertiesDeep< - operations["get_dataset"]["parameters"]["path"] ->; export type GetDatasetEventsVariables = CamelCasedPropertiesDeep< operations["get_dataset_events"]["parameters"]["query"] >; export type CreateDatasetEventVariables = CamelCasedPropertiesDeep< operations["create_dataset_event"]["requestBody"]["content"]["application/json"] >; +export type GetDatasetVariables = CamelCasedPropertiesDeep< + operations["get_dataset"]["parameters"]["path"] +>; export type GetConfigVariables = CamelCasedPropertiesDeep< operations["get_config"]["parameters"]["query"] >; diff --git a/airflow/www/views.py b/airflow/www/views.py index 328312658bd185..0a60385d4f4ba4 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -3223,7 +3223,6 @@ def historical_metrics_data(self): """Return cluster activity historical metrics.""" start_date = _safe_parse_datetime(request.args.get("start_date")) end_date = _safe_parse_datetime(request.args.get("end_date")) - with create_session() as session: # DagRuns dag_run_types = session.execute( @@ -3575,7 +3574,7 @@ class RedocView(AirflowBaseView): @expose("/redoc") def redoc(self): """Redoc API documentation.""" - openapi_spec_url = url_for("/api/v1./api/v1_openapi_yaml") + openapi_spec_url = "/api/v1/openapi.yaml" return self.render_template("airflow/redoc.html", openapi_spec_url=openapi_spec_url) diff --git a/airflow/www/yarn.lock b/airflow/www/yarn.lock index b4ec5af7a21cef..097855d3b2fb8c 100644 --- a/airflow/www/yarn.lock +++ b/airflow/www/yarn.lock @@ -10386,6 +10386,18 @@ safe-regex-test@^1.0.0: resolved "https://registry.yarnpkg.com/safer-buffer/-/safer-buffer-2.1.2.tgz#44fa161b0187b9549dd84bb91802f9bd8385cd6a" integrity sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg== +sanitize-html@^2.12.1: + version "2.12.1" + resolved "https://registry.yarnpkg.com/sanitize-html/-/sanitize-html-2.12.1.tgz#280a0f5c37305222921f6f9d605be1f6558914c7" + integrity sha512-Plh+JAn0UVDpBRP/xEjsk+xDCoOvMBwQUf/K+/cBAVuTbtX8bj2VB7S1sL1dssVpykqp0/KPSesHrqXtokVBpA== + dependencies: + deepmerge "^4.2.2" + escape-string-regexp "^4.0.0" + htmlparser2 "^8.0.0" + is-plain-object "^5.0.0" + parse-srcset "^1.0.2" + postcss "^8.3.11" + sax@^1.2.4: version "1.2.4" resolved "https://registry.yarnpkg.com/sax/-/sax-1.2.4.tgz#2816234e2378bddc4e5354fab5caa895df7100d9" @@ -11022,10 +11034,10 @@ svgo@^2.7.0: picocolors "^1.0.0" stable "^0.1.8" -swagger-ui-dist@4.1.3: - version "4.1.3" - resolved "https://registry.yarnpkg.com/swagger-ui-dist/-/swagger-ui-dist-4.1.3.tgz#2be9f9de9b5c19132fa4a5e40933058c151563dc" - integrity sha512-WvfPSfAAMlE/sKS6YkW47nX/hA7StmhYnAHc6wWCXNL0oclwLj6UXv0hQCkLnDgvebi0MEV40SJJpVjKUgH1IQ== +swagger-ui-dist@5.11.8: + version "5.11.8" + resolved "https://registry.yarnpkg.com/swagger-ui-dist/-/swagger-ui-dist-5.11.8.tgz#5f92f1f4ca979a5df847da5df180c8b10ccc3e0c" + integrity sha512-IfPtCPdf6opT5HXrzHO4kjL1eco0/8xJCtcs7ilhKuzatrpF2j9s+3QbOag6G3mVFKf+g+Ca5UG9DquVUs2obA== swagger2openapi@^7.0.6: version "7.0.6" diff --git a/docs/apache-airflow/core-concepts/auth-manager.rst b/docs/apache-airflow/core-concepts/auth-manager.rst index aaead4a2b3aa4a..9edb51e14991fe 100644 --- a/docs/apache-airflow/core-concepts/auth-manager.rst +++ b/docs/apache-airflow/core-concepts/auth-manager.rst @@ -163,7 +163,7 @@ Auth managers may vend CLI commands which will be included in the ``airflow`` co Rest API ^^^^^^^^ -Auth managers may vend Rest API endpoints which will be included in the :doc:`/stable-rest-api-ref` by implementing the ``get_api_endpoints`` method. The endpoints can be used to manage resources such as users, groups, roles (if any) handled by your auth manager. Endpoints are only vended for the currently configured auth manager. +Auth managers may vend Rest API endpoints which will be included in the :doc:`/stable-rest-api-ref` by implementing the ``set_api_endpoints`` method. The endpoints can be used to manage resources such as users, groups, roles (if any) handled by your auth manager. Endpoints are only vended for the currently configured auth manager. Next Steps ^^^^^^^^^^ diff --git a/docs/apache-airflow/img/airflow_erd.sha256 b/docs/apache-airflow/img/airflow_erd.sha256 index 8947b7e6315981..768ed9173c6158 100644 --- a/docs/apache-airflow/img/airflow_erd.sha256 +++ b/docs/apache-airflow/img/airflow_erd.sha256 @@ -1 +1 @@ -072fb4b43a86ccb57765ec3f163350519773be83ab38b7ac747d25e1197233e8 \ No newline at end of file +ecc9e116e1692b948b7e7e26645ce055edc5385bc600b6126d904565a6a6af04 diff --git a/docs/apache-airflow/img/airflow_erd.svg b/docs/apache-airflow/img/airflow_erd.svg index bf4c6c94906a06..019ff78714fa6e 100644 --- a/docs/apache-airflow/img/airflow_erd.svg +++ b/docs/apache-airflow/img/airflow_erd.svg @@ -1421,7 +1421,7 @@ task_instance--xcom -0..N +1 1 @@ -1442,7 +1442,7 @@ task_instance--xcom -1 +0..N 1 diff --git a/hatch_build.py b/hatch_build.py index 7fddae5bb25454..f520f76d4ed89d 100644 --- a/hatch_build.py +++ b/hatch_build.py @@ -427,7 +427,7 @@ # The usage was added in #30596, seemingly only to override and improve the default error message. # Either revert that change or find another way, preferably without using connexion internals. # This limit can be removed after https://github.com/apache/airflow/issues/35234 is fixed - "connexion[flask]>=2.10.0,<3.0", + "connexion[flask,uvicorn]>=3.0", "cron-descriptor>=1.2.24", "croniter>=2.0.2", "cryptography>=39.0.0", @@ -486,6 +486,7 @@ # The issue tracking it is https://github.com/apache/airflow/issues/28723 "sqlalchemy>=1.4.36,<2.0", "sqlalchemy-jsonfield>=1.0", + "starlette>=0.37.1", "tabulate>=0.7.5", "tenacity>=6.2.0,!=8.2.0", "termcolor>=1.1.0", diff --git a/newsfragments/37638.significant.rst b/newsfragments/37638.significant.rst new file mode 100644 index 00000000000000..7e498df5bb617b --- /dev/null +++ b/newsfragments/37638.significant.rst @@ -0,0 +1,4 @@ +Replaced test_should_respond_400_on_invalid_request with test_ignore_read_only_fields in the test_dag_endpoint.py. + +Connexion V3 request body validator doesn't raise the read-only property error and just ignore the read-only field. +You can find the detail about the change `here `_ diff --git a/tests/api_connexion/conftest.py b/tests/api_connexion/conftest.py index abd09fa1c02ecd..8479863c824597 100644 --- a/tests/api_connexion/conftest.py +++ b/tests/api_connexion/conftest.py @@ -21,6 +21,7 @@ from airflow.www import app from tests.test_utils.config import conf_vars from tests.test_utils.decorators import dont_initialize_flask_app_submodules +from tests.test_utils.mock_cors_middeleware import init_mock_cors_middleware @pytest.fixture(scope="session") @@ -30,6 +31,32 @@ def minimal_app_for_api(): "init_appbuilder", "init_api_experimental_auth", "init_api_connexion", + "init_jinja_globals", + "init_api_error_handlers", + "init_airflow_session_interface", + "init_appbuilder_views", + ] + ) + def factory(): + with conf_vars({("api", "auth_backends"): "tests.test_utils.remote_user_api_auth_backend"}): + _app = app.create_app( + testing=True, + config={"WTF_CSRF_ENABLED": False, "AUTH_ROLE_PUBLIC": None}, + ) # type:ignore + init_mock_cors_middleware(_app, allow_origins=["http://apache.org", "http://example.com"]) + return _app + + return factory() + + +@pytest.fixture(scope="session") +def minimal_app_for_api_cors_allow_all(): + @dont_initialize_flask_app_submodules( + skip_all_except=[ + "init_appbuilder", + "init_api_experimental_auth", + "init_api_connexion", + "init_jinja_globals", "init_api_error_handlers", "init_airflow_session_interface", "init_appbuilder_views", @@ -38,7 +65,7 @@ def minimal_app_for_api(): def factory(): with conf_vars({("api", "auth_backends"): "tests.test_utils.remote_user_api_auth_backend"}): _app = app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore - _app.config["AUTH_ROLE_PUBLIC"] = None + init_mock_cors_middleware(_app, allow_origins=["*"]) return _app return factory() @@ -63,9 +90,9 @@ def dagbag(): @pytest.fixture def set_auto_role_public(request): app = request.getfixturevalue("minimal_app_for_api") - auto_role_public = app.config["AUTH_ROLE_PUBLIC"] - app.config["AUTH_ROLE_PUBLIC"] = request.param + auto_role_public = app.app.config["AUTH_ROLE_PUBLIC"] + app.app.config["AUTH_ROLE_PUBLIC"] = request.param yield - app.config["AUTH_ROLE_PUBLIC"] = auto_role_public + app.app.config["AUTH_ROLE_PUBLIC"] = auto_role_public diff --git a/tests/api_connexion/endpoints/test_config_endpoint.py b/tests/api_connexion/endpoints/test_config_endpoint.py index 3dd5814e5d79e4..8384f8fb061cf7 100644 --- a/tests/api_connexion/endpoints/test_config_endpoint.py +++ b/tests/api_connexion/endpoints/test_config_endpoint.py @@ -22,7 +22,7 @@ import pytest from airflow.security import permissions -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.config import conf_vars pytestmark = pytest.mark.db_test @@ -49,33 +49,31 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type:ignore + connexion_app.app, # type:ignore username="test", role_name="Test", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG)], # type: ignore ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore with conf_vars({("webserver", "expose_config"): "True"}): yield minimal_app_for_api - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore class TestGetConfig: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore @patch("airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", return_value=MOCK_CONF) def test_should_respond_200_text_plain(self, mock_as_dict): - response = self.client.get( - "/api/v1/config", headers={"Accept": "text/plain"}, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/config", headers={"Accept": "text/plain", "REMOTE_USER": "test"}) mock_as_dict.assert_called_with(display_source=False, display_sensitive=True) assert response.status_code == 200 expected = textwrap.dedent( @@ -88,14 +86,12 @@ def test_should_respond_200_text_plain(self, mock_as_dict): smtp_mail_from = airflow@example.com """ ) - assert expected == response.data.decode() + assert expected == response.text @patch("airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", return_value=MOCK_CONF) @conf_vars({("webserver", "expose_config"): "non-sensitive-only"}) def test_should_respond_200_text_plain_with_non_sensitive_only(self, mock_as_dict): - response = self.client.get( - "/api/v1/config", headers={"Accept": "text/plain"}, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/config", headers={"Accept": "text/plain", "REMOTE_USER": "test"}) mock_as_dict.assert_called_with(display_source=False, display_sensitive=False) assert response.status_code == 200 expected = textwrap.dedent( @@ -108,14 +104,13 @@ def test_should_respond_200_text_plain_with_non_sensitive_only(self, mock_as_dic smtp_mail_from = airflow@example.com """ ) - assert expected == response.data.decode() + assert expected == response.text @patch("airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", return_value=MOCK_CONF) def test_should_respond_200_application_json(self, mock_as_dict): response = self.client.get( "/api/v1/config", - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "application/json", "REMOTE_USER": "test"}, ) mock_as_dict.assert_called_with(display_source=False, display_sensitive=True) assert response.status_code == 200 @@ -136,14 +131,13 @@ def test_should_respond_200_application_json(self, mock_as_dict): }, ] } - assert expected == response.json + assert response.json() == expected @patch("airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", return_value=MOCK_CONF) def test_should_respond_200_single_section_as_text_plain(self, mock_as_dict): response = self.client.get( "/api/v1/config?section=smtp", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "text/plain", "REMOTE_USER": "test"}, ) mock_as_dict.assert_called_with(display_source=False, display_sensitive=True) assert response.status_code == 200 @@ -154,14 +148,13 @@ def test_should_respond_200_single_section_as_text_plain(self, mock_as_dict): smtp_mail_from = airflow@example.com """ ) - assert expected == response.data.decode() + assert expected == response.text @patch("airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", return_value=MOCK_CONF) def test_should_respond_200_single_section_as_json(self, mock_as_dict): response = self.client.get( "/api/v1/config?section=smtp", - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "application/json", "REMOTE_USER": "test"}, ) mock_as_dict.assert_called_with(display_source=False, display_sensitive=True) assert response.status_code == 200 @@ -176,38 +169,35 @@ def test_should_respond_200_single_section_as_json(self, mock_as_dict): }, ] } - assert expected == response.json + assert expected == response.json() @patch("airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", return_value=MOCK_CONF) def test_should_respond_404_when_section_not_exist(self, mock_as_dict): response = self.client.get( "/api/v1/config?section=smtp1", - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "application/json", "REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert "section=smtp1 not found." in response.json["detail"] + assert "section=smtp1 not found." in response.json()["detail"] @patch("airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", return_value=MOCK_CONF) def test_should_respond_406(self, mock_as_dict): response = self.client.get( "/api/v1/config", - headers={"Accept": "application/octet-stream"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "application/octet-stream", "REMOTE_USER": "test"}, ) assert response.status_code == 406 def test_should_raises_401_unauthenticated(self): response = self.client.get("/api/v1/config", headers={"Accept": "application/json"}) - assert_401(response) + assert response.status_code == 401 def test_should_raises_403_unauthorized(self): response = self.client.get( "/api/v1/config", - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"Accept": "application/json", "REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -216,11 +206,10 @@ def test_should_raises_403_unauthorized(self): def test_should_respond_403_when_expose_config_off(self): response = self.client.get( "/api/v1/config", - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "application/json", "REMOTE_USER": "test"}, ) assert response.status_code == 403 - assert "chose not to expose" in response.json["detail"] + assert "chose not to expose" in response.json()["detail"] @pytest.mark.parametrize( "set_auto_role_public, expected_status_code", @@ -236,15 +225,14 @@ def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_c class TestGetValue: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore @patch("airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", return_value=MOCK_CONF) def test_should_respond_200_text_plain(self, mock_as_dict): response = self.client.get( "/api/v1/config/section/smtp/option/smtp_mail_from", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "text/plain", "REMOTE_USER": "test"}, ) assert response.status_code == 200 expected = textwrap.dedent( @@ -253,7 +241,7 @@ def test_should_respond_200_text_plain(self, mock_as_dict): smtp_mail_from = airflow@example.com """ ) - assert expected == response.data.decode() + assert expected == response.text @patch( "airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", @@ -272,8 +260,7 @@ def test_should_respond_200_text_plain(self, mock_as_dict): def test_should_respond_200_text_plain_with_non_sensitive_only(self, mock_as_dict, section, option): response = self.client.get( f"/api/v1/config/section/{section}/option/{option}", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "text/plain", "REMOTE_USER": "test"}, ) assert response.status_code == 200 expected = textwrap.dedent( @@ -282,14 +269,13 @@ def test_should_respond_200_text_plain_with_non_sensitive_only(self, mock_as_dic {option} = < hidden > """ ) - assert expected == response.data.decode() + assert expected == response.text @patch("airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", return_value=MOCK_CONF) def test_should_respond_200_application_json(self, mock_as_dict): response = self.client.get( "/api/v1/config/section/smtp/option/smtp_mail_from", - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "application/json", "REMOTE_USER": "test"}, ) assert response.status_code == 200 expected = { @@ -302,25 +288,23 @@ def test_should_respond_200_application_json(self, mock_as_dict): }, ] } - assert expected == response.json + assert expected == response.json() @patch("airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", return_value=MOCK_CONF) def test_should_respond_404_when_option_not_exist(self, mock_as_dict): response = self.client.get( "/api/v1/config/section/smtp/option/smtp_mail_from1", - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "application/json", "REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert "The option [smtp/smtp_mail_from1] is not found in config." in response.json["detail"] + assert "The option [smtp/smtp_mail_from1] is not found in config." in response.json()["detail"] @patch("airflow.api_connexion.endpoints.config_endpoint.conf.as_dict", return_value=MOCK_CONF) def test_should_respond_406(self, mock_as_dict): response = self.client.get( "/api/v1/config/section/smtp/option/smtp_mail_from", - headers={"Accept": "application/octet-stream"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "application/octet-stream", "REMOTE_USER": "test"}, ) assert response.status_code == 406 @@ -329,13 +313,12 @@ def test_should_raises_401_unauthenticated(self): "/api/v1/config/section/smtp/option/smtp_mail_from", headers={"Accept": "application/json"} ) - assert_401(response) + assert response.status_code == 401 def test_should_raises_403_unauthorized(self): response = self.client.get( "/api/v1/config/section/smtp/option/smtp_mail_from", - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"Accept": "application/json", "REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -344,11 +327,10 @@ def test_should_raises_403_unauthorized(self): def test_should_respond_403_when_expose_config_off(self): response = self.client.get( "/api/v1/config/section/smtp/option/smtp_mail_from", - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "application/json", "REMOTE_USER": "test"}, ) assert response.status_code == 403 - assert "chose not to expose" in response.json["detail"] + assert "chose not to expose" in response.json()["detail"] @pytest.mark.parametrize( "set_auto_role_public, expected_status_code", diff --git a/tests/api_connexion/endpoints/test_connection_endpoint.py b/tests/api_connexion/endpoints/test_connection_endpoint.py index c88b8a56de9d53..8a209af3a1b933 100644 --- a/tests/api_connexion/endpoints/test_connection_endpoint.py +++ b/tests/api_connexion/endpoints/test_connection_endpoint.py @@ -26,7 +26,7 @@ from airflow.secrets.environment_variables import CONN_ENV_PREFIX from airflow.security import permissions from airflow.utils.session import provide_session -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_connections from tests.test_utils.www import _check_last_log @@ -36,9 +36,9 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -48,19 +48,19 @@ def configured_app(minimal_app_for_api): (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_CONNECTION), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore class TestConnectionEndpoint: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore # we want only the connection created here for this test clear_db_connections(False) @@ -81,20 +81,16 @@ def test_delete_should_respond_204(self, session): session.commit() conn = session.query(Connection).all() assert len(conn) == 1 - response = self.client.delete( - "/api/v1/connections/test-connection", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.delete("/api/v1/connections/test-connection", headers={"REMOTE_USER": "test"}) assert response.status_code == 204 connection = session.query(Connection).all() assert len(connection) == 0 _check_last_log(session, dag_id=None, event="api.connection.delete", execution_date=None) def test_delete_should_respond_404(self): - response = self.client.delete( - "/api/v1/connections/test-connection", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.delete("/api/v1/connections/test-connection", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 - assert response.json == { + assert response.json() == { "detail": "The Connection with connection_id: `test-connection` was not found", "status": 404, "title": "Connection not found", @@ -104,11 +100,11 @@ def test_delete_should_respond_404(self): def test_should_raises_401_unauthenticated(self): response = self.client.delete("/api/v1/connections/test-connection") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.get( - "/api/v1/connections/test-connection-id", environ_overrides={"REMOTE_USER": "test_no_permissions"} + "/api/v1/connections/test-connection-id", headers={"REMOTE_USER": "test_no_permissions"} ) assert response.status_code == 403 @@ -145,11 +141,9 @@ def test_should_respond_200(self, session): session.commit() result = session.query(Connection).all() assert len(result) == 1 - response = self.client.get( - "/api/v1/connections/test-connection-id", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/connections/test-connection-id", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "connection_id": "test-connection-id", "conn_type": "mysql", "description": "test description", @@ -171,28 +165,24 @@ def test_should_mask_sensitive_values_in_extra(self, session): session.add(connection_model) session.commit() - response = self.client.get( - "/api/v1/connections/test-connection-id", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/connections/test-connection-id", headers={"REMOTE_USER": "test"}) - assert response.json["extra"] == '{"nonsensitive": "just_a_value", "api_token": "***"}' + assert response.json()["extra"] == '{"nonsensitive": "just_a_value", "api_token": "***"}' def test_should_respond_404(self): - response = self.client.get( - "/api/v1/connections/invalid-connection", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/connections/invalid-connection", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 assert { "detail": "The Connection with connection_id: `invalid-connection` was not found", "status": 404, "title": "Connection not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self): response = self.client.get("/api/v1/connections/test-connection-id") - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize( "set_auto_role_public, expected_status_code", @@ -229,9 +219,9 @@ def test_should_respond_200(self, session): session.commit() result = session.query(Connection).all() assert len(result) == 2 - response = self.client.get("/api/v1/connections", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/connections", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "connections": [ { "connection_id": "test-connection-id-1", @@ -264,11 +254,11 @@ def test_should_respond_200_with_order_by(self, session): result = session.query(Connection).all() assert len(result) == 2 response = self.client.get( - "/api/v1/connections?order_by=-connection_id", environ_overrides={"REMOTE_USER": "test"} + "/api/v1/connections?order_by=-connection_id", headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 # Using - means descending - assert response.json == { + assert response.json() == { "connections": [ { "connection_id": "test-connection-id-2", @@ -295,7 +285,7 @@ def test_should_respond_200_with_order_by(self, session): def test_should_raises_401_unauthenticated(self): response = self.client.get("/api/v1/connections") - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize( "set_auto_role_public, expected_status_code", @@ -352,10 +342,10 @@ def test_handle_limit_offset(self, url, expected_conn_ids, session): connections = self._create_connections(10) session.add_all(connections) session.commit() - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 10 - conn_ids = [conn["connection_id"] for conn in response.json["connections"] if conn] + assert response.json()["total_entries"] == 10 + conn_ids = [conn["connection_id"] for conn in response.json()["connections"] if conn] assert conn_ids == expected_conn_ids def test_should_respect_page_size_limit_default(self, session): @@ -363,23 +353,21 @@ def test_should_respect_page_size_limit_default(self, session): session.add_all(connection_models) session.commit() - response = self.client.get("/api/v1/connections", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/connections", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 200 - assert len(response.json["connections"]) == 100 + assert response.json()["total_entries"] == 200 + assert len(response.json()["connections"]) == 100 def test_invalid_order_by_raises_400(self, session): connection_models = self._create_connections(200) session.add_all(connection_models) session.commit() - response = self.client.get( - "/api/v1/connections?order_by=invalid", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/connections?order_by=invalid", headers={"REMOTE_USER": "test"}) assert response.status_code == 400 assert ( - response.json["detail"] == "Ordering with 'invalid' is disallowed or" + response.json()["detail"] == "Ordering with 'invalid' is disallowed or" " the attribute does not exist on the model" ) @@ -388,11 +376,11 @@ def test_limit_of_zero_should_return_default(self, session): session.add_all(connection_models) session.commit() - response = self.client.get("/api/v1/connections?limit=0", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/connections?limit=0", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 200 - assert len(response.json["connections"]) == 100 + assert response.json()["total_entries"] == 200 + assert len(response.json()["connections"]) == 100 @conf_vars({("api", "maximum_page_limit"): "150"}) def test_should_return_conf_max_if_req_max_above_conf(self, session): @@ -400,9 +388,9 @@ def test_should_return_conf_max_if_req_max_above_conf(self, session): session.add_all(connection_models) session.commit() - response = self.client.get("/api/v1/connections?limit=180", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/connections?limit=180", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["connections"]) == 150 + assert len(response.json()["connections"]) == 150 def _create_connections(self, count): return [ @@ -424,7 +412,7 @@ def test_patch_should_respond_200(self, payload, session): self._create_connection(session) response = self.client.patch( - "/api/v1/connections/test-connection-id", json=payload, environ_overrides={"REMOTE_USER": "test"} + "/api/v1/connections/test-connection-id", json=payload, headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 _check_last_log(session, dag_id=None, event="api.connection.edit", execution_date=None) @@ -442,12 +430,12 @@ def test_patch_should_respond_200_with_update_mask(self, session): response = self.client.patch( "/api/v1/connections/test-connection-id?update_mask=port,login", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 connection = session.query(Connection).filter_by(conn_id=test_connection).first() assert connection.password is None - assert response.json == { + assert response.json() == { "connection_id": test_connection, # not updated "conn_type": "test_type", # Not updated "description": None, # Not updated @@ -513,10 +501,10 @@ def test_patch_should_respond_400_for_invalid_fields_in_update_mask( response = self.client.patch( f"/api/v1/connections/test-connection-id?{update_mask}", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json["detail"] == error_message + assert response.json()["detail"] == error_message @pytest.mark.parametrize( "payload, error_message", @@ -552,15 +540,15 @@ def test_patch_should_respond_400_for_invalid_fields_in_update_mask( def test_patch_should_respond_400_for_invalid_update(self, payload, error_message, session): self._create_connection(session) response = self.client.patch( - "/api/v1/connections/test-connection-id", json=payload, environ_overrides={"REMOTE_USER": "test"} + "/api/v1/connections/test-connection-id", json=payload, headers={"REMOTE_USER": "test"} ) assert response.status_code == 400 - assert error_message in response.json["detail"] + assert error_message in response.json()["detail"] def test_patch_should_respond_404_not_found(self): payload = {"connection_id": "test-connection-id", "conn_type": "test-type", "port": 90} response = self.client.patch( - "/api/v1/connections/test-connection-id", json=payload, environ_overrides={"REMOTE_USER": "test"} + "/api/v1/connections/test-connection-id", json=payload, headers={"REMOTE_USER": "test"} ) assert response.status_code == 404 assert { @@ -568,7 +556,7 @@ def test_patch_should_respond_404_not_found(self): "status": 404, "title": "Connection not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self, session): self._create_connection(session) @@ -578,7 +566,7 @@ def test_should_raises_401_unauthenticated(self, session): json={"connection_id": "test-connection-id", "conn_type": "test_type", "extra": "{'key': 'var'}"}, ) - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize( "set_auto_role_public, expected_status_code", @@ -599,9 +587,7 @@ def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_c class TestPostConnection(TestConnectionEndpoint): def test_post_should_respond_200(self, session): payload = {"connection_id": "test-connection-id", "conn_type": "test_type"} - response = self.client.post( - "/api/v1/connections", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/api/v1/connections", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 connection = session.query(Connection).all() assert len(connection) == 1 @@ -612,11 +598,9 @@ def test_post_should_respond_200(self, session): def test_post_should_respond_200_extra_null(self, session): payload = {"connection_id": "test-connection-id", "conn_type": "test_type", "extra": None} - response = self.client.post( - "/api/v1/connections", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/api/v1/connections", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["extra"] is None + assert response.json()["extra"] is None connection = session.query(Connection).all() assert len(connection) == 1 assert connection[0].conn_id == "test-connection-id" @@ -626,11 +610,9 @@ def test_post_should_respond_400_for_invalid_payload(self): payload = { "connection_id": "test-connection-id", } # conn_type missing - response = self.client.post( - "/api/v1/connections", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/api/v1/connections", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": "{'conn_type': ['Missing data for required field.']}", "status": 400, "title": "Bad Request", @@ -639,11 +621,9 @@ def test_post_should_respond_400_for_invalid_payload(self): def test_post_should_respond_400_for_invalid_conn_id(self): payload = {"connection_id": "****", "conn_type": "test_type"} - response = self.client.post( - "/api/v1/connections", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/api/v1/connections", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": "The key '****' has to be made of " "alphanumeric characters, dashes, dots and underscores exclusively", "status": 400, @@ -653,16 +633,12 @@ def test_post_should_respond_400_for_invalid_conn_id(self): def test_post_should_respond_409_already_exist(self): payload = {"connection_id": "test-connection-id", "conn_type": "test_type"} - response = self.client.post( - "/api/v1/connections", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/api/v1/connections", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 # Another request - response = self.client.post( - "/api/v1/connections", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/api/v1/connections", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 409 - assert response.json == { + assert response.json() == { "detail": "Connection already exist. ID: test-connection-id", "status": 409, "title": "Conflict", @@ -674,7 +650,7 @@ def test_should_raises_401_unauthenticated(self): "/api/v1/connections", json={"connection_id": "test-connection-id", "conn_type": "test_type"} ) - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize( "set_auto_role_public, expected_status_code", @@ -693,11 +669,9 @@ class TestConnection(TestConnectionEndpoint): @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) def test_should_respond_200(self): payload = {"connection_id": "test-connection-id", "conn_type": "sqlite"} - response = self.client.post( - "/api/v1/connections/test", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/api/v1/connections/test", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "status": True, "message": "Connection successfully tested", } @@ -705,7 +679,7 @@ def test_should_respond_200(self): @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) def test_connection_env_is_cleaned_after_run(self): payload = {"connection_id": "test-connection-id", "conn_type": "sqlite"} - self.client.post("/api/v1/connections/test", json=payload, environ_overrides={"REMOTE_USER": "test"}) + self.client.post("/api/v1/connections/test", json=payload, headers={"REMOTE_USER": "test"}) assert not any([key.startswith(CONN_ENV_PREFIX) for key in os.environ.keys()]) @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) @@ -713,11 +687,9 @@ def test_post_should_respond_400_for_invalid_payload(self): payload = { "connection_id": "test-connection-id", } # conn_type missing - response = self.client.post( - "/api/v1/connections/test", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/api/v1/connections/test", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": "{'conn_type': ['Missing data for required field.']}", "status": 400, "title": "Bad Request", @@ -729,13 +701,11 @@ def test_should_raises_401_unauthenticated(self): "/api/v1/connections/test", json={"connection_id": "test-connection-id", "conn_type": "test_type"} ) - assert_401(response) + assert response.status_code == 401 def test_should_respond_403_by_default(self): payload = {"connection_id": "test-connection-id", "conn_type": "sqlite"} - response = self.client.post( - "/api/v1/connections/test", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/api/v1/connections/test", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 403 assert response.text == ( "Testing connections is disabled in Airflow configuration. " diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py index b514faba276d99..5804178a024e0b 100644 --- a/tests/api_connexion/endpoints/test_dag_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_endpoint.py @@ -31,7 +31,7 @@ from airflow.security import permissions from airflow.utils.session import provide_session from airflow.utils.state import TaskInstanceState -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags from tests.test_utils.www import _check_last_log @@ -53,10 +53,10 @@ def current_file_token(url_safe_serializer) -> str: @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -65,13 +65,13 @@ def configured_app(minimal_app_for_api): (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - create_user(app, username="test_granular_permissions", role_name="TestGranularDag") # type: ignore - app.appbuilder.sm.sync_perm_for_dag( # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_granular_permissions", role_name="TestGranularDag") # type: ignore + connexion_app.app.appbuilder.sm.sync_perm_for_dag( # type: ignore "TEST_DAG_1", access_control={"TestGranularDag": [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]}, ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore + connexion_app.app.appbuilder.sm.sync_perm_for_dag( # type: ignore "TEST_DAG_1", access_control={"TestGranularDag": [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]}, ) @@ -94,13 +94,13 @@ def configured_app(minimal_app_for_api): dag_bag = DagBag(os.devnull, include_examples=False) dag_bag.dags = {dag.dag_id: dag, dag2.dag_id: dag2, dag3.dag_id: dag3} - app.dag_bag = dag_bag + connexion_app.app.dag_bag = dag_bag - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_granular_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test_granular_permissions") # type: ignore class TestDagEndpoint: @@ -113,8 +113,9 @@ def clean_db(): @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: self.clean_db() - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.flask_app = configured_app.app + self.client = self.connexion_app.test_client() # type:ignore self.dag_id = DAG_ID self.dag2_id = DAG2_ID self.dag3_id = DAG3_ID @@ -177,7 +178,7 @@ class TestGetDag(TestDagEndpoint): @conf_vars({("webserver", "secret_key"): "mysecret"}) def test_should_respond_200(self): self._create_dag_models(1) - response = self.client.get("/api/v1/dags/TEST_DAG_1", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/dags/TEST_DAG_1", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 assert { "dag_id": "TEST_DAG_1", @@ -208,7 +209,7 @@ def test_should_respond_200(self): "timetable_description": None, "has_import_errors": False, "pickle_id": None, - } == response.json + } == response.json() @conf_vars({("webserver", "secret_key"): "mysecret"}) def test_should_respond_200_with_schedule_interval_none(self, session): @@ -220,7 +221,7 @@ def test_should_respond_200_with_schedule_interval_none(self, session): ) session.add(dag_model) session.commit() - response = self.client.get("/api/v1/dags/TEST_DAG_1", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/dags/TEST_DAG_1", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 assert { "dag_id": "TEST_DAG_1", @@ -251,17 +252,17 @@ def test_should_respond_200_with_schedule_interval_none(self, session): "timetable_description": None, "has_import_errors": False, "pickle_id": None, - } == response.json + } == response.json() def test_should_respond_200_with_granular_dag_access(self): self._create_dag_models(1) response = self.client.get( - "/api/v1/dags/TEST_DAG_1", environ_overrides={"REMOTE_USER": "test_granular_permissions"} + "/api/v1/dags/TEST_DAG_1", headers={"REMOTE_USER": "test_granular_permissions"} ) assert response.status_code == 200 def test_should_respond_404(self): - response = self.client.get("/api/v1/dags/INVALID_DAG", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/dags/INVALID_DAG", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 def test_should_raises_401_unauthenticated(self): @@ -269,18 +270,18 @@ def test_should_raises_401_unauthenticated(self): response = self.client.get("/api/v1/dags/TEST_DAG_1") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.get( - f"/api/v1/dags/{self.dag_id}/details", environ_overrides={"REMOTE_USER": "test_no_permissions"} + f"/api/v1/dags/{self.dag_id}/details", headers={"REMOTE_USER": "test_no_permissions"} ) assert response.status_code == 403 def test_should_respond_403_with_granular_access_for_different_dag(self): self._create_dag_models(3) response = self.client.get( - "/api/v1/dags/TEST_DAG_2", environ_overrides={"REMOTE_USER": "test_granular_permissions"} + "/api/v1/dags/TEST_DAG_2", headers={"REMOTE_USER": "test_granular_permissions"} ) assert response.status_code == 403 @@ -295,9 +296,9 @@ def test_should_respond_403_with_granular_access_for_different_dag(self): def test_should_return_specified_fields(self, fields): self._create_dag_models(1) response = self.client.get( - f"/api/v1/dags/TEST_DAG_1?fields={','.join(fields)}", environ_overrides={"REMOTE_USER": "test"} + f"/api/v1/dags/TEST_DAG_1?fields={','.join(fields)}", headers={"REMOTE_USER": "test"} ) - res_json = response.json + res_json = response.json() assert len(res_json.keys()) == len(fields) for field in fields: assert field in res_json @@ -313,7 +314,7 @@ def test_should_return_specified_fields(self, fields): def test_should_respond_400_with_not_exists_fields(self, fields): self._create_dag_models(1) response = self.client.get( - f"/api/v1/dags/TEST_DAG_1?fields={','.join(fields)}", environ_overrides={"REMOTE_USER": "test"} + f"/api/v1/dags/TEST_DAG_1?fields={','.join(fields)}", headers={"REMOTE_USER": "test"} ) assert response.status_code == 400, f"Current code: {response.status_code}" @@ -340,11 +341,9 @@ class TestGetDagDetails(TestDagEndpoint): def test_should_respond_200(self, url_safe_serializer): self._create_dag_model_for_details_endpoint(self.dag_id) current_file_token = url_safe_serializer.dumps("/tmp/dag.py") - response = self.client.get( - f"/api/v1/dags/{self.dag_id}/details", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/dags/{self.dag_id}/details", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - last_parsed = response.json["last_parsed"] + last_parsed = response.json()["last_parsed"] expected = { "catchup": True, "concurrency": 16, @@ -396,16 +395,14 @@ def test_should_respond_200(self, url_safe_serializer): "timetable_description": None, "timezone": UTC_JSON_REPR, } - assert response.json == expected + assert response.json() == expected def test_should_respond_200_with_dataset_expression(self, url_safe_serializer): self._create_dag_model_for_details_endpoint_with_dataset_expression(self.dag_id) current_file_token = url_safe_serializer.dumps("/tmp/dag.py") - response = self.client.get( - f"/api/v1/dags/{self.dag_id}/details", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/dags/{self.dag_id}/details", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - last_parsed = response.json["last_parsed"] + last_parsed = response.json()["last_parsed"] expected = { "catchup": True, "concurrency": 16, @@ -462,16 +459,14 @@ def test_should_respond_200_with_dataset_expression(self, url_safe_serializer): "timetable_description": None, "timezone": UTC_JSON_REPR, } - assert response.json == expected + assert response.json() == expected def test_should_response_200_with_doc_md_none(self, url_safe_serializer): current_file_token = url_safe_serializer.dumps("/tmp/dag.py") self._create_dag_model_for_details_endpoint(self.dag2_id) - response = self.client.get( - f"/api/v1/dags/{self.dag2_id}/details", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/dags/{self.dag2_id}/details", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - last_parsed = response.json["last_parsed"] + last_parsed = response.json()["last_parsed"] expected = { "catchup": True, "concurrency": 16, @@ -516,16 +511,14 @@ def test_should_response_200_with_doc_md_none(self, url_safe_serializer): "timetable_description": None, "timezone": UTC_JSON_REPR, } - assert response.json == expected + assert response.json() == expected def test_should_response_200_for_null_start_date(self, url_safe_serializer): current_file_token = url_safe_serializer.dumps("/tmp/dag.py") self._create_dag_model_for_details_endpoint(self.dag3_id) - response = self.client.get( - f"/api/v1/dags/{self.dag3_id}/details", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/dags/{self.dag3_id}/details", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - last_parsed = response.json["last_parsed"] + last_parsed = response.json()["last_parsed"] expected = { "catchup": True, "concurrency": 16, @@ -570,17 +563,17 @@ def test_should_response_200_for_null_start_date(self, url_safe_serializer): "timetable_description": None, "timezone": UTC_JSON_REPR, } - assert response.json == expected + assert response.json() == expected def test_should_respond_200_serialized(self, url_safe_serializer): current_file_token = url_safe_serializer.dumps("/tmp/dag.py") self._create_dag_model_for_details_endpoint(self.dag_id) # Get the dag out of the dagbag before we patch it to an empty one - SerializedDagModel.write_dag(self.app.dag_bag.get_dag(self.dag_id)) + SerializedDagModel.write_dag(self.flask_app.dag_bag.get_dag(self.dag_id)) # Create empty app with empty dagbag to check if DAG is read from db dag_bag = DagBag(os.devnull, include_examples=False, read_dags_from_db=True) - patcher = unittest.mock.patch.object(self.app, "dag_bag", dag_bag) + patcher = unittest.mock.patch.object(self.flask_app, "dag_bag", dag_bag) patcher.start() expected = { @@ -633,19 +626,15 @@ def test_should_respond_200_serialized(self, url_safe_serializer): "timetable_description": None, "timezone": UTC_JSON_REPR, } - response = self.client.get( - f"/api/v1/dags/{self.dag_id}/details", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/dags/{self.dag_id}/details", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - expected.update({"last_parsed": response.json["last_parsed"]}) - assert response.json == expected + expected.update({"last_parsed": response.json()["last_parsed"]}) + assert response.json() == expected patcher.stop() - response = self.client.get( - f"/api/v1/dags/{self.dag_id}/details", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/dags/{self.dag_id}/details", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 expected = { "catchup": True, @@ -697,20 +686,20 @@ def test_should_respond_200_serialized(self, url_safe_serializer): "timetable_description": None, "timezone": UTC_JSON_REPR, } - expected.update({"last_parsed": response.json["last_parsed"]}) - assert response.json == expected + expected.update({"last_parsed": response.json()["last_parsed"]}) + assert response.json() == expected def test_should_raises_401_unauthenticated(self): response = self.client.get(f"/api/v1/dags/{self.dag_id}/details") - assert_401(response) + assert response.status_code == 401 def test_should_raise_404_when_dag_is_not_found(self): response = self.client.get( - "/api/v1/dags/non_existing_dag_id/details", environ_overrides={"REMOTE_USER": "test"} + "/api/v1/dags/non_existing_dag_id/details", headers={"REMOTE_USER": "test"} ) assert response.status_code == 404 - assert response.json == { + assert response.json() == { "detail": "The DAG with dag_id: non_existing_dag_id was not found", "status": 404, "title": "DAG not found", @@ -729,10 +718,10 @@ def test_should_return_specified_fields(self, fields): self._create_dag_model_for_details_endpoint(self.dag2_id) response = self.client.get( f"/api/v1/dags/{self.dag2_id}/details?fields={','.join(fields)}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - res_json = response.json + res_json = response.json() assert len(res_json.keys()) == len(fields) for field in fields: assert field in res_json @@ -742,7 +731,7 @@ def test_should_respond_400_with_not_exists_fields(self): self._create_dag_model_for_details_endpoint(self.dag2_id) response = self.client.get( f"/api/v1/dags/{self.dag2_id}/details?fields={','.join(fields)}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400, f"Current code: {response.status_code}" @@ -768,7 +757,7 @@ def test_should_respond_200(self, session, url_safe_serializer): dags_query = session.query(DagModel).filter(~DagModel.is_subdag) assert len(dags_query.all()) == 3 - response = self.client.get("api/v1/dags", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("api/v1/dags", headers={"REMOTE_USER": "test"}) file_token = url_safe_serializer.dumps("/tmp/dag_1.py") file_token2 = url_safe_serializer.dumps("/tmp/dag_2.py") @@ -843,12 +832,12 @@ def test_should_respond_200(self, session, url_safe_serializer): }, ], "total_entries": 2, - } == response.json + } == response.json() def test_only_active_true_returns_active_dags(self, url_safe_serializer): self._create_dag_models(1) self._create_deactivated_dag() - response = self.client.get("api/v1/dags?only_active=True", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("api/v1/dags?only_active=True", headers={"REMOTE_USER": "test"}) file_token = url_safe_serializer.dumps("/tmp/dag_1.py") assert response.status_code == 200 assert { @@ -888,12 +877,12 @@ def test_only_active_true_returns_active_dags(self, url_safe_serializer): } ], "total_entries": 1, - } == response.json + } == response.json() def test_only_active_false_returns_all_dags(self, url_safe_serializer): self._create_dag_models(1) self._create_deactivated_dag() - response = self.client.get("api/v1/dags?only_active=False", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("api/v1/dags?only_active=False", headers={"REMOTE_USER": "test"}) file_token = url_safe_serializer.dumps("/tmp/dag_1.py") file_token_2 = url_safe_serializer.dumps("/tmp/dag_del_1.py") assert response.status_code == 200 @@ -967,7 +956,7 @@ def test_only_active_false_returns_all_dags(self, url_safe_serializer): }, ], "total_entries": 2, - } == response.json + } == response.json() @pytest.mark.parametrize( "url, expected_dag_ids", @@ -989,9 +978,9 @@ def test_filter_dags_by_tags_works(self, url, expected_dag_ids): dag3.sync_to_db() dag4.sync_to_db() - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - dag_ids = [dag["dag_id"] for dag in response.json["dags"]] + dag_ids = [dag["dag_id"] for dag in response.json()["dags"]] assert expected_dag_ids == dag_ids @@ -1017,20 +1006,18 @@ def test_filter_dags_by_dag_id_works(self, url, expected_dag_ids): dag3.sync_to_db() dag4.sync_to_db() - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - dag_ids = {dag["dag_id"] for dag in response.json["dags"]} + dag_ids = {dag["dag_id"] for dag in response.json()["dags"]} assert expected_dag_ids == dag_ids def test_should_respond_200_with_granular_dag_access(self): self._create_dag_models(3) - response = self.client.get( - "/api/v1/dags", environ_overrides={"REMOTE_USER": "test_granular_permissions"} - ) + response = self.client.get("/api/v1/dags", headers={"REMOTE_USER": "test_granular_permissions"}) assert response.status_code == 200 - assert len(response.json["dags"]) == 1 - assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1" + assert len(response.json()["dags"]) == 1 + assert response.json()["dags"][0]["dag_id"] == "TEST_DAG_1" @pytest.mark.parametrize( "url, expected_dag_ids", @@ -1064,41 +1051,41 @@ def test_should_respond_200_with_granular_dag_access(self): def test_should_respond_200_and_handle_pagination(self, url, expected_dag_ids): self._create_dag_models(10) - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - dag_ids = [dag["dag_id"] for dag in response.json["dags"]] + dag_ids = [dag["dag_id"] for dag in response.json()["dags"]] assert expected_dag_ids == dag_ids - assert 10 == response.json["total_entries"] + assert 10 == response.json()["total_entries"] def test_should_respond_200_default_limit(self): self._create_dag_models(101) - response = self.client.get("api/v1/dags", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("api/v1/dags", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert 100 == len(response.json["dags"]) - assert 101 == response.json["total_entries"] + assert 100 == len(response.json()["dags"]) + assert 101 == response.json()["total_entries"] def test_should_raises_401_unauthenticated(self): response = self.client.get("api/v1/dags") - assert_401(response) + assert response.status_code == 401 def test_should_respond_403_unauthorized(self): self._create_dag_models(1) - response = self.client.get("api/v1/dags", environ_overrides={"REMOTE_USER": "test_no_permissions"}) + response = self.client.get("api/v1/dags", headers={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 def test_paused_true_returns_paused_dags(self, url_safe_serializer): self._create_dag_models(1, dag_id_prefix="TEST_DAG_PAUSED", is_paused=True) self._create_dag_models(1, dag_id_prefix="TEST_DAG_UNPAUSED", is_paused=False) - response = self.client.get("api/v1/dags?paused=True", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("api/v1/dags?paused=True", headers={"REMOTE_USER": "test"}) file_token = url_safe_serializer.dumps("/tmp/dag_1.py") assert response.status_code == 200 assert { @@ -1138,12 +1125,12 @@ def test_paused_true_returns_paused_dags(self, url_safe_serializer): } ], "total_entries": 1, - } == response.json + } == response.json() def test_paused_false_returns_unpaused_dags(self, url_safe_serializer): self._create_dag_models(1, dag_id_prefix="TEST_DAG_PAUSED", is_paused=True) self._create_dag_models(1, dag_id_prefix="TEST_DAG_UNPAUSED", is_paused=False) - response = self.client.get("api/v1/dags?paused=False", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("api/v1/dags?paused=False", headers={"REMOTE_USER": "test"}) file_token = url_safe_serializer.dumps("/tmp/dag_1.py") assert response.status_code == 200 assert { @@ -1183,12 +1170,12 @@ def test_paused_false_returns_unpaused_dags(self, url_safe_serializer): } ], "total_entries": 1, - } == response.json + } == response.json() def test_paused_none_returns_all_dags(self, url_safe_serializer): self._create_dag_models(1, dag_id_prefix="TEST_DAG_PAUSED", is_paused=True) self._create_dag_models(1, dag_id_prefix="TEST_DAG_UNPAUSED", is_paused=False) - response = self.client.get("api/v1/dags", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("api/v1/dags", headers={"REMOTE_USER": "test"}) file_token = url_safe_serializer.dumps("/tmp/dag_1.py") assert response.status_code == 200 assert { @@ -1261,19 +1248,17 @@ def test_paused_none_returns_all_dags(self, url_safe_serializer): }, ], "total_entries": 2, - } == response.json + } == response.json() def test_should_return_specified_fields(self): self._create_dag_models(2) self._create_deactivated_dag() fields = ["dag_id", "file_token", "owners"] - response = self.client.get( - f"api/v1/dags?fields={','.join(fields)}", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"api/v1/dags?fields={','.join(fields)}", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - res_json = response.json + res_json = response.json() for dag in res_json["dags"]: assert len(dag.keys()) == len(fields) for field in fields: @@ -1283,9 +1268,7 @@ def test_should_respond_400_with_not_exists_fields(self): self._create_dag_models(1) self._create_deactivated_dag() fields = ["#caw&c"] - response = self.client.get( - f"api/v1/dags?fields={','.join(fields)}", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"api/v1/dags?fields={','.join(fields)}", headers={"REMOTE_USER": "test"}) assert response.status_code == 400, f"Current code: {response.status_code}" @@ -1314,7 +1297,7 @@ def test_should_respond_200_on_patch_is_paused(self, url_safe_serializer, sessio response = self.client.patch( f"/api/v1/dags/{dag_model.dag_id}", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 expected_response = { @@ -1350,7 +1333,7 @@ def test_should_respond_200_on_patch_is_paused(self, url_safe_serializer, sessio "has_import_errors": False, "pickle_id": None, } - assert response.json == expected_response + assert response.json() == expected_response _check_last_log( session, dag_id="TEST_DAG_1", event="api.patch_dag", execution_date=None, expected_extra=payload ) @@ -1362,28 +1345,26 @@ def test_should_respond_200_on_patch_with_granular_dag_access(self, session): json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + headers={"REMOTE_USER": "test_granular_permissions"}, ) assert response.status_code == 200 _check_last_log(session, dag_id="TEST_DAG_1", event="api.patch_dag", execution_date=None) - def test_should_respond_400_on_invalid_request(self): + def test_ignore_read_only_fields(self): patch_body = { - "is_paused": True, + "is_paused": False, "schedule_interval": { "__type": "CronExpression", "value": "1 1 * * *", }, } dag_model = self._create_dag_model() - response = self.client.patch(f"/api/v1/dags/{dag_model.dag_id}", json=patch_body) - assert response.status_code == 400 - assert response.json == { - "detail": "Property is read-only - 'schedule_interval'", - "status": 400, - "title": "Bad Request", - "type": EXCEPTIONS_LINK_MAP[400], - } + response = self.client.patch( + f"/api/v1/dags/{dag_model.dag_id}", json=patch_body, headers={"REMOTE_USER": "test"} + ) + assert response.status_code == 200 + assert response.json()["is_paused"] is False + assert response.json()["schedule_interval"] == {"__type": "CronExpression", "value": "2 2 * * *"} def test_validation_error_raises_400(self): patch_body = { @@ -1393,10 +1374,10 @@ def test_validation_error_raises_400(self): response = self.client.patch( f"/api/v1/dags/{dag_model.dag_id}", json=patch_body, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + headers={"REMOTE_USER": "test_granular_permissions"}, ) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": "{'ispaused': ['Unknown field.']}", "status": 400, "title": "Bad Request", @@ -1408,10 +1389,10 @@ def test_non_existing_dag_raises_not_found(self): "is_paused": True, } response = self.client.patch( - "/api/v1/dags/non_existing_dag", json=patch_body, environ_overrides={"REMOTE_USER": "test"} + "/api/v1/dags/non_existing_dag", json=patch_body, headers={"REMOTE_USER": "test"} ) assert response.status_code == 404 - assert response.json == { + assert response.json() == { "detail": None, "status": 404, "title": "Dag with id: 'non_existing_dag' not found", @@ -1419,7 +1400,7 @@ def test_non_existing_dag_raises_not_found(self): } def test_should_respond_404(self): - response = self.client.get("/api/v1/dags/INVALID_DAG", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/dags/INVALID_DAG", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 @provide_session @@ -1439,7 +1420,7 @@ def test_should_raises_401_unauthenticated(self): }, ) - assert_401(response) + assert response.status_code == 401 def test_should_respond_200_with_update_mask(self, url_safe_serializer): file_token = url_safe_serializer.dumps("/tmp/dag_1.py") @@ -1450,7 +1431,7 @@ def test_should_respond_200_with_update_mask(self, url_safe_serializer): response = self.client.patch( f"/api/v1/dags/{dag_model.dag_id}?update_mask=is_paused", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 @@ -1487,7 +1468,7 @@ def test_should_respond_200_with_update_mask(self, url_safe_serializer): "has_import_errors": False, "pickle_id": None, } - assert response.json == expected_response + assert response.json() == expected_response @pytest.mark.parametrize( "payload, update_mask, error_message", @@ -1514,10 +1495,10 @@ def test_should_respond_400_for_invalid_fields_in_update_mask(self, payload, upd response = self.client.patch( f"/api/v1/dags/{dag_model.dag_id}?{update_mask}", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json["detail"] == error_message + assert response.json()["detail"] == error_message def test_should_respond_403_unauthorized(self): dag_model = self._create_dag_model() @@ -1526,7 +1507,7 @@ def test_should_respond_403_unauthorized(self): json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -1566,7 +1547,7 @@ def test_should_respond_200_on_patch_is_paused(self, session, url_safe_serialize json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 @@ -1640,7 +1621,7 @@ def test_should_respond_200_on_patch_is_paused(self, session, url_safe_serialize }, ], "total_entries": 2, - } == response.json + } == response.json() _check_last_log(session, dag_id=None, event="api.patch_dags", execution_date=None) def test_should_respond_200_on_patch_is_paused_using_update_mask(self, session, url_safe_serializer): @@ -1657,7 +1638,7 @@ def test_should_respond_200_on_patch_is_paused_using_update_mask(self, session, json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 @@ -1731,7 +1712,7 @@ def test_should_respond_200_on_patch_is_paused_using_update_mask(self, session, }, ], "total_entries": 2, - } == response.json + } == response.json() _check_last_log(session, dag_id=None, event="api.patch_dags", execution_date=None) def test_wrong_value_as_update_mask_rasise(self, session): @@ -1746,11 +1727,11 @@ def test_wrong_value_as_update_mask_rasise(self, session): json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": "Only `is_paused` field can be updated through the REST API", "status": 400, "title": "Bad Request", @@ -1769,11 +1750,11 @@ def test_invalid_request_body_raises_badrequest(self, session): json={ "ispaused": False, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": "{'ispaused': ['Unknown field.']}", "status": 400, "title": "Bad Request", @@ -1789,7 +1770,7 @@ def test_only_active_true_returns_active_dags(self, url_safe_serializer, session json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 assert { @@ -1829,7 +1810,7 @@ def test_only_active_true_returns_active_dags(self, url_safe_serializer, session } ], "total_entries": 1, - } == response.json + } == response.json() _check_last_log(session, dag_id=None, event="api.patch_dags", execution_date=None) def test_only_active_false_returns_all_dags(self, url_safe_serializer, session): @@ -1841,7 +1822,7 @@ def test_only_active_false_returns_all_dags(self, url_safe_serializer, session): json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) file_token_2 = url_safe_serializer.dumps("/tmp/dag_del_1.py") @@ -1916,7 +1897,7 @@ def test_only_active_false_returns_all_dags(self, url_safe_serializer, session): }, ], "total_entries": 2, - } == response.json + } == response.json() _check_last_log(session, dag_id=None, event="api.patch_dags", execution_date=None) @pytest.mark.parametrize( @@ -1943,10 +1924,10 @@ def test_filter_dags_by_tags_works(self, url, expected_dag_ids): json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - dag_ids = [dag["dag_id"] for dag in response.json["dags"]] + dag_ids = [dag["dag_id"] for dag in response.json()["dags"]] assert expected_dag_ids == dag_ids @@ -1977,10 +1958,10 @@ def test_filter_dags_by_dag_id_works(self, url, expected_dag_ids): json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - dag_ids = {dag["dag_id"] for dag in response.json["dags"]} + dag_ids = {dag["dag_id"] for dag in response.json()["dags"]} assert expected_dag_ids == dag_ids @@ -1991,11 +1972,11 @@ def test_should_respond_200_with_granular_dag_access(self): json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + headers={"REMOTE_USER": "test_granular_permissions"}, ) assert response.status_code == 200 - assert len(response.json["dags"]) == 1 - assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1" + assert len(response.json()["dags"]) == 1 + assert response.json()["dags"][0]["dag_id"] == "TEST_DAG_1" @pytest.mark.parametrize( "url, expected_dag_ids", @@ -2034,15 +2015,15 @@ def test_should_respond_200_and_handle_pagination(self, url, expected_dag_ids): json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - dag_ids = [dag["dag_id"] for dag in response.json["dags"]] + dag_ids = [dag["dag_id"] for dag in response.json()["dags"]] assert expected_dag_ids == dag_ids - assert 10 == response.json["total_entries"] + assert 10 == response.json()["total_entries"] def test_should_respond_200_default_limit(self): self._create_dag_models(101) @@ -2052,13 +2033,13 @@ def test_should_respond_200_default_limit(self): json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert 100 == len(response.json["dags"]) - assert 101 == response.json["total_entries"] + assert 100 == len(response.json()["dags"]) + assert 101 == response.json()["total_entries"] def test_should_raises_401_unauthenticated(self): response = self.client.patch( @@ -2068,7 +2049,7 @@ def test_should_raises_401_unauthenticated(self): }, ) - assert_401(response) + assert response.status_code == 401 def test_should_respond_403_unauthorized(self): self._create_dag_models(1) @@ -2077,7 +2058,7 @@ def test_should_respond_403_unauthorized(self): json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -2092,7 +2073,7 @@ def test_should_respond_200_and_pause_dags(self, url_safe_serializer): json={ "is_paused": True, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 @@ -2166,7 +2147,7 @@ def test_should_respond_200_and_pause_dags(self, url_safe_serializer): }, ], "total_entries": 2, - } == response.json + } == response.json() @provide_session def test_should_respond_200_and_pause_dag_pattern(self, session, url_safe_serializer): @@ -2179,7 +2160,7 @@ def test_should_respond_200_and_pause_dag_pattern(self, session, url_safe_serial json={ "is_paused": True, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 @@ -2253,7 +2234,7 @@ def test_should_respond_200_and_pause_dag_pattern(self, session, url_safe_serial }, ], "total_entries": 2, - } == response.json + } == response.json() dags_not_updated = session.query(DagModel).filter(~DagModel.is_paused) assert len(dags_not_updated.all()) == 8 @@ -2268,7 +2249,7 @@ def test_should_respond_200_and_reverse_ordering(self, session, url_safe_seriali response = self.client.get( "/api/v1/dags?order_by=-dag_id", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 @@ -2342,7 +2323,7 @@ def test_should_respond_200_and_reverse_ordering(self, session, url_safe_seriali }, ], "total_entries": 2, - } == response.json + } == response.json() def test_should_respons_400_dag_id_pattern_missing(self): self._create_dag_models(1) @@ -2351,7 +2332,7 @@ def test_should_respons_400_dag_id_pattern_missing(self): json={ "is_paused": False, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 @@ -2385,7 +2366,7 @@ def test_that_dag_can_be_deleted(self, session): response = self.client.delete( "/api/v1/dags/TEST_DAG_1", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 204 _check_last_log(session, dag_id="TEST_DAG_1", event="api.delete_dag", execution_date=None) @@ -2393,10 +2374,10 @@ def test_that_dag_can_be_deleted(self, session): def test_raise_when_dag_is_not_found(self): response = self.client.delete( "/api/v1/dags/TEST_DAG_1", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert response.json == { + assert response.json() == { "detail": None, "status": 404, "title": "Dag with id: 'TEST_DAG_1' not found", @@ -2412,10 +2393,10 @@ def test_raises_when_task_instances_of_dag_is_still_running(self, dag_maker, ses session.flush() response = self.client.delete( "/api/v1/dags/TEST_DAG_1", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 409 - assert response.json == { + assert response.json() == { "detail": "Task instances of dag with id: 'TEST_DAG_1' are still running", "status": 409, "title": "Conflict", @@ -2426,7 +2407,7 @@ def test_users_without_delete_permission_cannot_delete_dag(self): self._create_dag_models(1) response = self.client.delete( "/api/v1/dags/TEST_DAG_1", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index 5182ef427e6245..f2e999f605f2ee 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -34,7 +34,7 @@ from airflow.utils.session import create_session, provide_session from airflow.utils.state import DagRunState, State from airflow.utils.types import DagRunType -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_roles, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_roles, delete_user from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags from tests.test_utils.www import _check_last_log @@ -44,10 +44,10 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -62,7 +62,7 @@ def configured_app(minimal_app_for_api): ], ) create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test_dag_view_only", role_name="TestViewDags", permissions=[ @@ -74,7 +74,7 @@ def configured_app(minimal_app_for_api): ], ) create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test_view_dags", role_name="TestViewDags", permissions=[ @@ -83,25 +83,25 @@ def configured_app(minimal_app_for_api): ], ) create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test_granular_permissions", role_name="TestGranularDag", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN)], ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore + connexion_app.app.appbuilder.sm.sync_perm_for_dag( # type: ignore "TEST_DAG_ID", access_control={"TestGranularDag": [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]}, ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_dag_view_only") # type: ignore - delete_user(app, username="test_view_dags") # type: ignore - delete_user(app, username="test_granular_permissions") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_roles(app) + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_dag_view_only") # type: ignore + delete_user(connexion_app.app, username="test_view_dags") # type: ignore + delete_user(connexion_app.app, username="test_granular_permissions") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore + delete_roles(connexion_app.app) class TestDagRunEndpoint: @@ -111,8 +111,9 @@ class TestDagRunEndpoint: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.flask_app = configured_app.app + self.client = self.connexion_app.test_client() # type:ignore clear_db_runs() clear_db_serialized_dags() clear_db_dags() @@ -122,13 +123,14 @@ def teardown_method(self) -> None: clear_db_dags() clear_db_serialized_dags() - def _create_dag(self, dag_id): + def _create_dag(self, dag_id, is_active=True, has_import_errors=False): dag_instance = DagModel(dag_id=dag_id) - dag_instance.is_active = True + dag_instance.is_active = is_active + dag_instance.has_import_errors = has_import_errors with create_session() as session: session.add(dag_instance) dag = DAG(dag_id=dag_id, schedule=None) - self.app.dag_bag.bag_dag(dag, root_dag=dag) + self.flask_app.dag_bag.bag_dag(dag, root_dag=dag) return dag_instance def _create_test_dag_run(self, state=DagRunState.RUNNING, extra_dag=False, commit=True, idx_start=1): @@ -176,21 +178,21 @@ def test_should_respond_204(self, session): session.add_all(self._create_test_dag_run()) session.commit() response = self.client.delete( - "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", headers={"REMOTE_USER": "test"} ) assert response.status_code == 204 # Check if the Dag Run is deleted from the database response = self.client.get( - "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", headers={"REMOTE_USER": "test"} ) assert response.status_code == 404 def test_should_respond_404(self): response = self.client.delete( - "api/v1/dags/INVALID_DAG_RUN/dagRuns/INVALID_DAG_RUN", environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/INVALID_DAG_RUN/dagRuns/INVALID_DAG_RUN", headers={"REMOTE_USER": "test"} ) assert response.status_code == 404 - assert response.json == { + assert response.json() == { "detail": "DAGRun with DAG ID: 'INVALID_DAG_RUN' and DagRun ID: 'INVALID_DAG_RUN' not found", "status": 404, "title": "Not Found", @@ -205,12 +207,12 @@ def test_should_raises_401_unauthenticated(self, session): "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.get( "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -243,10 +245,10 @@ def test_should_respond_200(self, session): result = session.query(DagRun).all() assert len(result) == 1 response = self.client.get( - "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID", environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID", headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_id": "TEST_DAG_ID", "dag_run_id": "TEST_DAG_RUN_ID", "end_date": None, @@ -265,7 +267,7 @@ def test_should_respond_200(self, session): def test_should_respond_404(self): response = self.client.get( - "api/v1/dags/invalid-id/dagRuns/invalid-id", environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/invalid-id/dagRuns/invalid-id", headers={"REMOTE_USER": "test"} ) assert response.status_code == 404 expected_resp = { @@ -274,7 +276,7 @@ def test_should_respond_404(self): "title": "DAGRun not found", "type": EXCEPTIONS_LINK_MAP[404], } - assert expected_resp == response.json + assert expected_resp == response.json() def test_should_raises_401_unauthenticated(self, session): dagrun_model = DagRun( @@ -290,7 +292,7 @@ def test_should_raises_401_unauthenticated(self, session): response = self.client.get("api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID") - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize( "fields", @@ -315,11 +317,10 @@ def test_should_return_specified_fields(self, session, fields): assert len(result) == 1 response = self.client.get( f"api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID?fields={','.join(fields)}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - res_json = response.json - print("get dagRun", res_json) + res_json = response.json() assert len(res_json.keys()) == len(fields) for field in fields: assert field in res_json @@ -341,7 +342,7 @@ def test_should_respond_400_with_not_exists_fields(self, session): fields = ["#caw&c"] response = self.client.get( f"api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID?fields={','.join(fields)}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400, f"Current code: {response.status_code}" @@ -374,11 +375,9 @@ def test_should_respond_200(self, session): self._create_test_dag_run() result = session.query(DagRun).all() assert len(result) == 2 - response = self.client.get( - "api/v1/dags/TEST_DAG_ID/dagRuns", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("api/v1/dags/TEST_DAG_ID/dagRuns", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_runs": [ { "dag_id": "TEST_DAG_ID", @@ -421,22 +420,22 @@ def test_filter_by_state(self, session): self._create_test_dag_run(state="queued", idx_start=3) assert session.query(DagRun).count() == 4 response = self.client.get( - "api/v1/dags/TEST_DAG_ID/dagRuns?state=running,queued", environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/TEST_DAG_ID/dagRuns?state=running,queued", headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - assert response.json["total_entries"] == 4 - assert response.json["dag_runs"][0]["state"] == response.json["dag_runs"][1]["state"] == "running" - assert response.json["dag_runs"][2]["state"] == response.json["dag_runs"][3]["state"] == "queued" + assert response.json()["total_entries"] == 4 + assert response.json()["dag_runs"][0]["state"] == response.json()["dag_runs"][1]["state"] == "running" + assert response.json()["dag_runs"][2]["state"] == response.json()["dag_runs"][3]["state"] == "queued" def test_invalid_order_by_raises_400(self): self._create_test_dag_run() response = self.client.get( - "api/v1/dags/TEST_DAG_ID/dagRuns?order_by=invalid", environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/TEST_DAG_ID/dagRuns?order_by=invalid", headers={"REMOTE_USER": "test"} ) assert response.status_code == 400 msg = "Ordering with 'invalid' is disallowed or the attribute does not exist on the model" - assert response.json["detail"] == msg + assert response.json()["detail"] == msg def test_return_correct_results_with_order_by(self, session): self._create_test_dag_run() @@ -444,13 +443,13 @@ def test_return_correct_results_with_order_by(self, session): assert len(result) == 2 response = self.client.get( "api/v1/dags/TEST_DAG_ID/dagRuns?order_by=-execution_date", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 assert self.default_time < self.default_time_2 # - means descending - assert response.json == { + assert response.json() == { "dag_runs": [ { "dag_id": "TEST_DAG_ID", @@ -491,19 +490,19 @@ def test_return_correct_results_with_order_by(self, session): def test_should_return_all_with_tilde_as_dag_id_and_all_dag_permissions(self): self._create_test_dag_run(extra_dag=True) expected_dag_run_ids = ["TEST_DAG_ID", "TEST_DAG_ID", "TEST_DAG_ID_3", "TEST_DAG_ID_4"] - response = self.client.get("api/v1/dags/~/dagRuns", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("api/v1/dags/~/dagRuns", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - dag_run_ids = [dag_run["dag_id"] for dag_run in response.json["dag_runs"]] + dag_run_ids = [dag_run["dag_id"] for dag_run in response.json()["dag_runs"]] assert dag_run_ids == expected_dag_run_ids def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self): self._create_test_dag_run(extra_dag=True) expected_dag_run_ids = ["TEST_DAG_ID", "TEST_DAG_ID"] response = self.client.get( - "api/v1/dags/~/dagRuns", environ_overrides={"REMOTE_USER": "test_granular_permissions"} + "api/v1/dags/~/dagRuns", headers={"REMOTE_USER": "test_granular_permissions"} ) assert response.status_code == 200 - dag_run_ids = [dag_run["dag_id"] for dag_run in response.json["dag_runs"]] + dag_run_ids = [dag_run["dag_id"] for dag_run in response.json()["dag_runs"]] assert dag_run_ids == expected_dag_run_ids def test_should_raises_401_unauthenticated(self): @@ -511,7 +510,7 @@ def test_should_raises_401_unauthenticated(self): response = self.client.get("api/v1/dags/TEST_DAG_ID/dagRuns") - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize( "fields", @@ -526,10 +525,10 @@ def test_should_return_specified_fields(self, session, fields): assert len(result) == 2 response = self.client.get( f"api/v1/dags/TEST_DAG_ID/dagRuns?fields={','.join(fields)}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - for dag_run in response.json["dag_runs"]: + for dag_run in response.json()["dag_runs"]: assert len(dag_run.keys()) == len(fields) for field in fields: assert field in dag_run @@ -539,7 +538,7 @@ def test_should_respond_400_with_not_exists_fields(self): fields = ["#caw&c"] response = self.client.get( f"api/v1/dags/TEST_DAG_ID/dagRuns?fields={','.join(fields)}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400, f"Current code: {response.status_code}" @@ -600,31 +599,29 @@ class TestGetDagRunsPagination(TestDagRunEndpoint): ) def test_handle_limit_and_offset(self, url, expected_dag_run_ids): self._create_dag_runs(10) - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 10 - dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json["dag_runs"]] + assert response.json()["total_entries"] == 10 + dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json()["dag_runs"]] assert dag_run_ids == expected_dag_run_ids def test_should_respect_page_size_limit(self): self._create_dag_runs(200) - response = self.client.get( - "api/v1/dags/TEST_DAG_ID/dagRuns", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("api/v1/dags/TEST_DAG_ID/dagRuns", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 200 - assert len(response.json["dag_runs"]) == 100 # default is 100 + assert response.json()["total_entries"] == 200 + assert len(response.json()["dag_runs"]) == 100 # default is 100 @conf_vars({("api", "maximum_page_limit"): "150"}) def test_should_return_conf_max_if_req_max_above_conf(self): self._create_dag_runs(200) response = self.client.get( - "api/v1/dags/TEST_DAG_ID/dagRuns?limit=180", environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/TEST_DAG_ID/dagRuns?limit=180", headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - assert len(response.json["dag_runs"]) == 150 + assert len(response.json()["dag_runs"]) == 150 def _create_dag_runs(self, count): dag_runs = [ @@ -712,10 +709,10 @@ def test_date_filters_gte_and_lte(self, url, expected_dag_run_ids, session): d.updated_at = d.execution_date session.commit() - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == len(expected_dag_run_ids) - dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json["dag_runs"]] + assert response.json()["total_entries"] == len(expected_dag_run_ids) + dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json()["dag_runs"]] assert dag_run_ids == expected_dag_run_ids def _create_dag_runs(self): @@ -766,10 +763,10 @@ class TestGetDagRunsEndDateFilters(TestDagRunEndpoint): ) def test_end_date_gte_lte(self, url, expected_dag_run_ids): self._create_test_dag_run("success") # state==success, then end date is today - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == len(expected_dag_run_ids) - dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json["dag_runs"] if dag_run] + assert response.json()["total_entries"] == len(expected_dag_run_ids) + dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json()["dag_runs"] if dag_run] assert dag_run_ids == expected_dag_run_ids @@ -779,10 +776,10 @@ def test_should_respond_200(self): response = self.client.post( "api/v1/dags/~/dagRuns/list", json={"dag_ids": ["TEST_DAG_ID"]}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_runs": [ { "dag_id": "TEST_DAG_ID", @@ -825,10 +822,10 @@ def test_raises_validation_error_for_invalid_request(self): response = self.client.post( "api/v1/dags/~/dagRuns/list", json={"dagids": ["TEST_DAG_ID"]}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": "{'dagids': ['Unknown field.']}", "status": 400, "title": "Bad Request", @@ -841,22 +838,22 @@ def test_filter_by_state(self): response = self.client.post( "api/v1/dags/~/dagRuns/list", json={"dag_ids": ["TEST_DAG_ID"], "states": ["running", "queued"]}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 4 - assert response.json["dag_runs"][0]["state"] == response.json["dag_runs"][1]["state"] == "running" - assert response.json["dag_runs"][2]["state"] == response.json["dag_runs"][3]["state"] == "queued" + assert response.json()["total_entries"] == 4 + assert response.json()["dag_runs"][0]["state"] == response.json()["dag_runs"][1]["state"] == "running" + assert response.json()["dag_runs"][2]["state"] == response.json()["dag_runs"][3]["state"] == "queued" def test_order_by_descending_works(self): self._create_test_dag_run() response = self.client.post( "api/v1/dags/~/dagRuns/list", json={"dag_ids": ["TEST_DAG_ID"], "order_by": "-dag_run_id"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_runs": [ { "dag_id": "TEST_DAG_ID", @@ -899,21 +896,21 @@ def test_order_by_raises_for_invalid_attr(self): response = self.client.post( "api/v1/dags/~/dagRuns/list", json={"dag_ids": ["TEST_DAG_ID"], "order_by": "-dag_ru"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 msg = "Ordering with 'dag_ru' is disallowed or the attribute does not exist on the model" - assert response.json["detail"] == msg + assert response.json()["detail"] == msg def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self): self._create_test_dag_run(extra_dag=True) response = self.client.post( "api/v1/dags/~/dagRuns/list", json={"dag_ids": []}, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + headers={"REMOTE_USER": "test_granular_permissions"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_runs": [ { "dag_id": "TEST_DAG_ID", @@ -966,17 +963,17 @@ def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions def test_payload_validation(self, payload, error): self._create_test_dag_run() response = self.client.post( - "api/v1/dags/~/dagRuns/list", json=payload, environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/~/dagRuns/list", json=payload, headers={"REMOTE_USER": "test"} ) assert response.status_code == 400 - assert response.json.get("detail") == error + assert response.json()["detail"] == error def test_should_raises_401_unauthenticated(self): self._create_test_dag_run() response = self.client.post("api/v1/dags/~/dagRuns/list", json={"dag_ids": ["TEST_DAG_ID"]}) - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize( "set_auto_role_public, expected_status_code", @@ -1033,23 +1030,21 @@ class TestGetDagRunBatchPagination(TestDagRunEndpoint): def test_handle_limit_and_offset(self, payload, expected_dag_run_ids): self._create_dag_runs(10) response = self.client.post( - "api/v1/dags/~/dagRuns/list", json=payload, environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/~/dagRuns/list", json=payload, headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - assert response.json["total_entries"] == 10 - dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json["dag_runs"]] + assert response.json()["total_entries"] == 10 + dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json()["dag_runs"]] assert dag_run_ids == expected_dag_run_ids def test_should_respect_page_size_limit(self): self._create_dag_runs(200) - response = self.client.post( - "api/v1/dags/~/dagRuns/list", json={}, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("api/v1/dags/~/dagRuns/list", json={}, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 200 - assert len(response.json["dag_runs"]) == 100 # default is 100 + assert response.json()["total_entries"] == 200 + assert len(response.json()["dag_runs"]) == 100 # default is 100 def _create_dag_runs(self, count): dag_runs = [ @@ -1114,11 +1109,11 @@ class TestGetDagRunBatchDateFilters(TestDagRunEndpoint): def test_date_filters_gte_and_lte(self, payload, expected_dag_run_ids): self._create_dag_runs() response = self.client.post( - "api/v1/dags/~/dagRuns/list", json=payload, environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/~/dagRuns/list", json=payload, headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - assert response.json["total_entries"] == len(expected_dag_run_ids) - dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json["dag_runs"]] + assert response.json()["total_entries"] == len(expected_dag_run_ids) + dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json()["dag_runs"]] assert dag_run_ids == expected_dag_run_ids def _create_dag_runs(self): @@ -1186,10 +1181,10 @@ def test_naive_date_filters_raises_400(self, payload, expected_response): self._create_dag_runs() response = self.client.post( - "api/v1/dags/~/dagRuns/list", json=payload, environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/~/dagRuns/list", json=payload, headers={"REMOTE_USER": "test"} ) assert response.status_code == 400 - assert response.json["detail"] == expected_response + assert response.json()["detail"] == expected_response @pytest.mark.parametrize( "payload, expected_dag_run_ids", @@ -1207,11 +1202,11 @@ def test_naive_date_filters_raises_400(self, payload, expected_response): def test_end_date_gte_lte(self, payload, expected_dag_run_ids): self._create_test_dag_run("success") # state==success, then end date is today response = self.client.post( - "api/v1/dags/~/dagRuns/list", json=payload, environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/~/dagRuns/list", json=payload, headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - assert response.json["total_entries"] == len(expected_dag_run_ids) - dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json["dag_runs"] if dag_run] + assert response.json()["total_entries"] == len(expected_dag_run_ids) + dag_run_ids = [dag_run["dag_run_id"] for dag_run in response.json()["dag_runs"] if dag_run] assert dag_run_ids == expected_dag_run_ids @@ -1267,7 +1262,7 @@ def test_should_respond_200( response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", json=request_json, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 @@ -1287,7 +1282,7 @@ def test_should_respond_200( expected_data_interval_start = data_interval_start expected_data_interval_end = data_interval_end - assert response.json == { + assert response.json() == { "conf": {}, "dag_id": "TEST_DAG_ID", "dag_run_id": expected_dag_run_id, @@ -1310,10 +1305,10 @@ def test_raises_validation_error_for_invalid_request(self): response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", json={"executiondate": "2020-11-10T08:25:56Z"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": "{'executiondate': ['Unknown field.']}", "status": 400, "title": "Bad Request", @@ -1330,10 +1325,10 @@ def test_dagrun_creation_exception_is_handled(self, mock_get_app, session): response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", json={"execution_date": "2020-11-10T08:25:56Z"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": error_message, "status": 400, "title": "Bad Request", @@ -1341,34 +1336,28 @@ def test_dagrun_creation_exception_is_handled(self, mock_get_app, session): } def test_should_respond_404_if_a_dag_is_inactive(self, session): - dm = self._create_dag("TEST_INACTIVE_DAG_ID") - dm.is_active = False - session.add(dm) - session.flush() + self._create_dag("TEST_INACTIVE_DAG_ID", is_active=False) response = self.client.post( "api/v1/dags/TEST_INACTIVE_DAG_ID/dagRuns", json={}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 404 + assert response.json()["status"] == 404 def test_should_respond_400_if_a_dag_has_import_errors(self, session): """Test that if a dagmodel has import errors, dags won't be triggered""" - dm = self._create_dag("TEST_DAG_ID") - dm.has_import_errors = True - session.add(dm) - session.flush() + self._create_dag("TEST_DAG_ID", has_import_errors=True) response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", json={}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert { + assert response.json() == { "detail": "DAG with dag_id: 'TEST_DAG_ID' has import errors", "status": 400, "title": "DAG cannot be triggered", "type": EXCEPTIONS_LINK_MAP[400], - } == response.json + } def test_should_response_200_for_matching_execution_date_logical_date(self): execution_date = "2020-11-10T08:25:56.939143+00:00" @@ -1380,12 +1369,12 @@ def test_should_response_200_for_matching_execution_date_logical_date(self): "execution_date": execution_date, "logical_date": logical_date, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) dag_run_id = f"manual__{logical_date}" assert response.status_code == 200 - assert response.json == { + assert response.json() == { "conf": {}, "dag_id": "TEST_DAG_ID", "dag_run_id": dag_run_id, @@ -1409,11 +1398,11 @@ def test_should_response_400_for_conflicting_execution_date_logical_date(self): response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", json={"execution_date": execution_date, "logical_date": logical_date}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json["title"] == "logical_date conflicts with execution_date" - assert response.json["detail"] == (f"'{logical_date}' != '{execution_date}'") + assert response.json()["title"] == "logical_date conflicts with execution_date" + assert response.json()["detail"] == (f"'{logical_date}' != '{execution_date}'") @pytest.mark.parametrize( "data_interval_start, data_interval_end, expected", @@ -1451,10 +1440,10 @@ def test_should_response_400_for_missing_start_date_or_end_date( "data_interval_start": data_interval_start, "data_interval_end": data_interval_end, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json["detail"] == expected + assert response.json()["detail"] == expected @pytest.mark.parametrize( "data, expected", @@ -1480,10 +1469,10 @@ def test_should_response_400_for_missing_start_date_or_end_date( def test_should_response_400_for_naive_datetime_and_bad_datetime(self, data, expected): self._create_dag("TEST_DAG_ID") response = self.client.post( - "api/v1/dags/TEST_DAG_ID/dagRuns", json=data, environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/TEST_DAG_ID/dagRuns", json=data, headers={"REMOTE_USER": "test"} ) assert response.status_code == 400 - assert response.json["detail"] == expected + assert response.json()["detail"] == expected @pytest.mark.parametrize( "data, expected", @@ -1501,16 +1490,16 @@ def test_should_response_400_for_naive_datetime_and_bad_datetime(self, data, exp def test_should_response_400_for_non_dict_dagrun_conf(self, data, expected): self._create_dag("TEST_DAG_ID") response = self.client.post( - "api/v1/dags/TEST_DAG_ID/dagRuns", json=data, environ_overrides={"REMOTE_USER": "test"} + "api/v1/dags/TEST_DAG_ID/dagRuns", json=data, headers={"REMOTE_USER": "test"} ) assert response.status_code == 400 - assert response.json["detail"] == expected + assert response.json()["detail"] == expected def test_response_404(self): response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", json={"dag_run_id": "TEST_DAG_RUN", "execution_date": self.default_time}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 assert { @@ -1518,7 +1507,7 @@ def test_response_404(self): "status": 404, "title": "DAG not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() @pytest.mark.parametrize( "url, request_json, expected_response", @@ -1530,7 +1519,7 @@ def test_response_404(self): "execution_date": "2020-06-12T18:00:00+00:00", }, { - "detail": "Property is read-only - 'start_date'", + "detail": "{'start_date': ['Unknown field.']}", "status": 400, "title": "Bad Request", "type": EXCEPTIONS_LINK_MAP[400], @@ -1541,7 +1530,7 @@ def test_response_404(self): "api/v1/dags/TEST_DAG_ID/dagRuns", {"state": "failed", "execution_date": "2020-06-12T18:00:00+00:00"}, { - "detail": "Property is read-only - 'state'", + "detail": "{'state': ['Unknown field.']}", "status": 400, "title": "Bad Request", "type": EXCEPTIONS_LINK_MAP[400], @@ -1552,9 +1541,9 @@ def test_response_404(self): ) def test_response_400(self, url, request_json, expected_response): self._create_dag("TEST_DAG_ID") - response = self.client.post(url, json=request_json, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.post(url, json=request_json, headers={"REMOTE_USER": "test"}) assert response.status_code == 400, response.data - assert expected_response == response.json + assert expected_response == response.json() def test_response_409(self): self._create_test_dag_run() @@ -1564,10 +1553,10 @@ def test_response_409(self): "dag_run_id": "TEST_DAG_RUN_ID_1", "execution_date": self.default_time_3, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 409, response.data - assert response.json == { + assert response.json() == { "detail": "DAGRun with DAG ID: 'TEST_DAG_ID' and " "DAGRun ID: 'TEST_DAG_RUN_ID_1' already exists", "status": 409, @@ -1584,11 +1573,11 @@ def test_response_409_when_execution_date_is_same(self): "dag_run_id": "TEST_DAG_RUN_ID_6", "execution_date": self.default_time, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 409, response.data - assert response.json == { + assert response.json() == { "detail": "DAGRun with DAG ID: 'TEST_DAG_ID' and " "DAGRun logical date: '2020-06-11 18:00:00+00:00' already exists", "status": 409, @@ -1605,7 +1594,7 @@ def test_should_raises_401_unauthenticated(self): }, ) - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize( "username", @@ -1619,7 +1608,7 @@ def test_should_raises_403_unauthorized(self, username): "dag_run_id": "TEST_DAG_RUN_ID_1", "execution_date": self.default_time, }, - environ_overrides={"REMOTE_USER": username}, + headers={"REMOTE_USER": username}, ) assert response.status_code == 403 @@ -1652,7 +1641,7 @@ def test_should_respond_200(self, state, run_type, dag_maker, session): dag_run_id = "TEST_DAG_RUN_ID" with dag_maker(dag_id) as dag: task = EmptyOperator(task_id="task_id", dag=dag) - self.app.dag_bag.bag_dag(dag, root_dag=dag) + self.flask_app.dag_bag.bag_dag(dag, root_dag=dag) dr = dag_maker.create_dagrun(run_id=dag_run_id, run_type=run_type) ti = dr.get_task_instance(task_id="task_id") ti.task = task @@ -1665,7 +1654,7 @@ def test_should_respond_200(self, state, run_type, dag_maker, session): response = self.client.patch( f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}", json=request_json, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) if state != "queued": @@ -1674,7 +1663,7 @@ def test_should_respond_200(self, state, run_type, dag_maker, session): dr = session.query(DagRun).filter(DagRun.run_id == dr.run_id).first() assert response.status_code == 200 - assert response.json == { + assert response.json() == { "conf": {}, "dag_id": dag_id, "dag_run_id": dag_run_id, @@ -1695,17 +1684,21 @@ def test_schema_validation_error_raises(self, dag_maker, session): dag_id = "TEST_DAG_ID" dag_run_id = "TEST_DAG_RUN_ID" with dag_maker(dag_id) as dag: - EmptyOperator(task_id="task_id", dag=dag) - self.app.dag_bag.bag_dag(dag, root_dag=dag) - dag_maker.create_dagrun(run_id=dag_run_id) + task = EmptyOperator(task_id="task_id", dag=dag) + self.flask_app.dag_bag.bag_dag(dag, root_dag=dag) + dr = dag_maker.create_dagrun(run_id=dag_run_id, state=DagRunState.FAILED) + ti = dr.get_task_instance(task_id="task_id") + ti.task = task + ti.state = State.SUCCESS + session.merge(ti) + session.commit() response = self.client.patch( f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}", json={"states": "success"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": "{'states': ['Unknown field.']}", "status": 400, "title": "Bad Request", @@ -1726,10 +1719,10 @@ def test_should_response_400_for_non_existing_dag_run_state(self, invalid_state, response = self.client.patch( "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1", json=request_json, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": f"'{invalid_state}' is not one of ['success', 'failed', 'queued'] - 'state'", "status": 400, "title": "Bad Request", @@ -1744,7 +1737,7 @@ def test_should_raises_401_unauthenticated(self, session): }, ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.patch( @@ -1752,7 +1745,7 @@ def test_should_raise_403_forbidden(self): json={ "state": "success", }, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -1762,7 +1755,7 @@ def test_should_respond_404(self): json={ "state": "success", }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 @@ -1776,7 +1769,7 @@ def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_c dag_run_id = "TEST_DAG_RUN_ID" with dag_maker(dag_id) as dag: task = EmptyOperator(task_id="task_id", dag=dag) - self.app.dag_bag.bag_dag(dag, root_dag=dag) + self.flask_app.dag_bag.bag_dag(dag, root_dag=dag) dr = dag_maker.create_dagrun(run_id=dag_run_id, run_type=DagRunType.SCHEDULED) ti = dr.get_task_instance(task_id="task_id") ti.task = task @@ -1798,7 +1791,7 @@ def test_should_respond_200(self, dag_maker, session): dag_run_id = "TEST_DAG_RUN_ID" with dag_maker(dag_id) as dag: task = EmptyOperator(task_id="task_id", dag=dag) - self.app.dag_bag.bag_dag(dag, root_dag=dag) + self.flask_app.dag_bag.bag_dag(dag, root_dag=dag) dr = dag_maker.create_dagrun(run_id=dag_run_id, state=DagRunState.FAILED) ti = dr.get_task_instance(task_id="task_id") ti.task = task @@ -1811,12 +1804,12 @@ def test_should_respond_200(self, dag_maker, session): response = self.client.post( f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/clear", json=request_json, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) dr = session.query(DagRun).filter(DagRun.run_id == dr.run_id).first() assert response.status_code == 200 - assert response.json == { + assert response.json() == { "conf": {}, "dag_id": dag_id, "dag_run_id": dag_run_id, @@ -1840,16 +1833,20 @@ def test_schema_validation_error_raises_for_invalid_fields(self, dag_maker, sess dag_id = "TEST_DAG_ID" dag_run_id = "TEST_DAG_RUN_ID" with dag_maker(dag_id) as dag: - EmptyOperator(task_id="task_id", dag=dag) - self.app.dag_bag.bag_dag(dag, root_dag=dag) - dag_maker.create_dagrun(run_id=dag_run_id, state=DagRunState.FAILED) + task = EmptyOperator(task_id="task_id", dag=dag) + self.flask_app.dag_bag.bag_dag(dag, root_dag=dag) + dr = dag_maker.create_dagrun(run_id=dag_run_id, state=DagRunState.FAILED) + ti = dr.get_task_instance(task_id="task_id") + ti.task = task + ti.state = State.SUCCESS + session.merge(ti) + session.commit() response = self.client.post( f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/clear", json={"dryrun": False}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": "{'dryrun': ['Unknown field.']}", "status": 400, "title": "Bad Request", @@ -1862,7 +1859,7 @@ def test_dry_run(self, dag_maker, session): dag_run_id = "TEST_DAG_RUN_ID" with dag_maker(dag_id) as dag: task = EmptyOperator(task_id="task_id", dag=dag) - self.app.dag_bag.bag_dag(dag, root_dag=dag) + self.flask_app.dag_bag.bag_dag(dag, root_dag=dag) dr = dag_maker.create_dagrun(run_id=dag_run_id) ti = dr.get_task_instance(task_id="task_id") ti.task = task @@ -1875,11 +1872,11 @@ def test_dry_run(self, dag_maker, session): response = self.client.post( f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/clear", json=request_json, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "task_instances": [ { "dag_id": dag_id, @@ -1904,7 +1901,7 @@ def test_should_raises_401_unauthenticated(self, session): }, ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.post( @@ -1912,7 +1909,7 @@ def test_should_raise_403_forbidden(self): json={ "dry_run": True, }, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -1922,7 +1919,7 @@ def test_should_respond_404(self): json={ "dry_run": True, }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 @@ -1936,7 +1933,7 @@ def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_c dag_run_id = "TEST_DAG_RUN_ID" with dag_maker(dag_id) as dag: task = EmptyOperator(task_id="task_id", dag=dag) - self.app.dag_bag.bag_dag(dag, root_dag=dag) + self.flask_app.dag_bag.bag_dag(dag, root_dag=dag) dr = dag_maker.create_dagrun(run_id=dag_run_id, run_type=DagRunType.SCHEDULED) ti = dr.get_task_instance(task_id="task_id") ti.task = task @@ -1982,7 +1979,7 @@ def test_should_respond_200(self, dag_maker, session): response = self.client.get( "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/upstreamDatasetEvents", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 expected_response = { @@ -2013,12 +2010,12 @@ def test_should_respond_200(self, dag_maker, session): ], "total_entries": 1, } - assert response.json == expected_response + assert response.json() == expected_response def test_should_respond_404(self): response = self.client.get( "api/v1/dags/invalid-id/dagRuns/invalid-id/upstreamDatasetEvents", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 expected_resp = { @@ -2027,7 +2024,7 @@ def test_should_respond_404(self): "title": "DAGRun not found", "type": EXCEPTIONS_LINK_MAP[404], } - assert expected_resp == response.json + assert expected_resp == response.json() def test_should_raises_401_unauthenticated(self, session): dagrun_model = DagRun( @@ -2043,7 +2040,7 @@ def test_should_raises_401_unauthenticated(self, session): response = self.client.get("api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/upstreamDatasetEvents") - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize( "set_auto_role_public, expected_status_code", @@ -2092,13 +2089,13 @@ def test_should_respond_200(self, dag_maker, session): response = self.client.patch( f"api/v1/dags/{created_dr.dag_id}/dagRuns/{created_dr.run_id}/setNote", json={"note": new_note_value}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) dr = session.query(DagRun).filter(DagRun.run_id == created_dr.run_id).first() assert response.status_code == 200, response.text assert dr.note == new_note_value - assert response.json == { + assert response.json() == { "conf": {}, "dag_id": dr.dag_id, "dag_run_id": dr.run_id, @@ -2121,10 +2118,10 @@ def test_should_respond_200(self, dag_maker, session): response = self.client.patch( f"api/v1/dags/{created_dr.dag_id}/dagRuns/{created_dr.run_id}/setNote", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "conf": {}, "dag_id": dr.dag_id, "dag_run_id": dr.run_id, @@ -2159,10 +2156,10 @@ def test_schema_validation_error_raises(self, dag_maker, session): response = self.client.patch( f"api/v1/dags/{created_dr.dag_id}/dagRuns/{created_dr.run_id}/setNote", json={"notes": new_note_value}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": "{'notes': ['Unknown field.']}", "status": 400, "title": "Bad Request", @@ -2174,13 +2171,13 @@ def test_should_raises_401_unauthenticated(self, session): "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/setNote", json={"note": "I am setting a note while being unauthenticated."}, ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.patch( "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/setNote", json={"note": "I am setting a note without the proper permissions."}, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -2188,7 +2185,7 @@ def test_should_respond_404(self): response = self.client.patch( "api/v1/dags/INVALID_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/setNote", json={"note": "I am setting a note on a DAG that doesn't exist."}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 @@ -2201,7 +2198,7 @@ def test_should_respond_200_with_anonymous_user(self, dag_maker, session): from airflow.www import app as application app = application.create_app(testing=True) - app.config["AUTH_ROLE_PUBLIC"] = "Admin" + app.app.config["AUTH_ROLE_PUBLIC"] = "Admin" dag_runs = self._create_test_dag_run(DagRunState.SUCCESS) session.add_all(dag_runs) session.commit() diff --git a/tests/api_connexion/endpoints/test_dag_source_endpoint.py b/tests/api_connexion/endpoints/test_dag_source_endpoint.py index 14c7d1534d4dcf..1688600fe2245c 100644 --- a/tests/api_connexion/endpoints/test_dag_source_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_source_endpoint.py @@ -24,7 +24,7 @@ from airflow.models import DagBag from airflow.security import permissions -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.db import clear_db_dag_code, clear_db_dags, clear_db_serialized_dags pytestmark = pytest.mark.db_test @@ -42,38 +42,38 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type:ignore + connexion_app.app, # type:ignore username="test", role_name="Test", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE)], # type: ignore ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore + connexion_app.app.appbuilder.sm.sync_perm_for_dag( # type: ignore TEST_DAG_ID, access_control={"Test": [permissions.ACTION_CAN_READ]}, ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore + connexion_app.app.appbuilder.sm.sync_perm_for_dag( # type: ignore EXAMPLE_DAG_ID, access_control={"Test": [permissions.ACTION_CAN_READ]}, ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore + connexion_app.app.appbuilder.sm.sync_perm_for_dag( # type: ignore TEST_MULTIPLE_DAGS_ID, access_control={"Test": [permissions.ACTION_CAN_READ]}, ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore class TestGetSource: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore self.clear_db() def teardown_method(self) -> None: @@ -100,12 +100,9 @@ def test_should_respond_200_text(self, url_safe_serializer): dag_docstring = self._get_dag_file_docstring(test_dag.fileloc) url = f"/api/v1/dagSources/{url_safe_serializer.dumps(test_dag.fileloc)}" - response = self.client.get( - url, headers={"Accept": "text/plain"}, environ_overrides={"REMOTE_USER": "test"} - ) - + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert 200 == response.status_code - assert dag_docstring in response.data.decode() + assert dag_docstring in response.text assert "text/plain" == response.headers["Content-Type"] def test_should_respond_200_json(self, url_safe_serializer): @@ -115,12 +112,10 @@ def test_should_respond_200_json(self, url_safe_serializer): dag_docstring = self._get_dag_file_docstring(test_dag.fileloc) url = f"/api/v1/dagSources/{url_safe_serializer.dumps(test_dag.fileloc)}" - response = self.client.get( - url, headers={"Accept": "application/json"}, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(url, headers={"Accept": "application/json", "REMOTE_USER": "test"}) assert 200 == response.status_code - assert dag_docstring in response.json["content"] + assert dag_docstring in response.json()["content"] assert "application/json" == response.headers["Content-Type"] def test_should_respond_406(self, url_safe_serializer): @@ -129,18 +124,14 @@ def test_should_respond_406(self, url_safe_serializer): test_dag: DAG = dagbag.dags[TEST_DAG_ID] url = f"/api/v1/dagSources/{url_safe_serializer.dumps(test_dag.fileloc)}" - response = self.client.get( - url, headers={"Accept": "image/webp"}, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(url, headers={"Accept": "image/webp", "REMOTE_USER": "test"}) assert 406 == response.status_code def test_should_respond_404(self): wrong_fileloc = "abcd1234" url = f"/api/v1/dagSources/{wrong_fileloc}" - response = self.client.get( - url, headers={"Accept": "application/json"}, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(url, headers={"Accept": "application/json", "REMOTE_USER": "test"}) assert 404 == response.status_code @@ -154,7 +145,7 @@ def test_should_raises_401_unauthenticated(self, url_safe_serializer): headers={"Accept": "text/plain"}, ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self, url_safe_serializer): dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) @@ -163,8 +154,7 @@ def test_should_raise_403_forbidden(self, url_safe_serializer): response = self.client.get( f"/api/v1/dagSources/{url_safe_serializer.dumps(first_dag.fileloc)}", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"Accept": "text/plain", "REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -175,12 +165,11 @@ def test_should_respond_403_not_readable(self, url_safe_serializer): response = self.client.get( f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "text/plain", "REMOTE_USER": "test"}, ) read_dag = self.client.get( f"/api/v1/dags/{NOT_READABLE_DAG_ID}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 403 assert read_dag.status_code == 403 @@ -192,13 +181,12 @@ def test_should_respond_403_some_dags_not_readable_in_the_file(self, url_safe_se response = self.client.get( f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "text/plain", "REMOTE_USER": "test"}, ) read_dag = self.client.get( f"/api/v1/dags/{TEST_MULTIPLE_DAGS_ID}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 403 assert read_dag.status_code == 200 diff --git a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py index cc398329b96449..915f8b959000a5 100644 --- a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py @@ -24,7 +24,7 @@ from airflow.models.dagwarning import DagWarning from airflow.security import permissions from airflow.utils.session import create_session -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.db import clear_db_dag_warnings, clear_db_dags pytestmark = pytest.mark.db_test @@ -32,9 +32,9 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type:ignore + connexion_app.app, # type:ignore username="test", role_name="Test", permissions=[ @@ -42,9 +42,9 @@ def configured_app(minimal_app_for_api): (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), ], # type: ignore ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore create_user( - app, # type:ignore + connexion_app.app, # type:ignore username="test_with_dag2_read", role_name="TestWithDag2Read", permissions=[ @@ -53,11 +53,11 @@ def configured_app(minimal_app_for_api): ], # type: ignore ) - yield minimal_app_for_api + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_with_dag2_read") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test_with_dag2_read") # type: ignore class TestBaseDagWarning: @@ -65,8 +65,8 @@ class TestBaseDagWarning: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore def teardown_method(self) -> None: clear_db_dag_warnings() @@ -95,11 +95,11 @@ def setup_method(self): def test_response_one(self): response = self.client.get( "/api/v1/dagWarnings", - environ_overrides={"REMOTE_USER": "test"}, - query_string={"dag_id": "dag1", "warning_type": "non-existent pool"}, + headers={"REMOTE_USER": "test"}, + params={"dag_id": "dag1", "warning_type": "non-existent pool"}, ) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert response_data == { "dag_warnings": [ { @@ -115,11 +115,11 @@ def test_response_one(self): def test_response_some(self): response = self.client.get( "/api/v1/dagWarnings", - environ_overrides={"REMOTE_USER": "test"}, - query_string={"warning_type": "non-existent pool"}, + headers={"REMOTE_USER": "test"}, + params={"warning_type": "non-existent pool"}, ) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert len(response_data["dag_warnings"]) == 2 assert response_data == { "dag_warnings": ANY, @@ -129,11 +129,11 @@ def test_response_some(self): def test_response_none(self, session): response = self.client.get( "/api/v1/dagWarnings", - environ_overrides={"REMOTE_USER": "test"}, - query_string={"dag_id": "missing_dag"}, + headers={"REMOTE_USER": "test"}, + params={"dag_id": "missing_dag"}, ) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert response_data == { "dag_warnings": [], "total_entries": 0, @@ -142,11 +142,11 @@ def test_response_none(self, session): def test_response_all(self): response = self.client.get( "/api/v1/dagWarnings", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert len(response_data["dag_warnings"]) == 2 assert response_data == { "dag_warnings": ANY, @@ -155,19 +155,17 @@ def test_response_all(self): def test_should_raises_401_unauthenticated(self): response = self.client.get("/api/v1/dagWarnings") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): - response = self.client.get( - "/api/v1/dagWarnings", environ_overrides={"REMOTE_USER": "test_no_permissions"} - ) + response = self.client.get("/api/v1/dagWarnings", headers={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 def test_should_raise_403_forbidden_when_user_has_no_dag_read_permission(self): response = self.client.get( "/api/v1/dagWarnings", - environ_overrides={"REMOTE_USER": "test_with_dag2_read"}, - query_string={"dag_id": "dag1"}, + headers={"REMOTE_USER": "test_with_dag2_read"}, + params={"dag_id": "dag1"}, ) assert response.status_code == 403 @@ -178,7 +176,6 @@ def test_should_raise_403_forbidden_when_user_has_no_dag_read_permission(self): ) def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code): response = self.client.get( - "/api/v1/dagWarnings", - query_string={"dag_id": "dag1", "warning_type": "non-existent pool"}, + "/api/v1/dagWarnings?dag_id=dag1&warning_type=non-existent+pool", ) assert response.status_code == expected_status_code diff --git a/tests/api_connexion/endpoints/test_dataset_endpoint.py b/tests/api_connexion/endpoints/test_dataset_endpoint.py index 5b6e2f24146e47..f49fa5e26ea2f2 100644 --- a/tests/api_connexion/endpoints/test_dataset_endpoint.py +++ b/tests/api_connexion/endpoints/test_dataset_endpoint.py @@ -37,7 +37,7 @@ from airflow.utils import timezone from airflow.utils.session import provide_session from airflow.utils.types import DagRunType -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.asserts import assert_queries_count from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_datasets, clear_db_runs @@ -48,9 +48,9 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -58,9 +58,9 @@ def configured_app(minimal_app_for_api): (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DATASET), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test_queued_event", role_name="TestQueuedEvent", permissions=[ @@ -70,11 +70,11 @@ def configured_app(minimal_app_for_api): ], ) - yield app + yield connexion_app - delete_user(app, username="test_queued_event") # type: ignore - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test_queued_event") # type: ignore class TestDatasetEndpoint: @@ -82,8 +82,8 @@ class TestDatasetEndpoint: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() clear_db_datasets() clear_db_runs() @@ -112,10 +112,10 @@ def test_should_respond_200(self, session): with assert_queries_count(5): response = self.client.get( f"/api/v1/datasets/{urllib.parse.quote('s3://bucket/key', safe='')}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "id": 1, "uri": "s3://bucket/key", "extra": {"foo": "bar"}, @@ -128,7 +128,7 @@ def test_should_respond_200(self, session): def test_should_respond_404(self): response = self.client.get( f"/api/v1/datasets/{urllib.parse.quote('s3://bucket/key', safe='')}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 assert { @@ -136,12 +136,12 @@ def test_should_respond_404(self): "status": 404, "title": "Dataset not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self, session): self._create_dataset(session) response = self.client.get(f"/api/v1/datasets/{urllib.parse.quote('s3://bucket/key', safe='')}") - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize( "set_auto_role_public, expected_status_code", @@ -177,10 +177,10 @@ def test_should_respond_200(self, session): assert session.query(DatasetModel).count() == 2 with assert_queries_count(8): - response = self.client.get("/api/v1/datasets", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/datasets", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert response_data == { "datasets": [ { @@ -220,12 +220,12 @@ def test_order_by_raises_400_for_invalid_attr(self, session): assert session.query(DatasetModel).count() == 2 response = self.client.get( - "/api/v1/datasets?order_by=fake", environ_overrides={"REMOTE_USER": "test"} + "/api/v1/datasets?order_by=fake", headers={"REMOTE_USER": "test"} ) # missing attr assert response.status_code == 400 msg = "Ordering with 'fake' is disallowed or the attribute does not exist on the model" - assert response.json["detail"] == msg + assert response.json()["detail"] == msg def test_should_raises_401_unauthenticated(self, session): datasets = [ @@ -243,7 +243,7 @@ def test_should_raises_401_unauthenticated(self, session): response = self.client.get("/api/v1/datasets") - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize( "url, expected_datasets", @@ -273,9 +273,9 @@ def test_filter_datasets_by_uri_pattern_works(self, url, expected_datasets, sess dataset4 = DatasetModel("wasb://some_dataset_bucket_/key") session.add_all([dataset1, dataset2, dataset3, dataset4]) session.commit() - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - dataset_urls = {dataset["uri"] for dataset in response.json["datasets"]} + dataset_urls = {dataset["uri"] for dataset in response.json()["datasets"]} assert expected_datasets == dataset_urls @pytest.mark.parametrize("dag_ids, expected_num", [("dag1,dag2", 2), ("dag3", 1), ("dag2,dag3", 2)]) @@ -294,11 +294,9 @@ def test_filter_datasets_by_dag_ids_works(self, dag_ids, expected_num, session): task_ref1 = TaskOutletDatasetReference(dag_id="dag3", task_id="task1", dataset=dataset3) session.add_all([dataset1, dataset2, dataset3, dag1, dag2, dag3, dag_ref1, dag_ref2, task_ref1]) session.commit() - response = self.client.get( - f"/api/v1/datasets?dag_ids={dag_ids}", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/datasets?dag_ids={dag_ids}", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert len(response_data["datasets"]) == expected_num @pytest.mark.parametrize( @@ -323,10 +321,10 @@ def test_filter_datasets_by_dag_ids_and_uri_pattern_works( session.commit() response = self.client.get( f"/api/v1/datasets?dag_ids={dag_ids}&uri_pattern={uri_pattern}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert len(response_data["datasets"]) == expected_num @pytest.mark.parametrize( @@ -383,10 +381,10 @@ def test_limit_and_offset(self, url, expected_dataset_uris, session): session.add_all(datasets) session.commit() - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - dataset_uris = [dataset["uri"] for dataset in response.json["datasets"]] + dataset_uris = [dataset["uri"] for dataset in response.json()["datasets"]] assert dataset_uris == expected_dataset_uris def test_should_respect_page_size_limit_default(self, session): @@ -402,10 +400,10 @@ def test_should_respect_page_size_limit_default(self, session): session.add_all(datasets) session.commit() - response = self.client.get("/api/v1/datasets", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/datasets", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["datasets"]) == 100 + assert len(response.json()["datasets"]) == 100 @conf_vars({("api", "maximum_page_limit"): "150"}) def test_should_return_conf_max_if_req_max_above_conf(self, session): @@ -421,10 +419,10 @@ def test_should_return_conf_max_if_req_max_above_conf(self, session): session.add_all(datasets) session.commit() - response = self.client.get("/api/v1/datasets?limit=180", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/datasets?limit=180", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["datasets"]) == 150 + assert len(response.json()["datasets"]) == 150 class TestGetDatasetEvents(TestDatasetEndpoint): @@ -445,10 +443,10 @@ def test_should_respond_200(self, session): session.commit() assert session.query(DatasetEvent).count() == 2 - response = self.client.get("/api/v1/datasets/events", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/datasets/events", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert response_data == { "dataset_events": [ { @@ -507,12 +505,10 @@ def test_filtering(self, attr, value, session): session.commit() assert session.query(DatasetEvent).count() == 3 - response = self.client.get( - f"/api/v1/datasets/events?{attr}={value}", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/datasets/events?{attr}={value}", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert response_data == { "dataset_events": [ { @@ -550,16 +546,16 @@ def test_order_by_raises_400_for_invalid_attr(self, session): assert session.query(DatasetEvent).count() == 2 response = self.client.get( - "/api/v1/datasets/events?order_by=fake", environ_overrides={"REMOTE_USER": "test"} + "/api/v1/datasets/events?order_by=fake", headers={"REMOTE_USER": "test"} ) # missing attr assert response.status_code == 400 msg = "Ordering with 'fake' is disallowed or the attribute does not exist on the model" - assert response.json["detail"] == msg + assert response.json()["detail"] == msg def test_should_raises_401_unauthenticated(self, session): response = self.client.get("/api/v1/datasets/events") - assert_401(response) + assert response.status_code == 401 def test_includes_created_dagrun(self, session): self._create_dataset(session) @@ -587,10 +583,10 @@ def test_includes_created_dagrun(self, session): event.created_dagruns.append(dagrun) session.commit() - response = self.client.get("/api/v1/datasets/events", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/datasets/events", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert response_data == { "dataset_events": [ { @@ -662,11 +658,11 @@ def test_should_respond_200(self, session): self._create_dataset(session) event_payload = {"dataset_uri": "s3://bucket/key", "extra": {"foo": "bar"}} response = self.client.post( - "/api/v1/datasets/events", json=event_payload, environ_overrides={"REMOTE_USER": "test"} + "/api/v1/datasets/events", json=event_payload, headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert response_data == { "id": ANY, "created_dagruns": [], @@ -692,7 +688,7 @@ def test_should_mask_sensitive_extra_logs(self, session): self._create_dataset(session) event_payload = {"dataset_uri": "s3://bucket/key", "extra": {"password": "bar"}} response = self.client.post( - "/api/v1/datasets/events", json=event_payload, environ_overrides={"REMOTE_USER": "test"} + "/api/v1/datasets/events", json=event_payload, headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 @@ -709,14 +705,14 @@ def test_order_by_raises_400_for_invalid_attr(self, session): self._create_dataset(session) event_invalid_payload = {"dataset_uri": "TEST_DATASET_URI", "extra": {"foo": "bar"}, "fake": {}} response = self.client.post( - "/api/v1/datasets/events", json=event_invalid_payload, environ_overrides={"REMOTE_USER": "test"} + "/api/v1/datasets/events", json=event_invalid_payload, headers={"REMOTE_USER": "test"} ) - assert response.status_code == 400 + assert response.json()["status"] == 400 def test_should_raises_401_unauthenticated(self, session): self._create_dataset(session) response = self.client.post("/api/v1/datasets/events", json={"dataset_uri": "TEST_DATASET_URI"}) - assert_401(response) + assert response.json()["status"] == 401 @pytest.mark.parametrize( "set_auto_role_public, expected_status_code", @@ -775,10 +771,10 @@ def test_limit_and_offset(self, url, expected_event_runids, session): session.add_all(events) session.commit() - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - event_runids = [event["source_run_id"] for event in response.json["dataset_events"]] + event_runids = [event["source_run_id"] for event in response.json()["dataset_events"]] assert event_runids == expected_event_runids def test_should_respect_page_size_limit_default(self, session): @@ -797,10 +793,10 @@ def test_should_respect_page_size_limit_default(self, session): session.add_all(events) session.commit() - response = self.client.get("/api/v1/datasets/events", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/datasets/events", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["dataset_events"]) == 100 + assert len(response.json()["dataset_events"]) == 100 @conf_vars({("api", "maximum_page_limit"): "150"}) def test_should_return_conf_max_if_req_max_above_conf(self, session): @@ -819,12 +815,10 @@ def test_should_return_conf_max_if_req_max_above_conf(self, session): session.add_all(events) session.commit() - response = self.client.get( - "/api/v1/datasets/events?limit=180", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/datasets/events?limit=180", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["dataset_events"]) == 150 + assert len(response.json()["dataset_events"]) == 150 class TestQueuedEventEndpoint(TestDatasetEndpoint): @@ -855,11 +849,11 @@ def test_should_respond_200(self, session, create_dummy_dag): response = self.client.get( f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, + headers={"REMOTE_USER": "test_queued_event"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "created_at": self.default_time, "uri": "s3://bucket/key", "dag_id": "dag", @@ -871,7 +865,7 @@ def test_should_respond_404(self): response = self.client.get( f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, + headers={"REMOTE_USER": "test_queued_event"}, ) assert response.status_code == 404 @@ -880,7 +874,7 @@ def test_should_respond_404(self): "status": 404, "title": "Queue event not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self, session): dag_id = "dummy" @@ -888,7 +882,7 @@ def test_should_raises_401_unauthenticated(self, session): response = self.client.get(f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self, session): dag_id = "dummy" @@ -896,7 +890,7 @@ def test_should_raise_403_forbidden(self, session): response = self.client.get( f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -938,7 +932,7 @@ def test_delete_should_respond_204(self, session, create_dummy_dag): response = self.client.delete( f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, + headers={"REMOTE_USER": "test_queued_event"}, ) assert response.status_code == 204 @@ -954,7 +948,7 @@ def test_should_respond_404(self): response = self.client.delete( f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, + headers={"REMOTE_USER": "test_queued_event"}, ) assert response.status_code == 404 @@ -963,20 +957,20 @@ def test_should_respond_404(self): "status": 404, "title": "Queue event not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self, session): dag_id = "dummy" dataset_uri = "dummy" response = self.client.delete(f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self, session): dag_id = "dummy" dataset_uri = "dummy" response = self.client.delete( f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -991,11 +985,11 @@ def test_should_respond_200(self, session, create_dummy_dag): response = self.client.get( f"/api/v1/dags/{dag_id}/datasets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_queued_event"}, + headers={"REMOTE_USER": "test_queued_event"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "queued_events": [ { "created_at": self.default_time, @@ -1011,7 +1005,7 @@ def test_should_respond_404(self): response = self.client.get( f"/api/v1/dags/{dag_id}/datasets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_queued_event"}, + headers={"REMOTE_USER": "test_queued_event"}, ) assert response.status_code == 404 @@ -1020,21 +1014,21 @@ def test_should_respond_404(self): "status": 404, "title": "Queue event not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self): dag_id = "dummy" response = self.client.get(f"/api/v1/dags/{dag_id}/datasets/queuedEvent") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): dag_id = "dummy" response = self.client.get( f"/api/v1/dags/{dag_id}/datasets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -1064,7 +1058,7 @@ def test_should_respond_404(self): response = self.client.delete( f"/api/v1/dags/{dag_id}/datasets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_queued_event"}, + headers={"REMOTE_USER": "test_queued_event"}, ) assert response.status_code == 404 @@ -1073,21 +1067,21 @@ def test_should_respond_404(self): "status": 404, "title": "Queue event not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self): dag_id = "dummy" response = self.client.delete(f"/api/v1/dags/{dag_id}/datasets/queuedEvent") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): dag_id = "dummy" response = self.client.delete( f"/api/v1/dags/{dag_id}/datasets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -1129,11 +1123,11 @@ def test_should_respond_200(self, session, create_dummy_dag): response = self.client.get( f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, + headers={"REMOTE_USER": "test_queued_event"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "queued_events": [ { "created_at": self.default_time, @@ -1149,7 +1143,7 @@ def test_should_respond_404(self): response = self.client.get( f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, + headers={"REMOTE_USER": "test_queued_event"}, ) assert response.status_code == 404 @@ -1158,21 +1152,21 @@ def test_should_respond_404(self): "status": 404, "title": "Queue event not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self): dataset_uri = "not_exists" response = self.client.get(f"/api/v1/datasets/queuedEvent/{dataset_uri}") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): dataset_uri = "not_exists" response = self.client.get( f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -1208,7 +1202,7 @@ def test_delete_should_respond_204(self, session, create_dummy_dag): response = self.client.delete( f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, + headers={"REMOTE_USER": "test_queued_event"}, ) assert response.status_code == 204 @@ -1221,7 +1215,7 @@ def test_should_respond_404(self): response = self.client.delete( f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, + headers={"REMOTE_USER": "test_queued_event"}, ) assert response.status_code == 404 @@ -1230,21 +1224,21 @@ def test_should_respond_404(self): "status": 404, "title": "Queue event not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self): dataset_uri = "not_exists" response = self.client.delete(f"/api/v1/datasets/queuedEvent/{dataset_uri}") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): dataset_uri = "not_exists" response = self.client.delete( f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 diff --git a/tests/api_connexion/endpoints/test_event_log_endpoint.py b/tests/api_connexion/endpoints/test_event_log_endpoint.py index 6738858ddd00fa..aca91ae59a0e59 100644 --- a/tests/api_connexion/endpoints/test_event_log_endpoint.py +++ b/tests/api_connexion/endpoints/test_event_log_endpoint.py @@ -22,7 +22,7 @@ from airflow.models import Log from airflow.security import permissions from airflow.utils import timezone -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_logs @@ -31,34 +31,34 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type:ignore + connexion_app.app, # type:ignore username="test", role_name="Test", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)], # type: ignore ) create_user( - app, # type:ignore + connexion_app.app, # type:ignore username="test_granular", role_name="TestGranular", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)], # type: ignore ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore + connexion_app.app.appbuilder.sm.sync_perm_for_dag( # type: ignore "TEST_DAG_ID_1", access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore + connexion_app.app.appbuilder.sm.sync_perm_for_dag( # type: ignore "TEST_DAG_ID_2", access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_granular") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_granular") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore @pytest.fixture @@ -91,7 +91,9 @@ def maker(event, when, **kwargs): log_model.dttm = when session.add(log_model) + session.commit() session.flush() + session.close() return log_model return maker @@ -100,8 +102,8 @@ def maker(event, when, **kwargs): class TestEventLogEndpoint: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore clear_db_logs() self.default_time = timezone.parse("2020-06-10T20:00:00+00:00") self.default_time_2 = timezone.parse("2020-06-11T07:00:00+00:00") @@ -116,9 +118,7 @@ def teardown_method(self) -> None: ) def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, log_model): event_log_id = log_model.id - response = self.client.get( - f"/api/v1/eventLogs/{event_log_id}", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/eventLogs/{event_log_id}", headers={"REMOTE_USER": "test"}) response = self.client.get("/api/v1/eventLogs") @@ -128,11 +128,9 @@ def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_c class TestGetEventLog(TestEventLogEndpoint): def test_should_respond_200(self, log_model): event_log_id = log_model.id - response = self.client.get( - f"/api/v1/eventLogs/{event_log_id}", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/eventLogs/{event_log_id}", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "event_log_id": event_log_id, "event": "TEST_EVENT", "dag_id": "TEST_DAG_ID", @@ -145,26 +143,24 @@ def test_should_respond_200(self, log_model): } def test_should_respond_404(self): - response = self.client.get("/api/v1/eventLogs/1", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/eventLogs/1", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 assert { "detail": None, "status": 404, "title": "Event Log not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self, log_model): event_log_id = log_model.id response = self.client.get(f"/api/v1/eventLogs/{event_log_id}") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): - response = self.client.get( - "/api/v1/eventLogs", environ_overrides={"REMOTE_USER": "test_no_permissions"} - ) + response = self.client.get("/api/v1/eventLogs", headers={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 @pytest.mark.parametrize( @@ -188,10 +184,12 @@ def test_should_respond_200(self, session, create_log_model): log_model_3.dttm = self.default_time_2 session.add(log_model_3) + session.commit() session.flush() - response = self.client.get("/api/v1/eventLogs", environ_overrides={"REMOTE_USER": "test"}) + session.close() + response = self.client.get("/api/v1/eventLogs", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "event_logs": [ { "event_log_id": log_model_1.id, @@ -236,12 +234,12 @@ def test_order_eventlogs_by_owner(self, create_log_model, session): log_model_3 = Log(event="cli_scheduler", owner="root", extra='{"host_name": "e24b454f002a"}') log_model_3.dttm = self.default_time_2 session.add(log_model_3) + session.commit() session.flush() - response = self.client.get( - "/api/v1/eventLogs?order_by=-owner", environ_overrides={"REMOTE_USER": "test"} - ) + session.close() + response = self.client.get("/api/v1/eventLogs?order_by=-owner", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "event_logs": [ { "event_log_id": log_model_2.id, @@ -283,7 +281,7 @@ def test_order_eventlogs_by_owner(self, create_log_model, session): def test_should_raises_401_unauthenticated(self, log_model): response = self.client.get("/api/v1/eventLogs") - assert_401(response) + assert response.status_code == 401 def test_should_filter_eventlogs_by_allowed_attributes(self, create_log_model, session): eventlog1 = create_log_model( @@ -302,33 +300,36 @@ def test_should_filter_eventlogs_by_allowed_attributes(self, create_log_model, s ) session.add_all([eventlog1, eventlog2]) session.commit() + session.close() for attr in ["dag_id", "task_id", "owner", "event"]: attr_value = f"TEST_{attr}_1".upper() response = self.client.get( - f"/api/v1/eventLogs?{attr}={attr_value}", environ_overrides={"REMOTE_USER": "test_granular"} + f"/api/v1/eventLogs?{attr}={attr_value}", headers={"REMOTE_USER": "test_granular"} ) assert response.status_code == 200 - assert response.json["total_entries"] == 1 - assert len(response.json["event_logs"]) == 1 - assert response.json["event_logs"][0][attr] == attr_value + assert {eventlog[attr] for eventlog in response.json()["event_logs"]} == {attr_value} + assert response.json()["total_entries"] == 1 + assert len(response.json()["event_logs"]) == 1 + assert response.json()["event_logs"][0][attr] == attr_value def test_should_filter_eventlogs_by_when(self, create_log_model, session): eventlog1 = create_log_model(event="TEST_EVENT_1", when=self.default_time) eventlog2 = create_log_model(event="TEST_EVENT_2", when=self.default_time_2) session.add_all([eventlog1, eventlog2]) session.commit() + session.close() for when_attr, expected_eventlog_event in { "before": "TEST_EVENT_1", "after": "TEST_EVENT_2", }.items(): response = self.client.get( f"/api/v1/eventLogs?{when_attr}=2020-06-10T20%3A00%3A01%2B00%3A00", # self.default_time + 1s - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 1 - assert len(response.json["event_logs"]) == 1 - assert response.json["event_logs"][0]["event"] == expected_eventlog_event + assert response.json()["total_entries"] == 1 + assert len(response.json()["event_logs"]) == 1 + assert response.json()["event_logs"][0]["event"] == expected_eventlog_event def test_should_filter_eventlogs_by_run_id(self, create_log_model, session): eventlog1 = create_log_model(event="TEST_EVENT_1", when=self.default_time, run_id="run_1") @@ -336,29 +337,30 @@ def test_should_filter_eventlogs_by_run_id(self, create_log_model, session): eventlog3 = create_log_model(event="TEST_EVENT_3", when=self.default_time, run_id="run_2") session.add_all([eventlog1, eventlog2, eventlog3]) session.commit() + session.close() for run_id, expected_eventlogs in { "run_1": {"TEST_EVENT_1"}, "run_2": {"TEST_EVENT_2", "TEST_EVENT_3"}, }.items(): response = self.client.get( f"/api/v1/eventLogs?run_id={run_id}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == len(expected_eventlogs) - assert len(response.json["event_logs"]) == len(expected_eventlogs) - assert {eventlog["event"] for eventlog in response.json["event_logs"]} == expected_eventlogs - assert all({eventlog["run_id"] == run_id for eventlog in response.json["event_logs"]}) + assert response.json()["total_entries"] == len(expected_eventlogs) + assert len(response.json()["event_logs"]) == len(expected_eventlogs) + assert {eventlog["event"] for eventlog in response.json()["event_logs"]} == expected_eventlogs + assert all({eventlog["run_id"] == run_id for eventlog in response.json()["event_logs"]}) def test_should_filter_eventlogs_by_included_events(self, create_log_model): for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]: create_log_model(event=event, when=self.default_time) response = self.client.get( "/api/v1/eventLogs?included_events=TEST_EVENT_1,TEST_EVENT_2", - environ_overrides={"REMOTE_USER": "test_granular"}, + headers={"REMOTE_USER": "test_granular"}, ) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert len(response_data["event_logs"]) == 2 assert response_data["total_entries"] == 2 assert {"TEST_EVENT_1", "TEST_EVENT_2"} == {x["event"] for x in response_data["event_logs"]} @@ -368,10 +370,10 @@ def test_should_filter_eventlogs_by_excluded_events(self, create_log_model): create_log_model(event=event, when=self.default_time) response = self.client.get( "/api/v1/eventLogs?excluded_events=TEST_EVENT_1,TEST_EVENT_2", - environ_overrides={"REMOTE_USER": "test_granular"}, + headers={"REMOTE_USER": "test_granular"}, ) assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert len(response_data["event_logs"]) == 1 assert response_data["total_entries"] == 1 assert {"cli_scheduler"} == {x["event"] for x in response_data["event_logs"]} @@ -437,46 +439,48 @@ def test_handle_limit_and_offset(self, url, expected_events, task_instance, sess log_models = self._create_event_logs(task_instance, 10) session.add_all(log_models) session.commit() - - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + session.close() + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 10 - events = [event_log["event"] for event_log in response.json["event_logs"]] + assert response.json()["total_entries"] == 10 + events = [event_log["event"] for event_log in response.json()["event_logs"]] assert events == expected_events def test_should_respect_page_size_limit_default(self, task_instance, session): log_models = self._create_event_logs(task_instance, 200) session.add_all(log_models) + session.commit() session.flush() + session.close() - response = self.client.get("/api/v1/eventLogs", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/eventLogs", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 200 - assert len(response.json["event_logs"]) == 100 # default 100 + assert response.json()["total_entries"] == 200 + assert len(response.json()["event_logs"]) == 100 # default 100 def test_should_raise_400_for_invalid_order_by_name(self, task_instance, session): log_models = self._create_event_logs(task_instance, 200) session.add_all(log_models) + session.commit() session.flush() - - response = self.client.get( - "/api/v1/eventLogs?order_by=invalid", environ_overrides={"REMOTE_USER": "test"} - ) + session.close() + response = self.client.get("/api/v1/eventLogs?order_by=invalid", headers={"REMOTE_USER": "test"}) assert response.status_code == 400 msg = "Ordering with 'invalid' is disallowed or the attribute does not exist on the model" - assert response.json["detail"] == msg + assert response.json()["detail"] == msg @conf_vars({("api", "maximum_page_limit"): "150"}) def test_should_return_conf_max_if_req_max_above_conf(self, task_instance, session): log_models = self._create_event_logs(task_instance, 200) session.add_all(log_models) + session.commit() session.flush() - - response = self.client.get("/api/v1/eventLogs?limit=180", environ_overrides={"REMOTE_USER": "test"}) + session.close() + response = self.client.get("/api/v1/eventLogs?limit=180", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["event_logs"]) == 150 + assert len(response.json()["event_logs"]) == 150 def _create_event_logs(self, task_instance, count): return [Log(event=f"TEST_EVENT_{i}", task_instance=task_instance) for i in range(1, count + 1)] diff --git a/tests/api_connexion/endpoints/test_extra_link_endpoint.py b/tests/api_connexion/endpoints/test_extra_link_endpoint.py index 3e803a4bf4a573..19c43d3c738f52 100644 --- a/tests/api_connexion/endpoints/test_extra_link_endpoint.py +++ b/tests/api_connexion/endpoints/test_extra_link_endpoint.py @@ -42,10 +42,10 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -54,12 +54,12 @@ def configured_app(minimal_app_for_api): (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore class TestGetExtraLinks: @@ -70,13 +70,13 @@ def setup_attrs(self, configured_app, session) -> None: clear_db_runs() clear_db_xcom() - self.app = configured_app + self.connexion_app = configured_app self.dag = self._create_dag() - self.app.dag_bag = DagBag(os.devnull, include_examples=False) - self.app.dag_bag.dags = {self.dag.dag_id: self.dag} # type: ignore - self.app.dag_bag.sync_to_db() # type: ignore + self.connexion_app.app.dag_bag = DagBag(os.devnull, include_examples=False) + self.connexion_app.app.dag_bag.dags = {self.dag.dag_id: self.dag} # type: ignore + self.connexion_app.app.dag_bag.sync_to_db() # type: ignore self.dag.create_dagrun( run_id="TEST_DAG_RUN_ID", @@ -86,9 +86,10 @@ def setup_attrs(self, configured_app, session) -> None: session=session, data_interval=DataInterval(timezone.datetime(2020, 1, 1), timezone.datetime(2020, 1, 2)), ) + session.commit() session.flush() - - self.client = self.app.test_client() # type:ignore + session.close() + self.client = self.connexion_app.test_client() # type:ignore def teardown_method(self) -> None: clear_db_runs() @@ -124,7 +125,7 @@ def _create_dag(self): ], ) def test_should_respond_404(self, url, expected_title, expected_detail): - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert 404 == response.status_code assert { @@ -132,12 +133,12 @@ def test_should_respond_404(self, url, expected_title, expected_detail): "status": 404, "title": expected_title, "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raise_403_forbidden(self): response = self.client.get( "/api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/taskInstances/TEST_SINGLE_QUERY/links", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -152,23 +153,23 @@ def test_should_respond_200(self): ) response = self.client.get( "/api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/taskInstances/TEST_SINGLE_QUERY/links", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert 200 == response.status_code, response.data + assert 200 == response.status_code assert { "BigQuery Console": "https://console.cloud.google.com/bigquery?j=TEST_JOB_ID" - } == response.json + } == response.json() @mock_plugin_manager(plugins=[]) def test_should_respond_200_missing_xcom(self): response = self.client.get( "/api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/taskInstances/TEST_SINGLE_QUERY/links", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert 200 == response.status_code, response.data - assert {"BigQuery Console": None} == response.json + assert 200 == response.status_code + assert {"BigQuery Console": None} == response.json() @mock_plugin_manager(plugins=[]) def test_should_respond_200_multiple_links(self): @@ -181,24 +182,24 @@ def test_should_respond_200_multiple_links(self): ) response = self.client.get( "/api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/taskInstances/TEST_MULTIPLE_QUERY/links", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert 200 == response.status_code, response.data + assert 200 == response.status_code assert { "BigQuery Console #1": "https://console.cloud.google.com/bigquery?j=TEST_JOB_ID_1", "BigQuery Console #2": "https://console.cloud.google.com/bigquery?j=TEST_JOB_ID_2", - } == response.json + } == response.json() @mock_plugin_manager(plugins=[]) def test_should_respond_200_multiple_links_missing_xcom(self): response = self.client.get( "/api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/taskInstances/TEST_MULTIPLE_QUERY/links", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert 200 == response.status_code, response.data - assert {"BigQuery Console #1": None, "BigQuery Console #2": None} == response.json + assert 200 == response.status_code + assert {"BigQuery Console #1": None, "BigQuery Console #2": None} == response.json() def test_should_respond_200_support_plugins(self): class GoogleLink(BaseOperatorLink): @@ -229,10 +230,10 @@ class AirflowTestPlugin(AirflowPlugin): with mock_plugin_manager(plugins=[AirflowTestPlugin]): response = self.client.get( "/api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/taskInstances/TEST_SINGLE_QUERY/links", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert 200 == response.status_code, response.data + assert 200 == response.status_code assert { "BigQuery Console": None, "Google": "https://www.google.com", @@ -240,4 +241,4 @@ class AirflowTestPlugin(AirflowPlugin): "https://s3.amazonaws.com/airflow-logs/" "TEST_DAG_ID/TEST_SINGLE_QUERY/2020-01-01T00%3A00%3A00%2B00%3A00" ), - } == response.json + } == response.json() diff --git a/tests/api_connexion/endpoints/test_forward_to_fab_endpoint.py b/tests/api_connexion/endpoints/test_forward_to_fab_endpoint.py index a9f2d9ceb46916..3a71fc9d67e286 100644 --- a/tests/api_connexion/endpoints/test_forward_to_fab_endpoint.py +++ b/tests/api_connexion/endpoints/test_forward_to_fab_endpoint.py @@ -59,7 +59,7 @@ def autoclean_user_payload(autoclean_username, autoclean_email): @pytest.fixture def autoclean_admin_user(configured_app, autoclean_user_payload): - security_manager = configured_app.appbuilder.sm + security_manager = configured_app.app.appbuilder.sm return security_manager.add_user( role=security_manager.find_role("Admin"), **autoclean_user_payload, @@ -82,9 +82,9 @@ def autoclean_email(): @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -100,28 +100,29 @@ def configured_app(minimal_app_for_api): ], ) - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore class TestFABforwarding: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.flask_app = configured_app.app + self.client = self.connexion_app.test_client() # type:ignore def teardown_method(self): """ Delete all roles except these ones. Test and TestNoPermissions are deleted by delete_user above """ - session = self.app.appbuilder.get_session + session = self.flask_app.appbuilder.get_session existing_roles = set(EXISTING_ROLES) existing_roles.update(["Test", "TestNoPermissions"]) roles = session.query(Role).filter(~Role.name.in_(existing_roles)).all() for role in roles: - delete_role(self.app, role.name) + delete_role(self.flask_app, role.name) users = session.query(User).filter(User.changed_on == timezone.parse(DEFAULT_TIME)) users.delete(synchronize_session=False) session.commit() @@ -130,31 +131,31 @@ def teardown_method(self): class TestFABRoleForwarding(TestFABforwarding): @mock.patch("airflow.api_connexion.endpoints.forward_to_fab_endpoint.get_auth_manager") def test_raises_400_if_manager_is_not_fab(self, mock_get_auth_manager): - mock_get_auth_manager.return_value = BaseAuthManager(self.app.appbuilder) - response = self.client.get("api/v1/roles", environ_overrides={"REMOTE_USER": "test"}) + mock_get_auth_manager.return_value = BaseAuthManager(self.flask_app.appbuilder) + response = self.client.get("api/v1/roles", headers={"REMOTE_USER": "test"}) assert response.status_code == 400 assert ( - response.json["detail"] + response.json()["detail"] == "This endpoint is only available when using the default auth manager FabAuthManager." ) def test_get_role_forwards_to_fab(self): - resp = self.client.get("api/v1/roles/Test", environ_overrides={"REMOTE_USER": "test"}) + resp = self.client.get("api/v1/roles/Test", headers={"REMOTE_USER": "test"}) assert resp.status_code == 200 def test_get_roles_forwards_to_fab(self): - resp = self.client.get("api/v1/roles", environ_overrides={"REMOTE_USER": "test"}) + resp = self.client.get("api/v1/roles", headers={"REMOTE_USER": "test"}) assert resp.status_code == 200 def test_delete_role_forwards_to_fab(self): - role = create_role(self.app, "mytestrole") - resp = self.client.delete(f"api/v1/roles/{role.name}", environ_overrides={"REMOTE_USER": "test"}) + role = create_role(self.flask_app, "mytestrole") + resp = self.client.delete(f"api/v1/roles/{role.name}", headers={"REMOTE_USER": "test"}) assert resp.status_code == 204 def test_patch_role_forwards_to_fab(self): - role = create_role(self.app, "mytestrole") + role = create_role(self.flask_app, "mytestrole") resp = self.client.patch( - f"api/v1/roles/{role.name}", json={"name": "Test2"}, environ_overrides={"REMOTE_USER": "test"} + f"api/v1/roles/{role.name}", json={"name": "Test2"}, headers={"REMOTE_USER": "test"} ) assert resp.status_code == 200 @@ -163,11 +164,11 @@ def test_post_role_forwards_to_fab(self): "name": "Test2", "actions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}], } - resp = self.client.post("api/v1/roles", json=payload, environ_overrides={"REMOTE_USER": "test"}) + resp = self.client.post("api/v1/roles", json=payload, headers={"REMOTE_USER": "test"}) assert resp.status_code == 200 def test_get_role_permissions_forwards_to_fab(self): - resp = self.client.get("api/v1/permissions", environ_overrides={"REMOTE_USER": "test"}) + resp = self.client.get("api/v1/permissions", headers={"REMOTE_USER": "test"}) assert resp.status_code == 200 @@ -192,29 +193,29 @@ def _create_users(self, count, roles=None): def test_get_user_forwards_to_fab(self): users = self._create_users(1) - session = self.app.appbuilder.get_session + session = self.flask_app.appbuilder.get_session session.add_all(users) session.commit() - resp = self.client.get("api/v1/users/TEST_USER1", environ_overrides={"REMOTE_USER": "test"}) + resp = self.client.get("api/v1/users/TEST_USER1", headers={"REMOTE_USER": "test"}) assert resp.status_code == 200 def test_get_users_forwards_to_fab(self): users = self._create_users(2) - session = self.app.appbuilder.get_session + session = self.flask_app.appbuilder.get_session session.add_all(users) session.commit() - resp = self.client.get("api/v1/users", environ_overrides={"REMOTE_USER": "test"}) + resp = self.client.get("api/v1/users", headers={"REMOTE_USER": "test"}) assert resp.status_code == 200 def test_post_user_forwards_to_fab(self, autoclean_username, autoclean_user_payload): response = self.client.post( "/api/v1/users", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 200, response.json + assert response.status_code == 200, response.json() - security_manager = self.app.appbuilder.sm + security_manager = self.flask_app.appbuilder.sm user = security_manager.find_user(autoclean_username) assert user is not None assert user.roles == [security_manager.find_role("Public")] @@ -225,14 +226,14 @@ def test_patch_user_forwards_to_fab(self, autoclean_username, autoclean_user_pay response = self.client.patch( f"/api/v1/users/{autoclean_username}", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 200, response.json + assert response.status_code == 200, response.json() def test_delete_user_forwards_to_fab(self): users = self._create_users(1) - session = self.app.appbuilder.get_session + session = self.flask_app.appbuilder.get_session session.add_all(users) session.commit() - resp = self.client.delete("api/v1/users/TEST_USER1", environ_overrides={"REMOTE_USER": "test"}) + resp = self.client.delete("api/v1/users/TEST_USER1", headers={"REMOTE_USER": "test"}) assert resp.status_code == 204 diff --git a/tests/api_connexion/endpoints/test_health_endpoint.py b/tests/api_connexion/endpoints/test_health_endpoint.py index 7d73b338e5105d..3f68f75a4ba651 100644 --- a/tests/api_connexion/endpoints/test_health_endpoint.py +++ b/tests/api_connexion/endpoints/test_health_endpoint.py @@ -36,8 +36,8 @@ class TestHealthTestBase: @pytest.fixture(autouse=True) def setup_attrs(self, minimal_app_for_api) -> None: - self.app = minimal_app_for_api - self.client = self.app.test_client() # type:ignore + self.connexion_app = minimal_app_for_api + self.client = self.connexion_app.test_client() # type:ignore with create_session() as session: session.query(Job).delete() @@ -54,7 +54,8 @@ def test_healthy_scheduler_status(self, session): SchedulerJobRunner(job=job) session.add(job) session.commit() - resp_json = self.client.get("/api/v1/health").json + session.close() + resp_json = self.client.get("/api/v1/health").json() assert "healthy" == resp_json["metadatabase"]["status"] assert "healthy" == resp_json["scheduler"]["status"] assert ( @@ -69,7 +70,8 @@ def test_unhealthy_scheduler_is_slow(self, session): SchedulerJobRunner(job=job) session.add(job) session.commit() - resp_json = self.client.get("/api/v1/health").json + session.close() + resp_json = self.client.get("/api/v1/health").json() assert "healthy" == resp_json["metadatabase"]["status"] assert "unhealthy" == resp_json["scheduler"]["status"] assert ( @@ -78,7 +80,7 @@ def test_unhealthy_scheduler_is_slow(self, session): ) def test_unhealthy_scheduler_no_job(self): - resp_json = self.client.get("/api/v1/health").json + resp_json = self.client.get("/api/v1/health").json() assert "healthy" == resp_json["metadatabase"]["status"] assert "unhealthy" == resp_json["scheduler"]["status"] assert resp_json["scheduler"]["latest_scheduler_heartbeat"] is None @@ -86,6 +88,6 @@ def test_unhealthy_scheduler_no_job(self): @mock.patch.object(SchedulerJobRunner, "most_recent_job") def test_unhealthy_metadatabase_status(self, most_recent_job_mock): most_recent_job_mock.side_effect = Exception - resp_json = self.client.get("/api/v1/health").json + resp_json = self.client.get("/api/v1/health").json() assert "unhealthy" == resp_json["metadatabase"]["status"] assert resp_json["scheduler"]["latest_scheduler_heartbeat"] is None diff --git a/tests/api_connexion/endpoints/test_import_error_endpoint.py b/tests/api_connexion/endpoints/test_import_error_endpoint.py index ce084165d24e27..ad710b80179190 100644 --- a/tests/api_connexion/endpoints/test_import_error_endpoint.py +++ b/tests/api_connexion/endpoints/test_import_error_endpoint.py @@ -26,7 +26,7 @@ from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import provide_session -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_dags, clear_db_import_errors @@ -37,9 +37,9 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type:ignore + connexion_app.app, # type:ignore username="test", role_name="Test", permissions=[ @@ -47,16 +47,16 @@ def configured_app(minimal_app_for_api): (permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR), ], # type: ignore ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore create_user( - app, # type:ignore + connexion_app.app, # type:ignore username="test_single_dag", role_name="TestSingleDAG", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)], # type: ignore ) + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore # For some reason, DAG level permissions are not synced when in the above list of perms, # so do it manually here: - app.appbuilder.sm.bulk_sync_roles( + connexion_app.app.appbuilder.sm.bulk_sync_roles( [ { "role": "TestSingleDAG", @@ -65,11 +65,11 @@ def configured_app(minimal_app_for_api): ] ) - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_single_dag") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test_single_dag") # type: ignore class TestBaseImportError: @@ -77,8 +77,8 @@ class TestBaseImportError: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore clear_db_import_errors() clear_db_dags() @@ -103,12 +103,10 @@ def test_response_200(self, session): session.add(import_error) session.commit() - response = self.client.get( - f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/importErrors/{import_error.id}", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - response_data = response.json + response_data = response.json() response_data["import_error_id"] = 1 assert { "filename": "Lorem_ipsum.py", @@ -118,14 +116,14 @@ def test_response_200(self, session): } == response_data def test_response_404(self): - response = self.client.get("/api/v1/importErrors/2", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/importErrors/2", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 assert { "detail": "The ImportError with import_error_id: `2` was not found", "status": 404, "title": "Import error not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self, session): import_error = ParseImportError( @@ -138,12 +136,10 @@ def test_should_raises_401_unauthenticated(self, session): response = self.client.get(f"/api/v1/importErrors/{import_error.id}") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): - response = self.client.get( - "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_no_permissions"} - ) + response = self.client.get("/api/v1/importErrors", headers={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 def test_should_raise_403_forbidden_without_dag_read(self, session): @@ -156,7 +152,7 @@ def test_should_raise_403_forbidden_without_dag_read(self, session): session.commit() response = self.client.get( - f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} + f"/api/v1/importErrors/{import_error.id}", headers={"REMOTE_USER": "test_single_dag"} ) assert response.status_code == 403 @@ -173,11 +169,11 @@ def test_should_return_200_with_single_dag_read(self, session): session.commit() response = self.client.get( - f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} + f"/api/v1/importErrors/{import_error.id}", headers={"REMOTE_USER": "test_single_dag"} ) assert response.status_code == 200 - response_data = response.json + response_data = response.json() response_data["import_error_id"] = 1 assert { "filename": "Lorem_ipsum.py", @@ -199,11 +195,11 @@ def test_should_return_200_redacted_with_single_dag_read_in_dagfile(self, sessio session.commit() response = self.client.get( - f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} + f"/api/v1/importErrors/{import_error.id}", headers={"REMOTE_USER": "test_single_dag"} ) assert response.status_code == 200 - response_data = response.json + response_data = response.json() response_data["import_error_id"] = 1 assert { "filename": "Lorem_ipsum.py", @@ -226,10 +222,10 @@ def test_get_import_errors(self, session): session.add_all(import_error) session.commit() - response = self.client.get("/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/importErrors", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - response_data = response.json + response_data = response.json() self._normalize_import_errors(response_data["import_errors"]) assert { "import_errors": [ @@ -262,11 +258,11 @@ def test_get_import_errors_order_by(self, session): session.commit() response = self.client.get( - "/api/v1/importErrors?order_by=-timestamp", environ_overrides={"REMOTE_USER": "test"} + "/api/v1/importErrors?order_by=-timestamp", headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - response_data = response.json + response_data = response.json() self._normalize_import_errors(response_data["import_errors"]) assert { "import_errors": [ @@ -298,13 +294,11 @@ def test_order_by_raises_400_for_invalid_attr(self, session): session.add_all(import_error) session.commit() - response = self.client.get( - "/api/v1/importErrors?order_by=timest", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/importErrors?order_by=timest", headers={"REMOTE_USER": "test"}) assert response.status_code == 400 msg = "Ordering with 'timest' is disallowed or the attribute does not exist on the model" - assert response.json["detail"] == msg + assert response.json()["detail"] == msg def test_should_raises_401_unauthenticated(self, session): import_error = [ @@ -320,7 +314,7 @@ def test_should_raises_401_unauthenticated(self, session): response = self.client.get("/api/v1/importErrors") - assert_401(response) + assert response.status_code == 401 def test_get_import_errors_single_dag(self, session): for dag_id in TEST_DAG_IDS: @@ -335,12 +329,10 @@ def test_get_import_errors_single_dag(self, session): session.add(importerror) session.commit() - response = self.client.get( - "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) + response = self.client.get("/api/v1/importErrors", headers={"REMOTE_USER": "test_single_dag"}) assert response.status_code == 200 - response_data = response.json + response_data = response.json() self._normalize_import_errors(response_data["import_errors"]) assert { "import_errors": [ @@ -368,12 +360,10 @@ def test_get_import_errors_single_dag_in_dagfile(self, session): session.add(importerror) session.commit() - response = self.client.get( - "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) + response = self.client.get("/api/v1/importErrors", headers={"REMOTE_USER": "test_single_dag"}) assert response.status_code == 200 - response_data = response.json + response_data = response.json() self._normalize_import_errors(response_data["import_errors"]) assert { "import_errors": [ @@ -415,10 +405,10 @@ def test_limit_and_offset(self, url, expected_import_error_ids, session): session.add_all(import_errors) session.commit() - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - import_ids = [pool["filename"] for pool in response.json["import_errors"]] + import_ids = [pool["filename"] for pool in response.json()["import_errors"]] assert import_ids == expected_import_error_ids def test_should_respect_page_size_limit_default(self, session): @@ -432,9 +422,9 @@ def test_should_respect_page_size_limit_default(self, session): ] session.add_all(import_errors) session.commit() - response = self.client.get("/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/importErrors", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["import_errors"]) == 100 + assert len(response.json()["import_errors"]) == 100 @conf_vars({("api", "maximum_page_limit"): "150"}) def test_should_return_conf_max_if_req_max_above_conf(self, session): @@ -448,8 +438,6 @@ def test_should_return_conf_max_if_req_max_above_conf(self, session): ] session.add_all(import_errors) session.commit() - response = self.client.get( - "/api/v1/importErrors?limit=180", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/importErrors?limit=180", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["import_errors"]) == 150 + assert len(response.json()["import_errors"]) == 150 diff --git a/tests/api_connexion/endpoints/test_log_endpoint.py b/tests/api_connexion/endpoints/test_log_endpoint.py index d472b6902b3b19..05fce4e381629e 100644 --- a/tests/api_connexion/endpoints/test_log_endpoint.py +++ b/tests/api_connexion/endpoints/test_log_endpoint.py @@ -33,7 +33,7 @@ from airflow.security import permissions from airflow.utils import timezone from airflow.utils.types import DagRunType -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.db import clear_db_runs pytestmark = pytest.mark.db_test @@ -41,10 +41,10 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, + connexion_app.app, username="test", role_name="Test", permissions=[ @@ -52,12 +52,12 @@ def configured_app(minimal_app_for_api): (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") - yield app + yield connexion_app - delete_user(app, username="test") - delete_user(app, username="test_no_permissions") + delete_user(connexion_app.app, username="test") + delete_user(connexion_app.app, username="test_no_permissions") class TestGetLog: @@ -71,8 +71,9 @@ class TestGetLog: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app, configure_loggers, dag_maker, session) -> None: - self.app = configured_app - self.client = self.app.test_client() + self.connexion_app = configured_app + self.flask_app = self.connexion_app.app + self.client = self.connexion_app.test_client() # Make sure that the configure_logging is not cached self.old_modules = dict(sys.modules) @@ -92,7 +93,7 @@ def add_one(x: int): start_date=timezone.parse(self.default_time), ) - configured_app.dag_bag.bag_dag(dag, root_dag=dag) + self.flask_app.dag_bag.bag_dag(dag, root_dag=dag) # Add dummy dag for checking picking correct log with same task_id and different dag_id case. with dag_maker( @@ -105,13 +106,15 @@ def add_one(x: int): execution_date=timezone.parse(self.default_time), start_date=timezone.parse(self.default_time), ) - configured_app.dag_bag.bag_dag(dummy_dag, root_dag=dummy_dag) + self.flask_app.dag_bag.bag_dag(dummy_dag, root_dag=dummy_dag) for ti in dr.task_instances: ti.try_number = 1 ti.hostname = "localhost" self.ti = dr.task_instances[0] + session.commit() + session.close() @pytest.fixture def configure_loggers(self, tmp_path, create_log_template): @@ -145,6 +148,11 @@ def configure_loggers(self, tmp_path, create_log_template): logging.config.dictConfig(logging_config) + create_log_template( + "dag_id={{ ti.dag_id }}/run_id={{ ti.run_id }}/task_id={{ ti.task_id }}/" + "{% if ti.map_index >= 0 %}map_index={{ ti.map_index }}/{% endif %}" + "attempt={{ try_number }}.log" + ) yield logging.config.dictConfig(DEFAULT_LOGGING_CONFIG) @@ -153,23 +161,22 @@ def teardown_method(self): clear_db_runs() def test_should_respond_200_json(self): - key = self.app.config["SECRET_KEY"] + key = self.flask_app.config["SECRET_KEY"] serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": False}) response = self.client.get( f"api/v1/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskInstances/{self.TASK_ID}/logs/1", - query_string={"token": token}, - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test"}, + params={"token": token}, + headers={"Accept": "application/json", "REMOTE_USER": "test"}, ) expected_filename = ( f"{self.log_dir}/dag_id={self.DAG_ID}/run_id={self.RUN_ID}/task_id={self.TASK_ID}/attempt=1.log" ) assert ( - response.json["content"] + response.json()["content"] == f"[('localhost', '*** Found local files:\\n*** * {expected_filename}\\nLog for testing.')]" ) - info = serializer.loads(response.json["continuation_token"]) + info = serializer.loads(response.json()["continuation_token"]) assert info == {"end_of_log": True, "log_pos": 16} assert 200 == response.status_code @@ -191,19 +198,18 @@ def test_should_respond_200_json(self): def test_should_respond_200_text_plain(self, request_url, expected_filename, extra_query_string): expected_filename = expected_filename.replace("LOG_DIR", str(self.log_dir)) - key = self.app.config["SECRET_KEY"] + key = self.flask_app.config["SECRET_KEY"] serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": True}) response = self.client.get( request_url, - query_string={"token": token, **extra_query_string}, - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, + params={"token": token, **extra_query_string}, + headers={"Accept": "text/plain", "REMOTE_USER": "test"}, ) assert 200 == response.status_code assert ( - response.data.decode("utf-8") + response.text == f"localhost\n*** Found local files:\n*** * {expected_filename}\nLog for testing.\n" ) @@ -226,40 +232,39 @@ def test_get_logs_of_removed_task(self, request_url, expected_filename, extra_qu expected_filename = expected_filename.replace("LOG_DIR", str(self.log_dir)) # Recreate DAG without tasks - dagbag = self.app.dag_bag + dagbag = self.flask_app.dag_bag dag = DAG(self.DAG_ID, start_date=timezone.parse(self.default_time)) del dagbag.dags[self.DAG_ID] dagbag.bag_dag(dag=dag, root_dag=dag) - key = self.app.config["SECRET_KEY"] + key = self.flask_app.config["SECRET_KEY"] serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": True}) response = self.client.get( request_url, - query_string={"token": token, **extra_query_string}, - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, + params={"token": token, **extra_query_string}, + headers={"Accept": "text/plain", "REMOTE_USER": "test"}, ) assert 200 == response.status_code assert ( - response.data.decode("utf-8") + response.text == f"localhost\n*** Found local files:\n*** * {expected_filename}\nLog for testing.\n" ) def test_get_logs_response_with_ti_equal_to_none(self): - key = self.app.config["SECRET_KEY"] + key = self.flask_app.config["SECRET_KEY"] serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": True}) response = self.client.get( f"api/v1/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskInstances/Invalid-Task-ID/logs/1", - query_string={"token": token}, - environ_overrides={"REMOTE_USER": "test"}, + params={"token": token}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert response.json == { + assert response.json() == { "detail": None, "status": 404, "title": "TaskInstance not found", @@ -277,43 +282,40 @@ def test_get_logs_with_metadata_as_download_large_file(self): response = self.client.get( f"api/v1/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/" f"taskInstances/{self.TASK_ID}/logs/1?full_content=True", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "text/plain", "REMOTE_USER": "test"}, ) - assert "1st line" in response.data.decode("utf-8") - assert "2nd line" in response.data.decode("utf-8") - assert "3rd line" in response.data.decode("utf-8") - assert "should never be read" not in response.data.decode("utf-8") + assert "1st line" in response.text + assert "2nd line" in response.text + assert "3rd line" in response.text + assert "should never be read" not in response.text @mock.patch("airflow.api_connexion.endpoints.log_endpoint.TaskLogReader") def test_get_logs_for_handler_without_read_method(self, mock_log_reader): type(mock_log_reader.return_value).supports_read = PropertyMock(return_value=False) - key = self.app.config["SECRET_KEY"] + key = self.flask_app.config["SECRET_KEY"] serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": False}) # check guessing response = self.client.get( f"api/v1/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskInstances/{self.TASK_ID}/logs/1", - query_string={"token": token}, - headers={"Content-Type": "application/jso"}, - environ_overrides={"REMOTE_USER": "test"}, + params={"token": token}, + headers={"Content-Type": "application/json", "REMOTE_USER": "test"}, ) assert 400 == response.status_code - assert "Task log handler does not support read logs." in response.data.decode("utf-8") + assert "Task log handler does not support read logs." in response.text def test_bad_signature_raises(self): token = {"download_logs": False} response = self.client.get( f"api/v1/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskInstances/{self.TASK_ID}/logs/1", - query_string={"token": token}, - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test"}, + params={"token": token}, + headers={"Accept": "application/json", "REMOTE_USER": "test"}, ) - assert response.json == { + assert response.json() == { "detail": None, "status": 400, "title": "Bad Signature. Please use only the tokens provided by the API.", @@ -324,11 +326,10 @@ def test_raises_404_for_invalid_dag_run_id(self): response = self.client.get( f"api/v1/dags/{self.DAG_ID}/dagRuns/NO_DAG_RUN/" # invalid run_id f"taskInstances/{self.TASK_ID}/logs/1?", - headers={"Accept": "application/json"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"Accept": "application/json", "REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert response.json == { + assert response.json() == { "detail": None, "status": 404, "title": "TaskInstance not found", @@ -336,55 +337,52 @@ def test_raises_404_for_invalid_dag_run_id(self): } def test_should_raises_401_unauthenticated(self): - key = self.app.config["SECRET_KEY"] + key = self.flask_app.config["SECRET_KEY"] serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": False}) response = self.client.get( f"api/v1/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskInstances/{self.TASK_ID}/logs/1", - query_string={"token": token}, + params={"token": token}, headers={"Accept": "application/json"}, ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): - key = self.app.config["SECRET_KEY"] + key = self.flask_app.config["SECRET_KEY"] serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": True}) response = self.client.get( f"api/v1/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskInstances/{self.TASK_ID}/logs/1", - query_string={"token": token}, - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + params={"token": token}, + headers={"Accept": "text/plain", "REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 def test_should_raise_404_when_missing_map_index_param_for_mapped_task(self): - key = self.app.config["SECRET_KEY"] + key = self.flask_app.config["SECRET_KEY"] serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": True}) response = self.client.get( f"api/v1/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskInstances/{self.MAPPED_TASK_ID}/logs/1", - query_string={"token": token}, - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, + params={"token": token}, + headers={"Accept": "text/plain", "REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert response.json["title"] == "TaskInstance not found" + assert response.json()["title"] == "TaskInstance not found" def test_should_raise_404_when_filtering_on_map_index_for_unmapped_task(self): - key = self.app.config["SECRET_KEY"] + key = self.flask_app.config["SECRET_KEY"] serializer = URLSafeSerializer(key) token = serializer.dumps({"download_logs": True}) response = self.client.get( f"api/v1/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskInstances/{self.TASK_ID}/logs/1", - query_string={"token": token, "map_index": 0}, - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, + params={"token": token, "map_index": 0}, + headers={"Accept": "text/plain", "REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert response.json["title"] == "TaskInstance not found" + assert response.json()["title"] == "TaskInstance not found" diff --git a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py index 8d5c854eb4d83f..6ca9b571f3dfec 100644 --- a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py @@ -33,7 +33,7 @@ from airflow.utils.session import provide_session from airflow.utils.state import State, TaskInstanceState from airflow.utils.timezone import datetime -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_roles, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_roles, delete_user from tests.test_utils.db import clear_db_runs, clear_db_sla_miss, clear_rendered_ti_fields from tests.test_utils.mock_operators import MockOperator @@ -48,9 +48,9 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -61,13 +61,13 @@ def configured_app(minimal_app_for_api): (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_roles(app) + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore + delete_roles(connexion_app.app) class TestMappedTaskInstanceEndpoint: @@ -87,8 +87,9 @@ def setup_attrs(self, configured_app) -> None: "queue": "default_queue", "job_id": 0, } - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.flask_app = self.connexion_app.app + self.client = self.connexion_app.test_client() # type:ignore clear_db_runs() clear_db_sla_miss() clear_rendered_ti_fields() @@ -132,9 +133,9 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags=None): setattr(ti, "start_date", DEFAULT_DATETIME_1) session.add(ti) - self.app.dag_bag = DagBag(os.devnull, include_examples=False) - self.app.dag_bag.dags = {dag_id: dag_maker.dag} # type: ignore - self.app.dag_bag.sync_to_db() # type: ignore + self.flask_app.dag_bag = DagBag(os.devnull, include_examples=False) + self.flask_app.dag_bag.dags = {dag_id: dag_maker.dag} # type: ignore + self.flask_app.dag_bag.sync_to_db() # type: ignore session.flush() mapped.expand_mapped_task(dr.run_id, session=session) @@ -201,10 +202,10 @@ class TestNonExistent(TestMappedTaskInstanceEndpoint): def test_non_existent_task_instance(self, session): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert response.json["title"] == "DAG mapped_tis not found" + assert response.json()["title"] == "DAG mapped_tis not found" class TestGetMappedTaskInstance(TestMappedTaskInstanceEndpoint): @@ -212,10 +213,10 @@ class TestGetMappedTaskInstance(TestMappedTaskInstanceEndpoint): def test_mapped_task_instances(self, one_task_with_mapped_tis, session): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/0", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_id": "mapped_tis", "dag_run_id": "run_mapped_tis", "duration": None, @@ -250,22 +251,22 @@ def test_should_raises_401_unauthenticated(self): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/1", ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.get( "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 def test_without_map_index_returns_custom_404(self, one_task_with_mapped_tis): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert response.json == { + assert response.json() == { "detail": "Task instance is mapped, add the map_index value to the URL", "status": 404, "title": "Task instance not found", @@ -275,20 +276,20 @@ def test_without_map_index_returns_custom_404(self, one_task_with_mapped_tis): def test_one_mapped_task_works(self, one_task_with_single_mapped_ti): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/0", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/1", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert response.json == { + assert response.json() == { "detail": "Task instance is mapped, add the map_index value to the URL", "status": 404, "title": "Task instance not found", @@ -301,71 +302,73 @@ class TestGetMappedTaskInstances(TestMappedTaskInstanceEndpoint): def test_mapped_task_instances(self, one_task_with_many_mapped_tis, session): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 110 - assert len(response.json["task_instances"]) == 100 + assert response.json()["total_entries"] == 110 + assert len(response.json()["task_instances"]) == 100 @provide_session def test_mapped_task_instances_offset_limit(self, one_task_with_many_mapped_tis, session): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped" "?offset=4&limit=10", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 110 - assert len(response.json["task_instances"]) == 10 - assert list(range(4, 14)) == [ti["map_index"] for ti in response.json["task_instances"]] + assert response.json()["total_entries"] == 110 + assert len(response.json()["task_instances"]) == 10 + assert list(range(4, 14)) == [ti["map_index"] for ti in response.json()["task_instances"]] @provide_session def test_mapped_task_instances_order(self, one_task_with_many_mapped_tis, session): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 110 - assert len(response.json["task_instances"]) == 100 - assert list(range(100)) == [ti["map_index"] for ti in response.json["task_instances"]] + assert response.json()["total_entries"] == 110 + assert len(response.json()["task_instances"]) == 100 + assert list(range(100)) == [ti["map_index"] for ti in response.json()["task_instances"]] @provide_session def test_mapped_task_instances_reverse_order(self, one_task_with_many_mapped_tis, session): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped" "?order_by=-map_index", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 110 - assert len(response.json["task_instances"]) == 100 - assert list(range(109, 9, -1)) == [ti["map_index"] for ti in response.json["task_instances"]] + assert response.json()["total_entries"] == 110 + assert len(response.json()["task_instances"]) == 100 + assert list(range(109, 9, -1)) == [ti["map_index"] for ti in response.json()["task_instances"]] @provide_session def test_mapped_task_instances_state_order(self, one_task_with_many_mapped_tis, session): + session.commit() + session.close() response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped" "?order_by=-state", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 110 - assert len(response.json["task_instances"]) == 100 + assert response.json()["total_entries"] == 110 + assert len(response.json()["task_instances"]) == 100 assert list(range(5)) + list(range(25, 110)) + list(range(5, 15)) == [ - ti["map_index"] for ti in response.json["task_instances"] + ti["map_index"] for ti in response.json()["task_instances"] ] # State ascending response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped" "?order_by=state", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 110 - assert len(response.json["task_instances"]) == 100 + assert response.json()["total_entries"] == 110 + assert len(response.json()["task_instances"]) == 100 assert list(range(5, 25)) + list(range(90, 110)) + list(range(25, 85)) == [ - ti["map_index"] for ti in response.json["task_instances"] + ti["map_index"] for ti in response.json()["task_instances"] ] @provide_session @@ -373,100 +376,100 @@ def test_mapped_task_instances_invalid_order(self, one_task_with_many_mapped_tis response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped" "?order_by=unsupported", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json["detail"] == "Ordering with 'unsupported' is not supported" + assert response.json()["detail"] == "Ordering with 'unsupported' is not supported" @provide_session def test_mapped_task_instances_with_date(self, one_task_with_mapped_tis, session): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped" f"?start_date_gte={QUOTED_DEFAULT_DATETIME_STR_1}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 3 - assert len(response.json["task_instances"]) == 3 + assert response.json()["total_entries"] == 3 + assert len(response.json()["task_instances"]) == 3 response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped" f"?start_date_gte={QUOTED_DEFAULT_DATETIME_STR_2}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 0 - assert response.json["task_instances"] == [] + assert response.json()["total_entries"] == 0 + assert response.json()["task_instances"] == [] @provide_session def test_mapped_task_instances_with_state(self, one_task_with_mapped_tis, session): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped?state=success", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 3 - assert len(response.json["task_instances"]) == 3 + assert response.json()["total_entries"] == 3 + assert len(response.json()["task_instances"]) == 3 response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped?state=running", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 0 - assert response.json["task_instances"] == [] + assert response.json()["total_entries"] == 0 + assert response.json()["task_instances"] == [] @provide_session def test_mapped_task_instances_with_pool(self, one_task_with_mapped_tis, session): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped" "?pool=default_pool", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 3 - assert len(response.json["task_instances"]) == 3 + assert response.json()["total_entries"] == 3 + assert len(response.json()["task_instances"]) == 3 response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped?pool=test_pool", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 0 - assert response.json["task_instances"] == [] + assert response.json()["total_entries"] == 0 + assert response.json()["task_instances"] == [] @provide_session def test_mapped_task_instances_with_queue(self, one_task_with_mapped_tis, session): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped?queue=default", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 3 - assert len(response.json["task_instances"]) == 3 + assert response.json()["total_entries"] == 3 + assert len(response.json()["task_instances"]) == 3 response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped?queue=test_queue", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 0 - assert response.json["task_instances"] == [] + assert response.json()["total_entries"] == 0 + assert response.json()["task_instances"] == [] @provide_session def test_mapped_task_instances_with_zero_mapped(self, one_task_with_zero_mapped_tis, session): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["total_entries"] == 0 - assert response.json["task_instances"] == [] + assert response.json()["total_entries"] == 0 + assert response.json()["task_instances"] == [] def test_should_raise_404_not_found_for_nonexistent_task(self): response = self.client.get( "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/nonexistent_task/listMapped", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert response.json["title"] == "Task id nonexistent_task not found" + assert response.json()["title"] == "Task id nonexistent_task not found" diff --git a/tests/api_connexion/endpoints/test_plugin_endpoint.py b/tests/api_connexion/endpoints/test_plugin_endpoint.py index f56d04a7644335..92c29f3535add5 100644 --- a/tests/api_connexion/endpoints/test_plugin_endpoint.py +++ b/tests/api_connexion/endpoints/test_plugin_endpoint.py @@ -29,7 +29,7 @@ from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.timetables.base import Timetable from airflow.utils.module_loading import qualname -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.config import conf_vars from tests.test_utils.mock_plugins import mock_plugin_manager @@ -103,19 +103,19 @@ class MockPlugin(AirflowPlugin): @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_PLUGIN)], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore class TestPluginsEndpoint: @@ -124,8 +124,8 @@ def setup_attrs(self, configured_app) -> None: """ Setup For XCom endpoint TC """ - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore class TestGetPlugins(TestPluginsEndpoint): @@ -133,9 +133,9 @@ def test_get_plugins_return_200(self): mock_plugin = MockPlugin() mock_plugin.name = "test_plugin" with mock_plugin_manager(plugins=[mock_plugin]): - response = self.client.get("api/v1/plugins", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("api/v1/plugins", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "plugins": [ { "appbuilder_menu_items": [appbuilder_menu_items], @@ -167,24 +167,22 @@ def test_get_plugins_works_with_more_plugins(self): mock_plugin_2 = AirflowPlugin() mock_plugin_2.name = "test_plugin2" with mock_plugin_manager(plugins=[mock_plugin, mock_plugin_2]): - response = self.client.get("api/v1/plugins", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("api/v1/plugins", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 2 + assert response.json()["total_entries"] == 2 def test_get_plugins_return_200_if_no_plugins(self): with mock_plugin_manager(plugins=[]): - response = self.client.get("api/v1/plugins", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("api/v1/plugins", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 def test_should_raises_401_unauthenticated(self): response = self.client.get("/api/v1/plugins") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): - response = self.client.get( - "/api/v1/plugins", environ_overrides={"REMOTE_USER": "test_no_permissions"} - ) + response = self.client.get("/api/v1/plugins", headers={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 @@ -230,35 +228,35 @@ class TestGetPluginsPagination(TestPluginsEndpoint): def test_handle_limit_offset(self, url, expected_plugin_names): plugins = self._create_plugins(10) with mock_plugin_manager(plugins=plugins): - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 10 - plugin_names = [plugin["name"] for plugin in response.json["plugins"] if plugin] + assert response.json()["total_entries"] == 10 + plugin_names = [plugin["name"] for plugin in response.json()["plugins"] if plugin] assert plugin_names == expected_plugin_names def test_should_respect_page_size_limit_default(self): plugins = self._create_plugins(200) with mock_plugin_manager(plugins=plugins): - response = self.client.get("/api/v1/plugins", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/plugins", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 200 - assert len(response.json["plugins"]) == 100 + assert response.json()["total_entries"] == 200 + assert len(response.json()["plugins"]) == 100 def test_limit_of_zero_should_return_default(self): plugins = self._create_plugins(200) with mock_plugin_manager(plugins=plugins): - response = self.client.get("/api/v1/plugins?limit=0", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/plugins?limit=0", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 200 - assert len(response.json["plugins"]) == 100 + assert response.json()["total_entries"] == 200 + assert len(response.json()["plugins"]) == 100 @conf_vars({("api", "maximum_page_limit"): "150"}) def test_should_return_conf_max_if_req_max_above_conf(self): plugins = self._create_plugins(200) with mock_plugin_manager(plugins=plugins): - response = self.client.get("/api/v1/plugins?limit=180", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/plugins?limit=180", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["plugins"]) == 150 + assert len(response.json()["plugins"]) == 150 def _create_plugins(self, count): plugins = [] diff --git a/tests/api_connexion/endpoints/test_pool_endpoint.py b/tests/api_connexion/endpoints/test_pool_endpoint.py index f709bda9a1ed68..b7b56c59f5464b 100644 --- a/tests/api_connexion/endpoints/test_pool_endpoint.py +++ b/tests/api_connexion/endpoints/test_pool_endpoint.py @@ -22,7 +22,7 @@ from airflow.models.pool import Pool from airflow.security import permissions from airflow.utils.session import provide_session -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_pools from tests.test_utils.www import _check_last_log @@ -32,10 +32,10 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -45,19 +45,19 @@ def configured_app(minimal_app_for_api): (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_POOL), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore class TestBasePoolEndpoints: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore clear_db_pools() def teardown_method(self) -> None: @@ -69,9 +69,10 @@ def test_response_200(self, session): pool_model = Pool(pool="test_pool_a", slots=3, include_deferred=True) session.add(pool_model) session.commit() + session.close() result = session.query(Pool).all() assert len(result) == 2 # accounts for the default pool as well - response = self.client.get("/api/v1/pools", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/pools", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 assert { "pools": [ @@ -101,15 +102,16 @@ def test_response_200(self, session): }, ], "total_entries": 2, - } == response.json + } == response.json() def test_response_200_with_order_by(self, session): pool_model = Pool(pool="test_pool_a", slots=3, include_deferred=True) session.add(pool_model) session.commit() + session.close() result = session.query(Pool).all() assert len(result) == 2 # accounts for the default pool as well - response = self.client.get("/api/v1/pools?order_by=slots", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/pools?order_by=slots", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 assert { "pools": [ @@ -139,15 +141,15 @@ def test_response_200_with_order_by(self, session): }, ], "total_entries": 2, - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self): response = self.client.get("/api/v1/pools") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): - response = self.client.get("/api/v1/pools", environ_overrides={"REMOTE_USER": "test_no_permissions"}) + response = self.client.get("/api/v1/pools", headers={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 @@ -178,46 +180,48 @@ def test_limit_and_offset(self, url, expected_pool_ids, session): pools = [Pool(pool=f"test_pool{i}", slots=1, include_deferred=False) for i in range(1, 121)] session.add_all(pools) session.commit() + session.close() result = session.query(Pool).count() assert result == 121 # accounts for default pool as well - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - pool_ids = [pool["name"] for pool in response.json["pools"]] + pool_ids = [pool["name"] for pool in response.json()["pools"]] assert pool_ids == expected_pool_ids def test_should_respect_page_size_limit_default(self, session): pools = [Pool(pool=f"test_pool{i}", slots=1, include_deferred=False) for i in range(1, 121)] session.add_all(pools) session.commit() + session.close() result = session.query(Pool).count() assert result == 121 - response = self.client.get("/api/v1/pools", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/pools", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["pools"]) == 100 + assert len(response.json()["pools"]) == 100 def test_should_raise_400_for_invalid_orderby(self, session): pools = [Pool(pool=f"test_pool{i}", slots=1, include_deferred=False) for i in range(1, 121)] session.add_all(pools) session.commit() + session.close() result = session.query(Pool).count() assert result == 121 - response = self.client.get( - "/api/v1/pools?order_by=open_slots", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/pools?order_by=open_slots", headers={"REMOTE_USER": "test"}) assert response.status_code == 400 msg = "Ordering with 'open_slots' is disallowed or the attribute does not exist on the model" - assert response.json["detail"] == msg + assert response.json()["detail"] == msg @conf_vars({("api", "maximum_page_limit"): "150"}) def test_should_return_conf_max_if_req_max_above_conf(self, session): pools = [Pool(pool=f"test_pool{i}", slots=1, include_deferred=False) for i in range(1, 200)] session.add_all(pools) session.commit() + session.close() result = session.query(Pool).count() assert result == 200 - response = self.client.get("/api/v1/pools?limit=180", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/pools?limit=180", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["pools"]) == 150 + assert len(response.json()["pools"]) == 150 class TestGetPool(TestBasePoolEndpoints): @@ -225,7 +229,8 @@ def test_response_200(self, session): pool_model = Pool(pool="test_pool_a", slots=3, include_deferred=True) session.add(pool_model) session.commit() - response = self.client.get("/api/v1/pools/test_pool_a", environ_overrides={"REMOTE_USER": "test"}) + session.close() + response = self.client.get("/api/v1/pools/test_pool_a", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 assert { "name": "test_pool_a", @@ -238,22 +243,22 @@ def test_response_200(self, session): "open_slots": 3, "description": None, "include_deferred": True, - } == response.json + } == response.json() def test_response_404(self): - response = self.client.get("/api/v1/pools/invalid_pool", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/pools/invalid_pool", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 assert { "detail": "Pool with name:'invalid_pool' not found", "status": 404, "title": "Not Found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self): response = self.client.get("/api/v1/pools/default_pool") - assert_401(response) + assert response.status_code == 401 class TestDeletePool(TestBasePoolEndpoints): @@ -262,45 +267,57 @@ def test_response_204(self, session): pool_instance = Pool(pool=pool_name, slots=3, include_deferred=False) session.add(pool_instance) session.commit() - - response = self.client.delete(f"api/v1/pools/{pool_name}", environ_overrides={"REMOTE_USER": "test"}) + session.close() + response = self.client.delete(f"api/v1/pools/{pool_name}", headers={"REMOTE_USER": "test"}) assert response.status_code == 204 # Check if the pool is deleted from the db - response = self.client.get(f"api/v1/pools/{pool_name}", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(f"api/v1/pools/{pool_name}", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 _check_last_log(session, dag_id=None, event="api.delete_pool", execution_date=None) def test_response_404(self): - response = self.client.delete("api/v1/pools/invalid_pool", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.delete("api/v1/pools/invalid_pool", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 assert { "detail": "Pool with name:'invalid_pool' not found", "status": 404, "title": "Not Found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self, session): pool_name = "test_pool" pool_instance = Pool(pool=pool_name, slots=3, include_deferred=False) session.add(pool_instance) session.commit() - + session.close() response = self.client.delete(f"api/v1/pools/{pool_name}") - assert_401(response) + assert response.status_code == 401 # Should still exists - response = self.client.get(f"/api/v1/pools/{pool_name}", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(f"/api/v1/pools/{pool_name}", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 + def test_response_204(self, session): + pool_name = "test_pool" + pool_instance = Pool(pool=pool_name, slots=3, include_deferred=False) + session.add(pool_instance) + session.commit() + session.close() + response = self.client.delete(f"api/v1/pools/{pool_name}", headers={"REMOTE_USER": "test"}) + assert response.status_code == 204 + # Check if the pool is deleted from the db + response = self.client.get(f"api/v1/pools/{pool_name}", headers={"REMOTE_USER": "test"}) + assert response.status_code == 404 + class TestPostPool(TestBasePoolEndpoints): def test_response_200(self, session): response = self.client.post( "api/v1/pools", json={"name": "test_pool_a", "slots": 3, "description": "test pool", "include_deferred": True}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 assert { @@ -314,7 +331,7 @@ def test_response_200(self, session): "open_slots": 3, "description": "test pool", "include_deferred": True, - } == response.json + } == response.json() _check_last_log(session, dag_id=None, event="api.post_pool", execution_date=None) def test_response_409(self, session): @@ -322,10 +339,11 @@ def test_response_409(self, session): pool_instance = Pool(pool=pool_name, slots=3, include_deferred=False) session.add(pool_instance) session.commit() + session.close() response = self.client.post( "api/v1/pools", json={"name": "test_pool_a", "slots": 3, "include_deferred": False}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 409 assert { @@ -333,7 +351,7 @@ def test_response_409(self, session): "status": 409, "title": "Conflict", "type": EXCEPTIONS_LINK_MAP[409], - } == response.json + } == response.json() @pytest.mark.parametrize( "request_json, error_detail", @@ -361,21 +379,19 @@ def test_response_409(self, session): ], ) def test_response_400(self, request_json, error_detail): - response = self.client.post( - "api/v1/pools", json=request_json, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("api/v1/pools", json=request_json, headers={"REMOTE_USER": "test"}) assert response.status_code == 400 assert { "detail": error_detail, "status": 400, "title": "Bad Request", "type": EXCEPTIONS_LINK_MAP[400], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self): response = self.client.post("api/v1/pools", json={"name": "test_pool_a", "slots": 3}) - assert_401(response) + assert response.status_code == 401 class TestPatchPool(TestBasePoolEndpoints): @@ -383,10 +399,11 @@ def test_response_200(self, session): pool = Pool(pool="test_pool", slots=2, include_deferred=True) session.add(pool) session.commit() + session.close() response = self.client.patch( "api/v1/pools/test_pool", json={"name": "test_pool_a", "slots": 3, "include_deferred": False}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 assert { @@ -400,7 +417,7 @@ def test_response_200(self, session): "slots": 3, "description": None, "include_deferred": False, - } == response.json + } == response.json() _check_last_log(session, dag_id=None, event="api.patch_pool", execution_date=None) @pytest.mark.parametrize( @@ -422,8 +439,9 @@ def test_response_400(self, error_detail, request_json, session): pool = Pool(pool="test_pool", slots=2, include_deferred=False) session.add(pool) session.commit() + session.close() response = self.client.patch( - "api/v1/pools/test_pool", json=request_json, environ_overrides={"REMOTE_USER": "test"} + "api/v1/pools/test_pool", json=request_json, headers={"REMOTE_USER": "test"} ) assert response.status_code == 400 assert { @@ -431,13 +449,13 @@ def test_response_400(self, error_detail, request_json, session): "status": 400, "title": "Bad Request", "type": EXCEPTIONS_LINK_MAP[400], - } == response.json + } == response.json() def test_not_found_when_no_pool_available(self): response = self.client.patch( "api/v1/pools/test_pool", json={"name": "test_pool_a", "slots": 3}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 assert { @@ -445,31 +463,31 @@ def test_not_found_when_no_pool_available(self): "status": 404, "title": "Not Found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self, session): pool = Pool(pool="test_pool", slots=2, include_deferred=False) session.add(pool) session.commit() - + session.close() response = self.client.patch( "api/v1/pools/test_pool", json={"name": "test_pool_a", "slots": 3}, ) - assert_401(response) + assert response.status_code == 401 class TestModifyDefaultPool(TestBasePoolEndpoints): def test_delete_400(self): - response = self.client.delete("api/v1/pools/default_pool", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.delete("api/v1/pools/default_pool", headers={"REMOTE_USER": "test"}) assert response.status_code == 400 assert { "detail": "Default Pool can't be deleted", "status": 400, "title": "Bad Request", "type": EXCEPTIONS_LINK_MAP[400], - } == response.json + } == response.json() @pytest.mark.parametrize( "status_code, url, json, expected_response", @@ -595,9 +613,9 @@ def test_delete_400(self): ], ) def test_patch(self, status_code, url, json, expected_response, session): - response = self.client.patch(url, json=json, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.patch(url, json=json, headers={"REMOTE_USER": "test"}) assert response.status_code == status_code - assert response.json == expected_response + assert response.json() == expected_response _check_last_log(session, dag_id=None, event="api.patch_pool", execution_date=None) @@ -649,7 +667,8 @@ def test_response_200( pool = Pool(pool="test_pool", slots=3, include_deferred=False) session.add(pool) session.commit() - response = self.client.patch(url, json=patch_json, environ_overrides={"REMOTE_USER": "test"}) + session.close() + response = self.client.patch(url, json=patch_json, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 assert { "name": expected_name, @@ -662,20 +681,20 @@ def test_response_200( "open_slots": expected_slots, "description": None, "include_deferred": expected_include_deferred, - } == response.json + } == response.json() _check_last_log(session, dag_id=None, event="api.patch_pool", execution_date=None) @pytest.mark.parametrize( "error_detail, url, patch_json", [ pytest.param( - "Property is read-only - 'occupied_slots'", + "{'occupied_slots': ['Unknown field.']}", "api/v1/pools/test_pool?update_mask=slots, name, occupied_slots", {"name": "test_pool_a", "slots": 2, "occupied_slots": 1}, id="Patching read only field", ), pytest.param( - "Property is read-only - 'queued_slots'", + "{'queued_slots': ['Unknown field.']}", "api/v1/pools/test_pool?update_mask=slots, name, queued_slots", {"name": "test_pool_a", "slots": 2, "queued_slots": 1}, id="Patching read only field", @@ -699,11 +718,12 @@ def test_response_400(self, error_detail, url, patch_json, session): pool = Pool(pool="test_pool", slots=3, include_deferred=False) session.add(pool) session.commit() - response = self.client.patch(url, json=patch_json, environ_overrides={"REMOTE_USER": "test"}) + session.close() + response = self.client.patch(url, json=patch_json, headers={"REMOTE_USER": "test"}) assert response.status_code == 400 assert { "detail": error_detail, "status": 400, "title": "Bad Request", "type": EXCEPTIONS_LINK_MAP[400], - } == response.json + } == response.json() diff --git a/tests/api_connexion/endpoints/test_provider_endpoint.py b/tests/api_connexion/endpoints/test_provider_endpoint.py index 7c973a9bb4132b..fec203cdab1d37 100644 --- a/tests/api_connexion/endpoints/test_provider_endpoint.py +++ b/tests/api_connexion/endpoints/test_provider_endpoint.py @@ -52,26 +52,26 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_PROVIDER)], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore class TestBaseProviderEndpoint: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app, cleanup_providers_manager) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore class TestGetProviders(TestBaseProviderEndpoint): @@ -81,9 +81,9 @@ class TestGetProviders(TestBaseProviderEndpoint): return_value={}, ) def test_response_200_empty_list(self, mock_providers): - response = self.client.get("/api/v1/providers", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/providers", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == {"providers": [], "total_entries": 0} + assert response.json() == {"providers": [], "total_entries": 0} @mock.patch( "airflow.providers_manager.ProvidersManager.providers", @@ -91,9 +91,9 @@ def test_response_200_empty_list(self, mock_providers): return_value=MOCK_PROVIDERS, ) def test_response_200(self, mock_providers): - response = self.client.get("/api/v1/providers", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/providers", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "providers": [ { "description": "Amazon Web Services (AWS) https://aws.amazon.com/", @@ -114,7 +114,5 @@ def test_should_raises_401_unauthenticated(self): assert response.status_code == 401 def test_should_raise_403_forbidden(self): - response = self.client.get( - "/api/v1/providers", environ_overrides={"REMOTE_USER": "test_no_permissions"} - ) + response = self.client.get("/api/v1/providers", headers={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 diff --git a/tests/api_connexion/endpoints/test_task_endpoint.py b/tests/api_connexion/endpoints/test_task_endpoint.py index 454b0db7525d11..127a8c1cba6904 100644 --- a/tests/api_connexion/endpoints/test_task_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_endpoint.py @@ -28,7 +28,7 @@ from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.empty import EmptyOperator from airflow.security import permissions -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags pytestmark = pytest.mark.db_test @@ -36,9 +36,9 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -47,12 +47,12 @@ def configured_app(minimal_app_for_api): (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore class TestTaskEndpoint: @@ -80,7 +80,7 @@ def setup_dag(self, configured_app): task1 >> task2 dag_bag = DagBag(os.devnull, include_examples=False) dag_bag.dags = {dag.dag_id: dag, mapped_dag.dag_id: mapped_dag} - configured_app.dag_bag = dag_bag # type:ignore + configured_app.app.dag_bag = dag_bag # type:ignore @staticmethod def clean_db(): @@ -91,8 +91,9 @@ def clean_db(): @pytest.fixture(autouse=True) def setup_attrs(self, configured_app, setup_dag) -> None: self.clean_db() - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.flask_app = self.connexion_app.app + self.client = self.connexion_app.test_client() # type:ignore def teardown_method(self) -> None: self.clean_db() @@ -139,10 +140,10 @@ def test_should_respond_200(self): "is_mapped": False, } response = self.client.get( - f"/api/v1/dags/{self.dag_id}/tasks/{self.task_id}", environ_overrides={"REMOTE_USER": "test"} + f"/api/v1/dags/{self.dag_id}/tasks/{self.task_id}", headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - assert response.json == expected + assert response.json() == expected def test_mapped_task(self): expected = { @@ -175,17 +176,17 @@ def test_mapped_task(self): } response = self.client.get( f"/api/v1/dags/{self.mapped_dag_id}/tasks/{self.mapped_task_id}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == expected + assert response.json() == expected def test_should_respond_200_serialized(self): # Get the dag out of the dagbag before we patch it to an empty one - SerializedDagModel.write_dag(self.app.dag_bag.get_dag(self.dag_id)) + SerializedDagModel.write_dag(self.flask_app.dag_bag.get_dag(self.dag_id)) dag_bag = DagBag(os.devnull, include_examples=False, read_dags_from_db=True) - patcher = unittest.mock.patch.object(self.app, "dag_bag", dag_bag) + patcher = unittest.mock.patch.object(self.flask_app, "dag_bag", dag_bag) patcher.start() expected = { @@ -227,35 +228,35 @@ def test_should_respond_200_serialized(self): "is_mapped": False, } response = self.client.get( - f"/api/v1/dags/{self.dag_id}/tasks/{self.task_id}", environ_overrides={"REMOTE_USER": "test"} + f"/api/v1/dags/{self.dag_id}/tasks/{self.task_id}", headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - assert response.json == expected + assert response.json() == expected patcher.stop() def test_should_respond_404(self): task_id = "xxxx_not_existing" response = self.client.get( - f"/api/v1/dags/{self.dag_id}/tasks/{task_id}", environ_overrides={"REMOTE_USER": "test"} + f"/api/v1/dags/{self.dag_id}/tasks/{task_id}", headers={"REMOTE_USER": "test"} ) assert response.status_code == 404 def test_should_respond_404_when_dag_not_found(self): dag_id = "xxxx_not_existing" response = self.client.get( - f"/api/v1/dags/{dag_id}/tasks/{self.task_id}", environ_overrides={"REMOTE_USER": "test"} + f"/api/v1/dags/{dag_id}/tasks/{self.task_id}", headers={"REMOTE_USER": "test"} ) assert response.status_code == 404 - assert response.json["title"] == "DAG not found" + assert response.json()["title"] == "DAG not found" def test_should_raises_401_unauthenticated(self): response = self.client.get(f"/api/v1/dags/{self.dag_id}/tasks/{self.task_id}") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.get( - f"/api/v1/dags/{self.dag_id}/tasks", environ_overrides={"REMOTE_USER": "test_no_permissions"} + f"/api/v1/dags/{self.dag_id}/tasks", headers={"REMOTE_USER": "test_no_permissions"} ) assert response.status_code == 403 @@ -336,11 +337,9 @@ def test_should_respond_200(self): ], "total_entries": 2, } - response = self.client.get( - f"/api/v1/dags/{self.dag_id}/tasks", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get(f"/api/v1/dags/{self.dag_id}/tasks", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == expected + assert response.json() == expected def test_get_tasks_mapped(self): expected = { @@ -408,46 +407,48 @@ def test_get_tasks_mapped(self): "total_entries": 2, } response = self.client.get( - f"/api/v1/dags/{self.mapped_dag_id}/tasks", environ_overrides={"REMOTE_USER": "test"} + f"/api/v1/dags/{self.mapped_dag_id}/tasks", headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - assert response.json == expected + assert response.json() == expected def test_should_respond_200_ascending_order_by_start_date(self): response = self.client.get( f"/api/v1/dags/{self.dag_id}/tasks?order_by=start_date", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 assert self.task1_start_date < self.task2_start_date - assert response.json["tasks"][0]["task_id"] == self.task_id - assert response.json["tasks"][1]["task_id"] == self.task_id2 + assert response.json()["tasks"][0]["task_id"] == self.task_id + assert response.json()["tasks"][1]["task_id"] == self.task_id2 def test_should_respond_200_descending_order_by_start_date(self): response = self.client.get( f"/api/v1/dags/{self.dag_id}/tasks?order_by=-start_date", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 # - means is descending assert self.task1_start_date < self.task2_start_date - assert response.json["tasks"][0]["task_id"] == self.task_id2 - assert response.json["tasks"][1]["task_id"] == self.task_id + assert response.json()["tasks"][0]["task_id"] == self.task_id2 + assert response.json()["tasks"][1]["task_id"] == self.task_id def test_should_raise_400_for_invalid_order_by_name(self): response = self.client.get( f"/api/v1/dags/{self.dag_id}/tasks?order_by=invalid_task_colume_name", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json["detail"] == "'EmptyOperator' object has no attribute 'invalid_task_colume_name'" + assert ( + response.json()["detail"] == "'EmptyOperator' object has no attribute 'invalid_task_colume_name'" + ) def test_should_respond_404(self): dag_id = "xxxx_not_existing" - response = self.client.get(f"/api/v1/dags/{dag_id}/tasks", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(f"/api/v1/dags/{dag_id}/tasks", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 def test_should_raises_401_unauthenticated(self): response = self.client.get(f"/api/v1/dags/{self.dag_id}/tasks") - assert_401(response) + assert response.status_code == 401 diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py index 26f573be1e3e2b..99de556fbaf178 100644 --- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py @@ -36,7 +36,7 @@ from airflow.utils.state import State from airflow.utils.timezone import datetime from airflow.utils.types import DagRunType -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_roles, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_roles, delete_user from tests.test_utils.db import clear_db_runs, clear_db_sla_miss, clear_rendered_ti_fields from tests.test_utils.www import _check_last_log @@ -52,9 +52,9 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -67,7 +67,7 @@ def configured_app(minimal_app_for_api): ], ) create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test_dag_read_only", role_name="TestDagReadOnly", permissions=[ @@ -78,7 +78,7 @@ def configured_app(minimal_app_for_api): ], ) create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test_task_read_only", role_name="TestTaskReadOnly", permissions=[ @@ -89,7 +89,7 @@ def configured_app(minimal_app_for_api): ], ) create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test_read_only_one_dag", role_name="TestReadOnlyOneDag", permissions=[ @@ -99,7 +99,7 @@ def configured_app(minimal_app_for_api): ) # For some reason, "DAG:example_python_operator" is not synced when in the above list of perms, # so do it manually here: - app.appbuilder.sm.bulk_sync_roles( + connexion_app.app.appbuilder.sm.bulk_sync_roles( [ { "role": "TestReadOnlyOneDag", @@ -107,16 +107,16 @@ def configured_app(minimal_app_for_api): } ] ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_dag_read_only") # type: ignore - delete_user(app, username="test_task_read_only") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_read_only_one_dag") # type: ignore - delete_roles(app) + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_dag_read_only") # type: ignore + delete_user(connexion_app.app, username="test_task_read_only") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test_read_only_one_dag") # type: ignore + delete_roles(connexion_app.app) class TestTaskInstanceEndpoint: @@ -136,8 +136,9 @@ def setup_attrs(self, configured_app, dagbag) -> None: "queue": "default_queue", "job_id": 0, } - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.flask_app = self.connexion_app.app + self.client = self.connexion_app.test_client() # type:ignore clear_db_runs() clear_db_sla_miss() clear_rendered_ti_fields() @@ -196,6 +197,7 @@ def create_task_instances( tis.append(ti) session.commit() + session.close() return tis @@ -217,12 +219,13 @@ def test_should_respond_200(self, username, session): # https://github.com/apache/airflow/issues/14421 session.query(TaskInstance).update({TaskInstance.operator: None}, synchronize_session="fetch") session.commit() + session.close() response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", - environ_overrides={"REMOTE_USER": username}, + headers={"REMOTE_USER": username}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_id": "example_python_operator", "duration": 10000.0, "end_date": "2020-01-03T00:00:00+00:00", @@ -263,12 +266,14 @@ def test_should_respond_200_with_task_state_in_deferred(self, session): ti.triggerer_job = Job() TriggererJobRunner(job=ti.triggerer_job) ti.triggerer_job.state = "running" + session.merge(ti) session.commit() + session.close() response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - data = response.json + data = response.json() # this logic in effect replicates mock.ANY for these values values_to_ignore = { @@ -324,10 +329,10 @@ def test_should_respond_200_with_task_state_in_removed(self, session): self.create_task_instances(session, task_instances=[{"state": State.REMOVED}], update_extras=True) response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_id": "example_python_operator", "duration": 10000.0, "end_date": "2020-01-03T00:00:00+00:00", @@ -371,13 +376,14 @@ def test_should_respond_200_task_instance_with_sla_and_rendered(self, session): rendered_fields = RTIF(tis[0], render_templates=False) session.add(rendered_fields) session.commit() + session.close() response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_id": "example_python_operator", "duration": 10000.0, "end_date": "2020-01-03T00:00:00+00:00", @@ -427,17 +433,18 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, session): setattr(ti, attr, getattr(old_ti, attr)) session.add(ti) session.commit() + session.close() # in each loop, we should get the right mapped TI back for map_index in (1, 2): response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances" f"/print_the_context/{map_index}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_id": "example_python_operator", "duration": 10000.0, "end_date": "2020-01-03T00:00:00+00:00", @@ -472,28 +479,28 @@ def test_should_raises_401_unauthenticated(self): response = self.client.get( "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.get( "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 def test_raises_404_for_nonexistent_task_instance(self): response = self.client.get( "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/nonexistent_task", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert response.json["title"] == "Task instance not found" + assert response.json()["title"] == "Task instance not found" def test_unmapped_map_index_should_return_404(self, session): self.create_task_instances(session) response = self.client.get( - "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/-1", - environ_overrides={"REMOTE_USER": "test"}, + "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/-6", + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 @@ -503,7 +510,7 @@ def test_should_return_404_for_mapped_endpoint(self, session): response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/" f"taskInstances/print_the_context/{index}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 @@ -512,7 +519,7 @@ def test_should_return_404_for_list_mapped_endpoint(self, session): response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/" "taskInstances/print_the_context/listMapped", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 @@ -675,10 +682,10 @@ def test_should_respond_200(self, task_instances, update_extras, url, expected_t update_extras=update_extras, task_instances=task_instances, ) - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == expected_ti - assert len(response.json["task_instances"]) == expected_ti + assert response.json()["total_entries"] == expected_ti + assert len(response.json()["task_instances"]) == expected_ti @pytest.mark.parametrize( "task_instances, user, expected_ti", @@ -719,36 +726,34 @@ def test_return_TI_only_from_readable_dags(self, task_instances, user, expected_ ], dag_id=dag_id, ) - response = self.client.get( - "/api/v1/dags/~/dagRuns/~/taskInstances", environ_overrides={"REMOTE_USER": user} - ) + response = self.client.get("/api/v1/dags/~/dagRuns/~/taskInstances", headers={"REMOTE_USER": user}) assert response.status_code == 200 - assert response.json["total_entries"] == expected_ti - assert len(response.json["task_instances"]) == expected_ti + assert response.json()["total_entries"] == expected_ti + assert len(response.json()["task_instances"]) == expected_ti def test_should_respond_200_for_dag_id_filter(self, session): self.create_task_instances(session) self.create_task_instances(session, dag_id="example_skip_dag") response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/~/taskInstances", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 count = session.query(TaskInstance).filter(TaskInstance.dag_id == "example_python_operator").count() - assert count == response.json["total_entries"] - assert count == len(response.json["task_instances"]) + assert count == response.json()["total_entries"] + assert count == len(response.json()["task_instances"]) def test_should_raises_401_unauthenticated(self): response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/~/taskInstances", ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/~/taskInstances", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -909,12 +914,12 @@ def test_should_respond_200( ) response = self.client.post( "/api/v1/dags/~/dagRuns/~/taskInstances/list", - environ_overrides={"REMOTE_USER": username}, + headers={"REMOTE_USER": username}, json=payload, ) - assert response.status_code == 200, response.json - assert expected_ti_count == response.json["total_entries"] - assert expected_ti_count == len(response.json["task_instances"]) + assert response.status_code == 200, response.json() + assert expected_ti_count == response.json()["total_entries"] + assert expected_ti_count == len(response.json()["task_instances"]) @pytest.mark.parametrize( "task_instances, payload, expected_ti_count", @@ -948,12 +953,12 @@ def test_should_respond_200_when_task_instance_properties_are_none( ) response = self.client.post( "/api/v1/dags/~/dagRuns/~/taskInstances/list", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) - assert response.status_code == 200, response.json - assert expected_ti_count == response.json["total_entries"] - assert expected_ti_count == len(response.json["task_instances"]) + assert response.status_code == 200, response.json() + assert expected_ti_count == response.json()["total_entries"] + assert expected_ti_count == len(response.json()["task_instances"]) @pytest.mark.parametrize( "payload, expected_ti, total_ti", @@ -972,24 +977,24 @@ def test_should_respond_200_dag_ids_filter(self, payload, expected_ti, total_ti, self.create_task_instances(session, dag_id="example_skip_dag") response = self.client.post( "/api/v1/dags/~/dagRuns/~/taskInstances/list", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert response.status_code == 200 - assert len(response.json["task_instances"]) == expected_ti - assert response.json["total_entries"] == total_ti + assert len(response.json()["task_instances"]) == expected_ti + assert response.json()["total_entries"] == total_ti def test_should_raises_401_unauthenticated(self): response = self.client.post( "/api/v1/dags/~/dagRuns/~/taskInstances/list", json={"dag_ids": ["example_python_operator", "example_skip_dag"]}, ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.post( "/api/v1/dags/~/dagRuns/~/taskInstances/list", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, json={"dag_ids": ["example_python_operator", "example_skip_dag"]}, ) assert response.status_code == 403 @@ -1001,11 +1006,11 @@ def test_returns_403_forbidden_when_user_has_access_to_only_some_dags(self, sess response = self.client.post( "/api/v1/dags/~/dagRuns/~/taskInstances/list", - environ_overrides={"REMOTE_USER": "test_read_only_one_dag"}, + headers={"REMOTE_USER": "test_read_only_one_dag"}, json=payload, ) assert response.status_code == 403 - assert response.json == { + assert response.json() == { "detail": "User not allowed to access some of these DAGs: ['example_python_operator', 'example_skip_dag']", "status": 403, "title": "Forbidden", @@ -1015,19 +1020,19 @@ def test_returns_403_forbidden_when_user_has_access_to_only_some_dags(self, sess def test_should_raise_400_for_no_json(self): response = self.client.post( "/api/v1/dags/~/dagRuns/~/taskInstances/list", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json["detail"] == "Request body must not be empty" + assert response.json()["detail"] == "RequestBody is required" def test_should_raise_400_for_unknown_fields(self): response = self.client.post( "/api/v1/dags/~/dagRuns/~/taskInstances/list", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={"unknown_field": "unknown_value"}, ) assert response.status_code == 400 - assert response.json["detail"] == "{'unknown_field': ['Unknown field.']}" + assert response.json()["detail"] == "{'unknown_field': ['Unknown field.']}" @pytest.mark.parametrize( "payload, expected", @@ -1045,11 +1050,11 @@ def test_should_raise_400_for_naive_and_bad_datetime(self, payload, expected, se self.create_task_instances(session) response = self.client.post( "/api/v1/dags/~/dagRuns/~/taskInstances/list", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert response.status_code == 400 - assert expected in response.json["detail"] + assert expected in response.json()["detail"] class TestPostClearTaskInstances(TestTaskInstanceEndpoint): @@ -1245,14 +1250,14 @@ def test_should_respond_200(self, main_dag, task_instances, request_dag, payload task_instances=task_instances, update_extras=False, ) - self.app.dag_bag.sync_to_db() + self.flask_app.dag_bag.sync_to_db() response = self.client.post( f"/api/v1/dags/{request_dag}/clearTaskInstances", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert response.status_code == 200 - assert len(response.json["task_instances"]) == expected_ti + assert len(response.json()["task_instances"]) == expected_ti _check_last_log( session, dag_id=request_dag, @@ -1267,15 +1272,15 @@ def test_clear_taskinstance_is_called_with_queued_dr_state(self, mock_clearti, s self.create_task_instances(session) dag_id = "example_python_operator" payload = {"include_subdags": True, "reset_dag_runs": True, "dry_run": False} - self.app.dag_bag.sync_to_db() + self.flask_app.dag_bag.sync_to_db() response = self.client.post( f"/api/v1/dags/{dag_id}/clearTaskInstances", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert response.status_code == 200 mock_clearti.assert_called_once_with( - [], session, dag=self.app.dag_bag.get_dag(dag_id), dag_run_state=State.QUEUED + [], mock.ANY, dag=self.flask_app.dag_bag.get_dag(dag_id), dag_run_state=State.QUEUED ) _check_last_log(session, dag_id=dag_id, event="api.post_clear_task_instances", execution_date=None) @@ -1287,10 +1292,10 @@ def test_clear_taskinstance_is_called_with_invalid_task_ids(self, session): assert dagrun.state == "running" payload = {"dry_run": False, "reset_dag_runs": True, "task_ids": [""]} - self.app.dag_bag.sync_to_db() + self.flask_app.dag_bag.sync_to_db() response = self.client.post( f"/api/v1/dags/{dag_id}/clearTaskInstances", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert response.status_code == 200 @@ -1342,7 +1347,7 @@ def test_should_respond_200_with_reset_dag_run(self, session): ) response = self.client.post( f"/api/v1/dags/{dag_id}/clearTaskInstances", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) @@ -1387,8 +1392,8 @@ def test_should_respond_200_with_reset_dag_run(self, session): }, ] for task_instance in expected_response: - assert task_instance in response.json["task_instances"] - assert 6 == len(response.json["task_instances"]) + assert task_instance in response.json()["task_instances"] + assert 6 == len(response.json()["task_instances"]) assert 0 == failed_dag_runs, 0 _check_last_log(session, dag_id=dag_id, event="api.post_clear_task_instances", execution_date=None) @@ -1435,7 +1440,7 @@ def test_should_respond_200_with_dag_run_id(self, session): ) response = self.client.post( f"/api/v1/dags/{dag_id}/clearTaskInstances", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert 200 == response.status_code @@ -1447,8 +1452,8 @@ def test_should_respond_200_with_dag_run_id(self, session): "task_id": "print_the_context", }, ] - assert response.json["task_instances"] == expected_response - assert 1 == len(response.json["task_instances"]) + assert response.json()["task_instances"] == expected_response + assert 1 == len(response.json()["task_instances"]) _check_last_log(session, dag_id=dag_id, event="api.post_clear_task_instances", execution_date=None) def test_should_respond_200_with_include_past(self, session): @@ -1494,7 +1499,7 @@ def test_should_respond_200_with_include_past(self, session): ) response = self.client.post( f"/api/v1/dags/{dag_id}/clearTaskInstances", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert 200 == response.status_code @@ -1537,8 +1542,8 @@ def test_should_respond_200_with_include_past(self, session): }, ] for task_instance in expected_response: - assert task_instance in response.json["task_instances"] - assert 6 == len(response.json["task_instances"]) + assert task_instance in response.json()["task_instances"] + assert 6 == len(response.json()["task_instances"]) _check_last_log(session, dag_id=dag_id, event="api.post_clear_task_instances", execution_date=None) def test_should_respond_200_with_include_future(self, session): @@ -1583,7 +1588,7 @@ def test_should_respond_200_with_include_future(self, session): ) response = self.client.post( f"/api/v1/dags/{dag_id}/clearTaskInstances", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) @@ -1627,8 +1632,8 @@ def test_should_respond_200_with_include_future(self, session): }, ] for task_instance in expected_response: - assert task_instance in response.json["task_instances"] - assert 6 == len(response.json["task_instances"]) + assert task_instance in response.json()["task_instances"] + assert 6 == len(response.json()["task_instances"]) _check_last_log(session, dag_id=dag_id, event="api.post_clear_task_instances", execution_date=None) def test_should_respond_404_for_nonexistent_dagrun_id(self, session): @@ -1658,13 +1663,13 @@ def test_should_respond_404_for_nonexistent_dagrun_id(self, session): ) response = self.client.post( f"/api/v1/dags/{dag_id}/clearTaskInstances", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert 404 == response.status_code assert ( - response.json["title"] + response.json()["title"] == "Dag Run id TEST_DAG_RUN_ID_100 not found in dag example_python_operator" ) _check_last_log(session, dag_id=dag_id, event="api.post_clear_task_instances", execution_date=None) @@ -1680,13 +1685,13 @@ def test_should_raises_401_unauthenticated(self): "include_subdags": True, }, ) - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize("username", ["test_no_permissions", "test_dag_read_only", "test_task_read_only"]) def test_should_raise_403_forbidden(self, username: str): response = self.client.post( "/api/v1/dags/example_python_operator/clearTaskInstances", - environ_overrides={"REMOTE_USER": username}, + headers={"REMOTE_USER": username}, json={ "dry_run": False, "reset_dag_runs": True, @@ -1721,19 +1726,19 @@ def test_should_raise_400_for_naive_and_bad_datetime(self, payload, expected, se task_instances=task_instances, update_extras=False, ) - self.app.dag_bag.sync_to_db() + self.flask_app.dag_bag.sync_to_db() response = self.client.post( "/api/v1/dags/example_python_operator/clearTaskInstances", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert response.status_code == 400 - assert response.json["detail"] == expected + assert response.json()["detail"] == expected def test_raises_404_for_non_existent_dag(self): response = self.client.post( "/api/v1/dags/non-existent-dag/clearTaskInstances", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": False, "reset_dag_runs": True, @@ -1743,7 +1748,7 @@ def test_raises_404_for_non_existent_dag(self): }, ) assert response.status_code == 404 - assert response.json["title"] == "Dag id non-existent-dag not found" + assert response.json()["title"] == "Dag id non-existent-dag not found" class TestPostSetTaskInstanceState(TestTaskInstanceEndpoint): @@ -1759,7 +1764,7 @@ def test_should_assert_call_mocked_api(self, mock_set_task_instance_state, sessi ) response = self.client.post( "/api/v1/dags/example_python_operator/updateTaskInstancesState", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": True, "task_id": "print_the_context", @@ -1772,7 +1777,7 @@ def test_should_assert_call_mocked_api(self, mock_set_task_instance_state, sessi }, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "task_instances": [ { "dag_id": "example_python_operator", @@ -1793,7 +1798,7 @@ def test_should_assert_call_mocked_api(self, mock_set_task_instance_state, sessi state="failed", task_id="print_the_context", upstream=True, - session=session, + session=mock.ANY, ) @mock.patch("airflow.models.dag.DAG.set_task_instance_state") @@ -1808,7 +1813,7 @@ def test_should_assert_call_mocked_api_when_run_id(self, mock_set_task_instance_ ) response = self.client.post( "/api/v1/dags/example_python_operator/updateTaskInstancesState", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": True, "task_id": "print_the_context", @@ -1821,7 +1826,7 @@ def test_should_assert_call_mocked_api_when_run_id(self, mock_set_task_instance_ }, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "task_instances": [ { "dag_id": "example_python_operator", @@ -1842,7 +1847,7 @@ def test_should_assert_call_mocked_api_when_run_id(self, mock_set_task_instance_ state="failed", task_id="print_the_context", upstream=True, - session=session, + session=mock.ANY, ) @pytest.mark.parametrize( @@ -1911,11 +1916,11 @@ def test_should_handle_errors(self, error, code, payload, session): self.create_task_instances(session) response = self.client.post( "/api/v1/dags/example_python_operator/updateTaskInstancesState", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert response.status_code == code - assert response.json["detail"] == error + assert response.json()["detail"] == error def test_should_raises_401_unauthenticated(self): response = self.client.post( @@ -1931,13 +1936,13 @@ def test_should_raises_401_unauthenticated(self): "new_state": "failed", }, ) - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize("username", ["test_no_permissions", "test_dag_read_only", "test_task_read_only"]) def test_should_raise_403_forbidden(self, username): response = self.client.post( "/api/v1/dags/example_python_operator/updateTaskInstancesState", - environ_overrides={"REMOTE_USER": username}, + headers={"REMOTE_USER": username}, json={ "dry_run": True, "task_id": "print_the_context", @@ -1954,7 +1959,7 @@ def test_should_raise_403_forbidden(self, username): def test_should_raise_404_not_found_dag(self): response = self.client.post( "/api/v1/dags/INVALID_DAG/updateTaskInstancesState", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": True, "task_id": "print_the_context", @@ -1974,7 +1979,7 @@ def test_should_raise_not_found_if_execution_date_is_wrong(self, mock_set_task_i date = DEFAULT_DATETIME_1 + dt.timedelta(days=1) response = self.client.post( "/api/v1/dags/example_python_operator/updateTaskInstancesState", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": True, "task_id": "print_the_context", @@ -1987,7 +1992,7 @@ def test_should_raise_not_found_if_execution_date_is_wrong(self, mock_set_task_i }, ) assert response.status_code == 404 - assert response.json["detail"] == ( + assert response.json()["detail"] == ( f"Task instance not found for task 'print_the_context' on execution_date {date}" ) assert mock_set_task_instance_state.call_count == 0 @@ -1995,7 +2000,7 @@ def test_should_raise_not_found_if_execution_date_is_wrong(self, mock_set_task_i def test_should_raise_404_not_found_task(self): response = self.client.post( "/api/v1/dags/example_python_operator/updateTaskInstancesState", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": True, "task_id": "INVALID_TASK", @@ -2045,11 +2050,11 @@ def test_should_raise_400_for_naive_and_bad_datetime(self, payload, expected, se self.create_task_instances(session) response = self.client.post( "/api/v1/dags/example_python_operator/updateTaskInstancesState", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert response.status_code == 400 - assert response.json["detail"] == expected + assert response.json()["detail"] == expected class TestPatchTaskInstance(TestTaskInstanceEndpoint): @@ -2073,14 +2078,14 @@ def test_should_call_mocked_api(self, mock_set_task_instance_state, session): ) response = self.client.patch( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": False, "new_state": NEW_STATE, }, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_id": "example_python_operator", "dag_run_id": "TEST_DAG_RUN_ID", "execution_date": "2020-01-01T00:00:00+00:00", @@ -2093,7 +2098,7 @@ def test_should_call_mocked_api(self, mock_set_task_instance_state, session): map_indexes=[-1], state=NEW_STATE, commit=True, - session=session, + session=mock.ANY, ) _check_last_log( session, @@ -2118,14 +2123,14 @@ def test_should_not_call_mocked_api_for_dry_run(self, mock_set_task_instance_sta ) response = self.client.patch( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": True, "new_state": NEW_STATE, }, ) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "dag_id": "example_python_operator", "dag_run_id": "TEST_DAG_RUN_ID", "execution_date": "2020-01-01T00:00:00+00:00", @@ -2141,7 +2146,7 @@ def test_should_update_task_instance_state(self, session): self.client.patch( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": False, "new_state": NEW_STATE, @@ -2150,11 +2155,10 @@ def test_should_update_task_instance_state(self, session): response2 = self.client.get( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": "test"}, - json={}, + headers={"REMOTE_USER": "test"}, ) assert response2.status_code == 200 - assert response2.json["state"] == NEW_STATE + assert response2.json()["state"] == NEW_STATE def test_should_update_task_instance_state_default_dry_run_to_true(self, session): self.create_task_instances(session) @@ -2163,7 +2167,7 @@ def test_should_update_task_instance_state_default_dry_run_to_true(self, session self.client.patch( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "new_state": NEW_STATE, }, @@ -2171,11 +2175,10 @@ def test_should_update_task_instance_state_default_dry_run_to_true(self, session response2 = self.client.get( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": "test"}, - json={}, + headers={"REMOTE_USER": "test"}, ) assert response2.status_code == 200 - assert response2.json["state"] == NEW_STATE + assert response2.json()["state"] == NEW_STATE def test_should_update_mapped_task_instance_state(self, session): NEW_STATE = "failed" @@ -2185,10 +2188,11 @@ def test_should_update_mapped_task_instance_state(self, session): ti.rendered_task_instance_fields = RTIF(ti, render_templates=False) session.add(ti) session.commit() + session.close() self.client.patch( f"{self.ENDPOINT_URL}/{map_index}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": False, "new_state": NEW_STATE, @@ -2197,11 +2201,10 @@ def test_should_update_mapped_task_instance_state(self, session): response2 = self.client.get( f"{self.ENDPOINT_URL}/{map_index}", - environ_overrides={"REMOTE_USER": "test"}, - json={}, + headers={"REMOTE_USER": "test"}, ) assert response2.status_code == 200 - assert response2.json["state"] == NEW_STATE + assert response2.json()["state"] == NEW_STATE @pytest.mark.parametrize( "error, code, payload", @@ -2219,51 +2222,51 @@ def test_should_update_mapped_task_instance_state(self, session): def test_should_handle_errors(self, error, code, payload, session): response = self.client.patch( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert response.status_code == code - assert response.json["detail"] == error + assert response.json()["detail"] == error def test_should_raise_400_for_unknown_fields(self, session): self.create_task_instances(session) response = self.client.patch( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dryrun": True, "new_state": "failed", }, ) assert response.status_code == 400 - assert response.json["detail"] == "{'dryrun': ['Unknown field.']}" + assert response.json()["detail"] == "{'dryrun': ['Unknown field.']}" def test_should_raise_404_for_non_existent_dag(self): response = self.client.patch( "/api/v1/dags/non-existent-dag/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": False, "new_state": "failed", }, ) assert response.status_code == 404 - assert response.json["title"] == "DAG not found" - assert response.json["detail"] == "DAG 'non-existent-dag' not found" + assert response.json()["title"] == "DAG not found" + assert response.json()["detail"] == "DAG 'non-existent-dag' not found" def test_should_raise_404_for_non_existent_task_in_dag(self): response = self.client.patch( "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/non_existent_task", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": False, "new_state": "failed", }, ) assert response.status_code == 404 - assert response.json["title"] == "Task not found" + assert response.json()["title"] == "Task not found" assert ( - response.json["detail"] == "Task 'non_existent_task' not found in DAG 'example_python_operator'" + response.json()["detail"] == "Task 'non_existent_task' not found in DAG 'example_python_operator'" ) def test_should_raises_401_unauthenticated(self): @@ -2274,13 +2277,13 @@ def test_should_raises_401_unauthenticated(self): "new_state": "failed", }, ) - assert_401(response) + assert response.status_code == 401 @pytest.mark.parametrize("username", ["test_no_permissions", "test_dag_read_only", "test_task_read_only"]) def test_should_raise_403_forbidden(self, username): response = self.client.patch( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": username}, + headers={"REMOTE_USER": username}, json={ "dry_run": True, "new_state": "failed", @@ -2291,7 +2294,7 @@ def test_should_raise_403_forbidden(self, username): def test_should_raise_404_not_found_dag(self): response = self.client.patch( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": True, "new_state": "failed", @@ -2302,7 +2305,7 @@ def test_should_raise_404_not_found_dag(self): def test_should_raise_404_not_found_task(self): response = self.client.patch( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json={ "dry_run": True, "new_state": "failed", @@ -2336,12 +2339,12 @@ def test_should_raise_400_for_invalid_task_instance_state(self, payload, expecte self.create_task_instances(session) response = self.client.patch( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, json=payload, ) assert response.status_code == 400 - assert response.json["detail"] == expected - assert response.json["detail"] == expected + assert response.json()["detail"] == expected + assert response.json()["detail"] == expected class TestSetTaskInstanceNote(TestTaskInstanceEndpoint): @@ -2359,10 +2362,10 @@ def test_should_respond_200(self, session): "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/" "print_the_context/setNote", json={"note": new_note_value}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200, response.text - assert response.json == { + assert response.json() == { "dag_id": "example_python_operator", "duration": 10000.0, "end_date": "2020-01-03T00:00:00+00:00", @@ -2409,6 +2412,7 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, session): setattr(ti, attr, getattr(old_ti, attr)) session.add(ti) session.commit() + session.close() # in each loop, we should get the right mapped TI back for map_index in (1, 2): @@ -2417,11 +2421,11 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, session): "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/" f"print_the_context/{map_index}/setNote", json={"note": new_note_value}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200, response.text - assert response.json == { + assert response.json() == { "dag_id": "example_python_operator", "duration": 10000.0, "end_date": "2020-01-03T00:00:00+00:00", @@ -2458,15 +2462,16 @@ def test_should_respond_200_when_note_is_empty(self, session): ti.task_instance_note = None session.add(ti) session.commit() + session.close() new_note_value = "My super cool TaskInstance note." response = self.client.patch( "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/" "print_the_context/setNote", json={"note": new_note_value}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200, response.text - assert response.json["note"] == new_note_value + assert response.json()["note"] == new_note_value def test_should_raise_400_for_unknown_fields(self, session): self.create_task_instances(session) @@ -2474,10 +2479,10 @@ def test_should_raise_400_for_unknown_fields(self, session): "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/" "print_the_context/setNote", json={"note": "a valid field", "not": "an unknown field"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json["detail"] == "{'not': ['Unknown field.']}" + assert response.json()["detail"] == "{'not': ['Unknown field.']}" def test_should_raises_401_unauthenticated(self): for map_index in ["", "/0"]: @@ -2489,7 +2494,7 @@ def test_should_raises_401_unauthenticated(self): url, json={"note": "I am setting a note while being unauthenticated."}, ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): for map_index in ["", "/0"]: @@ -2497,7 +2502,7 @@ def test_should_raise_403_forbidden(self): "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/" f"print_the_context{map_index}/setNote", json={"note": "I am setting a note without the proper permissions."}, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -2508,6 +2513,6 @@ def test_should_respond_404(self, session): f"api/v1/dags/INVALID_DAG_ID/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context" f"{map_index}/setNote", json={"note": "I am setting a note on a DAG that doesn't exist."}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 diff --git a/tests/api_connexion/endpoints/test_variable_endpoint.py b/tests/api_connexion/endpoints/test_variable_endpoint.py index 0e300b0a8f3800..37cdbf42f3db96 100644 --- a/tests/api_connexion/endpoints/test_variable_endpoint.py +++ b/tests/api_connexion/endpoints/test_variable_endpoint.py @@ -23,7 +23,7 @@ from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.models import Variable from airflow.security import permissions -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_variables from tests.test_utils.www import _check_last_log @@ -33,10 +33,10 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -47,7 +47,7 @@ def configured_app(minimal_app_for_api): ], ) create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test_read_only", role_name="TestReadOnly", permissions=[ @@ -55,28 +55,28 @@ def configured_app(minimal_app_for_api): ], ) create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test_delete_only", role_name="TestDeleteOnly", permissions=[ (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_read_only") # type: ignore - delete_user(app, username="test_delete_only") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_read_only") # type: ignore + delete_user(connexion_app.app, username="test_delete_only") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore class TestVariableEndpoint: @pytest.fixture(autouse=True) def setup_method(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore clear_db_variables() def teardown_method(self) -> None: @@ -87,22 +87,20 @@ class TestDeleteVariable(TestVariableEndpoint): def test_should_delete_variable(self, session): Variable.set("delete_var1", 1) # make sure variable is added - response = self.client.get("/api/v1/variables/delete_var1", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/variables/delete_var1", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - response = self.client.delete( - "/api/v1/variables/delete_var1", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.delete("/api/v1/variables/delete_var1", headers={"REMOTE_USER": "test"}) assert response.status_code == 204 # make sure variable is deleted - response = self.client.get("/api/v1/variables/delete_var1", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/variables/delete_var1", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 _check_last_log(session, dag_id=None, event="api.variable.delete", execution_date=None) def test_should_respond_404_if_key_does_not_exist(self): response = self.client.delete( - "/api/v1/variables/NONEXIST_VARIABLE_KEY", environ_overrides={"REMOTE_USER": "test"} + "/api/v1/variables/NONEXIST_VARIABLE_KEY", headers={"REMOTE_USER": "test"} ) assert response.status_code == 404 @@ -111,17 +109,17 @@ def test_should_raises_401_unauthenticated(self): # make sure variable is added response = self.client.delete("/api/v1/variables/delete_var1") - assert_401(response) + assert response.status_code == 401 # make sure variable is not deleted - response = self.client.get("/api/v1/variables/delete_var1", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/variables/delete_var1", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 def test_should_raise_403_forbidden(self): expected_value = '{"foo": 1}' Variable.set("TEST_VARIABLE_KEY", expected_value) response = self.client.get( - "/api/v1/variables/TEST_VARIABLE_KEY", environ_overrides={"REMOTE_USER": "test_no_permissions"} + "/api/v1/variables/TEST_VARIABLE_KEY", headers={"REMOTE_USER": "test_no_permissions"} ) assert response.status_code == 403 @@ -139,17 +137,17 @@ class TestGetVariable(TestVariableEndpoint): def test_read_variable(self, user, expected_status_code): expected_value = '{"foo": 1}' Variable.set("TEST_VARIABLE_KEY", expected_value) - response = self.client.get( - "/api/v1/variables/TEST_VARIABLE_KEY", environ_overrides={"REMOTE_USER": user} - ) + response = self.client.get("/api/v1/variables/TEST_VARIABLE_KEY", headers={"REMOTE_USER": user}) assert response.status_code == expected_status_code if expected_status_code == 200: - assert response.json == {"key": "TEST_VARIABLE_KEY", "value": expected_value, "description": None} + assert response.json() == { + "key": "TEST_VARIABLE_KEY", + "value": expected_value, + "description": None, + } def test_should_respond_404_if_not_found(self): - response = self.client.get( - "/api/v1/variables/NONEXIST_VARIABLE_KEY", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/variables/NONEXIST_VARIABLE_KEY", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 def test_should_raises_401_unauthenticated(self): @@ -157,17 +155,17 @@ def test_should_raises_401_unauthenticated(self): response = self.client.get("/api/v1/variables/TEST_VARIABLE_KEY") - assert_401(response) + assert response.status_code == 401 def test_should_handle_slashes_in_keys(self): expected_value = "hello" Variable.set("foo/bar", expected_value) response = self.client.get( f"/api/v1/variables/{urllib.parse.quote('foo/bar', safe='')}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == {"key": "foo/bar", "value": expected_value, "description": None} + assert response.json() == {"key": "foo/bar", "value": expected_value, "description": None} class TestGetVariables(TestVariableEndpoint): @@ -209,42 +207,40 @@ def test_should_get_list_variables(self, query, expected): Variable.set("var1", 1, "I am a variable") Variable.set("var2", "foo", "Another variable") Variable.set("var3", "[100, 101]") - response = self.client.get(query, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(query, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == expected + assert response.json() == expected def test_should_respect_page_size_limit_default(self): for i in range(101): Variable.set(f"var{i}", i) - response = self.client.get("/api/v1/variables", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/variables", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 101 - assert len(response.json["variables"]) == 100 + assert response.json()["total_entries"] == 101 + assert len(response.json()["variables"]) == 100 def test_should_raise_400_for_invalid_order_by(self): for i in range(101): Variable.set(f"var{i}", i) - response = self.client.get( - "/api/v1/variables?order_by=invalid", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/api/v1/variables?order_by=invalid", headers={"REMOTE_USER": "test"}) assert response.status_code == 400 msg = "Ordering with 'invalid' is disallowed or the attribute does not exist on the model" - assert response.json["detail"] == msg + assert response.json()["detail"] == msg @conf_vars({("api", "maximum_page_limit"): "150"}) def test_should_return_conf_max_if_req_max_above_conf(self): for i in range(200): Variable.set(f"var{i}", i) - response = self.client.get("/api/v1/variables?limit=180", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/api/v1/variables?limit=180", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["variables"]) == 150 + assert len(response.json()["variables"]) == 150 def test_should_raises_401_unauthenticated(self): Variable.set("var1", 1) response = self.client.get("/api/v1/variables?limit=2&offset=0") - assert_401(response) + assert response.status_code == 401 class TestPatchVariable(TestVariableEndpoint): @@ -257,10 +253,10 @@ def test_should_update_variable(self, session): response = self.client.patch( "/api/v1/variables/var1", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == {"key": "var1", "value": "updated", "description": None} + assert response.json() == {"key": "var1", "value": "updated", "description": None} _check_last_log( session, dag_id=None, event="api.variable.edit", execution_date=None, expected_extra=payload ) @@ -270,10 +266,10 @@ def test_should_update_variable_with_mask(self, session): response = self.client.patch( "/api/v1/variables/var1?update_mask=description", json={"key": "var1", "value": "updated", "description": "after_update"}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == {"key": "var1", "value": "foo", "description": "after_update"} + assert response.json() == {"key": "var1", "value": "foo", "description": "after_update"} _check_last_log(session, dag_id=None, event="api.variable.edit", execution_date=None) def test_should_reject_invalid_update(self): @@ -283,10 +279,10 @@ def test_should_reject_invalid_update(self): "key": "var1", "value": "foo", }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 404 - assert response.json == { + assert response.json() == { "title": "Variable not found", "status": 404, "type": EXCEPTIONS_LINK_MAP[404], @@ -299,10 +295,10 @@ def test_should_reject_invalid_update(self): "key": "var2", "value": "updated", }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "title": "Invalid post body", "status": 400, "type": EXCEPTIONS_LINK_MAP[400], @@ -314,9 +310,9 @@ def test_should_reject_invalid_update(self): json={ "key": "var2", }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.json == { + assert response.json() == { "title": "Invalid Variable schema", "status": 400, "type": EXCEPTIONS_LINK_MAP[400], @@ -334,7 +330,7 @@ def test_should_raises_401_unauthenticated(self): }, ) - assert_401(response) + assert response.status_code == 401 class TestPostVariables(TestVariableEndpoint): @@ -353,14 +349,14 @@ def test_should_create_variable(self, description, session): response = self.client.post( "/api/v1/variables", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 _check_last_log( session, dag_id=None, event="api.variable.create", execution_date=None, expected_extra=payload ) - response = self.client.get("/api/v1/variables/var_create", environ_overrides={"REMOTE_USER": "test"}) - assert response.json == { + response = self.client.get("/api/v1/variables/var_create", headers={"REMOTE_USER": "test"}) + assert response.json() == { "key": "var_create", "value": "{}", "description": description, @@ -372,7 +368,7 @@ def test_should_create_masked_variable(self, session): response = self.client.post( "/api/v1/variables", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 expected_extra = { @@ -386,8 +382,8 @@ def test_should_create_masked_variable(self, session): execution_date=None, expected_extra=expected_extra, ) - response = self.client.get("/api/v1/variables/api_key", environ_overrides={"REMOTE_USER": "test"}) - assert response.json == payload + response = self.client.get("/api/v1/variables/api_key", headers={"REMOTE_USER": "test"}) + assert response.json() == payload def test_should_reject_invalid_request(self, session): response = self.client.post( @@ -396,10 +392,10 @@ def test_should_reject_invalid_request(self, session): "key": "var_create", "v": "{}", }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "title": "Invalid Variable schema", "status": 400, "type": EXCEPTIONS_LINK_MAP[400], @@ -416,4 +412,4 @@ def test_should_raises_401_unauthenticated(self): }, ) - assert_401(response) + assert response.status_code == 401 diff --git a/tests/api_connexion/endpoints/test_version_endpoint.py b/tests/api_connexion/endpoints/test_version_endpoint.py index 6c21985a73584e..f966347e04a8a1 100644 --- a/tests/api_connexion/endpoints/test_version_endpoint.py +++ b/tests/api_connexion/endpoints/test_version_endpoint.py @@ -29,8 +29,8 @@ def setup_attrs(self, minimal_app_for_api) -> None: """ Setup For XCom endpoint TC """ - self.app = minimal_app_for_api - self.client = self.app.test_client() # type:ignore + self.connexion_app = minimal_app_for_api + self.client = self.connexion_app.test_client() # type:ignore @mock.patch("airflow.api_connexion.endpoints.version_endpoint.airflow.__version__", "MOCK_VERSION") @mock.patch( @@ -40,5 +40,5 @@ def test_should_respond_200(self, mock_get_airflow_get_commit): response = self.client.get("/api/v1/version") assert 200 == response.status_code - assert {"git_version": "GIT_COMMIT", "version": "MOCK_VERSION"} == response.json + assert {"git_version": "GIT_COMMIT", "version": "MOCK_VERSION"} == response.json() mock_get_airflow_get_commit.assert_called_once_with() diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py b/tests/api_connexion/endpoints/test_xcom_endpoint.py index 1e4dbb56780cfb..d0727b5292c1a5 100644 --- a/tests/api_connexion/endpoints/test_xcom_endpoint.py +++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py @@ -31,7 +31,7 @@ from airflow.utils.session import create_session from airflow.utils.timezone import utcnow from airflow.utils.types import DagRunType -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_xcom @@ -49,10 +49,10 @@ def orm_deserialize_value(self): @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + connexion_app = minimal_app_for_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -61,23 +61,23 @@ def configured_app(minimal_app_for_api): ], ) create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test_granular_permissions", role_name="TestGranularDag", permissions=[ (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), ], ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore + connexion_app.app.appbuilder.sm.sync_perm_for_dag( # type: ignore "test-dag-id-1", access_control={"TestGranularDag": [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]}, ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore def _compare_xcom_collections(collection1: dict, collection_2: dict): @@ -109,8 +109,8 @@ def setup_attrs(self, configured_app) -> None: """ Setup For XCom endpoint TC """ - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore # clear existing xcoms self.clean_db() @@ -132,11 +132,11 @@ def test_should_respond_200(self): self._create_xcom_entry(dag_id, run_id, execution_date_parsed, task_id, xcom_key) response = self.client.get( f"/api/v1/dags/{dag_id}/dagRuns/{run_id}/taskInstances/{task_id}/xcomEntries/{xcom_key}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert 200 == response.status_code - current_data = response.json + current_data = response.json() current_data["timestamp"] = "TIMESTAMP" assert current_data == { "dag_id": dag_id, @@ -158,10 +158,10 @@ def test_should_raise_404_for_non_existent_xcom(self): self._create_xcom_entry(dag_id, run_id, execution_date_parsed, task_id, xcom_key) response = self.client.get( f"/api/v1/dags/nonexistentdagid/dagRuns/{run_id}/taskInstances/{task_id}/xcomEntries/{xcom_key}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert 404 == response.status_code - assert response.json["title"] == "XCom entry not found" + assert response.json()["title"] == "XCom entry not found" def test_should_raises_401_unauthenticated(self): dag_id = "test-dag-id" @@ -175,7 +175,7 @@ def test_should_raises_401_unauthenticated(self): f"/api/v1/dags/{dag_id}/dagRuns/{run_id}/taskInstances/{task_id}/xcomEntries/{xcom_key}" ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): dag_id = "test-dag-id" @@ -188,7 +188,7 @@ def test_should_raise_403_forbidden(self): self._create_xcom_entry(dag_id, run_id, execution_date_parsed, task_id, xcom_key) response = self.client.get( f"/api/v1/dags/{dag_id}/dagRuns/{run_id}/taskInstances/{task_id}/xcomEntries/{xcom_key}", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -262,13 +262,13 @@ def test_custom_xcom_deserialize(self, allowed: bool, query: str, expected_statu url = f"/api/v1/dags/dag/dagRuns/run/taskInstances/task/xcomEntries/key{query}" with mock.patch("airflow.api_connexion.endpoints.xcom_endpoint.XCom", XCom): with conf_vars({("api", "enable_xcom_deserialize_support"): str(allowed)}): - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) if isinstance(expected_status_or_value, int): assert response.status_code == expected_status_or_value else: assert response.status_code == 200 - assert response.json["value"] == expected_status_or_value + assert response.json()["value"] == expected_status_or_value class TestGetXComEntries(TestXComEndpoint): @@ -282,11 +282,11 @@ def test_should_respond_200(self): self._create_xcom_entries(dag_id, run_id, execution_date_parsed, task_id) response = self.client.get( f"/api/v1/dags/{dag_id}/dagRuns/{run_id}/taskInstances/{task_id}/xcomEntries", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert 200 == response.status_code - response_data = response.json + response_data = response.json() for xcom_entry in response_data["xcom_entries"]: xcom_entry["timestamp"] = "TIMESTAMP" _compare_xcom_collections( @@ -329,11 +329,11 @@ def test_should_respond_200_with_tilde_and_access_to_all_dags(self): response = self.client.get( "/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert 200 == response.status_code - response_data = response.json + response_data = response.json() for xcom_entry in response_data["xcom_entries"]: xcom_entry["timestamp"] = "TIMESTAMP" _compare_xcom_collections( @@ -392,11 +392,11 @@ def test_should_respond_200_with_tilde_and_granular_dag_access(self): self._create_invalid_xcom_entries(execution_date_parsed) response = self.client.get( "/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries", - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + headers={"REMOTE_USER": "test_granular_permissions"}, ) assert 200 == response.status_code - response_data = response.json + response_data = response.json() for xcom_entry in response_data["xcom_entries"]: xcom_entry["timestamp"] = "TIMESTAMP" _compare_xcom_collections( @@ -436,11 +436,11 @@ def assert_expected_result(expected_entries, map_index=None): response = self.client.get( "/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries" f"{('?map_index=' + str(map_index)) if map_index is not None else ''}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert 200 == response.status_code - response_data = response.json + response_data = response.json() for xcom_entry in response_data["xcom_entries"]: xcom_entry["timestamp"] = "TIMESTAMP" assert response_data == { @@ -479,11 +479,11 @@ def test_should_respond_200_with_xcom_key(self): def assert_expected_result(expected_entries, key=None): response = self.client.get( f"/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries?xcom_key={key}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert 200 == response.status_code - response_data = response.json + response_data = response.json() for xcom_entry in response_data["xcom_entries"]: xcom_entry["timestamp"] = "TIMESTAMP" assert response_data == { @@ -522,7 +522,7 @@ def test_should_raises_401_unauthenticated(self): f"/api/v1/dags/{dag_id}/dagRuns/{run_id}/taskInstances/{task_id}/xcomEntries" ) - assert_401(response) + assert response.status_code == 401 def _create_xcom_entries(self, dag_id, run_id, execution_date, task_id, mapped_ti=False): with create_session() as session: @@ -683,8 +683,8 @@ def test_handle_limit_offset(self, query_params, expected_xcom_ids): ) session.add(xcom) - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 10 - conn_ids = [conn["key"] for conn in response.json["xcom_entries"] if conn] + assert response.json()["total_entries"] == 10 + conn_ids = [conn["key"] for conn in response.json()["xcom_entries"] if conn] assert conn_ids == expected_xcom_ids diff --git a/tests/api_connexion/schemas/test_dag_run_schema.py b/tests/api_connexion/schemas/test_dag_run_schema.py index ce187868c78f3b..6e7ba21dae63fe 100644 --- a/tests/api_connexion/schemas/test_dag_run_schema.py +++ b/tests/api_connexion/schemas/test_dag_run_schema.py @@ -129,7 +129,7 @@ def test_invalid_execution_date_raises(self): serialized_dagrun = {"execution_date": "mydate"} with pytest.raises(BadRequest) as ctx: dagrun_schema.load(serialized_dagrun) - assert str(ctx.value) == "Incorrect datetime argument" + assert str(ctx.value) == "400: Invalid date string: mydate" class TestDagRunCollection(TestDAGRunBase): diff --git a/tests/api_connexion/schemas/test_role_and_permission_schema.py b/tests/api_connexion/schemas/test_role_and_permission_schema.py index a8a49242168383..26cd87c976786e 100644 --- a/tests/api_connexion/schemas/test_role_and_permission_schema.py +++ b/tests/api_connexion/schemas/test_role_and_permission_schema.py @@ -33,17 +33,17 @@ class TestRoleCollectionItemSchema: @pytest.fixture(scope="class") def role(self, minimal_app_for_api): yield create_role( - minimal_app_for_api, # type: ignore + minimal_app_for_api.app, # type: ignore name="Test", permissions=[ (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION), ], ) - delete_role(minimal_app_for_api, "Test") + delete_role(minimal_app_for_api.app, "Test") @pytest.fixture(autouse=True) def _set_attrs(self, minimal_app_for_api, role): - self.app = minimal_app_for_api + self.connexion_app = minimal_app_for_api self.role = role def test_serialize(self): @@ -69,24 +69,24 @@ class TestRoleCollectionSchema: @pytest.fixture(scope="class") def role1(self, minimal_app_for_api): yield create_role( - minimal_app_for_api, # type: ignore + minimal_app_for_api.app, # type: ignore name="Test1", permissions=[ (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION), ], ) - delete_role(minimal_app_for_api, "Test1") + delete_role(minimal_app_for_api.app, "Test1") @pytest.fixture(scope="class") def role2(self, minimal_app_for_api): yield create_role( - minimal_app_for_api, # type: ignore + minimal_app_for_api.app, # type: ignore name="Test2", permissions=[ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), ], ) - delete_role(minimal_app_for_api, "Test2") + delete_role(minimal_app_for_api.app, "Test2") def test_serialize(self, role1, role2): instance = RoleCollection([role1, role2], total_entries=2) diff --git a/tests/api_connexion/test_auth.py b/tests/api_connexion/test_auth.py index 869b69990f00cd..2ec6187c4cc0ca 100644 --- a/tests/api_connexion/test_auth.py +++ b/tests/api_connexion/test_auth.py @@ -19,7 +19,6 @@ from base64 import b64encode import pytest -from flask_login import current_user from tests.test_utils.api_connexion_utils import assert_401 from tests.test_utils.config import conf_vars @@ -32,9 +31,10 @@ class BaseTestAuth: @pytest.fixture(autouse=True) def set_attrs(self, minimal_app_for_api): - self.app = minimal_app_for_api + self.connexion_app = minimal_app_for_api + self.flask_app = self.connexion_app.app - sm = self.app.appbuilder.sm + sm = self.flask_app.appbuilder.sm tester = sm.find_user(username="test") if not tester: role_admin = sm.find_role("Admin") @@ -53,25 +53,28 @@ class TestBasicAuth(BaseTestAuth): def with_basic_auth_backend(self, minimal_app_for_api): from airflow.www.extensions.init_security import init_api_experimental_auth - old_auth = getattr(minimal_app_for_api, "api_auth") + flask_app = minimal_app_for_api.app + old_auth = getattr(flask_app, "api_auth") try: - with conf_vars({("api", "auth_backends"): "airflow.api.auth.backend.basic_auth"}): - init_api_experimental_auth(minimal_app_for_api) + with conf_vars( + {("api", "auth_backends"): "airflow.providers.fab.auth_manager.api.auth.backend.basic_auth"} + ): + init_api_experimental_auth(flask_app) yield finally: - setattr(minimal_app_for_api, "api_auth", old_auth) + setattr(flask_app, "api_auth", old_auth) def test_success(self): token = "Basic " + b64encode(b"test:test").decode() clear_db_pools() - with self.app.test_client() as test_client: + with self.connexion_app.test_client() as test_client: response = test_client.get("/api/v1/pools", headers={"Authorization": token}) - assert current_user.email == "test@fab.org" + # assert current_user.email == "test@fab.org" assert response.status_code == 200 - assert response.json == { + assert response.json() == { "pools": [ { "name": "default_pool", @@ -103,7 +106,7 @@ def test_success(self): ], ) def test_malformed_headers(self, token): - with self.app.test_client() as test_client: + with self.connexion_app.test_client() as test_client: response = test_client.get("/api/v1/pools", headers={"Authorization": token}) assert response.status_code == 401 assert response.headers["Content-Type"] == "application/problem+json" @@ -120,7 +123,7 @@ def test_malformed_headers(self, token): ], ) def test_invalid_auth_header(self, token): - with self.app.test_client() as test_client: + with self.connexion_app.test_client() as test_client: response = test_client.get("/api/v1/pools", headers={"Authorization": token}) assert response.status_code == 401 assert response.headers["Content-Type"] == "application/problem+json" @@ -133,22 +136,23 @@ class TestSessionAuth(BaseTestAuth): def with_session_backend(self, minimal_app_for_api): from airflow.www.extensions.init_security import init_api_experimental_auth - old_auth = getattr(minimal_app_for_api, "api_auth") + flask_app = minimal_app_for_api.app + old_auth = getattr(flask_app, "api_auth") try: with conf_vars({("api", "auth_backends"): "airflow.api.auth.backend.session"}): - init_api_experimental_auth(minimal_app_for_api) + init_api_experimental_auth(flask_app) yield finally: - setattr(minimal_app_for_api, "api_auth", old_auth) + setattr(flask_app, "api_auth", old_auth) def test_success(self): clear_db_pools() - admin_user = client_with_login(self.app, username="test", password="test") + admin_user = client_with_login(self.connexion_app, username="test", password="test") response = admin_user.get("/api/v1/pools") assert response.status_code == 200 - assert response.json == { + assert response.json() == { "pools": [ { "name": "default_pool", @@ -167,7 +171,7 @@ def test_success(self): } def test_failure(self): - with self.app.test_client() as test_client: + with self.connexion_app.test_client() as test_client: response = test_client.get("/api/v1/pools") assert response.status_code == 401 assert response.headers["Content-Type"] == "application/problem+json" @@ -179,7 +183,8 @@ class TestSessionWithBasicAuthFallback(BaseTestAuth): def with_basic_auth_backend(self, minimal_app_for_api): from airflow.www.extensions.init_security import init_api_experimental_auth - old_auth = getattr(minimal_app_for_api, "api_auth") + flask_app = minimal_app_for_api.app + old_auth = getattr(flask_app, "api_auth") try: with conf_vars( @@ -187,29 +192,29 @@ def with_basic_auth_backend(self, minimal_app_for_api): ( "api", "auth_backends", - ): "airflow.api.auth.backend.session,airflow.api.auth.backend.basic_auth" + ): "airflow.api.auth.backend.session,airflow.providers.fab.auth_manager.api.auth.backend.basic_auth" } ): - init_api_experimental_auth(minimal_app_for_api) + init_api_experimental_auth(flask_app) yield finally: - setattr(minimal_app_for_api, "api_auth", old_auth) + setattr(flask_app, "api_auth", old_auth) def test_basic_auth_fallback(self): token = "Basic " + b64encode(b"test:test").decode() clear_db_pools() # request uses session - admin_user = client_with_login(self.app, username="test", password="test") + admin_user = client_with_login(self.connexion_app, username="test", password="test") response = admin_user.get("/api/v1/pools") assert response.status_code == 200 # request uses basic auth - with self.app.test_client() as test_client: + with self.connexion_app.test_client() as test_client: response = test_client.get("/api/v1/pools", headers={"Authorization": token}) assert response.status_code == 200 # request without session or basic auth header - with self.app.test_client() as test_client: + with self.connexion_app.test_client() as test_client: response = test_client.get("/api/v1/pools") assert response.status_code == 401 diff --git a/tests/api_connexion/test_cors.py b/tests/api_connexion/test_cors.py index 4dc4950df99460..fb60eebb44e7bb 100644 --- a/tests/api_connexion/test_cors.py +++ b/tests/api_connexion/test_cors.py @@ -28,10 +28,12 @@ class BaseTestAuth: @pytest.fixture(autouse=True) - def set_attrs(self, minimal_app_for_api): - self.app = minimal_app_for_api + def set_attrs(self, minimal_app_for_api, minimal_app_for_api_cors_allow_all): + self.connexion_app = minimal_app_for_api + self.connexion_app_cors_allow_all = minimal_app_for_api_cors_allow_all + self.flask_app = self.connexion_app.app - sm = self.app.appbuilder.sm + sm = self.flask_app.appbuilder.sm tester = sm.find_user(username="test") if not tester: role_admin = sm.find_role("Admin") @@ -50,20 +52,21 @@ class TestEmptyCors(BaseTestAuth): def with_basic_auth_backend(self, minimal_app_for_api): from airflow.www.extensions.init_security import init_api_experimental_auth - old_auth = getattr(minimal_app_for_api, "api_auth") + flask_app = minimal_app_for_api.app + old_auth = getattr(flask_app, "api_auth") try: with conf_vars({("api", "auth_backends"): "airflow.api.auth.backend.basic_auth"}): - init_api_experimental_auth(minimal_app_for_api) + init_api_experimental_auth(flask_app) yield finally: - setattr(minimal_app_for_api, "api_auth", old_auth) + setattr(flask_app, "api_auth", old_auth) def test_empty_cors_headers(self): token = "Basic " + b64encode(b"test:test").decode() clear_db_pools() - with self.app.test_client() as test_client: + with self.connexion_app.test_client() as test_client: response = test_client.get("/api/v1/pools", headers={"Authorization": token}) assert response.status_code == 200 assert "Access-Control-Allow-Headers" not in response.headers @@ -76,29 +79,25 @@ class TestCorsOrigin(BaseTestAuth): def with_basic_auth_backend(self, minimal_app_for_api): from airflow.www.extensions.init_security import init_api_experimental_auth - old_auth = getattr(minimal_app_for_api, "api_auth") + flask_app = minimal_app_for_api.app + old_auth = getattr(flask_app, "api_auth") try: with conf_vars( { ("api", "auth_backends"): "airflow.api.auth.backend.basic_auth", - ("api", "access_control_allow_origins"): "http://apache.org http://example.com", } ): - init_api_experimental_auth(minimal_app_for_api) + init_api_experimental_auth(flask_app) yield finally: - setattr(minimal_app_for_api, "api_auth", old_auth) + setattr(flask_app, "api_auth", old_auth) def test_cors_origin_reflection(self): token = "Basic " + b64encode(b"test:test").decode() clear_db_pools() - with self.app.test_client() as test_client: - response = test_client.get("/api/v1/pools", headers={"Authorization": token}) - assert response.status_code == 200 - assert response.headers["Access-Control-Allow-Origin"] == "http://apache.org" - + with self.connexion_app.test_client() as test_client: response = test_client.get( "/api/v1/pools", headers={"Authorization": token, "Origin": "http://apache.org"} ) @@ -109,33 +108,35 @@ def test_cors_origin_reflection(self): "/api/v1/pools", headers={"Authorization": token, "Origin": "http://example.com"} ) assert response.status_code == 200 + assert response.headers["Access-Control-Allow-Origin"] == "http://example.com" class TestCorsWildcard(BaseTestAuth): @pytest.fixture(autouse=True, scope="class") - def with_basic_auth_backend(self, minimal_app_for_api): + def with_basic_auth_backend(self, minimal_app_for_api_cors_allow_all): from airflow.www.extensions.init_security import init_api_experimental_auth - old_auth = getattr(minimal_app_for_api, "api_auth") + self.connexion_app = minimal_app_for_api_cors_allow_all + flask_app = minimal_app_for_api_cors_allow_all.app + old_auth = getattr(flask_app, "api_auth") try: with conf_vars( { ("api", "auth_backends"): "airflow.api.auth.backend.basic_auth", - ("api", "access_control_allow_origins"): "*", } ): - init_api_experimental_auth(minimal_app_for_api) + init_api_experimental_auth(flask_app) yield finally: - setattr(minimal_app_for_api, "api_auth", old_auth) + setattr(flask_app, "api_auth", old_auth) def test_cors_origin_reflection(self): token = "Basic " + b64encode(b"test:test").decode() clear_db_pools() - with self.app.test_client() as test_client: + with self.connexion_app_cors_allow_all.test_client() as test_client: response = test_client.get( "/api/v1/pools", headers={"Authorization": token, "Origin": "http://example.com"} ) diff --git a/tests/api_connexion/test_error_handling.py b/tests/api_connexion/test_error_handling.py index d89515d05b68f1..59a056bed1e88a 100644 --- a/tests/api_connexion/test_error_handling.py +++ b/tests/api_connexion/test_error_handling.py @@ -31,8 +31,8 @@ def test_incorrect_endpoint_should_return_json(minimal_app_for_api): # Then we have parsable JSON as output - assert "Not Found" == resp.json["title"] - assert 404 == resp.json["status"] + assert "Not Found" == resp.json()["title"] + assert 404 == resp.json()["status"] assert 404 == resp.status_code @@ -45,8 +45,7 @@ def test_incorrect_endpoint_should_return_html(minimal_app_for_api): # Then we do not have JSON as response, rather standard HTML - assert resp.json is None - assert resp.mimetype == "text/html" + assert resp.headers["content-type"].startswith("text/html") assert resp.status_code == 404 @@ -60,8 +59,8 @@ def test_incorrect_method_should_return_json(minimal_app_for_api): # Then we have parsable JSON as output - assert "Method Not Allowed" == resp.json["title"] - assert 405 == resp.json["status"] + assert "Method Not Allowed" == resp.json()["title"] + assert 405 == resp.json()["status"] assert 405 == resp.status_code @@ -74,6 +73,5 @@ def test_incorrect_method_should_return_html(minimal_app_for_api): # Then we do not have JSON as response, rather standard HTML - assert resp.json is None - assert resp.mimetype == "text/html" + assert resp.headers["content-type"].startswith("text/html") assert resp.status_code == 405 diff --git a/tests/api_connexion/test_security.py b/tests/api_connexion/test_security.py index e75eba53e40f41..d0fa1988caaba8 100644 --- a/tests/api_connexion/test_security.py +++ b/tests/api_connexion/test_security.py @@ -20,35 +20,37 @@ from airflow.security import permissions from tests.test_utils.api_connexion_utils import create_user, delete_user +from tests.test_utils.config import conf_vars pytestmark = pytest.mark.db_test @pytest.fixture(scope="module") def configured_app(minimal_app_for_api): - app = minimal_app_for_api + flask_app = minimal_app_for_api.app create_user( - app, # type:ignore + flask_app, # type:ignore username="test", role_name="Test", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG)], # type: ignore ) - yield minimal_app_for_api + with conf_vars({("webserver", "expose_config"): "True"}): + yield minimal_app_for_api - delete_user(app, username="test") # type: ignore + delete_user(flask_app, username="test") # type: ignore class TestSession: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.client = self.connexion_app.test_client() # type:ignore def test_session_not_created_on_api_request(self): - self.client.get("api/v1/dags", environ_overrides={"REMOTE_USER": "test"}) - assert all(cookie.name != "session" for cookie in self.client.cookie_jar) + self.client.get("/api/v1/dags", headers={"REMOTE_USER": "test"}) + assert all(cookie.name != "session" for cookie in self.client.cookies) def test_session_not_created_on_health_endpoint_request(self): self.client.get("health") - assert all(cookie.name != "session" for cookie in self.client.cookie_jar) + assert all(cookie.name != "session" for cookie in self.client.cookies) diff --git a/tests/api_experimental/auth/backend/test_basic_auth.py b/tests/api_experimental/auth/backend/test_basic_auth.py index 0d84465dd04462..2f045447950b5d 100644 --- a/tests/api_experimental/auth/backend/test_basic_auth.py +++ b/tests/api_experimental/auth/backend/test_basic_auth.py @@ -29,9 +29,9 @@ class TestBasicAuth: @pytest.fixture(autouse=True) def set_attrs(self, minimal_app_for_experimental_api): - self.app = minimal_app_for_experimental_api + self.connexion_app = minimal_app_for_experimental_api - self.appbuilder = self.app.appbuilder + self.appbuilder = self.connexion_app.app.appbuilder role_admin = self.appbuilder.sm.find_role("Admin") tester = self.appbuilder.sm.find_user(username="test") if not tester: @@ -48,7 +48,7 @@ def test_success(self): token = "Basic " + b64encode(b"test:test").decode() clear_db_pools() - with self.app.test_client() as test_client: + with self.connexion_app.app.test_client() as test_client: response = test_client.get("/api/experimental/pools", headers={"Authorization": token}) assert current_user.email == "test@fab.org" @@ -68,7 +68,7 @@ def test_success(self): ], ) def test_malformed_headers(self, token): - with self.app.test_client() as test_client: + with self.connexion_app.app.test_client() as test_client: response = test_client.get("/api/experimental/pools", headers={"Authorization": token}) assert response.status_code == 401 assert response.headers["WWW-Authenticate"] == "Basic" @@ -83,14 +83,14 @@ def test_malformed_headers(self, token): ], ) def test_invalid_auth_header(self, token): - with self.app.test_client() as test_client: + with self.connexion_app.app.test_client() as test_client: response = test_client.get("/api/experimental/pools", headers={"Authorization": token}) assert response.status_code == 401 assert response.headers["WWW-Authenticate"] == "Basic" @pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_experimental_api(self): - with self.app.test_client() as test_client: + with self.connexion_app.app.test_client() as test_client: response = test_client.get("/api/experimental/pools", headers={"Authorization": "Basic"}) assert response.status_code == 401 assert response.headers["WWW-Authenticate"] == "Basic" diff --git a/tests/api_internal/endpoints/test_rpc_api_endpoint.py b/tests/api_internal/endpoints/test_rpc_api_endpoint.py index afa3bc9920ef47..5984d597895b46 100644 --- a/tests/api_internal/endpoints/test_rpc_api_endpoint.py +++ b/tests/api_internal/endpoints/test_rpc_api_endpoint.py @@ -70,8 +70,8 @@ def equals(a, b) -> bool: class TestRpcApiEndpoint: @pytest.fixture(autouse=True) def setup_attrs(self, minimal_app_for_internal_api: Flask) -> Generator: - self.app = minimal_app_for_internal_api - self.client = self.app.test_client() # type:ignore + self.connexion_app = minimal_app_for_internal_api + self.client = self.connexion_app.test_client() # type:ignore mock_test_method.reset_mock() mock_test_method.side_effect = None with mock.patch( @@ -85,7 +85,7 @@ def setup_attrs(self, minimal_app_for_internal_api: Flask) -> Generator: @pytest.mark.parametrize( "input_params, method_result, result_cmp_func, method_params", [ - ({}, None, lambda got, _: got == b"", {}), + ({}, None, lambda got, _: got == "", {}), ({}, "test_me", equals, {}), ( BaseSerialization.serialize({"dag_id": 15, "task_id": "fake-task"}), @@ -123,9 +123,9 @@ def test_method(self, input_params, method_result, result_cmp_func, method_param ) assert response.status_code == 200 if method_result: - response_data = BaseSerialization.deserialize(json.loads(response.data), use_pydantic_models=True) + response_data = BaseSerialization.deserialize(json.loads(response.text), use_pydantic_models=True) else: - response_data = response.data + response_data = response.text assert result_cmp_func(response_data, method_result) @@ -139,7 +139,7 @@ def test_method_with_exception(self): "/internal_api/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) ) assert response.status_code == 500 - assert response.data, b"Error executing method: test_method." + assert response.text, b"Error executing method: test_method." mock_test_method.assert_called_once() def test_unknown_method(self): @@ -149,7 +149,7 @@ def test_unknown_method(self): "/internal_api/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) ) assert response.status_code == 400 - assert response.data == b"Unrecognized method: i-bet-it-does-not-exist." + assert response.text == "Unrecognized method: i-bet-it-does-not-exist." mock_test_method.assert_not_called() def test_invalid_jsonrpc(self): @@ -159,5 +159,5 @@ def test_invalid_jsonrpc(self): "/internal_api/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) ) assert response.status_code == 400 - assert response.data == b"Expected jsonrpc 2.0 request." + assert response.text == "Expected jsonrpc 2.0 request." mock_test_method.assert_not_called() diff --git a/tests/auth/managers/test_base_auth_manager.py b/tests/auth/managers/test_base_auth_manager.py index 04191c4838c8aa..151b3bb1c8c567 100644 --- a/tests/auth/managers/test_base_auth_manager.py +++ b/tests/auth/managers/test_base_auth_manager.py @@ -125,8 +125,9 @@ class TestBaseAuthManager: def test_get_cli_commands_return_empty_list(self, auth_manager): assert auth_manager.get_cli_commands() == [] - def test_get_api_endpoints_return_none(self, auth_manager): - assert auth_manager.get_api_endpoints() is None + def test_set_api_endpoints_return_none(self, auth_manager): + flask_app = Flask(__name__) + assert auth_manager.set_api_endpoints(flask_app) is None def test_get_user_name(self, auth_manager): user = Mock() diff --git a/tests/cli/commands/test_internal_api_command.py b/tests/cli/commands/test_internal_api_command.py index 9de857588a3fca..aadfd52589574a 100644 --- a/tests/cli/commands/test_internal_api_command.py +++ b/tests/cli/commands/test_internal_api_command.py @@ -163,8 +163,7 @@ def test_cli_internal_api_debug(self, app): internal_api_command.internal_api(args) app_run.assert_called_with( - debug=True, - use_reloader=False, + log_level="debug", port=9080, host="0.0.0.0", ) @@ -192,7 +191,7 @@ def test_cli_internal_api_args(self): "--workers", "4", "--worker-class", - "sync", + "uvicorn.workers.UvicornWorker", "--timeout", "120", "--bind", diff --git a/tests/cli/commands/test_webserver_command.py b/tests/cli/commands/test_webserver_command.py index 07d95a9e5f75a2..d24f1adf9bd953 100644 --- a/tests/cli/commands/test_webserver_command.py +++ b/tests/cli/commands/test_webserver_command.py @@ -324,11 +324,11 @@ def test_cli_webserver_debug(self, app): webserver_command.webserver(args) app_run.assert_called_with( - debug=True, - use_reloader=False, + log_level="debug", port=8080, host="0.0.0.0", - ssl_context=None, + ssl_certfile=None, + ssl_keyfile=None, ) def test_cli_webserver_args(self): @@ -352,7 +352,7 @@ def test_cli_webserver_args(self): "--workers", "4", "--worker-class", - "sync", + "uvicorn.workers.UvicornWorker", "--timeout", "120", "--bind", diff --git a/tests/conftest.py b/tests/conftest.py index 32dbc7ec3ea5f5..5e8f69dab27e7f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -97,7 +97,7 @@ os.environ["_AIRFLOW_RUN_DB_TESTS_ONLY"] = "true" AIRFLOW_TESTS_DIR = Path(os.path.dirname(os.path.realpath(__file__))).resolve() -AIRFLOW_SOURCES_ROOT_DIR = AIRFLOW_TESTS_DIR.parent.parent +AIRFLOW_SOURCES_ROOT_DIR = AIRFLOW_TESTS_DIR.parent os.environ["AIRFLOW__CORE__PLUGINS_FOLDER"] = os.fspath(AIRFLOW_TESTS_DIR / "plugins") os.environ["AIRFLOW__CORE__DAGS_FOLDER"] = os.fspath(AIRFLOW_TESTS_DIR / "dags") @@ -1131,20 +1131,21 @@ def _get(dag_id): @pytest.fixture def create_log_template(request): - from airflow import settings from airflow.models.tasklog import LogTemplate - session = settings.Session() - def _create_log_template(filename_template, elasticsearch_id=""): - log_template = LogTemplate(filename=filename_template, elasticsearch_id=elasticsearch_id) - session.add(log_template) - session.commit() + from airflow.utils.session import create_session - def _delete_log_template(): - session.delete(log_template) + with create_session() as session: + log_template = LogTemplate(filename=filename_template, elasticsearch_id=elasticsearch_id) + session.add(log_template) session.commit() + def _delete_log_template(): + with create_session() as session: + session.delete(log_template) + session.commit() + request.addfinalizer(_delete_log_template) return _create_log_template @@ -1251,6 +1252,16 @@ def initialize_providers_manager(): ProvidersManager().initialize_providers_configuration() +@pytest.fixture(autouse=True) +def create_swagger_ui_dir_if_missing(): + """ + The directory needs to exist to satisfy starlette attempting to register it as middleware + :return: + """ + swagger_ui_dir = AIRFLOW_SOURCES_ROOT_DIR / "airflow" / "www" / "static" / "dist" / "swagger-ui" + swagger_ui_dir.mkdir(exist_ok=True, parents=True) + + @pytest.fixture(autouse=True) def close_all_sqlalchemy_sessions(): from sqlalchemy.orm import close_all_sessions diff --git a/tests/integration/api_experimental/auth/backend/test_kerberos_auth.py b/tests/integration/api_experimental/auth/backend/test_kerberos_auth.py index 3641163952bb80..efaae029e53af1 100644 --- a/tests/integration/api_experimental/auth/backend/test_kerberos_auth.py +++ b/tests/integration/api_experimental/auth/backend/test_kerberos_auth.py @@ -57,15 +57,15 @@ def dagbag_to_db(): class TestApiKerberos: @pytest.fixture(autouse=True) def _set_attrs(self, app_for_kerberos, dagbag_to_db): - self.app = app_for_kerberos + self.connexion_app = app_for_kerberos def test_trigger_dag(self): - with self.app.test_client() as client: + with self.connexion_app.app.test_client() as client: url_template = "/api/experimental/dags/{}/dag_runs" response = client.post( url_template.format("example_bash_operator"), data=json.dumps(dict(run_id="my_run" + datetime.now().isoformat())), - content_type="application/json", + headers={"Content-Type": "application/json"}, ) assert 401 == response.status_code @@ -86,21 +86,22 @@ class Request: CLIENT_AUTH.handle_response(response) assert "Authorization" in response.request.headers + headers = response.request.headers + headers.update({"Content-Type": "application/json"}) response2 = client.post( url_template.format("example_bash_operator"), data=json.dumps(dict(run_id="my_run" + datetime.now().isoformat())), - content_type="application/json", - headers=response.request.headers, + headers=headers, ) assert 200 == response2.status_code def test_unauthorized(self): - with self.app.test_client() as client: + with self.connexion_app.app.test_client() as client: url_template = "/api/experimental/dags/{}/dag_runs" response = client.post( url_template.format("example_bash_operator"), data=json.dumps(dict(run_id="my_run" + datetime.now().isoformat())), - content_type="application/json", + headers={"Content-Type": "application/json"}, ) assert 401 == response.status_code diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py index 62c68fd6e9b1c6..986782a1425459 100644 --- a/tests/plugins/test_plugins_manager.py +++ b/tests/plugins/test_plugins_manager.py @@ -77,8 +77,8 @@ def wrapper(*args, **kwargs): class TestPluginsRBAC: @pytest.fixture(autouse=True) def _set_attrs(self, app): - self.app = app - self.appbuilder = app.appbuilder + self.connexion_app = app + self.appbuilder = app.app.appbuilder def test_flaskappbuilder_views(self): from tests.plugins.test_plugin import v_appbuilder_package @@ -136,12 +136,15 @@ def test_app_blueprints(self): from tests.plugins.test_plugin import bp # Blueprint should be present in the app - assert "test_plugin" in self.app.blueprints - assert self.app.blueprints["test_plugin"].name == bp.name + assert "test_plugin" in self.connexion_app.app.blueprints + assert self.connexion_app.app.blueprints["test_plugin"].name == bp.name def test_app_static_folder(self): # Blueprint static folder should be properly set - assert AIRFLOW_SOURCES_ROOT / "airflow" / "www" / "static" == Path(self.app.static_folder).resolve() + assert ( + AIRFLOW_SOURCES_ROOT / "airflow" / "www" / "static" + == Path(self.connexion_app.app.static_folder).resolve() + ) @pytest.mark.db_test @@ -154,7 +157,7 @@ class AirflowNoMenuViewsPlugin(AirflowPlugin): appbuilder_class_name = str(v_nomenu_appbuilder_package["view"].__class__.__name__) with mock_plugin_manager(plugins=[AirflowNoMenuViewsPlugin()]): - appbuilder = application.create_app(testing=True).appbuilder + appbuilder = application.create_app(testing=True).app.appbuilder plugin_views = [view for view in appbuilder.baseviews if view.blueprint.name == appbuilder_class_name] diff --git a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py index 90d9138bd956e5..f6c12a5339cba3 100644 --- a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py +++ b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py @@ -154,7 +154,7 @@ def test_avp_facade(self, auth_manager): def test_get_user(self, mock_is_logged_in, auth_manager, app, test_user): mock_is_logged_in.return_value = True - with app.test_request_context(): + with app.app.test_request_context(): session["aws_user"] = test_user result = auth_manager.get_user() @@ -169,7 +169,7 @@ def test_get_user_return_none_when_not_logged_in(self, mock_is_logged_in, auth_m @pytest.mark.db_test def test_is_logged_in(self, auth_manager, app, test_user): - with app.test_request_context(): + with app.app.test_request_context(): session["aws_user"] = test_user result = auth_manager.is_logged_in() @@ -177,7 +177,7 @@ def test_is_logged_in(self, auth_manager, app, test_user): @pytest.mark.db_test def test_is_logged_in_return_false_when_no_user_in_session(self, auth_manager, app, test_user): - with app.test_request_context(): + with app.app.test_request_context(): result = auth_manager.is_logged_in() assert result is False diff --git a/tests/providers/amazon/aws/auth_manager/views/test_auth.py b/tests/providers/amazon/aws/auth_manager/views/test_auth.py index 85ef6aafe65051..70533ec7b7455c 100644 --- a/tests/providers/amazon/aws/auth_manager/views/test_auth.py +++ b/tests/providers/amazon/aws/auth_manager/views/test_auth.py @@ -70,19 +70,19 @@ def aws_app(): @pytest.mark.db_test class TestAwsAuthManagerAuthenticationViews: def test_login_redirect_to_identity_center(self, aws_app): - with aws_app.test_client() as client: + with aws_app.app.test_client() as client: response = client.get("/login") assert response.status_code == 302 assert response.location.startswith("https://portal.sso.us-east-1.amazonaws.com/saml/assertion/") def test_logout_redirect_to_identity_center(self, aws_app): - with aws_app.test_client() as client: + with aws_app.app.test_client() as client: response = client.get("/logout") assert response.status_code == 302 assert response.location.startswith("https://portal.sso.us-east-1.amazonaws.com/saml/logout/") def test_login_metadata_return_xml_file(self, aws_app): - with aws_app.test_client() as client: + with aws_app.app.test_client() as client: response = client.get("/login_metadata") assert response.status_code == 200 assert response.headers["Content-Type"] == "text/xml" @@ -116,8 +116,8 @@ def test_login_callback_set_user_in_session(self): "email": ["email"], } mock_init_saml_auth.return_value = auth - app = application.create_app(testing=True) - with app.test_client() as client: + connexion_app = application.create_app(testing=True) + with connexion_app.app.test_client() as client: response = client.get("/login_callback") assert response.status_code == 302 assert response.location == url_for("Airflow.index") @@ -148,12 +148,12 @@ def test_login_callback_raise_exception_if_errors(self): auth = Mock() auth.is_authenticated.return_value = False mock_init_saml_auth.return_value = auth - app = application.create_app(testing=True) - with app.test_client() as client: + connexion_app = application.create_app(testing=True) + with connexion_app.app.test_client() as client: with pytest.raises(AirflowException): client.get("/login_callback") def test_logout_callback_raise_not_implemented_error(self, aws_app): - with aws_app.test_client() as client: + with aws_app.app.test_client() as client: with pytest.raises(NotImplementedError): client.get("/logout_callback") diff --git a/tests/providers/fab/auth_manager/api/auth/backend/test_basic_auth.py b/tests/providers/fab/auth_manager/api/auth/backend/test_basic_auth.py index 3bd81dcd108078..2a6f96232ccd00 100644 --- a/tests/providers/fab/auth_manager/api/auth/backend/test_basic_auth.py +++ b/tests/providers/fab/auth_manager/api/auth/backend/test_basic_auth.py @@ -65,7 +65,7 @@ def setup_method(self) -> None: mock_call.reset_mock() def test_requires_authentication_with_no_header(self, app): - with app.test_request_context() as mock_context: + with app.app.test_request_context() as mock_context: mock_context.request.authorization = None result = function_decorated() @@ -82,7 +82,7 @@ def test_requires_authentication_with_ldap( user = Mock() mock_sm.auth_user_ldap.return_value = user - with app.test_request_context() as mock_context: + with app.app.test_request_context() as mock_context: mock_context.request.authorization = mock_authorization function_decorated() @@ -101,7 +101,7 @@ def test_requires_authentication_with_db( user = Mock() mock_sm.auth_user_db.return_value = user - with app.test_request_context() as mock_context: + with app.app.test_request_context() as mock_context: mock_context.request.authorization = mock_authorization function_decorated() diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py index b8970d797eb6fd..ca4b4d1eec71d2 100644 --- a/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py @@ -23,7 +23,6 @@ from airflow.providers.fab.auth_manager.security_manager.override import EXISTING_ROLES from airflow.security import permissions from tests.test_utils.api_connexion_utils import ( - assert_401, create_role, create_user, delete_role, @@ -35,9 +34,9 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api + connexion_app = minimal_app_for_auth_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -48,58 +47,55 @@ def configured_app(minimal_app_for_auth_api): (permissions.ACTION_CAN_READ, permissions.RESOURCE_ACTION), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore class TestRoleEndpoint: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore + self.connexion_app = configured_app + self.flask_app = self.connexion_app.app + self.client = self.connexion_app.test_client() # type:ignore def teardown_method(self): """ Delete all roles except these ones. Test and TestNoPermissions are deleted by delete_user above """ - session = self.app.appbuilder.get_session + session = self.flask_app.appbuilder.get_session existing_roles = set(EXISTING_ROLES) existing_roles.update(["Test", "TestNoPermissions"]) roles = session.query(Role).filter(~Role.name.in_(existing_roles)).all() for role in roles: - delete_role(self.app, role.name) + delete_role(self.flask_app, role.name) class TestGetRoleEndpoint(TestRoleEndpoint): def test_should_response_200(self): - response = self.client.get("/auth/fab/v1/roles/Admin", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/auth/fab/v1/roles/Admin", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["name"] == "Admin" + assert response.json()["name"] == "Admin" def test_should_respond_404(self): - response = self.client.get( - "/auth/fab/v1/roles/invalid-role", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/auth/fab/v1/roles/invalid-role", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 assert { "detail": "Role with name 'invalid-role' was not found", "status": 404, "title": "Role not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self): response = self.client.get("/auth/fab/v1/roles/Admin") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): - response = self.client.get( - "/auth/fab/v1/roles/Admin", environ_overrides={"REMOTE_USER": "test_no_permissions"} - ) + response = self.client.get("/auth/fab/v1/roles/Admin", headers={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 @pytest.mark.parametrize( @@ -114,30 +110,26 @@ def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_c class TestGetRolesEndpoint(TestRoleEndpoint): def test_should_response_200(self): - response = self.client.get("/auth/fab/v1/roles", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/auth/fab/v1/roles", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 existing_roles = set(EXISTING_ROLES) existing_roles.update(["Test", "TestNoPermissions"]) - assert response.json["total_entries"] == len(existing_roles) - roles = {role["name"] for role in response.json["roles"]} + assert response.json()["total_entries"] == len(existing_roles) + roles = {role["name"] for role in response.json()["roles"]} assert roles == existing_roles def test_should_raises_401_unauthenticated(self): response = self.client.get("/auth/fab/v1/roles") - assert_401(response) + assert response.status_code == 401 def test_should_raises_400_for_invalid_order_by(self): - response = self.client.get( - "/auth/fab/v1/roles?order_by=invalid", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/auth/fab/v1/roles?order_by=invalid", headers={"REMOTE_USER": "test"}) assert response.status_code == 400 msg = "Ordering with 'invalid' is disallowed or the attribute does not exist on the model" - assert response.json["detail"] == msg + assert response.json()["detail"] == msg def test_should_raise_403_forbidden(self): - response = self.client.get( - "/auth/fab/v1/roles", environ_overrides={"REMOTE_USER": "test_no_permissions"} - ) + response = self.client.get("/auth/fab/v1/roles", headers={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 @pytest.mark.parametrize( @@ -173,33 +165,31 @@ class TestGetRolesEndpointPaginationandFilter(TestRoleEndpoint): ], ) def test_can_handle_limit_and_offset(self, url, expected_roles): - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 existing_roles = set(EXISTING_ROLES) existing_roles.update(["Test", "TestNoPermissions"]) - assert response.json["total_entries"] == len(existing_roles) - roles = [role["name"] for role in response.json["roles"] if role] + assert response.json()["total_entries"] == len(existing_roles) + roles = [role["name"] for role in response.json()["roles"] if role] assert roles == expected_roles class TestGetPermissionsEndpoint(TestRoleEndpoint): def test_should_response_200(self): - response = self.client.get("/auth/fab/v1/permissions", environ_overrides={"REMOTE_USER": "test"}) - actions = {i[0] for i in self.app.appbuilder.sm.get_all_permissions() if i} + response = self.client.get("/auth/fab/v1/permissions", headers={"REMOTE_USER": "test"}) + actions = {i[0] for i in self.flask_app.appbuilder.sm.get_all_permissions() if i} assert response.status_code == 200 - assert response.json["total_entries"] == len(actions) - returned_actions = {perm["name"] for perm in response.json["actions"]} + assert response.json()["total_entries"] == len(actions) + returned_actions = {perm["name"] for perm in response.json()["actions"]} assert actions == returned_actions def test_should_raises_401_unauthenticated(self): response = self.client.get("/auth/fab/v1/permissions") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): - response = self.client.get( - "/auth/fab/v1/permissions", environ_overrides={"REMOTE_USER": "test_no_permissions"} - ) + response = self.client.get("/auth/fab/v1/permissions", headers={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 @pytest.mark.parametrize( @@ -218,11 +208,9 @@ def test_post_should_respond_200(self): "name": "Test2", "actions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}], } - response = self.client.post( - "/auth/fab/v1/roles", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/auth/fab/v1/roles", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - role = self.app.appbuilder.sm.find_role("Test2") + role = self.flask_app.appbuilder.sm.find_role("Test2") assert role is not None @pytest.mark.parametrize( @@ -291,11 +279,9 @@ def test_post_should_respond_200(self): ], ) def test_post_should_respond_400_for_invalid_payload(self, payload, error_message): - response = self.client.post( - "/auth/fab/v1/roles", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/auth/fab/v1/roles", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "detail": error_message, "status": 400, "title": "Bad Request", @@ -307,11 +293,9 @@ def test_post_should_respond_409_already_exist(self): "name": "Test", "actions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}], } - response = self.client.post( - "/auth/fab/v1/roles", json=payload, environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.post("/auth/fab/v1/roles", json=payload, headers={"REMOTE_USER": "test"}) assert response.status_code == 409 - assert response.json == { + assert response.json() == { "detail": "Role with name 'Test' already exists; please update with the PATCH endpoint", "status": 409, "title": "Conflict", @@ -327,7 +311,7 @@ def test_should_raises_401_unauthenticated(self): }, ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.post( @@ -336,7 +320,7 @@ def test_should_raise_403_forbidden(self): "name": "mytest2", "actions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}], }, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -356,20 +340,16 @@ def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_c class TestDeleteRole(TestRoleEndpoint): def test_delete_should_respond_204(self, session): - role = create_role(self.app, "mytestrole") - response = self.client.delete( - f"/auth/fab/v1/roles/{role.name}", environ_overrides={"REMOTE_USER": "test"} - ) + role = create_role(self.flask_app, "mytestrole") + response = self.client.delete(f"/auth/fab/v1/roles/{role.name}", headers={"REMOTE_USER": "test"}) assert response.status_code == 204 role_obj = session.query(Role).filter(Role.name == role.name).all() assert len(role_obj) == 0 def test_delete_should_respond_404(self): - response = self.client.delete( - "/auth/fab/v1/roles/invalidrolename", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.delete("/auth/fab/v1/roles/invalidrolename", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 - assert response.json == { + assert response.json() == { "detail": "Role with name 'invalidrolename' was not found", "status": 404, "title": "Role not found", @@ -379,11 +359,11 @@ def test_delete_should_respond_404(self): def test_should_raises_401_unauthenticated(self): response = self.client.delete("/auth/fab/v1/roles/test") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.delete( - "/auth/fab/v1/roles/test", environ_overrides={"REMOTE_USER": "test_no_permissions"} + "/auth/fab/v1/roles/test", headers={"REMOTE_USER": "test_no_permissions"} ) assert response.status_code == 403 @@ -393,7 +373,7 @@ def test_should_raise_403_forbidden(self): indirect=["set_auto_role_public"], ) def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code): - role = create_role(self.app, "mytestrole") + role = create_role(self.flask_app, "mytestrole") response = self.client.delete(f"/auth/fab/v1/roles/{role.name}") assert response.status_code == expected_status_code @@ -414,17 +394,17 @@ class TestPatchRole(TestRoleEndpoint): ], ) def test_patch_should_respond_200(self, payload, expected_name, expected_actions): - role = create_role(self.app, "mytestrole") + role = create_role(self.flask_app, "mytestrole") response = self.client.patch( - f"/auth/fab/v1/roles/{role.name}", json=payload, environ_overrides={"REMOTE_USER": "test"} + f"/auth/fab/v1/roles/{role.name}", json=payload, headers={"REMOTE_USER": "test"} ) assert response.status_code == 200 - assert response.json["name"] == expected_name - assert response.json["actions"] == expected_actions + assert response.json()["name"] == expected_name + assert response.json()["actions"] == expected_actions def test_patch_should_update_correct_roles_permissions(self): - create_role(self.app, "role_to_change") - create_role(self.app, "already_exists") + create_role(self.flask_app, "role_to_change") + create_role(self.flask_app, "already_exists") response = self.client.patch( "/auth/fab/v1/roles/role_to_change", @@ -432,16 +412,16 @@ def test_patch_should_update_correct_roles_permissions(self): "name": "already_exists", "actions": [{"action": {"name": "can_delete"}, "resource": {"name": "XComs"}}], }, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - updated_permissions = self.app.appbuilder.sm.find_role("role_to_change").permissions + updated_permissions = self.flask_app.appbuilder.sm.find_role("role_to_change").permissions assert len(updated_permissions) == 1 assert updated_permissions[0].resource.name == "XComs" assert updated_permissions[0].action.name == "can_delete" - assert len(self.app.appbuilder.sm.find_role("already_exists").permissions) == 0 + assert len(self.flask_app.appbuilder.sm.find_role("already_exists").permissions) == 0 @pytest.mark.parametrize( "update_mask, payload, expected_name, expected_actions", @@ -469,27 +449,27 @@ def test_patch_should_update_correct_roles_permissions(self): def test_patch_should_respond_200_with_update_mask( self, update_mask, payload, expected_name, expected_actions ): - role = create_role(self.app, "mytestrole") + role = create_role(self.flask_app, "mytestrole") assert role.permissions == [] response = self.client.patch( f"/auth/fab/v1/roles/{role.name}{update_mask}", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json["name"] == expected_name - assert response.json["actions"] == expected_actions + assert response.json()["name"] == expected_name + assert response.json()["actions"] == expected_actions def test_patch_should_respond_400_for_invalid_fields_in_update_mask(self): - role = create_role(self.app, "mytestrole") + role = create_role(self.flask_app, "mytestrole") payload = {"name": "testme"} response = self.client.patch( f"/auth/fab/v1/roles/{role.name}?update_mask=invalid_name", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json["detail"] == "'invalid_name' in update_mask is unknown" + assert response.json()["detail"] == "'invalid_name' in update_mask is unknown" @pytest.mark.parametrize( "payload, expected_error", @@ -542,14 +522,14 @@ def test_patch_should_respond_400_for_invalid_fields_in_update_mask(self): ], ) def test_patch_should_respond_400_for_invalid_update(self, payload, expected_error): - role = create_role(self.app, "mytestrole") + role = create_role(self.flask_app, "mytestrole") response = self.client.patch( f"/auth/fab/v1/roles/{role.name}", json=payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400 - assert response.json["detail"] == expected_error + assert response.json()["detail"] == expected_error def test_should_raises_401_unauthenticated(self): response = self.client.patch( @@ -560,7 +540,7 @@ def test_should_raises_401_unauthenticated(self): }, ) - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.patch( @@ -569,7 +549,7 @@ def test_should_raise_403_forbidden(self): "name": "mytest2", "actions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}], }, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 @@ -579,7 +559,7 @@ def test_should_raise_403_forbidden(self): indirect=["set_auto_role_public"], ) def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code): - role = create_role(self.app, "mytestrole") + role = create_role(self.flask_app, "mytestrole") response = self.client.patch( f"/auth/fab/v1/roles/{role.name}", json={"name": "mytest"}, diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py index 9092f7c36361b1..8ba1441710299d 100644 --- a/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py @@ -26,19 +26,19 @@ from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import create_session -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user +from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.config import conf_vars pytestmark = pytest.mark.db_test -DEFAULT_TIME = "2020-06-11T18:00:00+00:00" +DEFAULT_TIME = "2020-06-11T18:00:00" @pytest.fixture(scope="module") def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api + connexion_app = minimal_app_for_auth_api create_user( - app, # type: ignore + connexion_app.app, # type: ignore username="test", role_name="Test", permissions=[ @@ -48,20 +48,21 @@ def configured_app(minimal_app_for_auth_api): (permissions.ACTION_CAN_READ, permissions.RESOURCE_USER), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(connexion_app.app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - yield app + yield connexion_app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(connexion_app.app, username="test") # type: ignore + delete_user(connexion_app.app, username="test_no_permissions") # type: ignore class TestUserEndpoint: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore - self.session = self.app.appbuilder.get_session + self.connexion_app = configured_app + self.flask_app = self.connexion_app.app + self.client = self.connexion_app.test_client() # type:ignore + self.session = self.flask_app.appbuilder.get_session def teardown_method(self) -> None: # Delete users that have our custom default time @@ -94,9 +95,9 @@ def test_should_respond_200(self): users = self._create_users(1) self.session.add_all(users) self.session.commit() - response = self.client.get("/auth/fab/v1/users/TEST_USER1", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/auth/fab/v1/users/TEST_USER1", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "active": True, "changed_on": DEFAULT_TIME, "created_on": DEFAULT_TIME, @@ -122,9 +123,9 @@ def test_last_names_can_be_empty(self): ) self.session.add_all([prince]) self.session.commit() - response = self.client.get("/auth/fab/v1/users/prince", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/auth/fab/v1/users/prince", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "active": True, "changed_on": DEFAULT_TIME, "created_on": DEFAULT_TIME, @@ -150,9 +151,9 @@ def test_first_names_can_be_empty(self): ) self.session.add_all([liberace]) self.session.commit() - response = self.client.get("/auth/fab/v1/users/liberace", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/auth/fab/v1/users/liberace", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "active": True, "changed_on": DEFAULT_TIME, "created_on": DEFAULT_TIME, @@ -178,9 +179,9 @@ def test_both_first_and_last_names_can_be_empty(self): ) self.session.add_all([nameless]) self.session.commit() - response = self.client.get("/auth/fab/v1/users/nameless", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/auth/fab/v1/users/nameless", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json == { + assert response.json() == { "active": True, "changed_on": DEFAULT_TIME, "created_on": DEFAULT_TIME, @@ -195,44 +196,40 @@ def test_both_first_and_last_names_can_be_empty(self): } def test_should_respond_404(self): - response = self.client.get( - "/auth/fab/v1/users/invalid-user", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/auth/fab/v1/users/invalid-user", headers={"REMOTE_USER": "test"}) assert response.status_code == 404 assert { "detail": "The User with username `invalid-user` was not found", "status": 404, "title": "User not found", "type": EXCEPTIONS_LINK_MAP[404], - } == response.json + } == response.json() def test_should_raises_401_unauthenticated(self): response = self.client.get("/auth/fab/v1/users/TEST_USER1") - assert_401(response) + assert response.status_code == 401 def test_should_raise_403_forbidden(self): response = self.client.get( - "/auth/fab/v1/users/TEST_USER1", environ_overrides={"REMOTE_USER": "test_no_permissions"} + "/auth/fab/v1/users/TEST_USER1", headers={"REMOTE_USER": "test_no_permissions"} ) assert response.status_code == 403 class TestGetUsers(TestUserEndpoint): def test_should_response_200(self): - response = self.client.get("/auth/fab/v1/users", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/auth/fab/v1/users", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 2 - usernames = [user["username"] for user in response.json["users"] if user] + assert response.json()["total_entries"] == 2 + usernames = [user["username"] for user in response.json()["users"] if user] assert usernames == ["test", "test_no_permissions"] def test_should_raises_401_unauthenticated(self): response = self.client.get("/auth/fab/v1/users") - assert_401(response) + assert response.status_code def test_should_raise_403_forbidden(self): - response = self.client.get( - "/auth/fab/v1/users", environ_overrides={"REMOTE_USER": "test_no_permissions"} - ) + response = self.client.get("/auth/fab/v1/users", headers={"REMOTE_USER": "test_no_permissions"}) assert response.status_code == 403 @@ -283,10 +280,10 @@ def test_handle_limit_offset(self, url, expected_usernames): users = self._create_users(10) self.session.add_all(users) self.session.commit() - response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get(url, headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert response.json["total_entries"] == 12 - usernames = [user["username"] for user in response.json["users"] if user] + assert response.json()["total_entries"] == 12 + usernames = [user["username"] for user in response.json()["users"] if user] assert usernames == expected_usernames def test_should_respect_page_size_limit_default(self): @@ -294,33 +291,31 @@ def test_should_respect_page_size_limit_default(self): self.session.add_all(users) self.session.commit() - response = self.client.get("/auth/fab/v1/users", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/auth/fab/v1/users", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 # Explicitly add the 2 users on setUp - assert response.json["total_entries"] == 200 + len(["test", "test_no_permissions"]) - assert len(response.json["users"]) == 100 + assert response.json()["total_entries"] == 200 + len(["test", "test_no_permissions"]) + assert len(response.json()["users"]) == 100 def test_should_response_400_with_invalid_order_by(self): users = self._create_users(2) self.session.add_all(users) self.session.commit() - response = self.client.get( - "/auth/fab/v1/users?order_by=myname", environ_overrides={"REMOTE_USER": "test"} - ) + response = self.client.get("/auth/fab/v1/users?order_by=myname", headers={"REMOTE_USER": "test"}) assert response.status_code == 400 msg = "Ordering with 'myname' is disallowed or the attribute does not exist on the model" - assert response.json["detail"] == msg + assert response.json()["detail"] == msg def test_limit_of_zero_should_return_default(self): users = self._create_users(200) self.session.add_all(users) self.session.commit() - response = self.client.get("/auth/fab/v1/users?limit=0", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/auth/fab/v1/users?limit=0", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 # Explicit add the 2 users on setUp - assert response.json["total_entries"] == 200 + len(["test", "test_no_permissions"]) - assert len(response.json["users"]) == 100 + assert response.json()["total_entries"] == 200 + len(["test", "test_no_permissions"]) + assert len(response.json()["users"]) == 100 @conf_vars({("api", "maximum_page_limit"): "150"}) def test_should_return_conf_max_if_req_max_above_conf(self): @@ -328,9 +323,9 @@ def test_should_return_conf_max_if_req_max_above_conf(self): self.session.add_all(users) self.session.commit() - response = self.client.get("/auth/fab/v1/users?limit=180", environ_overrides={"REMOTE_USER": "test"}) + response = self.client.get("/auth/fab/v1/users?limit=180", headers={"REMOTE_USER": "test"}) assert response.status_code == 200 - assert len(response.json["users"]) == 150 + assert len(response.json()["users"]) == 150 EXAMPLE_USER_NAME = "example_user" @@ -343,6 +338,7 @@ def _delete_user(**filters): user = session.query(User).filter_by(**filters).first() if user is None: return + session.refresh(user) user.roles = [] session.delete(user) @@ -364,7 +360,7 @@ def autoclean_email(): @pytest.fixture def user_with_same_username(configured_app, autoclean_username): user = create_user( - configured_app, + configured_app.app, username=autoclean_username, email="another_user@example.com", role_name="TestNoPermissions", @@ -376,7 +372,7 @@ def user_with_same_username(configured_app, autoclean_username): @pytest.fixture def user_with_same_email(configured_app, autoclean_email): user = create_user( - configured_app, + configured_app.app, username="another_user", email=autoclean_email, role_name="TestNoPermissions", @@ -391,7 +387,7 @@ def user_different(configured_app): email = "another_user@example.com" _delete_user(username=username, email=email) - user = create_user(configured_app, username=username, email=email, role_name="TestNoPermissions") + user = create_user(configured_app.app, username=username, email=email, role_name="TestNoPermissions") assert user, "failed to create user 'another_user '" yield user _delete_user(username=username, email=email) @@ -410,7 +406,7 @@ def autoclean_user_payload(autoclean_username, autoclean_email): @pytest.fixture def autoclean_admin_user(configured_app, autoclean_user_payload): - security_manager = configured_app.appbuilder.sm + security_manager = configured_app.app.appbuilder.sm return security_manager.add_user( role=security_manager.find_role("Admin"), **autoclean_user_payload, @@ -419,27 +415,29 @@ def autoclean_admin_user(configured_app, autoclean_user_payload): class TestPostUser(TestUserEndpoint): def test_with_default_role(self, autoclean_username, autoclean_user_payload): + self.flask_app.config["AUTH_USER_REGISTRATION_ROLE"] = "Public" response = self.client.post( "/auth/fab/v1/users", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200, response.json - security_manager = self.app.appbuilder.sm + security_manager = self.flask_app.appbuilder.sm user = security_manager.find_user(autoclean_username) assert user is not None assert user.roles == [security_manager.find_role("Public")] + self.flask_app.config["AUTH_USER_REGISTRATION_ROLE"] = None def test_with_custom_roles(self, autoclean_username, autoclean_user_payload): response = self.client.post( "/auth/fab/v1/users", json={"roles": [{"name": "User"}, {"name": "Viewer"}], **autoclean_user_payload}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 200, response.json - security_manager = self.app.appbuilder.sm + security_manager = self.flask_app.appbuilder.sm user = security_manager.find_user(autoclean_username) assert user is not None assert {r.name for r in user.roles} == {"User", "Viewer"} @@ -449,24 +447,24 @@ def test_with_existing_different_user(self, autoclean_user_payload): response = self.client.post( "/auth/fab/v1/users", json={"roles": [{"name": "User"}, {"name": "Viewer"}], **autoclean_user_payload}, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 200, response.json + assert response.status_code == 200, response.json() def test_unauthenticated(self, autoclean_user_payload): response = self.client.post( "/auth/fab/v1/users", json=autoclean_user_payload, ) - assert response.status_code == 401, response.json + assert response.status_code == 401, response.json() def test_forbidden(self, autoclean_user_payload): response = self.client.post( "/auth/fab/v1/users", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) - assert response.status_code == 403, response.json + assert response.status_code == 403, response.json() @pytest.mark.parametrize( "existing_user_fixture_name, error_detail_template", @@ -488,12 +486,12 @@ def test_already_exists( response = self.client.post( "/auth/fab/v1/users", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 409, response.json + assert response.status_code == 409, response.json() error_detail = error_detail_template.format(username=existing.username, email=existing.email) - assert response.json["detail"] == error_detail + assert response.json()["detail"] == error_detail @pytest.mark.parametrize( "payload_converter, error_message", @@ -524,10 +522,10 @@ def test_invalid_payload(self, autoclean_user_payload, payload_converter, error_ response = self.client.post( "/auth/fab/v1/users", json=payload_converter(autoclean_user_payload), - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 400, response.json - assert response.json == { + assert response.status_code == 400, response.json() + assert response.json() == { "detail": error_message, "status": 400, "title": "Bad Request", @@ -535,13 +533,13 @@ def test_invalid_payload(self, autoclean_user_payload, payload_converter, error_ } def test_internal_server_error(self, autoclean_user_payload): - with unittest.mock.patch.object(self.app.appbuilder.sm, "add_user", return_value=None): + with unittest.mock.patch.object(self.flask_app.appbuilder.sm, "add_user", return_value=None): response = self.client.post( "/auth/fab/v1/users", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.json == { + assert response.json() == { "detail": "Failed to add user `example_user`.", "status": 500, "title": "Internal Server Error", @@ -556,12 +554,12 @@ def test_change(self, autoclean_username, autoclean_user_payload): response = self.client.patch( f"/auth/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 200, response.json + assert response.status_code == 200, response.json() # The first name is changed. - data = response.json + data = response.json() assert data["first_name"] == "Changed" assert data["last_name"] == "" @@ -572,12 +570,12 @@ def test_change_with_update_mask(self, autoclean_username, autoclean_user_payloa response = self.client.patch( f"/auth/fab/v1/users/{autoclean_username}?update_mask=last_name", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 200, response.json + assert response.status_code == 200, response.json() # The first name is changed, but the last name isn't since we masked it. - data = response.json + data = response.json() assert data["first_name"] == "Tester" assert data["last_name"] == "McTesterson" @@ -602,11 +600,11 @@ def test_patch_already_exists( response = self.client.patch( f"/auth/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 409, response.json - assert response.json["detail"] == error_message + assert response.json()["detail"] == error_message @pytest.mark.parametrize( "field", @@ -623,10 +621,10 @@ def test_required_fields( response = self.client.patch( f"/auth/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) assert response.status_code == 400, response.json - assert response.json["detail"] == f"{{'{field}': ['Missing data for required field.']}}" + assert response.json()["detail"] == f"{{'{field}': ['Missing data for required field.']}}" @pytest.mark.usefixtures("autoclean_admin_user") def test_username_can_be_updated(self, autoclean_user_payload, autoclean_username): @@ -635,10 +633,10 @@ def test_username_can_be_updated(self, autoclean_user_payload, autoclean_usernam response = self.client.patch( f"/auth/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) _delete_user(username=testusername) - assert response.json["username"] == testusername + assert response.json()["username"] == testusername @pytest.mark.usefixtures("autoclean_admin_user") @unittest.mock.patch( @@ -655,10 +653,10 @@ def test_password_hashed( response = self.client.patch( f"/auth/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 200, response.json - assert "password" not in response.json + assert response.status_code == 200, response.json() + assert "password" not in response.json() mock_generate_password_hash.assert_called_once_with("new-pass") @@ -674,10 +672,10 @@ def test_replace_roles(self, autoclean_username, autoclean_user_payload): response = self.client.patch( f"/auth/fab/v1/users/{autoclean_username}?update_mask=roles", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 200, response.json - assert {d["name"] for d in response.json["roles"]} == {"User", "Viewer"} + assert response.status_code == 200, response.json() + assert {d["name"] for d in response.json()["roles"]} == {"User", "Viewer"} @pytest.mark.usefixtures("autoclean_admin_user") def test_unchanged(self, autoclean_username, autoclean_user_payload): @@ -685,12 +683,12 @@ def test_unchanged(self, autoclean_username, autoclean_user_payload): response = self.client.patch( f"/auth/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 200, response.json + assert response.status_code == 200, response.json() expected = {k: v for k, v in autoclean_user_payload.items() if k != "password"} - assert {k: response.json[k] for k in expected} == expected + assert {k: response.json()[k] for k in expected} == expected @pytest.mark.usefixtures("autoclean_admin_user") def test_unauthenticated(self, autoclean_username, autoclean_user_payload): @@ -698,25 +696,25 @@ def test_unauthenticated(self, autoclean_username, autoclean_user_payload): f"/auth/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, ) - assert response.status_code == 401, response.json + assert response.status_code == 401, response.json() @pytest.mark.usefixtures("autoclean_admin_user") def test_forbidden(self, autoclean_username, autoclean_user_payload): response = self.client.patch( f"/auth/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) - assert response.status_code == 403, response.json + assert response.status_code == 403, response.json() def test_not_found(self, autoclean_username, autoclean_user_payload): # This test does not populate autoclean_admin_user into the database. response = self.client.patch( f"/auth/fab/v1/users/{autoclean_username}", json=autoclean_user_payload, - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 404, response.json + assert response.status_code == 404, response.json() @pytest.mark.parametrize( "payload_converter, error_message", @@ -754,10 +752,10 @@ def test_invalid_payload( response = self.client.patch( f"/auth/fab/v1/users/{autoclean_username}", json=payload_converter(autoclean_user_payload), - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 400, response.json - assert response.json == { + assert response.status_code == 400, response.json() + assert response.json() == { "detail": error_message, "status": 400, "title": "Bad Request", @@ -770,9 +768,9 @@ class TestDeleteUser(TestUserEndpoint): def test_delete(self, autoclean_username): response = self.client.delete( f"/auth/fab/v1/users/{autoclean_username}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 204, response.json # NO CONTENT. + assert response.status_code == 204, response.json() # NO CONTENT. assert self.session.query(count(User.id)).filter(User.username == autoclean_username).scalar() == 0 @pytest.mark.usefixtures("autoclean_admin_user") @@ -780,22 +778,22 @@ def test_unauthenticated(self, autoclean_username): response = self.client.delete( f"/auth/fab/v1/users/{autoclean_username}", ) - assert response.status_code == 401, response.json + assert response.status_code == 401, response.json() assert self.session.query(count(User.id)).filter(User.username == autoclean_username).scalar() == 1 @pytest.mark.usefixtures("autoclean_admin_user") def test_forbidden(self, autoclean_username): response = self.client.delete( f"/auth/fab/v1/users/{autoclean_username}", - environ_overrides={"REMOTE_USER": "test_no_permissions"}, + headers={"REMOTE_USER": "test_no_permissions"}, ) - assert response.status_code == 403, response.json + assert response.status_code == 403, response.json() assert self.session.query(count(User.id)).filter(User.username == autoclean_username).scalar() == 1 def test_not_found(self, autoclean_username): # This test does not populate autoclean_admin_user into the database. response = self.client.delete( f"/auth/fab/v1/users/{autoclean_username}", - environ_overrides={"REMOTE_USER": "test"}, + headers={"REMOTE_USER": "test"}, ) - assert response.status_code == 404, response.json + assert response.status_code == 404, response.json() diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py b/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py index 222dbdbbb49a57..cf756333a6032f 100644 --- a/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py @@ -32,24 +32,25 @@ @pytest.fixture(scope="module") def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api + connexion_app = minimal_app_for_auth_api create_role( - app, + connexion_app.app, name="TestRole", permissions=[], ) - yield app + yield connexion_app - delete_role(app, "TestRole") # type:ignore + delete_role(connexion_app.app, "TestRole") # type:ignore class TestUserBase: @pytest.fixture(autouse=True) def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore - self.role = self.app.appbuilder.sm.find_role("TestRole") - self.session = self.app.appbuilder.get_session + self.connexion_app = configured_app + self.flask_app = self.connexion_app.app + self.client = self.connexion_app.test_client() # type:ignore + self.role = self.flask_app.appbuilder.sm.find_role("TestRole") + self.session = self.flask_app.appbuilder.get_session def teardown_method(self): user = self.session.query(User).filter(User.email == TEST_EMAIL).first() diff --git a/tests/providers/fab/auth_manager/conftest.py b/tests/providers/fab/auth_manager/conftest.py index 6b4feb143f4b5a..844d10eee6f1e4 100644 --- a/tests/providers/fab/auth_manager/conftest.py +++ b/tests/providers/fab/auth_manager/conftest.py @@ -29,14 +29,16 @@ def minimal_app_for_auth_api(): skip_all_except=[ "init_appbuilder", "init_api_experimental_auth", - "init_api_auth_provider", + "init_api_auth_manager", "init_api_error_handlers", ] ) def factory(): with conf_vars({("api", "auth_backends"): "tests.test_utils.remote_user_api_auth_backend"}): - _app = app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore - _app.config["AUTH_ROLE_PUBLIC"] = None + _app = app.create_app( + testing=True, + config={"WTF_CSRF_ENABLED": False, "AUTH_ROLE_PUBLIC": None}, + ) # type:ignore return _app return factory() @@ -45,9 +47,9 @@ def factory(): @pytest.fixture def set_auto_role_public(request): app = request.getfixturevalue("minimal_app_for_auth_api") - auto_role_public = app.config["AUTH_ROLE_PUBLIC"] - app.config["AUTH_ROLE_PUBLIC"] = request.param + auto_role_public = app.app.config["AUTH_ROLE_PUBLIC"] + app.app.config["AUTH_ROLE_PUBLIC"] = request.param yield - app.config["AUTH_ROLE_PUBLIC"] = auto_role_public + app.app.config["AUTH_ROLE_PUBLIC"] = auto_role_public diff --git a/tests/providers/fab/auth_manager/decorators/test_auth.py b/tests/providers/fab/auth_manager/decorators/test_auth.py index 4e0b6b6ffdccd4..a9978b22ef7c91 100644 --- a/tests/providers/fab/auth_manager/decorators/test_auth.py +++ b/tests/providers/fab/auth_manager/decorators/test_auth.py @@ -53,7 +53,7 @@ def mock_auth_manager(mock_sm): @pytest.fixture def mock_app(mock_appbuilder): app = Mock() - app.appbuilder = mock_appbuilder + app.app.appbuilder = mock_appbuilder return app @@ -76,11 +76,11 @@ def setup_method(self) -> None: def test_requires_access_fab_sync_resource_permissions( self, mock_get_auth_manager, mock_sm, mock_appbuilder, mock_auth_manager, app ): - app.appbuilder = mock_appbuilder + app.app.appbuilder = mock_appbuilder mock_appbuilder.update_perms = True mock_get_auth_manager.return_value = mock_auth_manager - with app.test_request_context(): + with app.app.test_request_context(): @_requires_access_fab() def decorated_requires_access_fab(): @@ -96,7 +96,7 @@ def test_requires_access_fab_access_denied( mock_sm.check_authorization.return_value = False mock_get_auth_manager.return_value = mock_auth_manager - with app.test_request_context(): + with app.app.test_request_context(): @_requires_access_fab(permissions) def decorated_requires_access_fab(): @@ -117,7 +117,7 @@ def test_requires_access_fab_access_granted( mock_sm.check_authorization.return_value = True mock_get_auth_manager.return_value = mock_auth_manager - with app.test_request_context(): + with app.app.test_request_context(): @_requires_access_fab(permissions) def decorated_requires_access_fab(): @@ -131,8 +131,8 @@ def decorated_requires_access_fab(): @patch("airflow.providers.fab.auth_manager.decorators.auth._has_access") def test_has_access_fab_with_no_dags(self, mock_has_access, mock_sm, mock_appbuilder, app): - app.appbuilder = mock_appbuilder - with app.test_request_context(): + app.app.appbuilder = mock_appbuilder + with app.app.test_request_context(): decorated_has_access_fab() mock_sm.check_authorization.assert_called_once_with(permissions, None) @@ -143,8 +143,8 @@ def test_has_access_fab_with_no_dags(self, mock_has_access, mock_sm, mock_appbui def test_has_access_fab_with_multiple_dags_render_error( self, mock_has_access, mock_render_template, mock_sm, mock_appbuilder, app ): - app.appbuilder = mock_appbuilder - with app.test_request_context() as mock_context: + app.app.appbuilder = mock_appbuilder + with app.app.test_request_context() as mock_context: mock_context.request.args = {"dag_id": "dag1"} mock_context.request.form = {"dag_id": "dag2"} decorated_has_access_fab() diff --git a/tests/providers/fab/auth_manager/test_security.py b/tests/providers/fab/auth_manager/test_security.py index aa262f675f666d..fed5a9117cf16a 100644 --- a/tests/providers/fab/auth_manager/test_security.py +++ b/tests/providers/fab/auth_manager/test_security.py @@ -163,16 +163,16 @@ def clear_db_before_test(): @pytest.fixture(scope="module") def app(): _app = application.create_app(testing=True) - _app.config["WTF_CSRF_ENABLED"] = False + _app.app.config["WTF_CSRF_ENABLED"] = False return _app @pytest.fixture(scope="module") def app_builder(app): - app_builder = app.appbuilder + app_builder = app.app.appbuilder app_builder.add_view(SomeBaseView, "SomeBaseView", category="BaseViews") app_builder.add_view(SomeModelView, "SomeModelView", category="ModelViews") - return app.appbuilder + return app.app.appbuilder @pytest.fixture(scope="module") @@ -187,7 +187,7 @@ def session(app_builder): @pytest.fixture(scope="module") def db(app): - return SQLA(app) + return SQLA(app.app) @pytest.fixture @@ -199,7 +199,7 @@ def role(request, app, security_manager): security_manager.bulk_sync_roles(params["mock_roles"]) _role = security_manager.find_role(params["name"]) yield _role, params - delete_role(app, params["name"]) + delete_role(app.app, params["name"]) @pytest.fixture @@ -338,10 +338,10 @@ def test_verify_public_role_has_no_permissions(security_manager): def test_verify_default_anon_user_has_no_accessible_dag_ids( mock_is_logged_in, app, session, security_manager ): - with app.app_context(): + with app.app.app_context(): mock_is_logged_in.return_value = False user = AnonymousUser() - app.config["AUTH_ROLE_PUBLIC"] = "Public" + app.app.config["AUTH_ROLE_PUBLIC"] = "Public" assert security_manager.get_user_roles(user) == {security_manager.get_public_role()} with _create_dag_model_context("test_dag_id", session, security_manager): @@ -351,9 +351,9 @@ def test_verify_default_anon_user_has_no_accessible_dag_ids( def test_verify_default_anon_user_has_no_access_to_specific_dag(app, session, security_manager, has_dag_perm): - with app.app_context(): + with app.app.app_context(): user = AnonymousUser() - app.config["AUTH_ROLE_PUBLIC"] = "Public" + app.app.config["AUTH_ROLE_PUBLIC"] = "Public" assert security_manager.get_user_roles(user) == {security_manager.get_public_role()} dag_id = "test_dag_id" @@ -376,8 +376,8 @@ def test_verify_anon_user_with_admin_role_has_all_dag_access( mock_is_logged_in, app, security_manager, mock_dag_models ): test_dag_ids = mock_dag_models - with app.app_context(): - app.config["AUTH_ROLE_PUBLIC"] = "Admin" + with app.app.app_context(): + app.app.config["AUTH_ROLE_PUBLIC"] = "Admin" mock_is_logged_in.return_value = False user = AnonymousUser() @@ -391,9 +391,9 @@ def test_verify_anon_user_with_admin_role_has_all_dag_access( def test_verify_anon_user_with_admin_role_has_access_to_each_dag( app, session, security_manager, has_dag_perm ): - with app.app_context(): + with app.app.app_context(): user = AnonymousUser() - app.config["AUTH_ROLE_PUBLIC"] = "Admin" + app.app.config["AUTH_ROLE_PUBLIC"] = "Admin" # Call `.get_user_roles` bc `user` is a mock and the `user.roles` prop needs to be set. user.roles = security_manager.get_user_roles(user) @@ -453,9 +453,9 @@ def test_get_user_roles_for_anonymous_user(app, security_manager): (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DOCS_MENU), (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DOCS), } - app.config["AUTH_ROLE_PUBLIC"] = "Viewer" + app.app.config["AUTH_ROLE_PUBLIC"] = "Viewer" - with app.app_context(): + with app.app.app_context(): user = AnonymousUser() perms_views = set() @@ -468,9 +468,9 @@ def test_get_current_user_permissions(app): action = "can_some_action" resource = "SomeBaseView" - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username="get_current_user_permissions", role_name="MyRole5", permissions=[ @@ -480,7 +480,7 @@ def test_get_current_user_permissions(app): assert user.perms == {(action, resource)} with create_user_scope( - app, + app.app, username="no_perms", ) as user: assert len(user.perms) == 0 @@ -493,9 +493,9 @@ def test_get_accessible_dag_ids(mock_is_logged_in, app, security_manager, sessio dag_id = "dag_id" username = "ElUser" - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username=username, role_name=role_name, permissions=[ @@ -525,9 +525,9 @@ def test_dont_get_inaccessible_dag_ids_for_dag_resource_permission( role_name = "MyRole1" permission_action = [permissions.ACTION_CAN_EDIT] dag_id = "dag_id" - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username=username, role_name=role_name, permissions=[ @@ -566,9 +566,9 @@ def test_sync_perm_for_dag_creates_permissions_for_specified_roles(app, security test_dag_id = "TEST_DAG" test_role = "limited-role" security_manager.bulk_sync_roles([{"role": test_role, "perms": []}]) - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username="test_user", role_name=test_role, permissions=[], @@ -585,9 +585,9 @@ def test_sync_perm_for_dag_removes_existing_permissions_if_empty(app, security_m test_dag_id = "TEST_DAG" test_role = "limited-role" - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username="test_user", role_name=test_role, permissions=[], @@ -623,9 +623,9 @@ def test_sync_perm_for_dag_removes_permissions_from_other_roles(app, security_ma test_dag_id = "TEST_DAG" test_role = "limited-role" - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username="test_user", role_name=test_role, permissions=[], @@ -662,9 +662,9 @@ def test_sync_perm_for_dag_does_not_prune_roles_when_access_control_unset(app, s test_dag_id = "TEST_DAG" test_role = "limited-role" - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username="test_user", role_name=test_role, permissions=[], @@ -695,35 +695,35 @@ def test_sync_perm_for_dag_does_not_prune_roles_when_access_control_unset(app, s def test_has_all_dag_access(app, security_manager): for role_name in ["Admin", "Viewer", "Op", "User"]: - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username="user", role_name=role_name, ) as user: assert _has_all_dags_access(user) - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username="user", role_name="read_all", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)], ) as user: assert _has_all_dags_access(user) - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username="user", role_name="edit_all", permissions=[(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)], ) as user: assert _has_all_dags_access(user) - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username="user", role_name="nada", permissions=[], @@ -745,9 +745,9 @@ def test_access_control_with_non_existent_role(security_manager): def test_all_dag_access_doesnt_give_non_dag_access(app, security_manager): username = "dag_access_user" role_name = "dag_access_role" - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username=username, role_name=role_name, permissions=[ @@ -769,7 +769,7 @@ def test_access_control_with_invalid_permission(app, security_manager): username = "LaUser" rolename = "team-a" with create_user_scope( - app, + app.app, username=username, role_name=rolename, ): @@ -791,9 +791,9 @@ def test_access_control_is_set_on_init( username = "access_control_is_set_on_init" role_name = "team-a" negated_role = "NOT-team-a" - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username=username, role_name=role_name, permissions=[], @@ -809,7 +809,7 @@ def test_access_control_is_set_on_init( ) security_manager.bulk_sync_roles([{"role": negated_role, "perms": []}]) - set_user_single_role(app, user, role_name=negated_role) + set_user_single_role(app.app, user, role_name=negated_role) assert_user_does_not_have_dag_perms( perms=["PUT", "GET"], dag_id="access_control_test", @@ -825,14 +825,14 @@ def test_access_control_stale_perms_are_revoked( ): username = "access_control_stale_perms_are_revoked" role_name = "team-a" - with app.app_context(): + with app.app.app_context(): with create_user_scope( - app, + app.app, username=username, role_name=role_name, permissions=[], ) as user: - set_user_single_role(app, user, role_name="team-a") + set_user_single_role(app.app, user, role_name="team-a") security_manager._sync_dag_view_permissions( "access_control_test", access_control={"team-a": READ_WRITE} ) @@ -976,7 +976,7 @@ def test_parent_dag_access_applies_to_subdag(app, security_manager, assert_user_ parent_dag_name = "parent_dag" subdag_name = parent_dag_name + ".subdag" subsubdag_name = parent_dag_name + ".subdag.subsubdag" - with app.app_context(): + with app.app.app_context(): mock_roles = [ { "role": role_name, @@ -987,7 +987,7 @@ def test_parent_dag_access_applies_to_subdag(app, security_manager, assert_user_ } ] with create_user_scope( - app, + app.app, username=username, role_name=role_name, ) as user: @@ -1017,7 +1017,7 @@ def test_permissions_work_for_dags_with_dot_in_dagname( role_name = "dag_permission_role" dag_id = "dag_id_1" dag_id_2 = "dag_id_1.with_dot" - with app.app_context(): + with app.app.app_context(): mock_roles = [ { "role": role_name, @@ -1028,7 +1028,7 @@ def test_permissions_work_for_dags_with_dot_in_dagname( } ] with create_user_scope( - app, + app.app, username=username, role_name=role_name, ) as user: @@ -1117,14 +1117,14 @@ def test_update_user_auth_stat_subsequent_unsuccessful_auth(mock_security_manage def test_users_can_be_found(app, security_manager, session, caplog): """Test that usernames are case insensitive""" - create_user(app, "Test") - create_user(app, "test") - create_user(app, "TEST") - create_user(app, "TeSt") + create_user(app.app, "Test") + create_user(app.app, "test") + create_user(app.app, "TEST") + create_user(app.app, "TeSt") assert security_manager.find_user("Test") users = security_manager.get_all_users() assert len(users) == 1 - delete_user(app, "Test") + delete_user(app.app, "Test") assert "Error adding new user to database" in caplog.text @@ -1174,7 +1174,7 @@ def test_dag_id_consistency( dag_id_json: str | None, fail: bool, ): - with app.test_request_context() as mock_context: + with app.app.test_request_context() as mock_context: from airflow.www.auth import has_access_dag mock_context.request.args = {"dag_id": dag_id_args} if dag_id_args else {} @@ -1185,7 +1185,7 @@ def test_dag_id_consistency( mock_context.request._parsed_content_type = ["application/json"] with create_user_scope( - app, + app.app, username="test-user", role_name="limited-role", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)], diff --git a/tests/providers/fab/auth_manager/views/test_permissions.py b/tests/providers/fab/auth_manager/views/test_permissions.py index 14fbdd2232e794..f044c472b9e22d 100644 --- a/tests/providers/fab/auth_manager/views/test_permissions.py +++ b/tests/providers/fab/auth_manager/views/test_permissions.py @@ -33,7 +33,7 @@ def fab_app(): @pytest.fixture(scope="module") def user_permissions_reader(fab_app): return create_user( - fab_app, + fab_app.app, username="user_permissions", role_name="role_permissions", permissions=[ @@ -47,7 +47,7 @@ def user_permissions_reader(fab_app): @pytest.fixture def client_permissions_reader(fab_app, user_permissions_reader): - fab_app.config["WTF_CSRF_ENABLED"] = False + fab_app.app.config["WTF_CSRF_ENABLED"] = False return client_with_login( fab_app, username="user_permissions", diff --git a/tests/providers/fab/auth_manager/views/test_roles_list.py b/tests/providers/fab/auth_manager/views/test_roles_list.py index 9631190de42c7a..362ec8b99232c8 100644 --- a/tests/providers/fab/auth_manager/views/test_roles_list.py +++ b/tests/providers/fab/auth_manager/views/test_roles_list.py @@ -22,7 +22,7 @@ from airflow.security import permissions from airflow.www import app as application from tests.test_utils.api_connexion_utils import create_user -from tests.test_utils.www import client_with_login +from tests.test_utils.www import flask_client_with_login @pytest.fixture(scope="module") @@ -33,7 +33,7 @@ def fab_app(): @pytest.fixture(scope="module") def user_roles_reader(fab_app): return create_user( - fab_app, + fab_app.app, username="user_roles", role_name="role_roles", permissions=[ @@ -45,8 +45,8 @@ def user_roles_reader(fab_app): @pytest.fixture def client_roles_reader(fab_app, user_roles_reader): - fab_app.config["WTF_CSRF_ENABLED"] = False - return client_with_login( + fab_app.app.config["WTF_CSRF_ENABLED"] = False + return flask_client_with_login( fab_app, username="user_roles_reader", password="user_roles_reader", diff --git a/tests/providers/fab/auth_manager/views/test_user.py b/tests/providers/fab/auth_manager/views/test_user.py index 80c3c59d4d17f5..fb877dd2d471d2 100644 --- a/tests/providers/fab/auth_manager/views/test_user.py +++ b/tests/providers/fab/auth_manager/views/test_user.py @@ -22,7 +22,7 @@ from airflow.security import permissions from airflow.www import app as application from tests.test_utils.api_connexion_utils import create_user -from tests.test_utils.www import client_with_login +from tests.test_utils.www import flask_client_with_login @pytest.fixture(scope="module") @@ -33,7 +33,7 @@ def fab_app(): @pytest.fixture(scope="module") def user_user_reader(fab_app): return create_user( - fab_app, + fab_app.app, username="user_user", role_name="role_user", permissions=[ @@ -45,8 +45,8 @@ def user_user_reader(fab_app): @pytest.fixture def client_user_reader(fab_app, user_user_reader): - fab_app.config["WTF_CSRF_ENABLED"] = False - return client_with_login( + fab_app.app.config["WTF_CSRF_ENABLED"] = False + return flask_client_with_login( fab_app, username="user_user_reader", password="user_user_reader", diff --git a/tests/providers/fab/auth_manager/views/test_user_edit.py b/tests/providers/fab/auth_manager/views/test_user_edit.py index 11cc65a5bfc407..738ee816d3ae92 100644 --- a/tests/providers/fab/auth_manager/views/test_user_edit.py +++ b/tests/providers/fab/auth_manager/views/test_user_edit.py @@ -22,7 +22,7 @@ from airflow.security import permissions from airflow.www import app as application from tests.test_utils.api_connexion_utils import create_user -from tests.test_utils.www import client_with_login +from tests.test_utils.www import flask_client_with_login @pytest.fixture(scope="module") @@ -33,7 +33,7 @@ def fab_app(): @pytest.fixture(scope="module") def user_user_reader(fab_app): return create_user( - fab_app, + fab_app.app, username="user_user", role_name="role_user", permissions=[ @@ -45,8 +45,8 @@ def user_user_reader(fab_app): @pytest.fixture def client_user_reader(fab_app, user_user_reader): - fab_app.config["WTF_CSRF_ENABLED"] = False - return client_with_login( + fab_app.app.config["WTF_CSRF_ENABLED"] = False + return flask_client_with_login( fab_app, username="user_user_reader", password="user_user_reader", diff --git a/tests/providers/fab/auth_manager/views/test_user_stats.py b/tests/providers/fab/auth_manager/views/test_user_stats.py index 28891621385777..1ac7fe2a0c551d 100644 --- a/tests/providers/fab/auth_manager/views/test_user_stats.py +++ b/tests/providers/fab/auth_manager/views/test_user_stats.py @@ -22,7 +22,7 @@ from airflow.security import permissions from airflow.www import app as application from tests.test_utils.api_connexion_utils import create_user -from tests.test_utils.www import client_with_login +from tests.test_utils.www import flask_client_with_login @pytest.fixture(scope="module") @@ -33,7 +33,7 @@ def fab_app(): @pytest.fixture(scope="module") def user_user_stats_reader(fab_app): return create_user( - fab_app, + fab_app.app, username="user_user_stats", role_name="role_user_stats", permissions=[ @@ -45,8 +45,8 @@ def user_user_stats_reader(fab_app): @pytest.fixture def client_user_stats_reader(fab_app, user_user_stats_reader): - fab_app.config["WTF_CSRF_ENABLED"] = False - return client_with_login( + fab_app.app.config["WTF_CSRF_ENABLED"] = False + return flask_client_with_login( fab_app, username="user_user_stats_reader", password="user_user_stats_reader", @@ -56,5 +56,5 @@ def client_user_stats_reader(fab_app, user_user_stats_reader): @pytest.mark.db_test class TestUserStats: def test_user_stats(self, client_user_stats_reader): - resp = client_user_stats_reader.get("/userstatschartview/chart", follow_redirects=True) + resp = client_user_stats_reader.get("/userstatschartview/chart/", follow_redirects=True) assert resp.status_code == 200 diff --git a/tests/providers/google/common/auth_backend/test_google_openid.py b/tests/providers/google/common/auth_backend/test_google_openid.py index d11613b5cf9f3e..2b66e9d5d8e77e 100644 --- a/tests/providers/google/common/auth_backend/test_google_openid.py +++ b/tests/providers/google/common/auth_backend/test_google_openid.py @@ -39,7 +39,7 @@ def google_openid_app(): @pytest.fixture(scope="module") def admin_user(google_openid_app): - appbuilder = google_openid_app.appbuilder + appbuilder = google_openid_app.app.appbuilder role_admin = appbuilder.sm.find_role("Admin") tester = appbuilder.sm.find_user(username="test") if not tester: @@ -58,7 +58,7 @@ def admin_user(google_openid_app): class TestGoogleOpenID: @pytest.fixture(autouse=True) def _set_attrs(self, google_openid_app, admin_user) -> None: - self.app = google_openid_app + self.connexion_app = google_openid_app self.admin_user = admin_user @mock.patch("google.oauth2.id_token.verify_token") @@ -70,7 +70,7 @@ def test_success(self, mock_verify_token): "email": "test@fab.org", } - with self.app.test_client() as test_client: + with self.connexion_app.app.test_client() as test_client: response = test_client.get( "/api/experimental/pools", headers={"Authorization": "bearer JWT_TOKEN"} ) @@ -88,7 +88,7 @@ def test_malformed_headers(self, mock_verify_token, auth_header): "email": "test@fab.org", } - with self.app.test_client() as test_client: + with self.connexion_app.app.test_client() as test_client: response = test_client.get("/api/experimental/pools", headers={"Authorization": auth_header}) assert 403 == response.status_code @@ -102,7 +102,7 @@ def test_invalid_iss_in_jwt_token(self, mock_verify_token): "email": "test@fab.org", } - with self.app.test_client() as test_client: + with self.connexion_app.app.test_client() as test_client: response = test_client.get( "/api/experimental/pools", headers={"Authorization": "bearer JWT_TOKEN"} ) @@ -118,7 +118,7 @@ def test_user_not_exists(self, mock_verify_token): "email": "invalid@fab.org", } - with self.app.test_client() as test_client: + with self.connexion_app.app.test_client() as test_client: response = test_client.get( "/api/experimental/pools", headers={"Authorization": "bearer JWT_TOKEN"} ) @@ -128,7 +128,7 @@ def test_user_not_exists(self, mock_verify_token): @conf_vars({("api", "auth_backends"): "airflow.providers.google.common.auth_backend.google_openid"}) def test_missing_id_token(self): - with self.app.test_client() as test_client: + with self.connexion_app.app.test_client() as test_client: response = test_client.get("/api/experimental/pools") assert 403 == response.status_code @@ -139,7 +139,7 @@ def test_missing_id_token(self): def test_invalid_id_token(self, mock_verify_token): mock_verify_token.side_effect = GoogleAuthError("Invalid token") - with self.app.test_client() as test_client: + with self.connexion_app.app.test_client() as test_client: response = test_client.get( "/api/experimental/pools", headers={"Authorization": "bearer JWT_TOKEN"} ) diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index 557e4cf00dea8d..aad2d6191ed1c0 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -1081,8 +1081,8 @@ def test_external_task_sensor_extra_link( assert ti.task.external_task_id == expected_external_task_id assert ti.task.external_task_ids == [expected_external_task_id] - app.config["SERVER_NAME"] = "" - with app.app_context(): + app.app.config["SERVER_NAME"] = "" + with app.app.app_context(): url = ti.task.get_extra_links(ti, "External DAG") assert f"/dags/{expected_external_dag_id}/grid" in url diff --git a/tests/test_utils/api_connexion_utils.py b/tests/test_utils/api_connexion_utils.py index 791f7ac0baad78..2731b0a601a24b 100644 --- a/tests/test_utils/api_connexion_utils.py +++ b/tests/test_utils/api_connexion_utils.py @@ -121,7 +121,7 @@ def delete_users(app): def assert_401(response): assert response.status_code == 401, f"Current code: {response.status_code}" - assert response.json == { + assert response.json() == { "detail": None, "status": 401, "title": "Unauthorized", diff --git a/tests/test_utils/decorators.py b/tests/test_utils/decorators.py index 5b028c694a8c62..cf382be98f174e 100644 --- a/tests/test_utils/decorators.py +++ b/tests/test_utils/decorators.py @@ -40,7 +40,7 @@ def no_op(*args, **kwargs): "init_api_connexion", "init_api_internal", "init_api_experimental", - "init_api_auth_provider", + "init_api_auth_manager", "init_api_error_handlers", "init_jinja_globals", "init_xframe_protection", diff --git a/tests/test_utils/mock_cors_middeleware.py b/tests/test_utils/mock_cors_middeleware.py new file mode 100644 index 00000000000000..211f46a44639b4 --- /dev/null +++ b/tests/test_utils/mock_cors_middeleware.py @@ -0,0 +1,35 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import connexion + +from airflow.configuration import conf + + +def init_mock_cors_middleware(connexion_app: connexion.FlaskApp, allow_origins: list): + from starlette.middleware.cors import CORSMiddleware + + connexion_app.add_middleware( + CORSMiddleware, + connexion.middleware.MiddlewarePosition.BEFORE_ROUTING, + allow_origins=allow_origins, + allow_credentials=True, + allow_methods=conf.get("api", "access_control_allow_methods"), + allow_headers=conf.get("api", "access_control_allow_headers"), + ) diff --git a/tests/test_utils/remote_user_api_auth_backend.py b/tests/test_utils/remote_user_api_auth_backend.py index b7714e5192e6ae..5be8a2bf9da0ae 100644 --- a/tests/test_utils/remote_user_api_auth_backend.py +++ b/tests/test_utils/remote_user_api_auth_backend.py @@ -62,7 +62,7 @@ def requires_authentication(function: T): @wraps(function) def decorated(*args, **kwargs): - user_id = request.remote_user + user_id = request.headers.get("REMOTE-USER") if not user_id: log.debug("Missing REMOTE_USER.") return Response("Forbidden", 403) diff --git a/tests/test_utils/www.py b/tests/test_utils/www.py index 0a19c312fba4eb..d8ff0f1abaf65b 100644 --- a/tests/test_utils/www.py +++ b/tests/test_utils/www.py @@ -23,19 +23,29 @@ from airflow.models import Log -def client_with_login(app, expected_response_code=302, **kwargs): +def client_with_login(app, expected_path=b"/home", **kwargs): patch_path = "airflow.providers.fab.auth_manager.security_manager.override.check_password_hash" with mock.patch(patch_path) as check_password_hash: check_password_hash.return_value = True client = app.test_client() resp = client.post("/login/", data=kwargs) + assert resp.url.raw_path == expected_path + return client + + +def flask_client_with_login(app, expected_response_code=302, **kwargs): + patch_path = "airflow.providers.fab.auth_manager.security_manager.override.check_password_hash" + with mock.patch(patch_path) as check_password_hash: + check_password_hash.return_value = True + client = app.app.test_client() + resp = client.post("/login/", data=kwargs) assert resp.status_code == expected_response_code return client def client_without_login(app): # Anonymous users can only view if AUTH_ROLE_PUBLIC is set to non-Public - app.config["AUTH_ROLE_PUBLIC"] = "Viewer" + app.app.config["AUTH_ROLE_PUBLIC"] = "Viewer" client = app.test_client() return client @@ -48,7 +58,7 @@ def client_without_login_as_admin(app): def check_content_in_response(text, resp, resp_code=200): - resp_html = resp.data.decode("utf-8") + resp_html = resp.text assert resp_code == resp.status_code if isinstance(text, list): for line in text: @@ -58,7 +68,7 @@ def check_content_in_response(text, resp, resp_code=200): def check_content_not_in_response(text, resp, resp_code=200): - resp_html = resp.data.decode("utf-8") + resp_html = resp.text assert resp_code == resp.status_code if isinstance(text, list): for line in text: diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 75b04b14c7b507..6eb51efa1025ef 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -170,12 +170,12 @@ def test_build_airflow_url_with_query(self): """ Test query generated with dag_id and params """ - query = {"dag_id": "test_dag", "param": "key/to.encode"} - expected_url = "/dags/test_dag/graph?param=key%2Fto.encode" + query = {"dag_id": "test_dag", "param": "key to.encode"} + expected_url = "/dags/test_dag/graph?param=key+to.encode" from airflow.www.app import cached_app - with cached_app(testing=True).test_request_context(): + with cached_app(testing=True).app.test_request_context(): assert build_airflow_url_with_query(query) == expected_url @pytest.mark.parametrize( diff --git a/tests/www/api/experimental/conftest.py b/tests/www/api/experimental/conftest.py index 59c6e13357c85f..d2395ea7fe0356 100644 --- a/tests/www/api/experimental/conftest.py +++ b/tests/www/api/experimental/conftest.py @@ -40,10 +40,10 @@ def experiemental_api_app(): ) def factory(): app = application.create_app(testing=True) - app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///" - app.config["SECRET_KEY"] = "secret_key" - app.config["CSRF_ENABLED"] = False - app.config["WTF_CSRF_ENABLED"] = False + app.app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///" + app.app.config["SECRET_KEY"] = "secret_key" + app.app.config["CSRF_ENABLED"] = False + app.app.config["WTF_CSRF_ENABLED"] = False return app return factory() diff --git a/tests/www/api/experimental/test_dag_runs_endpoint.py b/tests/www/api/experimental/test_dag_runs_endpoint.py index 9f4bbf30bc41ef..246e58b0df8fb1 100644 --- a/tests/www/api/experimental/test_dag_runs_endpoint.py +++ b/tests/www/api/experimental/test_dag_runs_endpoint.py @@ -17,8 +17,6 @@ # under the License. from __future__ import annotations -import json - import pytest from airflow.api.common.trigger_dag import trigger_dag @@ -59,7 +57,7 @@ def test_get_dag_runs_success(self): response = self.app.get(url_template.format(dag_id)) assert 200 == response.status_code - data = json.loads(response.data.decode("utf-8")) + data = response.json() assert isinstance(data, list) assert len(data) == 1 @@ -74,7 +72,7 @@ def test_get_dag_runs_success_with_state_parameter(self): response = self.app.get(url_template.format(dag_id)) assert 200 == response.status_code - data = json.loads(response.data.decode("utf-8")) + data = response.json() assert isinstance(data, list) assert len(data) == 1 @@ -89,7 +87,7 @@ def test_get_dag_runs_success_with_capital_state_parameter(self): response = self.app.get(url_template.format(dag_id)) assert 200 == response.status_code - data = json.loads(response.data.decode("utf-8")) + data = response.json() assert isinstance(data, list) assert len(data) == 1 @@ -102,8 +100,8 @@ def test_get_dag_runs_success_with_state_no_result(self): # Create DagRun trigger_dag(dag_id=dag_id, run_id="test_get_dag_runs_success") - with pytest.raises(ValueError): - self.app.get(url_template.format(dag_id)) + resp = self.app.get(url_template.format(dag_id)) + assert 500 == resp.status_code def test_get_dag_runs_invalid_dag_id(self): url_template = "/api/experimental/dags/{}/dag_runs" @@ -111,7 +109,7 @@ def test_get_dag_runs_invalid_dag_id(self): response = self.app.get(url_template.format(dag_id)) assert 400 == response.status_code - data = json.loads(response.data.decode("utf-8")) + data = response.json() assert not isinstance(data, list) @@ -121,7 +119,7 @@ def test_get_dag_runs_no_runs(self): response = self.app.get(url_template.format(dag_id)) assert 200 == response.status_code - data = json.loads(response.data.decode("utf-8")) + data = response.json() assert isinstance(data, list) assert len(data) == 0 diff --git a/tests/www/api/experimental/test_endpoints.py b/tests/www/api/experimental/test_endpoints.py index d78bc8fb37232e..c7ac0abe5e0c7a 100644 --- a/tests/www/api/experimental/test_endpoints.py +++ b/tests/www/api/experimental/test_endpoints.py @@ -53,7 +53,7 @@ class TestBase: @pytest.fixture(autouse=True) def _setup_attrs_base(self, experiemental_api_app, configured_session): self.app = experiemental_api_app - self.appbuilder = self.app.appbuilder + self.appbuilder = self.app.app.appbuilder self.client = self.app.test_client() self.session = configured_session @@ -92,7 +92,7 @@ def test_info(self): url = "/api/experimental/info" resp_raw = self.client.get(url) - resp = json.loads(resp_raw.data.decode("utf-8")) + resp = resp_raw.json() assert version == resp["version"] self.assert_deprecated(resp_raw) @@ -103,16 +103,16 @@ def test_task_info(self): response = self.client.get(url_template.format("example_bash_operator", "runme_0")) self.assert_deprecated(response) - assert '"email"' in response.data.decode("utf-8") - assert "error" not in response.data.decode("utf-8") + assert '"email"' in response.text + assert "error" not in response.json() assert 200 == response.status_code response = self.client.get(url_template.format("example_bash_operator", "does-not-exist")) - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() assert 404 == response.status_code response = self.client.get(url_template.format("does-not-exist", "does-not-exist")) - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() assert 404 == response.status_code def test_get_dag_code(self): @@ -120,7 +120,7 @@ def test_get_dag_code(self): response = self.client.get(url_template.format("example_bash_operator")) self.assert_deprecated(response) - assert "BashOperator(" in response.data.decode("utf-8") + assert "BashOperator(" in response.text assert 200 == response.status_code response = self.client.get(url_template.format("xyz")) @@ -133,22 +133,22 @@ def test_dag_paused(self): response = self.client.get(pause_url_template.format("example_bash_operator", "true")) self.assert_deprecated(response) - assert "ok" in response.data.decode("utf-8") + assert "ok" == response.json()["response"] assert 200 == response.status_code paused_response = self.client.get(paused_url) assert 200 == paused_response.status_code - assert {"is_paused": True} == paused_response.json + assert {"is_paused": True} == paused_response.json() response = self.client.get(pause_url_template.format("example_bash_operator", "false")) - assert "ok" in response.data.decode("utf-8") + assert "ok" in response.text assert 200 == response.status_code paused_response = self.client.get(paused_url) assert 200 == paused_response.status_code - assert {"is_paused": False} == paused_response.json + assert {"is_paused": False} == paused_response.json() def test_trigger_dag(self): url_template = "/api/experimental/dags/{}/dag_runs" @@ -156,7 +156,8 @@ def test_trigger_dag(self): # Test error for nonexistent dag response = self.client.post( - url_template.format("does_not_exist_dag"), data=json.dumps({}), content_type="application/json" + url_template.format("does_not_exist_dag"), + data=json.dumps({}), ) assert 404 == response.status_code @@ -164,7 +165,6 @@ def test_trigger_dag(self): response = self.client.post( url_template.format("example_bash_operator"), data=json.dumps({"conf": "This is a string not a dict"}), - content_type="application/json", ) assert 400 == response.status_code @@ -172,16 +172,15 @@ def test_trigger_dag(self): response = self.client.post( url_template.format("example_bash_operator"), data=json.dumps({"run_id": run_id, "conf": {"param": "value"}}), - content_type="application/json", ) self.assert_deprecated(response) assert 200 == response.status_code - response_execution_date = parse_datetime(json.loads(response.data.decode("utf-8"))["execution_date"]) + response_execution_date = parse_datetime(response.json()["execution_date"]) assert 0 == response_execution_date.microsecond # Check execution_date is correct - response = json.loads(response.data.decode("utf-8")) + response = response.json() dagbag = DagBag() dag = dagbag.get_dag("example_bash_operator") dag_run = dag.get_dagrun(response_execution_date) @@ -199,11 +198,10 @@ def test_trigger_dag_for_date(self): response = self.client.post( url_template.format(dag_id), data=json.dumps({"execution_date": datetime_string}), - content_type="application/json", ) self.assert_deprecated(response) assert 200 == response.status_code - assert datetime_string == json.loads(response.data.decode("utf-8"))["execution_date"] + assert datetime_string == response.json()["execution_date"] dagbag = DagBag() dag = dagbag.get_dag(dag_id) @@ -214,10 +212,9 @@ def test_trigger_dag_for_date(self): response = self.client.post( url_template.format(dag_id), data=json.dumps({"execution_date": datetime_string, "replace_microseconds": "true"}), - content_type="application/json", ) assert 200 == response.status_code - response_execution_date = parse_datetime(json.loads(response.data.decode("utf-8"))["execution_date"]) + response_execution_date = parse_datetime(response.json()["execution_date"]) assert 0 == response_execution_date.microsecond dagbag = DagBag() @@ -229,7 +226,6 @@ def test_trigger_dag_for_date(self): response = self.client.post( url_template.format("does_not_exist_dag"), data=json.dumps({"execution_date": datetime_string}), - content_type="application/json", ) assert 404 == response.status_code @@ -237,7 +233,6 @@ def test_trigger_dag_for_date(self): response = self.client.post( url_template.format(dag_id), data=json.dumps({"execution_date": "not_a_datetime"}), - content_type="application/json", ) assert 400 == response.status_code @@ -256,30 +251,30 @@ def test_task_instance_info(self): response = self.client.get(url_template.format(dag_id, datetime_string, task_id)) self.assert_deprecated(response) assert 200 == response.status_code - assert "state" in response.data.decode("utf-8") - assert "error" not in response.data.decode("utf-8") + assert "state" in response.json() + assert "error" not in response.json() # Test error for nonexistent dag response = self.client.get( url_template.format("does_not_exist_dag", datetime_string, task_id), ) assert 404 == response.status_code - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() # Test error for nonexistent task response = self.client.get(url_template.format(dag_id, datetime_string, "does_not_exist_task")) assert 404 == response.status_code - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() # Test error for nonexistent dag run (wrong execution_date) response = self.client.get(url_template.format(dag_id, wrong_datetime_string, task_id)) assert 404 == response.status_code - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() # Test error for bad datetime format response = self.client.get(url_template.format(dag_id, "not_a_datetime", task_id)) assert 400 == response.status_code - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() def test_dagrun_status(self): url_template = "/api/experimental/dags/{}/dag_runs/{}" @@ -295,25 +290,25 @@ def test_dagrun_status(self): response = self.client.get(url_template.format(dag_id, datetime_string)) self.assert_deprecated(response) assert 200 == response.status_code - assert "state" in response.data.decode("utf-8") - assert "error" not in response.data.decode("utf-8") + assert "state" in response.json() + assert "error" not in response.json() # Test error for nonexistent dag response = self.client.get( url_template.format("does_not_exist_dag", datetime_string), ) assert 404 == response.status_code - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() # Test error for nonexistent dag run (wrong execution_date) response = self.client.get(url_template.format(dag_id, wrong_datetime_string)) assert 404 == response.status_code - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() # Test error for bad datetime format response = self.client.get(url_template.format(dag_id, "not_a_datetime")) assert 400 == response.status_code - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() class TestLineageApiExperimental(TestBase): @@ -354,25 +349,25 @@ def test_lineage_info(self): response = self.client.get(url_template.format(dag_id, datetime_string)) self.assert_deprecated(response) assert 200 == response.status_code - assert "task_ids" in response.data.decode("utf-8") - assert "error" not in response.data.decode("utf-8") + assert "task_ids" in response.json() + assert "error" not in response.json() # Test error for nonexistent dag response = self.client.get( url_template.format("does_not_exist_dag", datetime_string), ) assert 404 == response.status_code - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() # Test error for nonexistent dag run (wrong execution_date) response = self.client.get(url_template.format(dag_id, wrong_datetime_string)) assert 404 == response.status_code - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() # Test error for bad datetime format response = self.client.get(url_template.format(dag_id, "not_a_datetime")) assert 400 == response.status_code - assert "error" in response.data.decode("utf-8") + assert "error" in response.json() class TestPoolApiExperimental(TestBase): @@ -399,7 +394,7 @@ def _setup_attrs(self, _setup_attrs_base): def _get_pool_count(self): response = self.client.get("/api/experimental/pools") assert response.status_code == 200 - return len(json.loads(response.data.decode("utf-8"))) + return len(response.json()) def test_get_pool(self): response = self.client.get( @@ -407,18 +402,18 @@ def test_get_pool(self): ) self.assert_deprecated(response) assert response.status_code == 200 - assert json.loads(response.data.decode("utf-8")) == self.pool.to_json() + assert response.json() == self.pool.to_json() def test_get_pool_non_existing(self): response = self.client.get("/api/experimental/pools/foo") assert response.status_code == 404 - assert json.loads(response.data.decode("utf-8"))["error"] == "Pool 'foo' doesn't exist" + assert response.json()["error"] == "Pool 'foo' doesn't exist" def test_get_pools(self): response = self.client.get("/api/experimental/pools") self.assert_deprecated(response) assert response.status_code == 200 - pools = json.loads(response.data.decode("utf-8")) + pools = response.json() assert len(pools) == self.TOTAL_POOL_COUNT for i, pool in enumerate(sorted(pools, key=lambda p: p["pool"])): assert pool == self.pools[i].to_json() @@ -433,11 +428,10 @@ def test_create_pool(self): "description": "", } ), - content_type="application/json", ) self.assert_deprecated(response) assert response.status_code == 200 - pool = json.loads(response.data.decode("utf-8")) + pool = response.json() assert pool["pool"] == "foo" assert pool["slots"] == 1 assert pool["description"] == "" @@ -455,10 +449,9 @@ def test_create_pool_with_bad_name(self): "description": "", } ), - content_type="application/json", ) assert response.status_code == 400 - assert json.loads(response.data.decode("utf-8"))["error"] == "Pool name shouldn't be empty" + assert response.json()["error"] == "Pool name shouldn't be empty" assert self._get_pool_count() == self.TOTAL_POOL_COUNT def test_delete_pool(self): @@ -467,7 +460,7 @@ def test_delete_pool(self): ) self.assert_deprecated(response) assert response.status_code == 200 - assert json.loads(response.data.decode("utf-8")) == self.pool.to_json() + assert response.json() == self.pool.to_json() assert self._get_pool_count() == self.TOTAL_POOL_COUNT - 1 def test_delete_pool_non_existing(self): @@ -475,7 +468,7 @@ def test_delete_pool_non_existing(self): "/api/experimental/pools/foo", ) assert response.status_code == 404 - assert json.loads(response.data.decode("utf-8"))["error"] == "Pool 'foo' doesn't exist" + assert response.json()["error"] == "Pool 'foo' doesn't exist" def test_delete_default_pool(self): clear_db_pools() @@ -483,4 +476,4 @@ def test_delete_default_pool(self): "/api/experimental/pools/default_pool", ) assert response.status_code == 400 - assert json.loads(response.data.decode("utf-8"))["error"] == "default_pool cannot be deleted" + assert response.json()["error"] == "default_pool cannot be deleted" diff --git a/tests/www/test_app.py b/tests/www/test_app.py index 1e7bd67c9ae042..05f91d7868122d 100644 --- a/tests/www/test_app.py +++ b/tests/www/test_app.py @@ -54,8 +54,8 @@ def setup_class(cls) -> None: ) @dont_initialize_flask_app_submodules def test_should_respect_proxy_fix(self): - app = application.cached_app(testing=True) - app.url_map.add(Rule("/debug", endpoint="debug")) + flask_app = application.cached_app(testing=True).app + flask_app.url_map.add(Rule("/debug", endpoint="debug")) def debug_view(): from flask import request @@ -68,7 +68,7 @@ def debug_view(): return Response("success") - app.view_functions["debug"] = debug_view + flask_app.view_functions["debug"] = debug_view new_environ = { "PATH_INFO": "/debug", @@ -82,7 +82,7 @@ def debug_view(): } environ = create_environ(environ_overrides=new_environ) - response = Response.from_app(app, environ) + response = Response.from_app(flask_app, environ) assert b"success" == response.get_data() assert response.status_code == 200 @@ -90,7 +90,7 @@ def debug_view(): @dont_initialize_flask_app_submodules def test_should_respect_base_url_ignore_proxy_headers(self): with conf_vars({("webserver", "base_url"): "http://localhost:8080/internal-client"}): - app = application.cached_app(testing=True) + app = application.cached_app(testing=True).app app.url_map.add(Rule("/debug", endpoint="debug")) def debug_view(): @@ -144,7 +144,7 @@ def test_base_url_contains_trailing_slash(self): @dont_initialize_flask_app_submodules def test_should_respect_base_url_when_proxy_fix_and_base_url_is_set_up_but_headers_missing(self): with conf_vars({("webserver", "base_url"): "http://localhost:8080/internal-client"}): - app = application.cached_app(testing=True) + app = application.cached_app(testing=True).app app.url_map.add(Rule("/debug", endpoint="debug")) def debug_view(): @@ -184,7 +184,7 @@ def debug_view(): ) @dont_initialize_flask_app_submodules def test_should_respect_base_url_and_proxy_when_proxy_fix_and_base_url_is_set_up(self): - app = application.cached_app(testing=True) + app = application.cached_app(testing=True).app app.url_map.add(Rule("/debug", endpoint="debug")) def debug_view(): @@ -224,16 +224,16 @@ def debug_view(): ) @dont_initialize_flask_app_submodules def test_should_set_permanent_session_timeout(self): - app = application.cached_app(testing=True) - assert app.config["PERMANENT_SESSION_LIFETIME"] == timedelta(minutes=3600) + flask_app = application.cached_app(testing=True).app + assert flask_app.config["PERMANENT_SESSION_LIFETIME"] == timedelta(minutes=3600) @conf_vars({("webserver", "cookie_samesite"): ""}) @dont_initialize_flask_app_submodules def test_correct_default_is_set_for_cookie_samesite(self): """An empty 'cookie_samesite' should be corrected to 'Lax' with a deprecation warning.""" with pytest.deprecated_call(): - app = application.cached_app(testing=True) - assert app.config["SESSION_COOKIE_SAMESITE"] == "Lax" + flask_app = application.cached_app(testing=True).app + assert flask_app.config["SESSION_COOKIE_SAMESITE"] == "Lax" @pytest.mark.parametrize( "hash_method, result", @@ -250,7 +250,7 @@ def test_correct_default_is_set_for_cookie_samesite(self): @dont_initialize_flask_app_submodules(skip_all_except=["init_auth_manager"]) def test_should_respect_caching_hash_method(self, hash_method, result): with conf_vars({("webserver", "caching_hash_method"): hash_method}): - app = application.cached_app(testing=True) + app = application.cached_app(testing=True).app assert next(iter(app.extensions["cache"])).cache._hash_method == result @dont_initialize_flask_app_submodules @@ -263,7 +263,7 @@ def test_should_respect_caching_hash_method_invalid(self): class TestFlaskCli: @dont_initialize_flask_app_submodules(skip_all_except=["init_appbuilder"]) def test_flask_cli_should_display_routes(self, capsys): - with mock.patch.dict("os.environ", FLASK_APP="airflow.www.app:cached_app"), mock.patch.object( + with mock.patch.dict("os.environ", FLASK_APP="airflow.www.app:cached_flask_app"), mock.patch.object( sys, "argv", ["flask", "routes"] ): # Import from flask.__main__ with a combination of mocking With mocking sys.argv @@ -282,5 +282,5 @@ def test_app_can_json_serialize_k8s_pod(): k8s = pytest.importorskip("kubernetes.client.models") pod = k8s.V1Pod(spec=k8s.V1PodSpec(containers=[k8s.V1Container(name="base")])) - app = application.cached_app(testing=True) - assert app.json.dumps(pod) == '{"spec": {"containers": [{"name": "base"}]}}' + flask_app = application.cached_app(testing=True).app + assert flask_app.json.dumps(pod) == '{"spec": {"containers": [{"name": "base"}]}}' diff --git a/tests/www/test_auth.py b/tests/www/test_auth.py index f21973a8b67829..0c67aa40c15f2b 100644 --- a/tests/www/test_auth.py +++ b/tests/www/test_auth.py @@ -101,7 +101,7 @@ def test_has_access_no_details_when_not_logged_in( auth_manager.get_url_login.return_value = "login_url" mock_get_auth_manager.return_value = auth_manager - with app.test_request_context(): + with app.app.test_request_context(): result = getattr(auth, decorator_name)("GET")(self.method_test)() mock_call.assert_not_called() @@ -171,7 +171,7 @@ def test_has_access_with_details_when_unauthorized( setattr(auth_manager, is_authorized_method_name, is_authorized_method) mock_get_auth_manager.return_value = auth_manager - with app.test_request_context(): + with app.app.test_request_context(): result = getattr(auth, decorator_name)("GET")(self.method_test)(None, items) mock_call.assert_not_called() @@ -215,7 +215,7 @@ def test_has_access_dag_entities_when_unauthorized(self, mock_get_auth_manager, mock_get_auth_manager.return_value = auth_manager items = [Mock(dag_id="dag_1"), Mock(dag_id="dag_2")] - with app.test_request_context(): + with app.app.test_request_context(): result = auth.has_access_dag_entities("GET", dag_access_entity)(self.method_test)(None, items) mock_call.assert_not_called() @@ -231,7 +231,7 @@ def test_has_access_dag_entities_when_logged_out(self, mock_get_auth_manager, ap mock_get_auth_manager.return_value = auth_manager items = [Mock(dag_id="dag_1"), Mock(dag_id="dag_2")] - with app.test_request_context(): + with app.app.test_request_context(): result = auth.has_access_dag_entities("GET", dag_access_entity)(self.method_test)(None, items) mock_call.assert_not_called() diff --git a/tests/www/test_security_manager.py b/tests/www/test_security_manager.py index 81a05e5fd063eb..24e6f014f6d667 100644 --- a/tests/www/test_security_manager.py +++ b/tests/www/test_security_manager.py @@ -39,7 +39,7 @@ def app(): @pytest.fixture def app_builder(app): - return app.appbuilder + return app.app.appbuilder @pytest.fixture diff --git a/tests/www/test_utils.py b/tests/www/test_utils.py index dfd8b563dc4156..1fa88e8abf0a74 100644 --- a/tests/www/test_utils.py +++ b/tests/www/test_utils.py @@ -163,14 +163,14 @@ def test_state_token(self): def test_task_instance_link(self): from airflow.www.app import cached_app - with cached_app(testing=True).test_request_context(): + with cached_app(testing=True).app.test_request_context(): html = str( utils.task_instance_link( {"dag_id": "", "task_id": "", "execution_date": datetime.now()} ) ) - assert "%3Ca%261%3E" in html + assert "%3Ca&1%3E" in html assert "%3Cb2%3E" in html assert "" not in html assert "" not in html @@ -179,10 +179,10 @@ def test_task_instance_link(self): def test_dag_link(self): from airflow.www.app import cached_app - with cached_app(testing=True).test_request_context(): + with cached_app(testing=True).app.test_request_context(): html = str(utils.dag_link({"dag_id": "", "execution_date": datetime.now()})) - assert "%3Ca%261%3E" in html + assert "%3Ca&1%3E" in html assert "" not in html @pytest.mark.db_test @@ -190,7 +190,7 @@ def test_dag_link_when_dag_is_none(self): """Test that when there is no dag_id, dag_link does not contain hyperlink""" from airflow.www.app import cached_app - with cached_app(testing=True).test_request_context(): + with cached_app(testing=True).app.test_request_context(): html = str(utils.dag_link({})) assert "None" in html @@ -200,12 +200,12 @@ def test_dag_link_when_dag_is_none(self): def test_dag_run_link(self): from airflow.www.app import cached_app - with cached_app(testing=True).test_request_context(): + with cached_app(testing=True).app.test_request_context(): html = str( utils.dag_run_link({"dag_id": "", "run_id": "", "execution_date": datetime.now()}) ) - assert "%3Ca%261%3E" in html + assert "%3Ca&1%3E" in html assert "%3Cb2%3E" in html assert "" not in html assert "" not in html diff --git a/tests/www/views/conftest.py b/tests/www/views/conftest.py index 821f541ef0c43f..46fda4c387a5f0 100644 --- a/tests/www/views/conftest.py +++ b/tests/www/views/conftest.py @@ -30,7 +30,12 @@ from tests.test_utils.api_connexion_utils import delete_user from tests.test_utils.config import conf_vars from tests.test_utils.decorators import dont_initialize_flask_app_submodules -from tests.test_utils.www import client_with_login, client_without_login, client_without_login_as_admin +from tests.test_utils.www import ( + client_with_login, + client_without_login, + client_without_login_as_admin, + flask_client_with_login, +) @pytest.fixture(autouse=True, scope="module") @@ -52,6 +57,7 @@ def app(examples_dag_bag): @dont_initialize_flask_app_submodules( skip_all_except=[ "init_api_connexion", + "init_api_error_handlers", "init_appbuilder", "init_appbuilder_links", "init_appbuilder_views", @@ -67,11 +73,11 @@ def factory(): return create_app(testing=True) app = factory() - app.config["WTF_CSRF_ENABLED"] = False - app.dag_bag = examples_dag_bag - app.jinja_env.undefined = jinja2.StrictUndefined + app.app.config["WTF_CSRF_ENABLED"] = False + app.app.dag_bag = examples_dag_bag + app.app.jinja_env.undefined = jinja2.StrictUndefined - security_manager = app.appbuilder.sm + security_manager = app.app.appbuilder.sm test_users = [ { @@ -107,7 +113,7 @@ def factory(): yield app for user_dict in test_users: - delete_user(app, user_dict["username"]) + delete_user(app.app, user_dict["username"]) @pytest.fixture @@ -115,6 +121,11 @@ def admin_client(app): return client_with_login(app, username="test_admin", password="test_admin") +@pytest.fixture +def flask_admin_client(app): + return flask_client_with_login(app, username="test_admin", password="test_admin") + + @pytest.fixture def viewer_client(app): return client_with_login(app, username="test_viewer", password="test_viewer") @@ -125,6 +136,11 @@ def user_client(app): return client_with_login(app, username="test_user", password="test_user") +@pytest.fixture +def flask_user_client(app): + return flask_client_with_login(app, username="test_user", password="test_user") + + @pytest.fixture def anonymous_client(app): return client_without_login(app) @@ -132,7 +148,12 @@ def anonymous_client(app): @pytest.fixture def anonymous_client_as_admin(app): - return client_without_login_as_admin(app) + return client_without_login_as_admin(app.app) + + +@pytest.fixture +def admin_flask_client(app): + return flask_client_with_login(app, username="test_admin", password="test_admin") class _TemplateWithContext(NamedTuple): @@ -198,11 +219,11 @@ def manager() -> Generator[list[_TemplateWithContext], None, None]: def record(sender, template, context, **extra): recorded.append(_TemplateWithContext(template, context)) - flask.template_rendered.connect(record, app) # type: ignore + flask.template_rendered.connect(record, app.app) # type: ignore try: yield recorded finally: - flask.template_rendered.disconnect(record, app) # type: ignore + flask.template_rendered.disconnect(record, app.app) # type: ignore assert recorded, "Failed to catch the templates" diff --git a/tests/www/views/test_anonymous_as_admin_role.py b/tests/www/views/test_anonymous_as_admin_role.py index b7603d1eae5bba..64ce1b1a42592f 100644 --- a/tests/www/views/test_anonymous_as_admin_role.py +++ b/tests/www/views/test_anonymous_as_admin_role.py @@ -55,8 +55,9 @@ def factory(**values): def test_delete_pool_anonymous_user_no_role(anonymous_client, pool_factory): pool = pool_factory() resp = anonymous_client.post(f"pool/delete/{pool.id}") - assert 302 == resp.status_code - assert f"/login/?next={quote_plus(f'http://localhost/pool/delete/{pool.id}')}" == resp.headers["Location"] + expected_path = f"/login/?next={quote_plus(f'http://testserver/pool/delete/{pool.id}', safe='/:?')}" + assert expected_path.encode("utf-8") == resp.url.raw_path + assert 200 == resp.status_code def test_delete_pool_anonymous_user_as_admin(anonymous_client_as_admin, pool_factory): diff --git a/tests/www/views/test_session.py b/tests/www/views/test_session.py index aeb9c0ffeeeb9e..f483990af3b463 100644 --- a/tests/www/views/test_session.py +++ b/tests/www/views/test_session.py @@ -29,7 +29,7 @@ def get_session_cookie(client): - return next((cookie for cookie in client.cookie_jar if cookie.name == "session"), None) + return next((cookie for cookie in client.cookies.jar if cookie.name == "session"), None) def test_session_cookie_created_on_login(user_client): @@ -40,13 +40,25 @@ def test_session_inaccessible_after_logout(user_client): session_cookie = get_session_cookie(user_client) assert session_cookie is not None + # correctly logs in + resp = user_client.get("/home") + assert resp.status_code == 200 + assert resp.url.raw_path == b"/home" + + # Same with cookies overwritten + user_client.get("/home", cookies={"session": session_cookie.value}) + assert resp.status_code == 200 + assert resp.url.raw_path == b"/home" + + # logs out resp = user_client.get("/logout/") - assert resp.status_code == 302 + assert resp.status_code == 200 + assert resp.url.raw_path == b"/login/?next=http://testserver/home" - # Try to access /home with the session cookie from earlier - user_client.set_cookie("session", session_cookie.value) - user_client.get("/home/") - assert resp.status_code == 302 + # Try to access /home with the session cookie from earlier call + user_client.get("/home", cookies={"session": session_cookie.value}) + assert resp.status_code == 200 + assert resp.url.raw_path == b"/login/?next=http://testserver/home" def test_invalid_session_backend_option(): @@ -78,14 +90,16 @@ def test_session_id_rotates(app, user_client): old_session_cookie = get_session_cookie(user_client) assert old_session_cookie is not None - resp = user_client.get("/logout/") - assert resp.status_code == 302 + resp = user_client.get("/logout/", follow_redirects=True) + assert resp.status_code == 200 patch_path = "airflow.providers.fab.auth_manager.security_manager.override.check_password_hash" with mock.patch(patch_path) as check_password_hash: check_password_hash.return_value = True - resp = user_client.post("/login/", data={"username": "test_user", "password": "test_user"}) - assert resp.status_code == 302 + resp = user_client.post( + "/login/", data={"username": "test_user", "password": "test_user"}, follow_redirects=True + ) + assert resp.status_code == 200 new_session_cookie = get_session_cookie(user_client) assert new_session_cookie is not None @@ -93,17 +107,16 @@ def test_session_id_rotates(app, user_client): def test_check_active_user(app, user_client): - user = app.appbuilder.sm.find_user(username="test_user") + user = app.app.appbuilder.sm.find_user(username="test_user") user.active = False resp = user_client.get("/home") - assert resp.status_code == 302 - assert "/login/?next=http%3A%2F%2Flocalhost%2Fhome" in resp.headers.get("Location") + assert resp.url.raw_path == b"/home" -def test_check_deactivated_user_redirected_to_login(app, user_client): - with app.test_request_context(): - user = app.appbuilder.sm.find_user(username="test_user") +def test_check_deactivated_user_redirected_to_login(app, flask_user_client): + with app.app.test_request_context(): + user = app.app.appbuilder.sm.find_user(username="test_user") user.active = False - resp = user_client.get("/home", follow_redirects=True) + resp = flask_user_client.get("/home", follow_redirects=True) assert resp.status_code == 200 assert "/login" in resp.request.url diff --git a/tests/www/views/test_views.py b/tests/www/views/test_views.py index 27f096403f05d7..d9f1f65bd4310c 100644 --- a/tests/www/views/test_views.py +++ b/tests/www/views/test_views.py @@ -226,7 +226,7 @@ def test_task_dag_id_equals_filter(admin_client, url, content): @mock.patch("airflow.www.views.url_for") def test_get_safe_url(mock_url_for, app, test_url, expected_url): mock_url_for.return_value = "/home" - with app.test_request_context(base_url="http://localhost:8080"): + with app.app.test_request_context(base_url="http://localhost:8080"): assert get_safe_url(test_url) == expected_url @@ -294,10 +294,10 @@ def get_task_instance(session, task): session.commit() - test_app.dag_bag = DagBag(dag_folder="/dev/null", include_examples=False) - test_app.dag_bag.bag_dag(dag=dag, root_dag=dag) + test_app.app.dag_bag = DagBag(dag_folder="/dev/null", include_examples=False) + test_app.app.dag_bag.bag_dag(dag=dag, root_dag=dag) - with test_app.test_request_context(): + with test_app.app.test_request_context(): view = Airflow() view._mark_task_instance_state( @@ -396,10 +396,10 @@ def get_task_instance(session, task): session.commit() - test_app.dag_bag = DagBag(dag_folder="/dev/null", include_examples=False) - test_app.dag_bag.bag_dag(dag=dag, root_dag=dag) + test_app.app.dag_bag = DagBag(dag_folder="/dev/null", include_examples=False) + test_app.app.dag_bag.bag_dag(dag=dag, root_dag=dag) - with test_app.test_request_context(): + with test_app.app.test_request_context(): view = Airflow() view._mark_task_group_state( @@ -483,7 +483,9 @@ def test_get_task_stats_from_query(): assert data == expected_data -INVALID_DATETIME_RESPONSE = re.compile(r"Invalid datetime: &#x?\d+;invalid&#x?\d+;") +# After upgrading to connexion v3, test client returns JSON response instead of HTML response. +# Returned JSON does not contain the previous pattern. +INVALID_DATETIME_RESPONSE = re.compile(r"Invalid datetime: 'invalid'") @pytest.mark.parametrize( @@ -522,6 +524,5 @@ def test_get_task_stats_from_query(): def test_invalid_dates(app, admin_client, url, content): """Test invalid date format doesn't crash page.""" resp = admin_client.get(url, follow_redirects=True) - assert resp.status_code == 400 - assert re.search(content, resp.get_data().decode()) + assert re.search(content, resp.json()["detail"]) diff --git a/tests/www/views/test_views_acl.py b/tests/www/views/test_views_acl.py index ead809e081c508..63b536e6510a3e 100644 --- a/tests/www/views/test_views_acl.py +++ b/tests/www/views/test_views_acl.py @@ -18,7 +18,6 @@ from __future__ import annotations import datetime -import json import urllib.parse import pytest @@ -32,7 +31,12 @@ from airflow.www.views import FILTER_STATUS_COOKIE from tests.test_utils.api_connexion_utils import create_user_scope from tests.test_utils.db import clear_db_runs -from tests.test_utils.www import check_content_in_response, check_content_not_in_response, client_with_login +from tests.test_utils.www import ( + check_content_in_response, + check_content_not_in_response, + client_with_login, + flask_client_with_login, +) pytestmark = pytest.mark.db_test @@ -81,7 +85,7 @@ @pytest.fixture(scope="module") def acl_app(app): - security_manager = app.appbuilder.sm + security_manager = app.app.appbuilder.sm for username, (role_name, kwargs) in USER_DATA.items(): if not security_manager.find_user(username=username): role = security_manager.add_role(role_name) @@ -138,7 +142,7 @@ def reset_dagruns(): @pytest.fixture(autouse=True) def init_dagruns(acl_app, reset_dagruns): - acl_app.dag_bag.get_dag("example_bash_operator").create_dagrun( + acl_app.app.dag_bag.get_dag("example_bash_operator").create_dagrun( run_id=DEFAULT_RUN_ID, run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, @@ -146,7 +150,7 @@ def init_dagruns(acl_app, reset_dagruns): start_date=timezone.utcnow(), state=State.RUNNING, ) - acl_app.dag_bag.get_dag("example_subdag_operator").create_dagrun( + acl_app.app.dag_bag.get_dag("example_subdag_operator").create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, start_date=timezone.utcnow(), @@ -159,7 +163,9 @@ def init_dagruns(acl_app, reset_dagruns): @pytest.fixture def dag_test_client(acl_app): - return client_with_login(acl_app, username="dag_test", password="dag_test") + return client_with_login( + acl_app, expected_path=b"/login/?next=/home", username="dag_test", password="dag_test" + ) @pytest.fixture @@ -179,7 +185,7 @@ def all_dag_user_client(acl_app): @pytest.fixture(scope="module") def user_edit_one_dag(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_edit_one_dag", role_name="role_edit_one_dag", permissions=[ @@ -192,8 +198,8 @@ def user_edit_one_dag(acl_app): @pytest.mark.usefixtures("user_edit_one_dag") def test_permission_exist(acl_app): - perms_views = acl_app.appbuilder.sm.get_resource_permissions( - acl_app.appbuilder.sm.get_resource("DAG:example_bash_operator"), + perms_views = acl_app.app.appbuilder.sm.get_resource_permissions( + acl_app.app.appbuilder.sm.get_resource("DAG:example_bash_operator"), ) assert len(perms_views) == 3 @@ -205,7 +211,7 @@ def test_permission_exist(acl_app): @pytest.mark.usefixtures("user_edit_one_dag") def test_role_permission_associate(acl_app): - test_role = acl_app.appbuilder.sm.find_role("role_edit_one_dag") + test_role = acl_app.app.appbuilder.sm.find_role("role_edit_one_dag") perms = {str(perm) for perm in test_role.permissions} assert "can edit on DAG:example_bash_operator" in perms assert "can read on DAG:example_bash_operator" in perms @@ -214,7 +220,7 @@ def test_role_permission_associate(acl_app): @pytest.fixture(scope="module") def user_all_dags(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_all_dags", role_name="role_all_dags", permissions=[ @@ -234,6 +240,15 @@ def client_all_dags(acl_app, user_all_dags): ) +@pytest.fixture +def flask_client_all_dags(acl_app, user_all_dags): + return flask_client_with_login( + acl_app, + username="user_all_dags", + password="user_all_dags", + ) + + def test_index_for_all_dag_user(client_all_dags): # The all dag user can access/view all dags. resp = client_all_dags.get("/", follow_redirects=True) @@ -261,7 +276,7 @@ def test_dag_autocomplete_success(client_all_dags): {"name": "tutorial_taskflow_api_virtualenv", "type": "dag"}, ] - assert resp.json == expected + assert resp.json() == expected @pytest.mark.parametrize( @@ -278,7 +293,7 @@ def test_dag_autocomplete_empty(client_all_dags, query, expected): if query is not None: url = f"{url}?query={query}" resp = client_all_dags.get(url, follow_redirects=False) - assert resp.json == expected + assert resp.json() == expected @pytest.fixture @@ -300,10 +315,11 @@ def setup_paused_dag(): ], ) @pytest.mark.usefixtures("setup_paused_dag") -def test_dag_autocomplete_status(client_all_dags, status, expected, unexpected): - with client_all_dags.session_transaction() as flask_session: +def test_dag_autocomplete_status(flask_client_all_dags, status, expected, unexpected): + with flask_client_all_dags.session_transaction() as flask_session: flask_session[FILTER_STATUS_COOKIE] = status - resp = client_all_dags.get( + + resp = flask_client_all_dags.get( "dagmodel/autocomplete?query=example_branch_", follow_redirects=False, ) @@ -314,7 +330,7 @@ def test_dag_autocomplete_status(client_all_dags, status, expected, unexpected): @pytest.fixture(scope="module") def user_all_dags_dagruns(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_all_dags_dagruns", role_name="role_all_dags_dagruns", permissions=[ @@ -338,7 +354,7 @@ def client_all_dags_dagruns(acl_app, user_all_dags_dagruns): def test_dag_stats_success(client_all_dags_dagruns): resp = client_all_dags_dagruns.post("dag_stats", follow_redirects=True) check_content_in_response("example_bash_operator", resp) - assert set(next(iter(resp.json.items()))[1][0].keys()) == {"state", "count"} + assert set(next(iter(resp.json().items()))[1][0].keys()) == {"state", "count"} def test_task_stats_failure(dag_test_client): @@ -355,7 +371,7 @@ def test_dag_stats_success_for_all_dag_user(client_all_dags_dagruns): @pytest.fixture(scope="module") def user_all_dags_dagruns_tis(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_all_dags_dagruns_tis", role_name="role_all_dags_dagruns_tis", permissions=[ @@ -408,7 +424,7 @@ def test_task_stats_success( assert resp.status_code == 200 for dag_id in unexpected_dag_ids: check_content_not_in_response(dag_id, resp) - stats = json.loads(resp.data.decode()) + stats = resp.json() for dag_id in dags_to_run: assert dag_id in stats @@ -416,7 +432,7 @@ def test_task_stats_success( @pytest.fixture(scope="module") def user_all_dags_codes(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_all_dags_codes", role_name="role_all_dags_codes", permissions=[ @@ -472,7 +488,7 @@ def test_dag_details_success_for_all_dag_user(client_all_dags_dagruns, dag_id): @pytest.fixture(scope="module") def user_all_dags_tis(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_all_dags_tis", role_name="role_all_dags_tis", permissions=[ @@ -497,7 +513,7 @@ def client_all_dags_tis(acl_app, user_all_dags_tis): @pytest.fixture(scope="module") def user_all_dags_tis_xcom(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_all_dags_tis_xcom", role_name="role_all_dags_tis_xcom", permissions=[ @@ -522,7 +538,7 @@ def client_all_dags_tis_xcom(acl_app, user_all_dags_tis_xcom): @pytest.fixture(scope="module") def user_dags_tis_logs(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_dags_tis_logs", role_name="role_dags_tis_logs", permissions=[ @@ -671,7 +687,7 @@ def test_blocked_success_when_selecting_dags( assert resp.status_code == 200 for dag_id in unexpected_dag_ids: check_content_not_in_response(dag_id, resp) - blocked_dags = {blocked["dag_id"] for blocked in json.loads(resp.data.decode())} + blocked_dags = {blocked["dag_id"] for blocked in resp.json()} for dag_id in dags_to_block: assert dag_id in blocked_dags @@ -679,7 +695,7 @@ def test_blocked_success_when_selecting_dags( @pytest.fixture(scope="module") def user_all_dags_edit_tis(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_all_dags_edit_tis", role_name="role_all_dags_edit_tis", permissions=[ @@ -723,7 +739,7 @@ def test_paused_post_success(dag_test_client): @pytest.fixture(scope="module") def user_only_dags_tis(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_only_dags_tis", role_name="role_only_dags_tis", permissions=[ @@ -755,7 +771,7 @@ def test_success_fail_for_read_only_task_instance_access(client_only_dags_tis): past="false", ) resp = client_only_dags_tis.post("success", data=form) - check_content_not_in_response("Wait a minute", resp, resp_code=302) + check_content_not_in_response("Wait a minute", resp, resp_code=200) GET_LOGS_WITH_METADATA_URL = ( @@ -786,7 +802,7 @@ def test_get_logs_with_metadata_failure(dag_faker_client): @pytest.fixture(scope="module") def user_no_roles(acl_app): - with create_user_scope(acl_app, username="no_roles_user", role_name="no_roles_user_role") as user: + with create_user_scope(acl_app.app, username="no_roles_user", role_name="no_roles_user_role") as user: user.roles = [] yield user @@ -803,7 +819,7 @@ def client_no_roles(acl_app, user_no_roles): @pytest.fixture(scope="module") def user_no_permissions(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="no_permissions_user", role_name="no_permissions_role", ) as user: @@ -841,7 +857,7 @@ def test_no_roles_permissions(request, client, url, status_code, expected_conten @pytest.fixture(scope="module") def user_dag_level_access_with_ti_edit(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_dag_level_access_with_ti_edit", role_name="role_dag_level_access_with_ti_edit", permissions=[ @@ -883,7 +899,7 @@ def test_success_edit_ti_with_dag_level_access_only(client_dag_level_access_with @pytest.fixture(scope="module") def user_ti_edit_without_dag_level_access(acl_app): with create_user_scope( - acl_app, + acl_app.app, username="user_ti_edit_without_dag_level_access", role_name="role_ti_edit_without_dag_level_access", permissions=[ diff --git a/tests/www/views/test_views_base.py b/tests/www/views/test_views_base.py index 63caa75f60d471..0ad1d189c51696 100644 --- a/tests/www/views/test_views_base.py +++ b/tests/www/views/test_views_base.py @@ -18,7 +18,6 @@ from __future__ import annotations import datetime -import json import pytest @@ -36,8 +35,9 @@ def test_index_redirect(admin_client): resp = admin_client.get("/") - assert resp.status_code == 302 - assert "/home" in resp.headers.get("Location") + # Starlette TestCliente used by connexion v3 responds after following the redirect + # therefore, the status code is 200 + assert resp.url.raw_path == b"/home" resp = admin_client.get("/", follow_redirects=True) check_content_in_response("DAGs", resp) @@ -122,7 +122,7 @@ def test_health(request, admin_client, heartbeat): # Load the corresponding fixture by name. scheduler_status, last_scheduler_heartbeat = request.getfixturevalue(heartbeat) resp = admin_client.get("health", follow_redirects=True) - resp_json = json.loads(resp.data.decode("utf-8")) + resp_json = resp.json() assert "healthy" == resp_json["metadatabase"]["status"] assert scheduler_status == resp_json["scheduler"]["status"] assert last_scheduler_heartbeat == resp_json["scheduler"]["latest_scheduler_heartbeat"] @@ -150,8 +150,8 @@ def test_roles_read_unauthorized(viewer_client): @pytest.fixture(scope="module") def delete_role_if_exists(app): def func(role_name): - if app.appbuilder.sm.find_role(role_name): - app.appbuilder.sm.delete_role(role_name) + if app.app.appbuilder.sm.find_role(role_name): + app.app.appbuilder.sm.delete_role(role_name) return func @@ -167,32 +167,32 @@ def non_exist_role_name(delete_role_if_exists): @pytest.fixture def exist_role_name(app, delete_role_if_exists): role_name = "test_roles_create_role_new" - app.appbuilder.sm.add_role(role_name) + app.app.appbuilder.sm.add_role(role_name) yield role_name delete_role_if_exists(role_name) @pytest.fixture def exist_role(app, exist_role_name): - return app.appbuilder.sm.find_role(exist_role_name) + return app.app.appbuilder.sm.find_role(exist_role_name) def test_roles_create(app, admin_client, non_exist_role_name): admin_client.post("roles/add", data={"name": non_exist_role_name}, follow_redirects=True) - assert app.appbuilder.sm.find_role(non_exist_role_name) is not None + assert app.app.appbuilder.sm.find_role(non_exist_role_name) is not None def test_roles_create_unauthorized(app, viewer_client, non_exist_role_name): resp = viewer_client.post("roles/add", data={"name": non_exist_role_name}, follow_redirects=True) check_content_in_response("Access is Denied", resp) - assert app.appbuilder.sm.find_role(non_exist_role_name) is None + assert app.app.appbuilder.sm.find_role(non_exist_role_name) is None def test_roles_edit(app, admin_client, non_exist_role_name, exist_role): admin_client.post( f"roles/edit/{exist_role.id}", data={"name": non_exist_role_name}, follow_redirects=True ) - updated_role = app.appbuilder.sm.find_role(non_exist_role_name) + updated_role = app.app.appbuilder.sm.find_role(non_exist_role_name) assert exist_role.id == updated_role.id @@ -201,19 +201,19 @@ def test_roles_edit_unauthorized(app, viewer_client, non_exist_role_name, exist_ f"roles/edit/{exist_role.id}", data={"name": non_exist_role_name}, follow_redirects=True ) check_content_in_response("Access is Denied", resp) - assert app.appbuilder.sm.find_role(exist_role_name) - assert app.appbuilder.sm.find_role(non_exist_role_name) is None + assert app.app.appbuilder.sm.find_role(exist_role_name) + assert app.app.appbuilder.sm.find_role(non_exist_role_name) is None def test_roles_delete(app, admin_client, exist_role_name, exist_role): admin_client.post(f"roles/delete/{exist_role.id}", follow_redirects=True) - assert app.appbuilder.sm.find_role(exist_role_name) is None + assert app.app.appbuilder.sm.find_role(exist_role_name) is None def test_roles_delete_unauthorized(app, viewer_client, exist_role, exist_role_name): resp = viewer_client.post(f"roles/delete/{exist_role.id}", follow_redirects=True) check_content_in_response("Access is Denied", resp) - assert app.appbuilder.sm.find_role(exist_role_name) + assert app.app.appbuilder.sm.find_role(exist_role_name) @pytest.mark.parametrize( @@ -253,7 +253,7 @@ def test_views_get(request, url, client, content): def _check_task_stats_json(resp): - return set(next(iter(resp.json.items()))[1][0]) == {"state", "count"} + return set(next(iter(resp.json().items()))[1][0]) == {"state", "count"} @pytest.mark.parametrize( @@ -281,7 +281,7 @@ def test_views_post(admin_client, url, check_response): ids=["my-viewer", "pk-admin", "pk-viewer"], ) def test_resetmypasswordview_edit(app, request, url, client, content, username): - user = app.appbuilder.sm.find_user(username) + user = app.app.appbuilder.sm.find_user(username) resp = request.getfixturevalue(client).post( url.format(user.id), data={"password": "blah", "conf_password": "blah"}, follow_redirects=True ) @@ -321,13 +321,13 @@ def test_views_post_access_denied(viewer_client, url): @pytest.fixture def non_exist_username(app): username = "fake_username" - user = app.appbuilder.sm.find_user(username) + user = app.app.appbuilder.sm.find_user(username) if user is not None: - app.appbuilder.sm.del_register_user(user) + app.app.appbuilder.sm.del_register_user(user) yield username - user = app.appbuilder.sm.find_user(username) + user = app.app.appbuilder.sm.find_user(username) if user is not None: - app.appbuilder.sm.del_register_user(user) + app.app.appbuilder.sm.del_register_user(user) def test_create_user(app, admin_client, non_exist_username): @@ -345,13 +345,13 @@ def test_create_user(app, admin_client, non_exist_username): follow_redirects=True, ) check_content_in_response("Added Row", resp) - assert app.appbuilder.sm.find_user(non_exist_username) + assert app.app.appbuilder.sm.find_user(non_exist_username) @pytest.fixture def exist_username(app, exist_role): username = "test_edit_user_user" - app.appbuilder.sm.add_user( + app.app.appbuilder.sm.add_user( username, "first_name", "last_name", @@ -360,12 +360,12 @@ def exist_username(app, exist_role): password="password", ) yield username - if app.appbuilder.sm.find_user(username): - app.appbuilder.sm.del_register_user(username) + if app.app.appbuilder.sm.find_user(username): + app.app.appbuilder.sm.del_register_user(username) def test_edit_user(app, admin_client, exist_username): - user = app.appbuilder.sm.find_user(exist_username) + user = app.app.appbuilder.sm.find_user(exist_username) resp = admin_client.post( f"users/edit/{user.id}", data={"first_name": "new_first_name"}, @@ -375,7 +375,7 @@ def test_edit_user(app, admin_client, exist_username): def test_delete_user(app, admin_client, exist_username): - user = app.appbuilder.sm.find_user(exist_username) + user = app.app.appbuilder.sm.find_user(exist_username) resp = admin_client.post( f"users/delete/{user.id}", follow_redirects=True, @@ -419,5 +419,5 @@ def test_page_instance_name_with_markup(admin_client): @conf_vars(instance_name_with_markup_conf) def test_page_instance_name_with_markup_title(): - appbuilder = application.create_app(testing=True).appbuilder + appbuilder = application.create_app(testing=True).app.appbuilder assert appbuilder.app_name == "Bold Site Title Test" diff --git a/tests/www/views/test_views_blocked.py b/tests/www/views/test_views_blocked.py index c3e8cd4e88cf1f..d0b44c77b6eb1c 100644 --- a/tests/www/views/test_views_blocked.py +++ b/tests/www/views/test_views_blocked.py @@ -81,7 +81,7 @@ def test_blocked_subdag_success(admin_client, running_subdag): """ resp = admin_client.post("/blocked", data={"dag_ids": [running_subdag.dag_id]}) assert resp.status_code == 200 - assert resp.json == [ + assert resp.json() == [ { "dag_id": running_subdag.dag_id, "active_dag_run": 1, diff --git a/tests/www/views/test_views_cluster_activity.py b/tests/www/views/test_views_cluster_activity.py index a0d5bcf39f70dd..acc3abb07c2064 100644 --- a/tests/www/views/test_views_cluster_activity.py +++ b/tests/www/views/test_views_cluster_activity.py @@ -94,7 +94,9 @@ def make_dag_runs(dag_maker, session, time_machine): time_machine.move_to("2023-07-02T00:00:00+00:00", tick=False) + session.commit() session.flush() + session.close() @pytest.mark.usefixtures("freeze_time_for_dagruns", "make_dag_runs") @@ -104,7 +106,7 @@ def test_historical_metrics_data(admin_client, session, time_machine): follow_redirects=True, ) assert resp.status_code == 200 - assert resp.json == { + assert resp.json() == { "dag_run_states": {"failed": 1, "queued": 0, "running": 1, "success": 1}, "dag_run_types": {"backfill": 0, "dataset_triggered": 1, "manual": 0, "scheduled": 2}, "task_instance_states": { @@ -133,7 +135,7 @@ def test_historical_metrics_data_date_filters(admin_client, session): follow_redirects=True, ) assert resp.status_code == 200 - assert resp.json == { + assert resp.json() == { "dag_run_states": {"failed": 1, "queued": 0, "running": 0, "success": 0}, "dag_run_types": {"backfill": 0, "dataset_triggered": 1, "manual": 0, "scheduled": 0}, "task_instance_states": { diff --git a/tests/www/views/test_views_connection.py b/tests/www/views/test_views_connection.py index a209cdfc2be8a5..507d2d1e5afaf3 100644 --- a/tests/www/views/test_views_connection.py +++ b/tests/www/views/test_views_connection.py @@ -424,7 +424,7 @@ def test_connection_form_widgets_testable_types(mock_pm_hooks, admin_client): assert ["first"] == ConnectionFormWidget().testable_connection_types -def test_process_form_invalid_extra_removed(admin_client): +def test_process_form_invalid_extra_removed(flask_admin_client): """ Test that when an invalid json `extra` is passed in the form, it is removed and _not_ saved over the existing extras. @@ -437,7 +437,7 @@ def test_process_form_invalid_extra_removed(admin_client): session.add(conn) data = {**conn_details, "extra": "Invalid"} - resp = admin_client.post("/connection/edit/1", data=data, follow_redirects=True) + resp = flask_admin_client.post("/connection/edit/1", data=data, follow_redirects=True) assert resp.status_code == 200 with create_session() as session: diff --git a/tests/www/views/test_views_custom_user_views.py b/tests/www/views/test_views_custom_user_views.py index ae6d0132827c21..692947c34782ad 100644 --- a/tests/www/views/test_views_custom_user_views.py +++ b/tests/www/views/test_views_custom_user_views.py @@ -28,7 +28,11 @@ from airflow.security import permissions from airflow.www import app as application from tests.test_utils.api_connexion_utils import create_user, delete_role -from tests.test_utils.www import check_content_in_response, check_content_not_in_response, client_with_login +from tests.test_utils.www import ( + check_content_in_response, + check_content_not_in_response, + client_with_login, +) pytestmark = pytest.mark.db_test @@ -67,23 +71,24 @@ def setup_method(self): # an exception because app context teardown is removed and if even single request is run via app # it cannot be re-intialized again by passing it as constructor to SQLA # This makes the tests slightly slower (but they work with Flask 2.1 and 2.2 - self.app = application.create_app(testing=True) - self.appbuilder = self.app.appbuilder - self.app.config["WTF_CSRF_ENABLED"] = False + self.connexion_app = application.create_app(testing=True) + self.flask_app = self.connexion_app.app + self.appbuilder = self.flask_app.appbuilder + self.flask_app.config["WTF_CSRF_ENABLED"] = False self.security_manager = self.appbuilder.sm self.delete_roles() - self.db = SQLA(self.app) + self.db = SQLA(self.flask_app) - self.client = self.app.test_client() # type:ignore + self.client = self.connexion_app.test_client() # type:ignore def delete_roles(self): for role_name in ["role_edit_one_dag"]: - delete_role(self.app, role_name) + delete_role(self.flask_app, role_name) @pytest.mark.parametrize("url, _, expected_text", PERMISSIONS_TESTS_PARAMS) def test_user_model_view_with_access(self, url, expected_text, _): user_without_access = create_user( - self.app, + self.flask_app, username="no_access", role_name="role_no_access", permissions=[ @@ -91,7 +96,7 @@ def test_user_model_view_with_access(self, url, expected_text, _): ], ) client = client_with_login( - self.app, + self.connexion_app, username="no_access", password="no_access", ) @@ -101,14 +106,14 @@ def test_user_model_view_with_access(self, url, expected_text, _): @pytest.mark.parametrize("url, permission, expected_text", PERMISSIONS_TESTS_PARAMS) def test_user_model_view_without_access(self, url, permission, expected_text): user_with_access = create_user( - self.app, + self.flask_app, username="has_access", role_name="role_has_access", permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), permission], ) client = client_with_login( - self.app, + self.connexion_app, username="has_access", password="has_access", ) @@ -117,22 +122,23 @@ def test_user_model_view_without_access(self, url, permission, expected_text): def test_user_model_view_without_delete_access(self): user_to_delete = create_user( - self.app, + self.flask_app, username="user_to_delete", role_name="user_to_delete", ) create_user( - self.app, + self.flask_app, username="no_access", role_name="role_no_access", permissions=[ (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_USER), ], ) client = client_with_login( - self.app, + self.connexion_app, username="no_access", password="no_access", ) @@ -140,27 +146,29 @@ def test_user_model_view_without_delete_access(self): response = client.post(f"/users/delete/{user_to_delete.id}", follow_redirects=True) check_content_not_in_response("Deleted Row", response) - assert bool(self.security_manager.get_user_by_id(user_to_delete.id)) is True + response = client.get(f"/users/show/{user_to_delete.id}", follow_redirects=True) + assert response.status_code == 200 def test_user_model_view_with_delete_access(self): user_to_delete = create_user( - self.app, + self.flask_app, username="user_to_delete", role_name="user_to_delete", ) create_user( - self.app, + self.flask_app, username="has_access", role_name="role_has_access", permissions=[ (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_USER), (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_USER), ], ) client = client_with_login( - self.app, + self.connexion_app, username="has_access", password="has_access", ) @@ -168,7 +176,8 @@ def test_user_model_view_with_delete_access(self): response = client.post(f"/users/delete/{user_to_delete.id}", follow_redirects=True) check_content_in_response("Deleted Row", response) check_content_not_in_response(user_to_delete.username, response) - assert bool(self.security_manager.get_user_by_id(user_to_delete.id)) is False + response = client.get(f"/users/show/{user_to_delete.id}", follow_redirects=True) + assert response.status_code == 404 # type: ignore[attr-defined] @@ -184,11 +193,12 @@ def setup_method(self): # an exception because app context teardown is removed and if even single request is run via app # it cannot be re-intialized again by passing it as constructor to SQLA # This makes the tests slightly slower (but they work with Flask 2.1 and 2.2 - self.app = application.create_app(testing=True) - self.appbuilder = self.app.appbuilder - self.app.config["WTF_CSRF_ENABLED"] = False + self.connexion_app = application.create_app(testing=True) + self.flask_app = self.connexion_app.app + self.appbuilder = self.flask_app.appbuilder + self.flask_app.config["WTF_CSRF_ENABLED"] = False self.security_manager = self.appbuilder.sm - self.interface = self.app.session_interface + self.interface = self.flask_app.session_interface self.model = self.interface.sql_session_model self.serializer = self.interface.serializer self.db = self.interface.db @@ -196,12 +206,12 @@ def setup_method(self): self.db.session.commit() self.db.session.flush() self.user_1 = create_user( - self.app, + self.flask_app, username="user_to_delete_1", role_name="user_to_delete", ) self.user_2 = create_user( - self.app, + self.flask_app, username="user_to_delete_2", role_name="user_to_delete", ) @@ -277,7 +287,7 @@ def test_refuse_delete(self, _mock_has_context, flash_mock): "airflow.providers.fab.auth_manager.security_manager.override.has_request_context", return_value=True ) def test_warn_securecookie(self, _mock_has_context, flash_mock): - self.app.session_interface = SecureCookieSessionInterface() + self.flask_app.session_interface = SecureCookieSessionInterface() self.security_manager.reset_password(self.user_1.id, "new_password") assert flash_mock.called assert ( @@ -309,7 +319,7 @@ def test_refuse_delete_cli(self, log_mock): @mock.patch("airflow.providers.fab.auth_manager.security_manager.override.log") def test_warn_securecookie_cli(self, log_mock): - self.app.session_interface = SecureCookieSessionInterface() + self.flask_app.session_interface = SecureCookieSessionInterface() self.security_manager.reset_password(self.user_1.id, "new_password") assert log_mock.warning.called assert ( diff --git a/tests/www/views/test_views_dagrun.py b/tests/www/views/test_views_dagrun.py index b7e048e0eaf216..705b6ff3d7ec6f 100644 --- a/tests/www/views/test_views_dagrun.py +++ b/tests/www/views/test_views_dagrun.py @@ -25,22 +25,25 @@ from airflow.utils.session import create_session from airflow.www.views import DagRunModelView from tests.test_utils.api_connexion_utils import create_user, delete_roles, delete_user -from tests.test_utils.www import check_content_in_response, check_content_not_in_response, client_with_login +from tests.test_utils.www import ( + check_content_in_response, + check_content_not_in_response, + flask_client_with_login, +) from tests.www.views.test_views_tasks import _get_appbuilder_pk_string pytestmark = pytest.mark.db_test @pytest.fixture(scope="module") -def client_dr_without_dag_edit(app): +def flask_client_dr_without_dag_run_create(app): create_user( - app, - username="all_dr_permissions_except_dag_edit", - role_name="all_dr_permissions_except_dag_edit", + app.app, + username="all_dr_permissions_except_dag_run_create", + role_name="all_dr_permissions_except_dag_run_create", permissions=[ (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), @@ -48,25 +51,26 @@ def client_dr_without_dag_edit(app): ], ) - yield client_with_login( + yield flask_client_with_login( app, - username="all_dr_permissions_except_dag_edit", - password="all_dr_permissions_except_dag_edit", + username="all_dr_permissions_except_dag_run_create", + password="all_dr_permissions_except_dag_run_create", ) - delete_user(app, username="all_dr_permissions_except_dag_edit") # type: ignore - delete_roles(app) + delete_user(app.app, username="all_dr_permissions_except_dag_run_create") # type: ignore + delete_roles(app.app) @pytest.fixture(scope="module") -def client_dr_without_dag_run_create(app): +def flask_client_dr_without_dag_edit(app): create_user( - app, - username="all_dr_permissions_except_dag_run_create", - role_name="all_dr_permissions_except_dag_run_create", + app.app, + username="all_dr_permissions_except_dag_edit", + role_name="all_dr_permissions_except_dag_edit", permissions=[ (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), @@ -74,14 +78,14 @@ def client_dr_without_dag_run_create(app): ], ) - yield client_with_login( + yield flask_client_with_login( app, - username="all_dr_permissions_except_dag_run_create", - password="all_dr_permissions_except_dag_run_create", + username="all_dr_permissions_except_dag_edit", + password="all_dr_permissions_except_dag_edit", ) - delete_user(app, username="all_dr_permissions_except_dag_run_create") # type: ignore - delete_roles(app) + delete_user(app.app, username="all_dr_permissions_except_dag_edit") # type: ignore + delete_roles(app.app) @pytest.fixture(scope="module", autouse=True) @@ -103,14 +107,16 @@ def reset_dagrun(): session.query(TaskInstance).delete() -def test_get_dagrun_can_view_dags_without_edit_perms(session, running_dag_run, client_dr_without_dag_edit): +def test_get_dagrun_can_view_dags_without_edit_perms( + session, running_dag_run, flask_client_dr_without_dag_edit +): """Test that a user without dag_edit but with dag_read permission can view the records""" assert session.query(DagRun).filter(DagRun.dag_id == running_dag_run.dag_id).count() == 1 - resp = client_dr_without_dag_edit.get("/dagrun/list/", follow_redirects=True) + resp = flask_client_dr_without_dag_edit.get("/dagrun/list/", follow_redirects=True) check_content_in_response(running_dag_run.dag_id, resp) -def test_create_dagrun_permission_denied(session, client_dr_without_dag_run_create): +def test_create_dagrun_permission_denied(session, flask_client_dr_without_dag_run_create): data = { "state": "running", "dag_id": "example_bash_operator", @@ -119,7 +125,7 @@ def test_create_dagrun_permission_denied(session, client_dr_without_dag_run_crea "conf": '{"include": "me"}', } - resp = client_dr_without_dag_run_create.post("/dagrun/add", data=data, follow_redirects=True) + resp = flask_client_dr_without_dag_run_create.post("/dagrun/add", data=data, follow_redirects=True) check_content_in_response("Access is Denied", resp) @@ -169,18 +175,18 @@ def completed_dag_run_with_missing_task(session): return dag, dr -def test_delete_dagrun(session, admin_client, running_dag_run): +def test_delete_dagrun(session, flask_admin_client, running_dag_run): composite_key = _get_appbuilder_pk_string(DagRunModelView, running_dag_run) assert session.query(DagRun).filter(DagRun.dag_id == running_dag_run.dag_id).count() == 1 - admin_client.post(f"/dagrun/delete/{composite_key}", follow_redirects=True) + flask_admin_client.post(f"/dagrun/delete/{composite_key}", follow_redirects=True) assert session.query(DagRun).filter(DagRun.dag_id == running_dag_run.dag_id).count() == 0 -def test_delete_dagrun_permission_denied(session, running_dag_run, client_dr_without_dag_edit): +def test_delete_dagrun_permission_denied(session, running_dag_run, flask_client_dr_without_dag_edit): composite_key = _get_appbuilder_pk_string(DagRunModelView, running_dag_run) assert session.query(DagRun).filter(DagRun.dag_id == running_dag_run.dag_id).count() == 1 - resp = client_dr_without_dag_edit.post(f"/dagrun/delete/{composite_key}", follow_redirects=True) + resp = flask_client_dr_without_dag_edit.post(f"/dagrun/delete/{composite_key}", follow_redirects=True) check_content_in_response("Access is Denied", resp) assert session.query(DagRun).filter(DagRun.dag_id == running_dag_run.dag_id).count() == 1 @@ -218,13 +224,13 @@ def test_delete_dagrun_permission_denied(session, running_dag_run, client_dr_wit ) def test_set_dag_runs_action( session, - admin_client, + flask_admin_client, running_dag_run, action, expected_ti_states, expected_message, ): - resp = admin_client.post( + resp = flask_admin_client.post( "/dagrun/action_post", data={"action": action, "rowid": [running_dag_run.id]}, follow_redirects=True, @@ -244,8 +250,8 @@ def test_set_dag_runs_action( ], ids=["clear", "success", "failed", "running", "queued"], ) -def test_set_dag_runs_action_fails(admin_client, action, expected_message): - resp = admin_client.post( +def test_set_dag_runs_action_fails(flask_admin_client, action, expected_message): + resp = flask_admin_client.post( "/dagrun/action_post", data={"action": action, "rowid": ["0"]}, follow_redirects=True, @@ -253,9 +259,9 @@ def test_set_dag_runs_action_fails(admin_client, action, expected_message): check_content_in_response(expected_message, resp) -def test_muldelete_dag_runs_action(session, admin_client, running_dag_run): +def test_muldelete_dag_runs_action(session, flask_admin_client, running_dag_run): dag_run_id = running_dag_run.id - resp = admin_client.post( + resp = flask_admin_client.post( "/dagrun/action_post", data={"action": "muldelete", "rowid": [dag_run_id]}, follow_redirects=True, @@ -270,9 +276,9 @@ def test_muldelete_dag_runs_action(session, admin_client, running_dag_run): ["clear", "set_success", "set_failed", "set_running"], ids=["clear", "success", "failed", "running"], ) -def test_set_dag_runs_action_permission_denied(client_dr_without_dag_edit, running_dag_run, action): +def test_set_dag_runs_action_permission_denied(flask_client_dr_without_dag_edit, running_dag_run, action): running_dag_id = running_dag_run.id - resp = client_dr_without_dag_edit.post( + resp = flask_client_dr_without_dag_edit.post( "/dagrun/action_post", data={"action": action, "rowid": [str(running_dag_id)]}, follow_redirects=True, @@ -280,9 +286,9 @@ def test_set_dag_runs_action_permission_denied(client_dr_without_dag_edit, runni check_content_in_response("Access is Denied", resp) -def test_dag_runs_queue_new_tasks_action(session, admin_client, completed_dag_run_with_missing_task): +def test_dag_runs_queue_new_tasks_action(session, flask_admin_client, completed_dag_run_with_missing_task): dag, dag_run = completed_dag_run_with_missing_task - resp = admin_client.post( + resp = flask_admin_client.post( "/dagrun_queued", data={"dag_id": dag.dag_id, "dag_run_id": dag_run.run_id, "confirmed": False}, ) diff --git a/tests/www/views/test_views_dataset.py b/tests/www/views/test_views_dataset.py index d67ed80f385e56..01771bf0a97a26 100644 --- a/tests/www/views/test_views_dataset.py +++ b/tests/www/views/test_views_dataset.py @@ -55,7 +55,7 @@ def test_should_respond_200(self, admin_client, session): response = admin_client.get("/object/datasets_summary") assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert response_data == { "datasets": [ { @@ -89,7 +89,7 @@ def test_order_by_raises_400_for_invalid_attr(self, admin_client, session): assert response.status_code == 400 msg = "Ordering with 'fake' is disallowed or the attribute does not exist on the model" - assert response.json["detail"] == msg + assert response.json()["detail"] == msg def test_order_by_raises_400_for_invalid_datetimes(self, admin_client, session): datasets = [ @@ -139,15 +139,15 @@ def test_filter_by_datetimes(self, admin_client, session): response = admin_client.get(f"/object/datasets_summary?updated_after={cutoff}") assert response.status_code == 200 - assert response.json["total_entries"] == 2 - assert [json_dict["id"] for json_dict in response.json["datasets"]] == [2, 3] + assert response.json()["total_entries"] == 2 + assert [json_dict["id"] for json_dict in response.json()["datasets"]] == [2, 3] cutoff = today.add(days=-1).add(minutes=5).to_iso8601_string() response = admin_client.get(f"/object/datasets_summary?updated_before={cutoff}") assert response.status_code == 200 - assert response.json["total_entries"] == 2 - assert [json_dict["id"] for json_dict in response.json["datasets"]] == [1, 2] + assert response.json()["total_entries"] == 2 + assert [json_dict["id"] for json_dict in response.json()["datasets"]] == [1, 2] @pytest.mark.parametrize( "order_by, ordered_dataset_ids", @@ -188,8 +188,8 @@ def test_order_by(self, admin_client, session, order_by, ordered_dataset_ids): response = admin_client.get(f"/object/datasets_summary?order_by={order_by}") assert response.status_code == 200 - assert ordered_dataset_ids == [json_dict["id"] for json_dict in response.json["datasets"]] - assert response.json["total_entries"] == len(ordered_dataset_ids) + assert ordered_dataset_ids == [json_dict["id"] for json_dict in response.json()["datasets"]] + assert response.json()["total_entries"] == len(ordered_dataset_ids) def test_search_uri_pattern(self, admin_client, session): datasets = [ @@ -207,7 +207,7 @@ def test_search_uri_pattern(self, admin_client, session): response = admin_client.get(f"/object/datasets_summary?uri_pattern={uri_pattern}") assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert response_data == { "datasets": [ { @@ -224,7 +224,7 @@ def test_search_uri_pattern(self, admin_client, session): response = admin_client.get(f"/object/datasets_summary?uri_pattern={uri_pattern}") assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert response_data == { "datasets": [ { @@ -289,7 +289,7 @@ def test_correct_counts_update(self, admin_client, session, dag_maker, app, monk ): EmptyOperator(task_id="task1", outlets=[datasets[4]]) - m.setattr(app, "dag_bag", dag_maker.dagbag) + m.setattr(app.app, "dag_bag", dag_maker.dagbag) ds1_id = session.query(DatasetModel.id).filter_by(uri=datasets[0].uri).scalar() ds2_id = session.query(DatasetModel.id).filter_by(uri=datasets[1].uri).scalar() @@ -342,7 +342,7 @@ def test_correct_counts_update(self, admin_client, session, dag_maker, app, monk response = admin_client.get("/object/datasets_summary") assert response.status_code == 200 - response_data = response.json + response_data = response.json() assert response_data == { "datasets": [ { @@ -408,7 +408,7 @@ def test_limit_and_offset(self, admin_client, session, url, expected_dataset_uri response = admin_client.get(url) assert response.status_code == 200 - dataset_uris = [dataset["uri"] for dataset in response.json["datasets"]] + dataset_uris = [dataset["uri"] for dataset in response.json()["datasets"]] assert dataset_uris == expected_dataset_uris def test_should_respect_page_size_limit_default(self, admin_client, session): @@ -425,7 +425,7 @@ def test_should_respect_page_size_limit_default(self, admin_client, session): response = admin_client.get("/object/datasets_summary") assert response.status_code == 200 - assert len(response.json["datasets"]) == 25 + assert len(response.json()["datasets"]) == 25 def test_should_return_max_if_req_above(self, admin_client, session): datasets = [ @@ -441,15 +441,19 @@ def test_should_return_max_if_req_above(self, admin_client, session): response = admin_client.get("/object/datasets_summary?limit=180") assert response.status_code == 200 - assert len(response.json["datasets"]) == 50 + assert len(response.json()["datasets"]) == 50 class TestGetDatasetNextRunSummary(TestDatasetEndpoint): - def test_next_run_dataset_summary(self, dag_maker, admin_client): - with dag_maker(dag_id="upstream", schedule=[Dataset(uri="s3://bucket/key/1")], serialized=True): + def test_next_run_dataset_summary(self, dag_maker, admin_client, session): + with dag_maker( + dag_id="upstream", schedule=[Dataset(uri="s3://bucket/key/1")], serialized=True, session=session + ): EmptyOperator(task_id="task1") + session.commit() + session.close() response = admin_client.post("/next_run_datasets_summary", data={"dag_ids": ["upstream"]}) assert response.status_code == 200 - assert response.json == {"upstream": {"ready": 0, "total": 1, "uri": "s3://bucket/key/1"}} + assert response.json() == {"upstream": {"ready": 0, "total": 1, "uri": "s3://bucket/key/1"}} diff --git a/tests/www/views/test_views_extra_links.py b/tests/www/views/test_views_extra_links.py index a37e9f32d88886..6ff080262c134f 100644 --- a/tests/www/views/test_views_extra_links.py +++ b/tests/www/views/test_views_extra_links.py @@ -78,13 +78,17 @@ def dag(): @pytest.fixture(scope="module") def create_dag_run(dag): def _create_dag_run(*, execution_date, session): - return dag.create_dagrun( - state=DagRunState.RUNNING, - execution_date=execution_date, - data_interval=(execution_date, execution_date), - run_type=DagRunType.MANUAL, - session=session, - ) + try: + return dag.create_dagrun( + state=DagRunState.RUNNING, + execution_date=execution_date, + data_interval=(execution_date, execution_date), + run_type=DagRunType.MANUAL, + session=session, + ) + finally: + session.commit() + session.close() return _create_dag_run @@ -96,7 +100,7 @@ def dag_run(create_dag_run, session): @pytest.fixture(scope="module", autouse=True) def patched_app(app, dag): - with mock.patch.object(app, "dag_bag") as mock_dag_bag: + with mock.patch.object(app.app, "dag_bag") as mock_dag_bag: mock_dag_bag.get_dag.return_value = dag yield @@ -139,7 +143,7 @@ def test_extra_links_works(dag_run, task_1, viewer_client, session): ) assert response.status_code == 200 - assert json.loads(response.data.decode()) == { + assert json.loads(response.text) == { "url": "http://www.example.com/some_dummy_task/foo-bar/manual__2017-01-01T00:00:00+00:00", "error": None, } @@ -153,7 +157,7 @@ def test_global_extra_links_works(dag_run, task_1, viewer_client, session): ) assert response.status_code == 200 - assert json.loads(response.data.decode()) == { + assert json.loads(response.text) == { "url": "https://github.com/apache/airflow", "error": None, } @@ -167,10 +171,7 @@ def test_operator_extra_link_override_global_extra_link(dag_run, task_1, viewer_ ) assert response.status_code == 200 - response_str = response.data - if isinstance(response.data, bytes): - response_str = response_str.decode() - assert json.loads(response_str) == {"url": "https://airflow.apache.org", "error": None} + assert json.loads(response.text) == {"url": "https://airflow.apache.org", "error": None} def test_extra_links_error_raised(dag_run, task_1, viewer_client): @@ -181,10 +182,7 @@ def test_extra_links_error_raised(dag_run, task_1, viewer_client): ) assert 404 == response.status_code - response_str = response.data - if isinstance(response.data, bytes): - response_str = response_str.decode() - assert json.loads(response_str) == {"url": None, "error": "This is an error"} + assert json.loads(response.text) == {"url": None, "error": "This is an error"} def test_extra_links_no_response(dag_run, task_1, viewer_client): @@ -195,10 +193,7 @@ def test_extra_links_no_response(dag_run, task_1, viewer_client): ) assert response.status_code == 404 - response_str = response.data - if isinstance(response.data, bytes): - response_str = response_str.decode() - assert json.loads(response_str) == {"url": None, "error": "No URL found for no_response"} + assert json.loads(response.text) == {"url": None, "error": "No URL found for no_response"} def test_operator_extra_link_override_plugin(dag_run, task_2, viewer_client): @@ -216,10 +211,8 @@ def test_operator_extra_link_override_plugin(dag_run, task_2, viewer_client): ) assert response.status_code == 200 - response_str = response.data - if isinstance(response.data, bytes): - response_str = response_str.decode() - assert json.loads(response_str) == {"url": "https://airflow.apache.org/1.10.5/", "error": None} + + assert json.loads(response.text) == {"url": "https://airflow.apache.org/1.10.5/", "error": None} def test_operator_extra_link_multiple_operators(dag_run, task_2, task_3, viewer_client): @@ -238,10 +231,8 @@ def test_operator_extra_link_multiple_operators(dag_run, task_2, task_3, viewer_ ) assert response.status_code == 200 - response_str = response.data - if isinstance(response.data, bytes): - response_str = response_str.decode() - assert json.loads(response_str) == {"url": "https://airflow.apache.org/1.10.5/", "error": None} + + assert json.loads(response.text) == {"url": "https://airflow.apache.org/1.10.5/", "error": None} response = viewer_client.get( f"{ENDPOINT}?dag_id={task_3.dag_id}&task_id={task_3.task_id}" @@ -250,10 +241,7 @@ def test_operator_extra_link_multiple_operators(dag_run, task_2, task_3, viewer_ ) assert response.status_code == 200 - response_str = response.data - if isinstance(response.data, bytes): - response_str = response_str.decode() - assert json.loads(response_str) == {"url": "https://airflow.apache.org/1.10.5/", "error": None} + assert json.loads(response.text) == {"url": "https://airflow.apache.org/1.10.5/", "error": None} # Also check that the other Operator Link defined for this operator exists response = viewer_client.get( @@ -263,7 +251,4 @@ def test_operator_extra_link_multiple_operators(dag_run, task_2, task_3, viewer_ ) assert response.status_code == 200 - response_str = response.data - if isinstance(response.data, bytes): - response_str = response_str.decode() - assert json.loads(response_str) == {"url": "https://www.google.com", "error": None} + assert json.loads(response.text) == {"url": "https://www.google.com", "error": None} diff --git a/tests/www/views/test_views_grid.py b/tests/www/views/test_views_grid.py index 47ea3d9ead2c85..9e8a250a1cef4c 100644 --- a/tests/www/views/test_views_grid.py +++ b/tests/www/views/test_views_grid.py @@ -80,7 +80,7 @@ def mapped_task_group(arg1): with TaskGroup(group_id="group"): MockOperator.partial(task_id="mapped").expand(arg1=["a", "b", "c", "d"]) - m.setattr(app, "dag_bag", dag_maker.dagbag) + m.setattr(app.app, "dag_bag", dag_maker.dagbag) yield dag_maker @@ -95,14 +95,15 @@ def dag_with_runs(dag_without_runs): run_type=DagRunType.SCHEDULED, execution_date=dag_without_runs.dag.next_dagrun_info(date).logical_date, ) - return run_1, run_2 -def test_no_runs(admin_client, dag_without_runs): +def test_no_runs(admin_client, dag_without_runs, session): + session.commit() + session.close() resp = admin_client.get(f"/object/grid_data?dag_id={DAG_ID}", follow_redirects=True) - assert resp.status_code == 200, resp.json - assert resp.json == { + assert resp.status_code == 200, resp.json() + assert resp.json() == { "dag_runs": [], "groups": { "children": [ @@ -162,7 +163,9 @@ def test_no_runs(admin_client, dag_without_runs): } -def test_grid_data_filtered_on_run_type_and_run_state(admin_client, dag_with_runs): +def test_grid_data_filtered_on_run_type_and_run_state(admin_client, dag_with_runs, session): + session.commit() + session.close() for uri_params, expected_run_types, expected_run_states in [ ("run_state=success&run_state=queued", ["scheduled"], ["success"]), ("run_state=running&run_state=failed", ["scheduled"], ["running"]), @@ -176,9 +179,9 @@ def test_grid_data_filtered_on_run_type_and_run_state(admin_client, dag_with_run ), ]: resp = admin_client.get(f"/object/grid_data?dag_id={DAG_ID}&{uri_params}", follow_redirects=True) - assert resp.status_code == 200, resp.json - actual_run_types = list(map(lambda x: x["run_type"], resp.json["dag_runs"])) - actual_run_states = list(map(lambda x: x["state"], resp.json["dag_runs"])) + assert resp.status_code == 200, resp.json() + actual_run_types = list(map(lambda x: x["run_type"], resp.json()["dag_runs"])) + actual_run_states = list(map(lambda x: x["state"], resp.json()["dag_runs"])) assert actual_run_types == expected_run_types assert actual_run_states == expected_run_states @@ -198,7 +201,6 @@ def test_one_run(admin_client, dag_with_runs: list[DagRun], session): - One TI not yet finished """ run1, run2 = dag_with_runs - for ti in run1.task_instances: ti.state = TaskInstanceState.SUCCESS for ti in sorted(run2.task_instances, key=lambda ti: (ti.task_id, ti.map_index)): @@ -213,14 +215,14 @@ def test_one_run(admin_client, dag_with_runs: list[DagRun], session): ti.state = TaskInstanceState.RUNNING ti.start_date = pendulum.DateTime(2021, 7, 1, 2, 3, 4, tzinfo=pendulum.UTC) ti.end_date = None - + session.commit() session.flush() - + session.close() resp = admin_client.get(f"/object/grid_data?dag_id={DAG_ID}", follow_redirects=True) - assert resp.status_code == 200, resp.json + assert resp.status_code == 200, resp.json() - assert resp.json == { + assert resp.json() == { "dag_runs": [ { "conf": None, @@ -428,7 +430,9 @@ def test_has_outlet_dataset_flag(admin_client, dag_maker, session, app, monkeypa EmptyOperator(task_id="task3", outlets=[Dataset("foo"), lineagefile]) EmptyOperator(task_id="task4", outlets=[Dataset("foo")]) - m.setattr(app, "dag_bag", dag_maker.dagbag) + m.setattr(app.app, "dag_bag", dag_maker.dagbag) + session.commit() + session.close() resp = admin_client.get(f"/object/grid_data?dag_id={DAG_ID}", follow_redirects=True) def _expected_task_details(task_id, has_outlet_datasets): @@ -443,8 +447,8 @@ def _expected_task_details(task_id, has_outlet_datasets): "trigger_rule": "all_success", } - assert resp.status_code == 200, resp.json - assert resp.json == { + assert resp.status_code == 200, resp.json() + assert resp.json() == { "dag_runs": [], "groups": { "children": [ @@ -469,7 +473,7 @@ def test_next_run_datasets(admin_client, dag_maker, session, app, monkeypatch): with dag_maker(dag_id=DAG_ID, schedule=datasets, serialized=True, session=session): EmptyOperator(task_id="task1") - m.setattr(app, "dag_bag", dag_maker.dagbag) + m.setattr(app.app, "dag_bag", dag_maker.dagbag) ds1_id = session.query(DatasetModel.id).filter_by(uri=datasets[0].uri).scalar() ds2_id = session.query(DatasetModel.id).filter_by(uri=datasets[1].uri).scalar() @@ -499,8 +503,8 @@ def test_next_run_datasets(admin_client, dag_maker, session, app, monkeypatch): resp = admin_client.get(f"/object/next_run_datasets/{DAG_ID}", follow_redirects=True) - assert resp.status_code == 200, resp.json - assert resp.json == { + assert resp.status_code == 200, resp.json() + assert resp.json() == { "dataset_expression": {"all": ["s3://bucket/key/1", "s3://bucket/key/2"]}, "events": [ {"id": ds1_id, "uri": "s3://bucket/key/1", "lastUpdate": "2022-08-02T02:00:00+00:00"}, @@ -511,5 +515,5 @@ def test_next_run_datasets(admin_client, dag_maker, session, app, monkeypatch): def test_next_run_datasets_404(admin_client): resp = admin_client.get("/object/next_run_datasets/missingdag", follow_redirects=True) - assert resp.status_code == 404, resp.json - assert resp.json == {"error": "can't find dag missingdag"} + assert resp.status_code == 404, resp.json() + assert resp.json() == {"error": "can't find dag missingdag"} diff --git a/tests/www/views/test_views_home.py b/tests/www/views/test_views_home.py index 5ddcb65a871f5a..15fef70b4d8d90 100644 --- a/tests/www/views/test_views_home.py +++ b/tests/www/views/test_views_home.py @@ -85,13 +85,15 @@ def call_kwargs(): update_stmt = update(DagModel).where(DagModel.dag_id == "filter_test_1").values(is_active=False) session.execute(update_stmt) + session.commit() + session.close() admin_client.get("home", follow_redirects=True) assert call_kwargs()["status_count_all"] == 3 -def test_home_status_filter_cookie(admin_client): - with admin_client: +def test_home_status_filter_cookie(admin_flask_client): + with admin_flask_client as admin_client: admin_client.get("home", follow_redirects=True) assert "all" == flask.session[FILTER_STATUS_COOKIE] @@ -115,7 +117,7 @@ def test_home_status_filter_cookie(admin_client): def user_no_importerror(app): """Create User that cannot access Import Errors""" return create_user( - app, + app.app, username="user_no_importerrors", role_name="role_no_importerrors", permissions=[ @@ -139,7 +141,7 @@ def client_no_importerror(app, user_no_importerror): def user_single_dag(app): """Create User that can only access the first DAG from TEST_FILTER_DAG_IDS""" return create_user( - app, + app.app, username="user_single_dag", role_name="role_single_dag", permissions=[ @@ -164,7 +166,7 @@ def client_single_dag(app, user_single_dag): def user_single_dag_edit(app): """Create User that can edit DAG resource only a single DAG""" return create_user( - app, + app.app, username="user_single_dag_edit", role_name="role_single_dag", permissions=[ @@ -275,8 +277,8 @@ def broken_dags_after_working(tmp_path): _process_file(path, session) -def test_home_filter_tags(working_dags, admin_client): - with admin_client: +def test_home_filter_tags(working_dags, admin_flask_client): + with admin_flask_client as admin_client: admin_client.get("home?tags=example&tags=data", follow_redirects=True) assert "example,data" == flask.session[FILTER_TAGS_COOKIE] @@ -447,7 +449,7 @@ def test_dashboard_flash_messages_type(user_client): ) def test_sorting_home_view(url, lower_key, greater_key, user_client, working_dags): resp = user_client.get(url, follow_redirects=True) - resp_html = resp.data.decode("utf-8") + resp_html = resp.text lower_index = resp_html.find(lower_key) greater_index = resp_html.find(greater_key) assert lower_index < greater_index diff --git a/tests/www/views/test_views_log.py b/tests/www/views/test_views_log.py index 3d3248f1108b2f..553d91f916c764 100644 --- a/tests/www/views/test_views_log.py +++ b/tests/www/views/test_views_log.py @@ -43,7 +43,7 @@ from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_dags, clear_db_runs from tests.test_utils.decorators import dont_initialize_flask_app_submodules -from tests.test_utils.www import client_with_login +from tests.test_utils.www import client_with_login, flask_client_with_login pytestmark = pytest.mark.db_test @@ -84,9 +84,9 @@ def log_app(backup_modules, log_path): ) def factory(): app = create_app(testing=True) - app.config["WTF_CSRF_ENABLED"] = False + app.app.config["WTF_CSRF_ENABLED"] = False settings.configure_orm() - security_manager = app.appbuilder.sm + security_manager = app.app.appbuilder.sm if not security_manager.find_user(username="test"): security_manager.add_user( username="test", @@ -142,7 +142,7 @@ def dags(log_app, create_dummy_dag, session): bag.bag_dag(dag=dag, root_dag=dag) bag.bag_dag(dag=dag_removed, root_dag=dag_removed) bag.sync_to_db(session=session) - log_app.dag_bag = bag + log_app.app.dag_bag = bag yield dag, dag_removed @@ -174,6 +174,9 @@ def tis(dags, session): (ti_removed_dag,) = dagrun_removed.task_instances ti_removed_dag.try_number = 1 + session.commit() + session.close() + yield ti, ti_removed_dag clear_db_runs() @@ -198,6 +201,11 @@ def create_expected_log_file(try_number): shutil.rmtree(sub_path) +@pytest.fixture +def flask_log_admin_client(log_app): + return flask_client_with_login(log_app, username="test", password="test") + + @pytest.fixture def log_admin_client(log_app): return client_with_login(log_app, username="test", password="test") @@ -233,12 +241,12 @@ def test_get_file_task_log(log_admin_client, tis, state, try_number, num_logs): response = log_admin_client.get( ENDPOINT, - data={"username": "test", "password": "test"}, + params={"username": "test", "password": "test"}, follow_redirects=True, ) assert response.status_code == 200 - data = response.data.decode() + data = response.text assert "Log by attempts" in data for num in range(1, num_logs + 1): assert f"log-group-{num}" in data @@ -271,8 +279,8 @@ def test_get_logs_with_metadata_as_download_file(log_admin_client, create_expect in content_disposition ) assert 200 == response.status_code - assert "Log for testing." in response.data.decode("utf-8") - assert "localhost\n" in response.data.decode("utf-8") + assert "Log for testing." in response.text + assert "localhost\n" in response.text DIFFERENT_LOG_FILENAME = "{{ ti.dag_id }}/{{ ti.run_id }}/{{ ti.task_id }}/{{ try_number }}.log" @@ -313,7 +321,7 @@ def test_get_logs_for_changed_filename_format_db( # Should find the log under corresponding db entry. assert 200 == response.status_code - assert "Log for testing." in response.data.decode("utf-8") + assert "Log for testing." in response.text content_disposition = response.headers["Content-Disposition"] expected_filename = ( f"{dag_run_with_log_filename.dag_id}/{dag_run_with_log_filename.run_id}/{TASK_ID}/{try_number}.log" @@ -347,7 +355,7 @@ def test_get_logs_with_metadata_as_download_large_file(_, log_admin_client): ) response = log_admin_client.get(url) - data = response.data.decode() + data = response.text assert "1st line" in data assert "2nd line" in data assert "3rd line" in data @@ -367,12 +375,12 @@ def test_get_logs_with_metadata(log_admin_client, metadata, create_expected_log_ try_number, metadata, ), - data={"username": "test", "password": "test"}, + params={"username": "test", "password": "test"}, follow_redirects=True, ) assert 200 == response.status_code - data = response.data.decode() + data = response.text assert '"message":' in data assert '"metadata":' in data assert "Log for testing." in data @@ -390,12 +398,12 @@ def test_get_logs_with_invalid_metadata(log_admin_client): 1, metadata, ), - data={"username": "test", "password": "test"}, + params={"username": "test", "password": "test"}, follow_redirects=True, ) assert response.status_code == 400 - assert response.json == {"error": "Invalid JSON metadata"} + assert response.json() == {"error": "Invalid JSON metadata"} @unittest.mock.patch( @@ -412,12 +420,12 @@ def test_get_logs_with_metadata_for_removed_dag(_, log_admin_client): 1, "{}", ), - data={"username": "test", "password": "test"}, + params={"username": "test", "password": "test"}, follow_redirects=True, ) assert 200 == response.status_code - data = response.data.decode() + data = response.text assert '"message":' in data assert '"metadata":' in data assert "airflow log line" in data @@ -439,7 +447,7 @@ def test_get_logs_response_with_ti_equal_to_none(log_admin_client): ) response = log_admin_client.get(url) - data = response.json + data = response.json() assert "message" in data assert "error" in data assert "*** Task instance did not exist in the DB\n" == data["message"] @@ -463,9 +471,9 @@ def test_get_logs_with_json_response_format(log_admin_client, create_expected_lo response = log_admin_client.get(url) assert 200 == response.status_code - assert "message" in response.json - assert "metadata" in response.json - assert "Log for testing." in response.json["message"][0][1] + assert "message" in response.json() + assert "metadata" in response.json() + assert "Log for testing." in response.json()["message"][0][1] def test_get_logs_invalid_execution_data_format(log_admin_client): @@ -484,7 +492,7 @@ def test_get_logs_invalid_execution_data_format(log_admin_client): ) response = log_admin_client.get(url) assert response.status_code == 400 - assert response.json == { + assert response.json() == { "error": ( "Given execution date 'Tuesday February 27, 2024' could not be identified as a date. " "Example date format: 2015-11-16T14:34:15+00:00" @@ -511,7 +519,7 @@ def test_get_logs_for_handler_without_read_method(mock_reader, log_admin_client) response = log_admin_client.get(url) assert 200 == response.status_code - data = response.json + data = response.json() assert "message" in data assert "metadata" in data assert "Task log handler does not support read logs." in data["message"] @@ -529,8 +537,8 @@ def test_redirect_to_external_log_with_local_log_handler(log_admin_client, task_ try_number, ) response = log_admin_client.get(url) - assert 302 == response.status_code - assert "/home" == response.headers["Location"] + assert 200 == response.status_code + assert "/home" == response.url.path class _ExternalHandler(ExternalLoggingMixin): @@ -553,7 +561,7 @@ def supports_external_link(self) -> bool: new_callable=unittest.mock.PropertyMock, return_value=_ExternalHandler(), ) -def test_redirect_to_external_log_with_external_log_handler(_, log_admin_client): +def test_redirect_to_external_log_with_external_log_handler(_, flask_log_admin_client): url_template = "redirect_to_external_log?dag_id={}&task_id={}&execution_date={}&try_number={}" try_number = 1 url = url_template.format( @@ -562,6 +570,6 @@ def test_redirect_to_external_log_with_external_log_handler(_, log_admin_client) urllib.parse.quote_plus(DEFAULT_DATE.isoformat()), try_number, ) - response = log_admin_client.get(url) + response = flask_log_admin_client.get(url) assert 302 == response.status_code assert _ExternalHandler.EXTERNAL_URL == response.headers["Location"] diff --git a/tests/www/views/test_views_mount.py b/tests/www/views/test_views_mount.py index f0c052294b60a7..df6dd390014cc1 100644 --- a/tests/www/views/test_views_mount.py +++ b/tests/www/views/test_views_mount.py @@ -34,13 +34,13 @@ def factory(): return create_app(testing=True) app = factory() - app.config["WTF_CSRF_ENABLED"] = False + app.app.config["WTF_CSRF_ENABLED"] = False return app @pytest.fixture def client(app): - return werkzeug.test.Client(app, werkzeug.wrappers.response.Response) + return werkzeug.test.Client(app.app, werkzeug.wrappers.response.Response) def test_mount(client): diff --git a/tests/www/views/test_views_paused.py b/tests/www/views/test_views_paused.py index 46b0a3aa03f1af..e54fe0ad253ccb 100644 --- a/tests/www/views/test_views_paused.py +++ b/tests/www/views/test_views_paused.py @@ -34,17 +34,17 @@ def dags(create_dummy_dag): clear_db_dags() -def test_logging_pause_dag(admin_client, dags, session): +def test_logging_pause_dag(flask_admin_client, dags, session): dag, _ = dags # is_paused=false mean pause the dag - admin_client.post(f"/paused?is_paused=false&dag_id={dag.dag_id}", follow_redirects=True) + flask_admin_client.post(f"/paused?is_paused=false&dag_id={dag.dag_id}", follow_redirects=True) dag_query = session.query(Log).filter(Log.dag_id == dag.dag_id) assert '{"is_paused": true}' in dag_query.first().extra -def test_logging_unpause_dag(admin_client, dags, session): +def test_logging_unpause_dag(flask_admin_client, dags, session): _, paused_dag = dags # is_paused=true mean unpause the dag - admin_client.post(f"/paused?is_paused=true&dag_id={paused_dag.dag_id}", follow_redirects=True) + flask_admin_client.post(f"/paused?is_paused=true&dag_id={paused_dag.dag_id}", follow_redirects=True) dag_query = session.query(Log).filter(Log.dag_id == paused_dag.dag_id) assert '{"is_paused": false}' in dag_query.first().extra diff --git a/tests/www/views/test_views_pool.py b/tests/www/views/test_views_pool.py index 3fcacbbbf8bed2..4b38c5f32ac9e2 100644 --- a/tests/www/views/test_views_pool.py +++ b/tests/www/views/test_views_pool.py @@ -83,7 +83,7 @@ def test_list(app, admin_client, pool_factory): resp = admin_client.get("/pool/list/") # We should see this link - with app.test_request_context(): + with app.app.test_request_context(): description_tag = markupsafe.Markup("{description}").format( description="test-pool-description" ) diff --git a/tests/www/views/test_views_rate_limit.py b/tests/www/views/test_views_rate_limit.py index fa4502a2753159..a6af2fc797dea4 100644 --- a/tests/www/views/test_views_rate_limit.py +++ b/tests/www/views/test_views_rate_limit.py @@ -22,7 +22,7 @@ from airflow.www.app import create_app from tests.test_utils.config import conf_vars from tests.test_utils.decorators import dont_initialize_flask_app_submodules -from tests.test_utils.www import client_with_login +from tests.test_utils.www import client_with_login, flask_client_with_login pytestmark = pytest.mark.db_test @@ -47,20 +47,22 @@ def factory(): return create_app(testing=True) app = factory() - app.config["WTF_CSRF_ENABLED"] = False + app.app.config["WTF_CSRF_ENABLED"] = False return app def test_rate_limit_one(app_with_rate_limit_one): - client_with_login( + flask_client_with_login( app_with_rate_limit_one, expected_response_code=302, username="test_admin", password="test_admin" ) - client_with_login( - app_with_rate_limit_one, expected_response_code=429, username="test_admin", password="test_admin" - ) - client_with_login( - app_with_rate_limit_one, expected_response_code=429, username="test_admin", password="test_admin" - ) + from starlette.exceptions import HTTPException + + with pytest.raises(HTTPException) as ex: + flask_client_with_login(app_with_rate_limit_one, username="test_admin", password="test_admin") + assert ex.value.status_code == 429 + with pytest.raises(HTTPException) as ex: + flask_client_with_login(app_with_rate_limit_one, username="test_admin", password="test_admin") + assert ex.value.status_code == 429 def test_rate_limit_disabled(app): diff --git a/tests/www/views/test_views_rendered.py b/tests/www/views/test_views_rendered.py index 842f1010138d49..3d26cb9f9bf039 100644 --- a/tests/www/views/test_views_rendered.py +++ b/tests/www/views/test_views_rendered.py @@ -161,7 +161,7 @@ def _create_dag_run(*, execution_date, session): @pytest.fixture def patch_app(app, dag): - with mock.patch.object(app, "dag_bag") as mock_dag_bag: + with mock.patch.object(app.app, "dag_bag") as mock_dag_bag: mock_dag_bag.get_dag.return_value = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) yield app @@ -215,7 +215,7 @@ def test_user_defined_filter_and_macros_raise_error(admin_client, create_dag_run resp = admin_client.get(url, follow_redirects=True) assert resp.status_code == 200 - resp_html: str = resp.data.decode("utf-8") + resp_html: str = resp.text assert "echo Hello Apache Airflow" not in resp_html assert ( "Webserver does not have access to User-defined Macros or Filters when " @@ -323,7 +323,7 @@ def test_rendered_task_detail_env_secret(patch_app, admin_client, request, env, Variable.set("plain_var", "banana") Variable.set("secret_var", "monkey") - dag: DAG = patch_app.dag_bag.get_dag("testdag") + dag: DAG = patch_app.app.dag_bag.get_dag("testdag") task_secret: BashOperator = dag.get_task(task_id="task1") task_secret.env = env date = quote_plus(str(DEFAULT_DATE)) diff --git a/tests/www/views/test_views_robots.py b/tests/www/views/test_views_robots.py index 03d8547c04d4b0..319fba3a7efcde 100644 --- a/tests/www/views/test_views_robots.py +++ b/tests/www/views/test_views_robots.py @@ -25,16 +25,16 @@ def test_robots(viewer_client): resp = viewer_client.get("/robots.txt", follow_redirects=True) - assert resp.data.decode("utf-8") == "User-agent: *\nDisallow: /\n" + assert resp.text == "User-agent: *\nDisallow: /\n" def test_deployment_warning_config(admin_client): warn_text = "webserver.warn_deployment_exposure" admin_client.get("/robots.txt", follow_redirects=True) resp = admin_client.get("", follow_redirects=True) - assert warn_text in resp.data.decode("utf-8") + assert warn_text in resp.text with conf_vars({("webserver", "warn_deployment_exposure"): "False"}): admin_client.get("/robots.txt", follow_redirects=True) resp = admin_client.get("/robots.txt", follow_redirects=True) - assert warn_text not in resp.data.decode("utf-8") + assert warn_text not in resp.text diff --git a/tests/www/views/test_views_task_norun.py b/tests/www/views/test_views_task_norun.py index a0709c4303d99d..7001f141bb11d9 100644 --- a/tests/www/views/test_views_task_norun.py +++ b/tests/www/views/test_views_task_norun.py @@ -41,7 +41,7 @@ def test_task_view_no_task_instance(admin_client): url = f"/task?task_id=runme_0&dag_id=example_bash_operator&execution_date={DEFAULT_VAL}" resp = admin_client.get(url, follow_redirects=True) assert resp.status_code == 200 - html = resp.data.decode("utf-8") + html = resp.text assert "
No Task Instance Available
" in html assert "
Task Instance Attributes
" not in html @@ -50,5 +50,5 @@ def test_rendered_templates_view_no_task_instance(admin_client): url = f"/rendered-templates?task_id=runme_0&dag_id=example_bash_operator&execution_date={DEFAULT_VAL}" resp = admin_client.get(url, follow_redirects=True) assert resp.status_code == 200 - html = resp.data.decode("utf-8") + html = resp.text assert "Rendered Template" in html diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py index bc7ce29cec73fc..71de3699d6a0ab 100644 --- a/tests/www/views/test_views_tasks.py +++ b/tests/www/views/test_views_tasks.py @@ -18,11 +18,11 @@ from __future__ import annotations import html -import json import unittest.mock import urllib.parse from getpass import getuser +import httpx import pendulum import pytest import time_machine @@ -47,7 +47,12 @@ from tests.test_utils.api_connexion_utils import create_user, delete_roles, delete_user from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_runs, clear_db_xcom -from tests.test_utils.www import check_content_in_response, check_content_not_in_response, client_with_login +from tests.test_utils.www import ( + check_content_in_response, + check_content_not_in_response, + client_with_login, + flask_client_with_login, +) pytestmark = pytest.mark.db_test @@ -68,7 +73,7 @@ def reset_dagruns(): @pytest.fixture(autouse=True) def init_dagruns(app, reset_dagruns): with time_machine.travel(DEFAULT_DATE, tick=False): - app.dag_bag.get_dag("example_bash_operator").create_dagrun( + app.app.dag_bag.get_dag("example_bash_operator").create_dagrun( run_id=DEFAULT_DAGRUN, run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, @@ -83,7 +88,7 @@ def init_dagruns(app, reset_dagruns): dag_id="example_bash_operator", execution_date=DEFAULT_DATE, ) - app.dag_bag.get_dag("example_subdag_operator").create_dagrun( + app.app.dag_bag.get_dag("example_subdag_operator").create_dagrun( run_id=DEFAULT_DAGRUN, run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, @@ -91,7 +96,7 @@ def init_dagruns(app, reset_dagruns): start_date=timezone.utcnow(), state=State.RUNNING, ) - app.dag_bag.get_dag("example_xcom").create_dagrun( + app.app.dag_bag.get_dag("example_xcom").create_dagrun( run_id=DEFAULT_DAGRUN, run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, @@ -99,7 +104,7 @@ def init_dagruns(app, reset_dagruns): start_date=timezone.utcnow(), state=State.RUNNING, ) - app.dag_bag.get_dag("latest_only").create_dagrun( + app.app.dag_bag.get_dag("latest_only").create_dagrun( run_id=DEFAULT_DAGRUN, run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, @@ -107,7 +112,7 @@ def init_dagruns(app, reset_dagruns): start_date=timezone.utcnow(), state=State.RUNNING, ) - app.dag_bag.get_dag("example_task_group").create_dagrun( + app.app.dag_bag.get_dag("example_task_group").create_dagrun( run_id=DEFAULT_DAGRUN, run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, @@ -121,9 +126,9 @@ def init_dagruns(app, reset_dagruns): @pytest.fixture(scope="module") -def client_ti_without_dag_edit(app): +def flask_client_ti_without_dag_edit(app): create_user( - app, + app.app, username="all_ti_permissions_except_dag_edit", role_name="all_ti_permissions_except_dag_edit", permissions=[ @@ -138,14 +143,14 @@ def client_ti_without_dag_edit(app): ], ) - yield client_with_login( + yield flask_client_with_login( app, username="all_ti_permissions_except_dag_edit", password="all_ti_permissions_except_dag_edit", ) - delete_user(app, username="all_ti_permissions_except_dag_edit") # type: ignore - delete_roles(app) + delete_user(app.app, username="all_ti_permissions_except_dag_edit") # type: ignore + delete_roles(app.app) @pytest.mark.parametrize( @@ -347,7 +352,7 @@ def test_xcom_return_value_is_not_bytes(admin_client): def test_rendered_task_view(admin_client): url = f"task?task_id=runme_0&dag_id=example_bash_operator&execution_date={DEFAULT_VAL}" resp = admin_client.get(url, follow_redirects=True) - resp_html = resp.data.decode("utf-8") + resp_html = resp.text assert resp.status_code == 200 assert "_try_number" not in resp_html assert "try_number" in resp_html @@ -368,7 +373,7 @@ def test_rendered_k8s_without_k8s(admin_client): def test_tree_trigger_origin_tree_view(app, admin_client): - app.dag_bag.get_dag("test_tree_view").create_dagrun( + app.app.dag_bag.get_dag("test_tree_view").create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, data_interval=(DEFAULT_DATE, DEFAULT_DATE), @@ -379,12 +384,12 @@ def test_tree_trigger_origin_tree_view(app, admin_client): url = "tree?dag_id=test_tree_view" resp = admin_client.get(url, follow_redirects=True) params = {"origin": "/dags/test_tree_view/grid"} - href = f"/dags/test_tree_view/trigger?{html.escape(urllib.parse.urlencode(params))}" + href = f"/dags/test_tree_view/trigger?{html.escape(urllib.parse.urlencode(params, safe='/:?'))}" check_content_in_response(href, resp) def test_graph_trigger_origin_grid_view(app, admin_client): - app.dag_bag.get_dag("test_tree_view").create_dagrun( + app.app.dag_bag.get_dag("test_tree_view").create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, data_interval=(DEFAULT_DATE, DEFAULT_DATE), @@ -395,12 +400,12 @@ def test_graph_trigger_origin_grid_view(app, admin_client): url = "/dags/test_tree_view/graph" resp = admin_client.get(url, follow_redirects=True) params = {"origin": "/dags/test_tree_view/grid?tab=graph"} - href = f"/dags/test_tree_view/trigger?{html.escape(urllib.parse.urlencode(params))}" + href = f"/dags/test_tree_view/trigger?{html.escape(urllib.parse.urlencode(params, safe='/:?'))}" check_content_in_response(href, resp) def test_gantt_trigger_origin_grid_view(app, admin_client): - app.dag_bag.get_dag("test_tree_view").create_dagrun( + app.app.dag_bag.get_dag("test_tree_view").create_dagrun( run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, data_interval=(DEFAULT_DATE, DEFAULT_DATE), @@ -411,7 +416,7 @@ def test_gantt_trigger_origin_grid_view(app, admin_client): url = "/dags/test_tree_view/gantt" resp = admin_client.get(url, follow_redirects=True) params = {"origin": "/dags/test_tree_view/grid?tab=gantt"} - href = f"/dags/test_tree_view/trigger?{html.escape(urllib.parse.urlencode(params))}" + href = f"/dags/test_tree_view/trigger?{html.escape(urllib.parse.urlencode(params, safe='/:?'))}" check_content_in_response(href, resp) @@ -419,16 +424,15 @@ def test_graph_view_without_dag_permission(app, one_dag_perm_user_client): url = "/dags/example_bash_operator/graph" resp = one_dag_perm_user_client.get(url, follow_redirects=True) assert resp.status_code == 200 - assert ( - resp.request.url - == "http://localhost/dags/example_bash_operator/grid?tab=graph&dag_run_id=TEST_DAGRUN" + assert resp.request.url == httpx.URL( + "http://testserver/dags/example_bash_operator/grid?tab=graph&dag_run_id=TEST_DAGRUN" ) check_content_in_response("example_bash_operator", resp) url = "/dags/example_xcom/graph" resp = one_dag_perm_user_client.get(url, follow_redirects=True) assert resp.status_code == 200 - assert resp.request.url == "http://localhost/home" + assert resp.request.url == httpx.URL("http://testserver/home") check_content_in_response("Access is Denied", resp) @@ -442,7 +446,7 @@ def test_last_dagruns_success_when_selecting_dags(admin_client): "last_dagruns", data={"dag_ids": ["example_subdag_operator"]}, follow_redirects=True ) assert resp.status_code == 200 - stats = json.loads(resp.data.decode("utf-8")) + stats = resp.text assert "example_bash_operator" not in stats assert "example_subdag_operator" in stats @@ -452,7 +456,7 @@ def test_last_dagruns_success_when_selecting_dags(admin_client): data={"dag_ids": ["example_subdag_operator", "example_bash_operator"]}, follow_redirects=True, ) - stats = json.loads(resp.data.decode("utf-8")) + stats = resp.text assert "example_bash_operator" in stats assert "example_subdag_operator" in stats check_content_not_in_response("example_xcom", resp) @@ -608,18 +612,20 @@ def new_dag_to_delete(): dag = DAG("new_dag_to_delete", is_paused_upon_creation=True) session = settings.Session() dag.sync_to_db(session=session) + session.commit() + session.close() return dag @pytest.fixture def per_dag_perm_user_client(app, new_dag_to_delete): - sm = app.appbuilder.sm + sm = app.app.appbuilder.sm perm = f"{permissions.RESOURCE_DAG_PREFIX}{new_dag_to_delete.dag_id}" sm.create_permission(permissions.ACTION_CAN_DELETE, perm) create_user( - app, + app.app, username="test_user_per_dag_perms", role_name="User with some perms", permissions=[ @@ -637,21 +643,21 @@ def per_dag_perm_user_client(app, new_dag_to_delete): password="test_user_per_dag_perms", ) - delete_user(app, username="test_user_per_dag_perms") # type: ignore - delete_roles(app) + delete_user(app.app, username="test_user_per_dag_perms") # type: ignore + delete_roles(app.app) @pytest.fixture def one_dag_perm_user_client(app): username = "test_user_one_dag_perm" dag_id = "example_bash_operator" - sm = app.appbuilder.sm + sm = app.app.appbuilder.sm perm = f"{permissions.RESOURCE_DAG_PREFIX}{dag_id}" sm.create_permission(permissions.ACTION_CAN_READ, perm) create_user( - app, + app.app, username=username, role_name="User with permission to access only one dag", permissions=[ @@ -671,8 +677,8 @@ def one_dag_perm_user_client(app): password=username, ) - delete_user(app, username=username) # type: ignore - delete_roles(app) + delete_user(app.app, username=username) # type: ignore + delete_roles(app.app) def test_delete_just_dag_per_dag_permissions(new_dag_to_delete, per_dag_perm_user_client): @@ -770,7 +776,7 @@ def _get_appbuilder_pk_string(model_view_cls, instance) -> str: return model_view_cls._serialize_pk_if_composite(model_view_cls, pk_value) -def test_task_instance_delete(session, admin_client, create_task_instance): +def test_task_instance_delete(session, flask_admin_client, create_task_instance): task_instance_to_delete = create_task_instance( task_id="test_task_instance_delete", execution_date=timezone.utcnow(), @@ -780,11 +786,13 @@ def test_task_instance_delete(session, admin_client, create_task_instance): task_id = task_instance_to_delete.task_id assert session.query(TaskInstance).filter(TaskInstance.task_id == task_id).count() == 1 - admin_client.post(f"/taskinstance/delete/{composite_key}", follow_redirects=True) + flask_admin_client.post(f"/taskinstance/delete/{composite_key}", follow_redirects=True) assert session.query(TaskInstance).filter(TaskInstance.task_id == task_id).count() == 0 -def test_task_instance_delete_permission_denied(session, client_ti_without_dag_edit, create_task_instance): +def test_task_instance_delete_permission_denied( + session, flask_client_ti_without_dag_edit, create_task_instance +): task_instance_to_delete = create_task_instance( task_id="test_task_instance_delete_permission_denied", execution_date=timezone.utcnow(), @@ -792,11 +800,14 @@ def test_task_instance_delete_permission_denied(session, client_ti_without_dag_e session=session, ) session.commit() + session.close() composite_key = _get_appbuilder_pk_string(TaskInstanceModelView, task_instance_to_delete) task_id = task_instance_to_delete.task_id assert session.query(TaskInstance).filter(TaskInstance.task_id == task_id).count() == 1 - resp = client_ti_without_dag_edit.post(f"/taskinstance/delete/{composite_key}", follow_redirects=True) + resp = flask_client_ti_without_dag_edit.post( + f"/taskinstance/delete/{composite_key}", follow_redirects=True + ) check_content_in_response("Access is Denied", resp) assert session.query(TaskInstance).filter(TaskInstance.task_id == task_id).count() == 1 @@ -984,7 +995,9 @@ def test_action_muldelete_task_instance(session, admin_client, task_search_tuple for task in tasks_to_delete ] session.bulk_save_objects(trs) + session.commit() session.flush() + session.close() # run the function to test resp = admin_client.post( @@ -1009,7 +1022,7 @@ def test_action_muldelete_task_instance(session, admin_client, task_search_tuple assert session.query(TaskReschedule).count() == 0 -def test_graph_view_doesnt_fail_on_recursion_error(app, dag_maker, admin_client): +def test_graph_view_doesnt_fail_on_recursion_error(app, dag_maker, flask_admin_client): """Test that the graph view doesn't fail on a recursion error.""" from airflow.models.baseoperator import chain @@ -1022,10 +1035,10 @@ def test_graph_view_doesnt_fail_on_recursion_error(app, dag_maker, admin_client) for i in range(1, 1000 + 1) ] chain(*tasks) - with unittest.mock.patch.object(app, "dag_bag") as mocked_dag_bag: + with unittest.mock.patch.object(app.app, "dag_bag") as mocked_dag_bag: mocked_dag_bag.get_dag.return_value = dag url = f"/dags/{dag.dag_id}/graph" - resp = admin_client.get(url, follow_redirects=True) + resp = flask_admin_client.get(url, follow_redirects=True) assert resp.status_code == 200 @@ -1036,7 +1049,7 @@ def test_task_instances(admin_client): follow_redirects=True, ) assert resp.status_code == 200 - assert resp.json == { + assert resp.json() == { "also_run_this": { "custom_operator_name": None, "dag_id": "example_bash_operator", diff --git a/tests/www/views/test_views_trigger_dag.py b/tests/www/views/test_views_trigger_dag.py index c53213c3e68ea8..3f3cba44fc1708 100644 --- a/tests/www/views/test_views_trigger_dag.py +++ b/tests/www/views/test_views_trigger_dag.py @@ -48,8 +48,8 @@ def initialize_one_dag(): def test_trigger_dag_button_normal_exist(admin_client): resp = admin_client.get("/", follow_redirects=True) - assert "/dags/example_bash_operator/trigger" in resp.data.decode("utf-8") - assert "return confirmDeleteDag(this, 'example_bash_operator')" in resp.data.decode("utf-8") + assert "/dags/example_bash_operator/trigger" in resp.text + assert "return confirmDeleteDag(this, 'example_bash_operator')" in resp.text # test trigger button with and without run_id @@ -174,10 +174,10 @@ def test_trigger_dag_form(admin_client): ("%2Fgraph%3Fdag_id%3Dexample_bash_operator", "http://localhost/graph?dag_id=example_bash_operator"), ], ) -def test_trigger_dag_form_origin_url(admin_client, test_origin, expected_origin): +def test_trigger_dag_form_origin_url(admin_flask_client, test_origin, expected_origin): test_dag_id = "example_bash_operator" - resp = admin_client.get(f"dags/{test_dag_id}/trigger?origin={test_origin}") + resp = admin_flask_client.get(f"dags/{test_dag_id}/trigger?origin={test_origin}") check_content_in_response(f'Cancel', resp) @@ -210,7 +210,7 @@ def test_trigger_dag_params_conf(admin_client, request_conf, expected_conf): check_content_in_response(str(expected_conf[key]), resp) -def test_trigger_dag_params_render(admin_client, dag_maker, session, app, monkeypatch): +def test_trigger_dag_params_render(admin_flask_client, dag_maker, session, app, monkeypatch): """ Test that textarea in Trigger DAG UI is pre-populated with param value set in DAG. @@ -236,8 +236,8 @@ def test_trigger_dag_params_render(admin_client, dag_maker, session, app, monkey with dag_maker(dag_id=DAG_ID, serialized=True, session=session, params={"accounts": param}): EmptyOperator(task_id="task1") - m.setattr(app, "dag_bag", dag_maker.dagbag) - resp = admin_client.get(f"dags/{DAG_ID}/trigger") + m.setattr(app.app, "dag_bag", dag_maker.dagbag) + resp = admin_flask_client.get(f"dags/{DAG_ID}/trigger") check_content_in_response( f'', @@ -246,7 +246,7 @@ def test_trigger_dag_params_render(admin_client, dag_maker, session, app, monkey @pytest.mark.parametrize("allow_html", [False, True]) -def test_trigger_dag_html_allow(admin_client, dag_maker, session, app, monkeypatch, allow_html): +def test_trigger_dag_html_allow(admin_flask_client, dag_maker, session, app, monkeypatch, allow_html): """ Test that HTML is escaped per default in description. """ @@ -277,8 +277,8 @@ def test_trigger_dag_html_allow(admin_client, dag_maker, session, app, monkeypat ): EmptyOperator(task_id="task1") - m.setattr(app, "dag_bag", dag_maker.dagbag) - resp = admin_client.get(f"dags/{DAG_ID}/trigger") + m.setattr(app.app, "dag_bag", dag_maker.dagbag) + resp = admin_flask_client.get(f"dags/{DAG_ID}/trigger") if expect_escape: check_content_in_response(escape(HTML_DESCRIPTION1), resp) @@ -309,7 +309,7 @@ def test_viewer_cant_trigger_dag(app): Test that the test_viewer user can't trigger DAGs. """ with create_test_client( - app, + app.app, user_name="test_user", role_name="test_role", permissions=[ @@ -324,7 +324,7 @@ def test_viewer_cant_trigger_dag(app): assert "Access is Denied" in response_data -def test_trigger_dag_params_array_value_none_render(admin_client, dag_maker, session, app, monkeypatch): +def test_trigger_dag_params_array_value_none_render(admin_flask_client, dag_maker, session, app, monkeypatch): """ Test that textarea in Trigger DAG UI is pre-populated with param value None and type ["null", "array"] set in DAG. @@ -341,8 +341,8 @@ def test_trigger_dag_params_array_value_none_render(admin_client, dag_maker, ses with dag_maker(dag_id=DAG_ID, serialized=True, session=session, params={"dag_param": param}): EmptyOperator(task_id="task1") - m.setattr(app, "dag_bag", dag_maker.dagbag) - resp = admin_client.get(f"dags/{DAG_ID}/trigger") + m.setattr(app.app, "dag_bag", dag_maker.dagbag) + resp = admin_flask_client.get(f"dags/{DAG_ID}/trigger") check_content_in_response( f'', diff --git a/tests/www/views/test_views_variable.py b/tests/www/views/test_views_variable.py index fcdad2bdb0bdd6..494af8a519af75 100644 --- a/tests/www/views/test_views_variable.py +++ b/tests/www/views/test_views_variable.py @@ -52,7 +52,7 @@ def clear_variables(): def user_variable_reader(app): """Create User that can only read variables""" return create_user( - app, + app.app, username="user_variable_reader", role_name="role_variable_reader", permissions=[ @@ -103,7 +103,7 @@ def test_import_variables_no_file(admin_client): check_content_in_response("Missing file or syntax error.", resp) -def test_import_variables_failed(session, admin_client): +def test_import_variables_failed(session, admin_flask_client): content = '{"str_key": "str_value"}' with mock.patch("airflow.models.Variable.set") as set_mock: @@ -112,32 +112,32 @@ def test_import_variables_failed(session, admin_client): bytes_content = BytesIO(bytes(content, encoding="utf-8")) - resp = admin_client.post( + resp = admin_flask_client.post( "/variable/varimport", data={"file": (bytes_content, "test.json")}, follow_redirects=True ) check_content_in_response("1 variable(s) failed to be updated.", resp) -def test_import_variables_success(session, admin_client): +def test_import_variables_success(session, admin_flask_client): assert session.query(Variable).count() == 0 content = '{"str_key": "str_value", "int_key": 60, "list_key": [1, 2], "dict_key": {"k_a": 2, "k_b": 3}}' bytes_content = BytesIO(bytes(content, encoding="utf-8")) - resp = admin_client.post( + resp = admin_flask_client.post( "/variable/varimport", data={"file": (bytes_content, "test.json")}, follow_redirects=True ) check_content_in_response("4 variable(s) successfully updated.", resp) _check_last_log(session, dag_id=None, event="variables.varimport", execution_date=None) -def test_import_variables_override_existing_variables_if_set(session, admin_client, caplog): +def test_import_variables_override_existing_variables_if_set(session, admin_flask_client, caplog): assert session.query(Variable).count() == 0 Variable.set("str_key", "str_value") content = '{"str_key": "str_value", "int_key": 60}' # str_key already exists bytes_content = BytesIO(bytes(content, encoding="utf-8")) - resp = admin_client.post( + resp = admin_flask_client.post( "/variable/varimport", data={"file": (bytes_content, "test.json"), "action_if_exist": "overwrite"}, follow_redirects=True, @@ -146,13 +146,13 @@ def test_import_variables_override_existing_variables_if_set(session, admin_clie _check_last_log(session, dag_id=None, event="variables.varimport", execution_date=None) -def test_import_variables_skips_update_if_set(session, admin_client, caplog): +def test_import_variables_skips_update_if_set(session, admin_flask_client, caplog): assert session.query(Variable).count() == 0 Variable.set("str_key", "str_value") content = '{"str_key": "str_value", "int_key": 60}' # str_key already exists bytes_content = BytesIO(bytes(content, encoding="utf-8")) - resp = admin_client.post( + resp = admin_flask_client.post( "/variable/varimport", data={"file": (bytes_content, "test.json"), "action_if_exists": "skip"}, follow_redirects=True, @@ -166,13 +166,13 @@ def test_import_variables_skips_update_if_set(session, admin_client, caplog): assert "Variable: str_key already exists, skipping." in caplog.text -def test_import_variables_fails_if_action_if_exists_is_fail(session, admin_client, caplog): +def test_import_variables_fails_if_action_if_exists_is_fail(session, admin_flask_client, caplog): assert session.query(Variable).count() == 0 Variable.set("str_key", "str_value") content = '{"str_key": "str_value", "int_key": 60}' # str_key already exists bytes_content = BytesIO(bytes(content, encoding="utf-8")) - admin_client.post( + admin_flask_client.post( "/variable/varimport", data={"file": (bytes_content, "test.json"), "action_if_exists": "fail"}, follow_redirects=True, @@ -244,7 +244,7 @@ def test_action_export(admin_client, variable): assert resp.status_code == 200 assert resp.headers["Content-Type"] == "application/json; charset=utf-8" assert resp.headers["Content-Disposition"] == "attachment; filename=variables.json" - assert resp.json == {"test_key": "text_val"} + assert resp.json() == {"test_key": "text_val"} def test_action_muldelete(session, admin_client, variable):