Skip to content

Commit

Permalink
Update methods to use Connexion v3, Ginucorn command and encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
Ulada Zakharava authored and Ulada Zakharava committed Dec 20, 2023
1 parent 2b31f37 commit 54ac459
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 142 deletions.
59 changes: 24 additions & 35 deletions airflow/api_connexion/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,13 @@
from __future__ import annotations

from http import HTTPStatus
from typing import TYPE_CHECKING, Any
from typing import Any

import werkzeug
from connexion import FlaskApi, ProblemException, problem
from connexion import ProblemException, problem
from connexion.lifecycle import ConnexionRequest, ConnexionResponse

from airflow.utils.docs import get_docs_url

if TYPE_CHECKING:
import flask

doc_link = get_docs_url("stable-rest-api-ref.html")

EXCEPTIONS_LINK_MAP = {
Expand All @@ -40,37 +37,29 @@
}


def common_error_handler(exception: BaseException) -> flask.Response:
def problem_error_handler(_request: ConnexionRequest, exception: ProblemException) -> ConnexionResponse:
"""Use to capture connexion exceptions and add link to the type field."""
if isinstance(exception, ProblemException):
link = EXCEPTIONS_LINK_MAP.get(exception.status)
if link:
response = problem(
status=exception.status,
title=exception.title,
detail=exception.detail,
type=link,
instance=exception.instance,
headers=exception.headers,
ext=exception.ext,
)
else:
response = problem(
status=exception.status,
title=exception.title,
detail=exception.detail,
type=exception.type,
instance=exception.instance,
headers=exception.headers,
ext=exception.ext,
)
link = EXCEPTIONS_LINK_MAP.get(exception.status)
if link:
return problem(
status=exception.status,
title=exception.title,
detail=exception.detail,
type=link,
instance=exception.instance,
headers=exception.headers,
ext=exception.ext,
)
else:
if not isinstance(exception, werkzeug.exceptions.HTTPException):
exception = werkzeug.exceptions.InternalServerError()

response = problem(title=exception.name, detail=exception.description, status=exception.code)

return FlaskApi.get_response(response)
return problem(
status=exception.status,
title=exception.title,
detail=exception.detail,
type=exception.type,
instance=exception.instance,
headers=exception.headers,
ext=exception.ext,
)


class NotFound(ProblemException):
Expand Down
4 changes: 2 additions & 2 deletions airflow/auth/managers/base_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
from flask import Blueprint
import connexion
from sqlalchemy.orm import Session

