From 54ac4594d13ee347ce930e28a97a98541a6d4a70 Mon Sep 17 00:00:00 2001 From: Ulada Zakharava Date: Thu, 23 Nov 2023 14:33:45 +0000 Subject: [PATCH] Update methods to use Connexion v3, Gunicorn command and encoding --- airflow/api_connexion/exceptions.py | 59 ++++---- airflow/auth/managers/base_auth_manager.py | 4 +- airflow/cli/commands/internal_api_command.py | 15 +- airflow/cli/commands/webserver_command.py | 8 +- .../fab/auth_manager/fab_auth_manager.py | 23 ++-- airflow/utils/json.py | 2 +- airflow/www/app.py | 49 +++++-- airflow/www/extensions/init_views.py | 130 ++++++++---------- setup.cfg | 2 +- 9 files changed, 150 insertions(+), 142 deletions(-) diff --git a/airflow/api_connexion/exceptions.py b/airflow/api_connexion/exceptions.py index 75d9261ef6d444..154cc3d599bed5 100644 --- a/airflow/api_connexion/exceptions.py +++ b/airflow/api_connexion/exceptions.py @@ -17,16 +17,13 @@ from __future__ import annotations from http import HTTPStatus -from typing import TYPE_CHECKING, Any +from typing import Any -import werkzeug -from connexion import FlaskApi, ProblemException, problem +from connexion import ProblemException, problem +from connexion.lifecycle import ConnexionRequest, ConnexionResponse from airflow.utils.docs import get_docs_url -if TYPE_CHECKING: - import flask - doc_link = get_docs_url("stable-rest-api-ref.html") EXCEPTIONS_LINK_MAP = { @@ -40,37 +37,29 @@ } -def common_error_handler(exception: BaseException) -> flask.Response: +def problem_error_handler(_request: ConnexionRequest, exception: ProblemException) -> ConnexionResponse: """Use to capture connexion exceptions and add link to the type field.""" - if isinstance(exception, ProblemException): - link = EXCEPTIONS_LINK_MAP.get(exception.status) - if link: - response = problem( - status=exception.status, - title=exception.title, - detail=exception.detail, - type=link, - instance=exception.instance, - headers=exception.headers, - ext=exception.ext, - ) - else: - response = problem( 
status=exception.status, - title=exception.title, - detail=exception.detail, - type=exception.type, - instance=exception.instance, - headers=exception.headers, - ext=exception.ext, - ) + link = EXCEPTIONS_LINK_MAP.get(exception.status) + if link: + return problem( + status=exception.status, + title=exception.title, + detail=exception.detail, + type=link, + instance=exception.instance, + headers=exception.headers, + ext=exception.ext, + ) else: - if not isinstance(exception, werkzeug.exceptions.HTTPException): - exception = werkzeug.exceptions.InternalServerError() - - response = problem(title=exception.name, detail=exception.description, status=exception.code) - - return FlaskApi.get_response(response) + return problem( + status=exception.status, + title=exception.title, + detail=exception.detail, + type=exception.type, + instance=exception.instance, + headers=exception.headers, + ext=exception.ext, + ) class NotFound(ProblemException): diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index 466e7287741124..517eaded876720 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -32,7 +32,7 @@ from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: - from flask import Blueprint + import connexion from sqlalchemy.orm import Session from airflow.auth.managers.models.base_user import BaseUser @@ -79,7 +79,7 @@ def get_cli_commands() -> list[CLICommand]: """ return [] - def get_api_endpoints(self) -> None | Blueprint: + def get_api_endpoints(self, connexion_app: connexion.FlaskApp) -> connexion.apps.flask.FlaskApi: """Return API endpoint(s) definition for the auth manager.""" return None diff --git a/airflow/cli/commands/internal_api_command.py b/airflow/cli/commands/internal_api_command.py index dd938015378a79..86d9cc575b75ce 100644 --- a/airflow/cli/commands/internal_api_command.py +++ b/airflow/cli/commands/internal_api_command.py @@ -28,8 
+28,8 @@ from tempfile import gettempdir from time import sleep +import connexion import psutil -from flask import Flask from flask_appbuilder import SQLA from flask_caching import Cache from flask_wtf.csrf import CSRFProtect @@ -54,7 +54,7 @@ from airflow.www.extensions.init_views import init_api_internal, init_error_handlers log = logging.getLogger(__name__) -app: Flask | None = None +app: connexion.FlaskApp | None = None @cli_utils.action_cli @@ -73,8 +73,8 @@ def internal_api(args): log.info(f"Starting the Internal API server on port {args.port} and host {args.hostname}.") app = create_app(testing=conf.getboolean("core", "unit_test_mode")) app.run( - debug=True, # nosec - use_reloader=not app.config["TESTING"], + log_level="debug", + # reload=not app.app.config["TESTING"], port=args.port, host=args.hostname, ) @@ -101,7 +101,7 @@ def internal_api(args): "--workers", str(num_workers), "--worker-class", - str(args.workerclass), + "uvicorn.workers.UvicornWorker", "--timeout", str(worker_timeout), "--bind", @@ -195,7 +195,8 @@ def start_and_monitor_gunicorn(args): def create_app(config=None, testing=False): """Create a new instance of Airflow Internal API app.""" - flask_app = Flask(__name__) + connexion_app = connexion.FlaskApp(__name__) + flask_app = connexion_app.app flask_app.config["APP_NAME"] = "Airflow Internal API" flask_app.config["TESTING"] = testing @@ -240,7 +241,7 @@ def create_app(config=None, testing=False): with flask_app.app_context(): init_error_handlers(flask_app) - init_api_internal(flask_app, standalone_api=True) + init_api_internal(connexion_app, standalone_api=True) init_jinja_globals(flask_app) init_xframe_protection(flask_app) diff --git a/airflow/cli/commands/webserver_command.py b/airflow/cli/commands/webserver_command.py index 4cb7939fd7fb5d..abe6a32c60977a 100644 --- a/airflow/cli/commands/webserver_command.py +++ b/airflow/cli/commands/webserver_command.py @@ -355,11 +355,11 @@ def webserver(args): print(f"Starting the web server on 
port {args.port} and host {args.hostname}.") app = create_app(testing=conf.getboolean("core", "unit_test_mode")) app.run( - debug=True, - use_reloader=not app.config["TESTING"], + log_level="debug", port=args.port, host=args.hostname, - ssl_context=(ssl_cert, ssl_key) if ssl_cert and ssl_key else None, + ssl_keyfile=ssl_key if ssl_cert and ssl_key else None, + ssl_certfile=ssl_cert if ssl_cert and ssl_key else None, ) else: print( @@ -383,7 +383,7 @@ def webserver(args): "--workers", str(num_workers), "--worker-class", - str(args.workerclass), + "uvicorn.workers.UvicornWorker", "--timeout", str(worker_timeout), "--bind", diff --git a/airflow/providers/fab/auth_manager/fab_auth_manager.py b/airflow/providers/fab/auth_manager/fab_auth_manager.py index 201dc050f9ad4f..5a39de53da301d 100644 --- a/airflow/providers/fab/auth_manager/fab_auth_manager.py +++ b/airflow/providers/fab/auth_manager/fab_auth_manager.py @@ -22,8 +22,9 @@ from pathlib import Path from typing import TYPE_CHECKING, Container -from connexion import FlaskApi -from flask import Blueprint, url_for +import connexion +from connexion.options import SwaggerUIOptions +from flask import url_for from sqlalchemy import select from sqlalchemy.orm import Session, joinedload @@ -83,7 +84,7 @@ ) from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.yaml import safe_load -from airflow.www.extensions.init_views import _CustomErrorRequestBodyValidator, _LazyResolver +from airflow.www.extensions.init_views import _LazyResolver if TYPE_CHECKING: from airflow.auth.managers.models.base_user import BaseUser @@ -147,21 +148,23 @@ def get_cli_commands() -> list[CLICommand]: SYNC_PERM_COMMAND, # not in a command group ] - def get_api_endpoints(self) -> None | Blueprint: + def get_api_endpoints(self, connexion_app: connexion.FlaskApp) -> connexion.apps.flask.FlaskApi: folder = Path(__file__).parents[0].resolve() # this is airflow/auth/managers/fab/ with folder.joinpath("openapi", "v1.yaml").open() 
as f: specification = safe_load(f) - return FlaskApi( + + swagger_ui_options = SwaggerUIOptions( + swagger_ui=conf.getboolean("webserver", "enable_swagger_ui", fallback=True), + ) + + return connexion_app.add_api( specification=specification, resolver=_LazyResolver(), base_path="/auth/fab/v1", - options={ - "swagger_ui": conf.getboolean("webserver", "enable_swagger_ui", fallback=True), - }, + swagger_ui_options=swagger_ui_options, strict_validation=True, validate_responses=True, - validator_map={"body": _CustomErrorRequestBodyValidator}, - ).blueprint + ) def get_user_display_name(self) -> str: """Return the user's display name associated to the user in session.""" diff --git a/airflow/utils/json.py b/airflow/utils/json.py index 4d89e340c1cd48..92d691be194e83 100644 --- a/airflow/utils/json.py +++ b/airflow/utils/json.py @@ -37,7 +37,7 @@ class AirflowJsonProvider(JSONProvider): def dumps(self, obj, **kwargs): kwargs.setdefault("ensure_ascii", self.ensure_ascii) kwargs.setdefault("sort_keys", self.sort_keys) - return json.dumps(obj, **kwargs, cls=WebEncoder) + return json.dumps(obj, cls=WebEncoder) def loads(self, s: str | bytes, **kwargs): return json.loads(s, **kwargs) diff --git a/airflow/www/app.py b/airflow/www/app.py index b8be40a4210da6..2e85bd8c821729 100644 --- a/airflow/www/app.py +++ b/airflow/www/app.py @@ -20,11 +20,14 @@ import warnings from datetime import timedelta -from flask import Flask +import connexion +from connexion.exceptions import ConnexionException +from flask import jsonify from flask_appbuilder import SQLA from flask_wtf.csrf import CSRFProtect from markupsafe import Markup from sqlalchemy.engine.url import make_url +from starlette.middleware.cors import CORSMiddleware from airflow import settings from airflow.api_internal.internal_api_call import InternalApiConfig @@ -61,16 +64,44 @@ ) from airflow.www.extensions.init_wsgi_middlewares import init_wsgi_middleware -app: Flask | None = None +app: connexion.FlaskApp | None = None # 
Initializes at the module level, so plugins can access it. # See: /docs/plugins.rst csrf = CSRFProtect() +def custom_jsonify(obj, **kwargs): + # Check if cls key is already present + if "cls" in kwargs: + cls = kwargs.get("cls") + if cls != AirflowJsonProvider: + raise Exception(f"Conflict: cls: {cls} already set") + + # Set cls to our custom provider + kwargs["cls"] = AirflowJsonProvider + + try: + return jsonify(obj, **kwargs) + except ConnexionException: + raise ConnexionException("Unable to serialize data") + + def create_app(config=None, testing=False): """Create a new instance of Airflow WWW app.""" - flask_app = Flask(__name__) + connexion_app = connexion.FlaskApp(__name__) + + connexion_app.add_middleware( + CORSMiddleware, + connexion.middleware.MiddlewarePosition.BEFORE_ROUTING, + allow_origins=conf.get("api", "access_control_allow_origins"), + allow_credentials=True, + allow_methods=conf.get("api", "access_control_allow_methods"), + allow_headers=conf.get("api", "access_control_allow_headers"), + ) + + connexion_app.jsonify = custom_jsonify + flask_app = connexion_app.app flask_app.secret_key = conf.get("webserver", "SECRET_KEY") flask_app.config["PERMANENT_SESSION_LIFETIME"] = timedelta(minutes=settings.get_session_lifetime_config()) @@ -156,14 +187,16 @@ def create_app(config=None, testing=False): init_appbuilder_links(flask_app) init_plugins(flask_app) init_error_handlers(flask_app) - init_api_connexion(flask_app) + init_api_connexion(connexion_app) if conf.getboolean("webserver", "run_internal_api", fallback=False): if not _ENABLE_AIP_44: raise RuntimeError("The AIP_44 is not enabled so you cannot use it.") - init_api_internal(flask_app) + init_api_internal(connexion_app) init_api_experimental(flask_app) - init_api_auth_provider(flask_app) - init_api_error_handlers(flask_app) # needs to be after all api inits to let them add their path first + init_api_auth_provider(connexion_app) + init_api_error_handlers( + connexion_app + ) # needs to be after all 
api inits to let them add their path first get_auth_manager().init() @@ -171,7 +204,7 @@ def create_app(config=None, testing=False): init_xframe_protection(flask_app) init_airflow_session_interface(flask_app) init_check_user_active(flask_app) - return flask_app + return connexion_app def cached_app(config=None, testing=False): diff --git a/airflow/www/extensions/init_views.py b/airflow/www/extensions/init_views.py index 92f2145764c69f..6030cf718d96c2 100644 --- a/airflow/www/extensions/init_views.py +++ b/airflow/www/extensions/init_views.py @@ -23,12 +23,14 @@ from pathlib import Path from typing import TYPE_CHECKING -from connexion import FlaskApi, ProblemException, Resolver -from connexion.decorators.validation import RequestBodyValidator -from connexion.exceptions import BadRequestProblem -from flask import request - -from airflow.api_connexion.exceptions import common_error_handler +import connexion +import starlette.exceptions +from connexion import ProblemException, Resolver +from connexion.lifecycle import ConnexionRequest, ConnexionResponse +from connexion.options import SwaggerUIOptions +from connexion.problem import problem + +from airflow.api_connexion.exceptions import problem_error_handler from airflow.configuration import conf from airflow.exceptions import RemovedInAirflow3Warning from airflow.security import permissions @@ -167,26 +169,6 @@ def init_error_handlers(app: Flask): from airflow.www import views app.register_error_handler(500, views.show_traceback) - app.register_error_handler(404, views.not_found) - - -def set_cors_headers_on_response(response): - """Add response headers.""" - allow_headers = conf.get("api", "access_control_allow_headers") - allow_methods = conf.get("api", "access_control_allow_methods") - allow_origins = conf.get("api", "access_control_allow_origins") - if allow_headers: - response.headers["Access-Control-Allow-Headers"] = allow_headers - if allow_methods: - response.headers["Access-Control-Allow-Methods"] = 
allow_methods - if allow_origins == "*": - response.headers["Access-Control-Allow-Origin"] = "*" - elif allow_origins: - allowed_origins = allow_origins.split(" ") - origin = request.environ.get("HTTP_ORIGIN", allowed_origins[0]) - if origin in allowed_origins: - response.headers["Access-Control-Allow-Origin"] = origin - return response class _LazyResolution: @@ -220,74 +202,71 @@ def resolve(self, operation): return _LazyResolution(self.resolve_function_from_operation_id, operation_id) -class _CustomErrorRequestBodyValidator(RequestBodyValidator): - """Custom request body validator that overrides error messages. - - By default, Connextion emits a very generic *None is not of type 'object'* - error when receiving an empty request body (with the view specifying the - body as non-nullable). We overrides it to provide a more useful message. - """ - - def validate_schema(self, data, url): - if not self.is_null_value_valid and data is None: - raise BadRequestProblem(detail="Request body must not be empty") - return super().validate_schema(data, url) - - base_paths: list[str] = [] # contains the list of base paths that have api endpoints -def init_api_error_handlers(app: Flask) -> None: +def init_api_error_handlers(connexion_app: connexion.FlaskApp) -> None: """Add error handlers for 404 and 405 errors for existing API paths.""" from airflow.www import views - @app.errorhandler(404) - def _handle_api_not_found(ex): - if any([request.path.startswith(p) for p in base_paths]): + def _handle_http_exception(ex: starlette.exceptions.HTTPException) -> ConnexionResponse: + return problem( + title=connexion.http_facts.HTTP_STATUS_CODES.get(ex.status_code), + detail=ex.detail, + status=ex.status_code, + ) + + def _handle_api_not_found( + request: ConnexionRequest, ex: starlette.exceptions.HTTPException + ) -> ConnexionResponse: + if any([request.url.path.startswith(p) for p in base_paths]): # 404 errors are never handled on the blueprint level # unless raised from a view func so 
actual 404 errors, # i.e. "no route for it" defined, need to be handled # here on the application level - return common_error_handler(ex) + return _handle_http_exception(ex) else: return views.not_found(ex) - @app.errorhandler(405) - def _handle_method_not_allowed(ex): - if any([request.path.startswith(p) for p in base_paths]): - return common_error_handler(ex) + def _handle_method_not_allowed( + request: ConnexionRequest, ex: starlette.exceptions.HTTPException + ) -> ConnexionResponse: + if any([request.url.path.startswith(p) for p in base_paths]): + return _handle_http_exception(ex) else: return views.method_not_allowed(ex) - app.register_error_handler(ProblemException, common_error_handler) + connexion_app.add_error_handler(404, _handle_api_not_found) + connexion_app.add_error_handler(405, _handle_method_not_allowed) + connexion_app.add_error_handler(ProblemException, problem_error_handler) -def init_api_connexion(app: Flask) -> None: +def init_api_connexion(connexion_app: connexion.FlaskApp) -> None: """Initialize Stable API.""" base_path = "/api/v1" base_paths.append(base_path) with ROOT_APP_DIR.joinpath("api_connexion", "openapi", "v1.yaml").open() as f: specification = safe_load(f) - api_bp = FlaskApi( + swagger_ui_options = SwaggerUIOptions( + swagger_ui=conf.getboolean("webserver", "enable_swagger_ui", fallback=True), + swagger_ui_path=os.fspath(ROOT_APP_DIR.joinpath("www", "static", "dist", "swagger-ui")), + ) + + connexion_app.add_api( specification=specification, resolver=_LazyResolver(), base_path=base_path, - options={ - "swagger_ui": conf.getboolean("webserver", "enable_swagger_ui", fallback=True), - "swagger_path": os.fspath(ROOT_APP_DIR.joinpath("www", "static", "dist", "swagger-ui")), - }, + swagger_ui_options=swagger_ui_options, strict_validation=True, validate_responses=True, - validator_map={"body": _CustomErrorRequestBodyValidator}, - ).blueprint - api_bp.after_request(set_cors_headers_on_response) + ) - app.register_blueprint(api_bp) - 
app.extensions["csrf"].exempt(api_bp) + # flask_app = connexion_app.app + # flask_app.extensions["csrf"].exempt(api_bp) -def init_api_internal(app: Flask, standalone_api: bool = False) -> None: +def init_api_internal(connexion_app: connexion.FlaskApp, standalone_api: bool = False) -> None: """Initialize Internal API.""" if not standalone_api and not conf.getboolean("webserver", "run_internal_api", fallback=False): return @@ -295,18 +274,20 @@ def init_api_internal(app: Flask, standalone_api: bool = False) -> None: base_paths.append("/internal_api/v1") with ROOT_APP_DIR.joinpath("api_internal", "openapi", "internal_api_v1.yaml").open() as f: specification = safe_load(f) - api_bp = FlaskApi( + swagger_ui_options = SwaggerUIOptions( + swagger_ui=conf.getboolean("webserver", "enable_swagger_ui", fallback=True), + ) + + connexion_app.add_api( specification=specification, base_path="/internal_api/v1", - options={"swagger_ui": conf.getboolean("webserver", "enable_swagger_ui", fallback=True)}, + swagger_ui_options=swagger_ui_options, strict_validation=True, validate_responses=True, - ).blueprint - api_bp.after_request(set_cors_headers_on_response) + ) - app.register_blueprint(api_bp) - app.after_request_funcs.setdefault(api_bp.name, []).append(set_cors_headers_on_response) - app.extensions["csrf"].exempt(api_bp) + # flask_app = connexion_app.app + # flask_app.extensions["csrf"].exempt(api_bp) def init_api_experimental(app): @@ -326,11 +307,12 @@ def init_api_experimental(app): app.extensions["csrf"].exempt(endpoints.api_experimental) -def init_api_auth_provider(app): +def init_api_auth_provider(connexion_app: connexion.FlaskApp): """Initialize the API offered by the auth manager.""" auth_mgr = get_auth_manager() - blueprint = auth_mgr.get_api_endpoints() - if blueprint: + api = auth_mgr.get_api_endpoints(connexion_app) + if api: + blueprint = api.blueprint base_paths.append(blueprint.url_prefix) - app.register_blueprint(blueprint) - app.extensions["csrf"].exempt(blueprint) 
+ flask_app = connexion_app.app + flask_app.extensions["csrf"].exempt(blueprint) diff --git a/setup.cfg b/setup.cfg index 9077a02915d703..486b47cae500fe 100644 --- a/setup.cfg +++ b/setup.cfg @@ -85,7 +85,7 @@ install_requires = # The usage was added in #30596, seemingly only to override and improve the default error message. # Either revert that change or find another way, preferably without using connexion internals. # This limit can be removed after https://github.com/apache/airflow/issues/35234 is fixed - connexion[flask]>=2.10.0,<3.0 + connexion[flask]>=3.0 cron-descriptor>=1.2.24 croniter>=0.3.17 cryptography>=0.9.3