from airflow.auth.managers.models.base_user import BaseUser
Expand Down Expand Up @@ -79,7 +79,7 @@ def get_cli_commands() -> list[CLICommand]:
"""
return []

def get_api_endpoints(self) -> None | Blueprint:
def get_api_endpoints(self, connexion_app: connexion.FlaskApp) -> connexion.apps.flask.FlaskApi:
"""Return API endpoint(s) definition for the auth manager."""
return None

Expand Down
15 changes: 8 additions & 7 deletions airflow/cli/commands/internal_api_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from tempfile import gettempdir
from time import sleep

import connexion
import psutil
from flask import Flask
from flask_appbuilder import SQLA
from flask_caching import Cache
from flask_wtf.csrf import CSRFProtect
Expand All @@ -54,7 +54,7 @@
from airflow.www.extensions.init_views import init_api_internal, init_error_handlers

log = logging.getLogger(__name__)
app: Flask | None = None
app: connexion.FlaskApp | None = None


@cli_utils.action_cli
Expand All @@ -73,8 +73,8 @@ def internal_api(args):
log.info(f"Starting the Internal API server on port {args.port} and host {args.hostname}.")
app = create_app(testing=conf.getboolean("core", "unit_test_mode"))
app.run(
debug=True, # nosec
use_reloader=not app.config["TESTING"],
log_level="debug",
# reload=not app.app.config["TESTING"],
port=args.port,
host=args.hostname,
)
Expand All @@ -101,7 +101,7 @@ def internal_api(args):
"--workers",
str(num_workers),
"--worker-class",
str(args.workerclass),
"uvicorn.workers.UvicornWorker",
"--timeout",
str(worker_timeout),
"--bind",
Expand Down Expand Up @@ -195,7 +195,8 @@ def start_and_monitor_gunicorn(args):

def create_app(config=None, testing=False):
"""Create a new instance of Airflow Internal API app."""
flask_app = Flask(__name__)
connexion_app = connexion.FlaskApp(__name__)
flask_app = connexion_app.app

flask_app.config["APP_NAME"] = "Airflow Internal API"
flask_app.config["TESTING"] = testing
Expand Down Expand Up @@ -240,7 +241,7 @@ def create_app(config=None, testing=False):

with flask_app.app_context():
init_error_handlers(flask_app)
init_api_internal(flask_app, standalone_api=True)
init_api_internal(connexion_app, standalone_api=True)

init_jinja_globals(flask_app)
init_xframe_protection(flask_app)
Expand Down
8 changes: 4 additions & 4 deletions airflow/cli/commands/webserver_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,11 +355,11 @@ def webserver(args):
print(f"Starting the web server on port {args.port} and host {args.hostname}.")
app = create_app(testing=conf.getboolean("core", "unit_test_mode"))
app.run(
debug=True,
use_reloader=not app.config["TESTING"],
log_level="debug",
port=args.port,
host=args.hostname,
ssl_context=(ssl_cert, ssl_key) if ssl_cert and ssl_key else None,
ssl_keyfile=ssl_key if ssl_cert and ssl_key else None,
ssl_certfile=ssl_cert if ssl_cert and ssl_key else None,
)
else:
print(
Expand All @@ -383,7 +383,7 @@ def webserver(args):
"--workers",
str(num_workers),
"--worker-class",
str(args.workerclass),
"uvicorn.workers.UvicornWorker",
"--timeout",
str(worker_timeout),
"--bind",
Expand Down
23 changes: 13 additions & 10 deletions airflow/providers/fab/auth_manager/fab_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
from pathlib import Path
from typing import TYPE_CHECKING, Container

from connexion import FlaskApi
from flask import Blueprint, url_for
import connexion
from connexion.options import SwaggerUIOptions
from flask import url_for
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload

Expand Down Expand Up @@ -83,7 +84,7 @@
)
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.yaml import safe_load
from airflow.www.extensions.init_views import _CustomErrorRequestBodyValidator, _LazyResolver
from airflow.www.extensions.init_views import _LazyResolver

if TYPE_CHECKING:
from airflow.auth.managers.models.base_user import BaseUser
Expand Down Expand Up @@ -147,21 +148,23 @@ def get_cli_commands() -> list[CLICommand]:
SYNC_PERM_COMMAND, # not in a command group
]

def get_api_endpoints(self) -> None | Blueprint:
def get_api_endpoints(self, connexion_app: connexion.FlaskApp) -> connexion.apps.flask.FlaskApi:
folder = Path(__file__).parents[0].resolve() # this is airflow/auth/managers/fab/
with folder.joinpath("openapi", "v1.yaml").open() as f:
specification = safe_load(f)
return FlaskApi(

swagger_ui_options = SwaggerUIOptions(
swagger_ui=conf.getboolean("webserver", "enable_swagger_ui", fallback=True),
)

return connexion_app.add_api(
specification=specification,
resolver=_LazyResolver(),
base_path="/auth/fab/v1",
options={
"swagger_ui": conf.getboolean("webserver", "enable_swagger_ui", fallback=True),
},
swagger_ui_options=swagger_ui_options,
strict_validation=True,
validate_responses=True,
validator_map={"body": _CustomErrorRequestBodyValidator},
).blueprint
)

def get_user_display_name(self) -> str:
"""Return the user's display name associated to the user in session."""
Expand Down
2 changes: 1 addition & 1 deletion airflow/utils/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class AirflowJsonProvider(JSONProvider):
def dumps(self, obj, **kwargs):
kwargs.setdefault("ensure_ascii", self.ensure_ascii)
kwargs.setdefault("sort_keys", self.sort_keys)
return json.dumps(obj, **kwargs, cls=WebEncoder)
return json.dumps(obj, cls=WebEncoder)

def loads(self, s: str | bytes, **kwargs):
return json.loads(s, **kwargs)
Expand Down
49 changes: 41 additions & 8 deletions airflow/www/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@
import warnings
from datetime import timedelta

from flask import Flask
import connexion
from connexion.exceptions import ConnexionException
from flask import jsonify
from flask_appbuilder import SQLA
from flask_wtf.csrf import CSRFProtect
from markupsafe import Markup
from sqlalchemy.engine.url import make_url
from starlette.middleware.cors import CORSMiddleware

from airflow import settings
from airflow.api_internal.internal_api_call import InternalApiConfig
Expand Down Expand Up @@ -61,16 +64,44 @@
)
from airflow.www.extensions.init_wsgi_middlewares import init_wsgi_middleware

app: Flask | None = None
app: connexion.FlaskApp | None = None

# Initializes at the module level, so plugins can access it.
# See: /docs/plugins.rst
csrf = CSRFProtect()


def custom_jsonify(obj, **kwargs):
# Check if cls key is already present
if "cls" in kwargs:
cls = kwargs.get("cls")
if cls != AirflowJsonProvider:
raise Exception(f"Conflict: cls: {cls} already set")

# Set cls to our custom provider
kwargs["cls"] = AirflowJsonProvider

try:
return jsonify(obj, **kwargs)
except ConnexionException:
raise ConnexionException("Unable to serialize data")


def create_app(config=None, testing=False):
"""Create a new instance of Airflow WWW app."""
flask_app = Flask(__name__)
connexion_app = connexion.FlaskApp(__name__)

connexion_app.add_middleware(
CORSMiddleware,
connexion.middleware.MiddlewarePosition.BEFORE_ROUTING,
allow_origins=conf.get("api", "access_control_allow_origins"),
allow_credentials=True,
allow_methods=conf.get("api", "access_control_allow_methods"),
allow_headers=conf.get("api", "access_control_allow_headers"),
)

connexion_app.jsonify = custom_jsonify
flask_app = connexion_app.app
flask_app.secret_key = conf.get("webserver", "SECRET_KEY")

flask_app.config["PERMANENT_SESSION_LIFETIME"] = timedelta(minutes=settings.get_session_lifetime_config())
Expand Down Expand Up @@ -156,22 +187,24 @@ def create_app(config=None, testing=False):
init_appbuilder_links(flask_app)
init_plugins(flask_app)
init_error_handlers(flask_app)
init_api_connexion(flask_app)
init_api_connexion(connexion_app)
if conf.getboolean("webserver", "run_internal_api", fallback=False):
if not _ENABLE_AIP_44:
raise RuntimeError("The AIP_44 is not enabled so you cannot use it.")
init_api_internal(flask_app)
init_api_internal(connexion_app)
init_api_experimental(flask_app)
init_api_auth_provider(flask_app)
init_api_error_handlers(flask_app) # needs to be after all api inits to let them add their path first
init_api_auth_provider(connexion_app)
init_api_error_handlers(
connexion_app
) # needs to be after all api inits to let them add their path first

get_auth_manager().init()

init_jinja_globals(flask_app)
init_xframe_protection(flask_app)
init_airflow_session_interface(flask_app)
init_check_user_active(flask_app)
return flask_app
return connexion_app


def cached_app(config=None, testing=False):
Expand Down
Loading

0 comments on commit 54ac459

Please sign in to comment.