Skip to content

Commit

Permalink
Update methods to use Connexion v3, Ginucorn command and encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
Ulada Zakharava authored and Ulada Zakharava committed Dec 20, 2023
1 parent 19241e7 commit 1485608
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 68 deletions.
7 changes: 2 additions & 5 deletions airflow/api_connexion/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,13 @@
from __future__ import annotations

from http import HTTPStatus
from typing import TYPE_CHECKING, Any
from typing import Any

from connexion import ProblemException, problem
from connexion.lifecycle import ConnexionRequest, ConnexionResponse

from airflow.utils.docs import get_docs_url

if TYPE_CHECKING:
import flask

doc_link = get_docs_url("stable-rest-api-ref.html")

EXCEPTIONS_LINK_MAP = {
Expand Down Expand Up @@ -62,7 +59,7 @@ def problem_error_handler(_request: ConnexionRequest, exception: ProblemExceptio
instance=exception.instance,
headers=exception.headers,
ext=exception.ext,
)
)


class NotFound(ProblemException):
Expand Down
15 changes: 0 additions & 15 deletions airflow/auth/managers/base_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@

if TYPE_CHECKING:
import connexion
from connexion import FlaskApi, FlaskApp
from sqlalchemy.orm import Session

from airflow.auth.managers.models.base_user import BaseUser
Expand Down Expand Up @@ -135,20 +134,6 @@ def is_authorized_configuration(
:param user: the user to perform the action on. If not provided (or None), it uses the current user
"""

@abstractmethod
def is_authorized_cluster_activity(
self,
*,
method: ResourceMethod,
user: BaseUser | None = None,
) -> bool:
"""
Return whether the user is authorized to perform a given action on the cluster activity.
:param method: the method to perform
:param user: the user to perform the action on. If not provided (or None), it uses the current user
"""

@abstractmethod
def is_authorized_connection(
self,
Expand Down
3 changes: 1 addition & 2 deletions airflow/cli/commands/internal_api_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

import connexion
import psutil
from flask import Flask
from flask_appbuilder import SQLA
from flask_caching import Cache
from flask_wtf.csrf import CSRFProtect
Expand Down Expand Up @@ -102,7 +101,7 @@ def internal_api(args):
"--workers",
str(num_workers),
"--worker-class",
str(args.workerclass),
"uvicorn.workers.UvicornWorker",
"--timeout",
str(worker_timeout),
"--bind",
Expand Down
16 changes: 2 additions & 14 deletions airflow/cli/commands/webserver_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,6 @@ def webserver(args):
app = create_app(testing=conf.getboolean("core", "unit_test_mode"))
app.run(
log_level="debug",
# reload=not app.app.config["TESTING"],
port=args.port,
host=args.hostname,
ssl_keyfile=ssl_key if ssl_cert and ssl_key else None,
Expand All @@ -377,22 +376,14 @@ def webserver(args):
)

pid_file, _, _, _ = setup_locations("webserver", pid=args.pid)
ROOT_APP_DIR = Path(__file__).parents[2].resolve()
swagger_ui_path = os.fspath(ROOT_APP_DIR.joinpath("www", "static", "dist", "swagger-ui"))
run_args = [
sys.executable,
"-m",
"gunicorn",
# "-k",
"--worker-class",
"uvicorn.workers.UvicornWorker",
# "eventlet",
# "--swagger_ui_path",
# str(swagger_ui_path),
"--workers",
str(num_workers),

# str(args.workerclass),
"--worker-class",
"uvicorn.workers.UvicornWorker",
"--timeout",
str(worker_timeout),
"--bind",
Expand Down Expand Up @@ -421,7 +412,6 @@ def webserver(args):
run_args += ["--certfile", ssl_cert, "--keyfile", ssl_key]

run_args += ["airflow.www.app:cached_app()"]
# run_args += ["airflow.www.app:cached_app()"]

if conf.getboolean("webserver", "reload_on_plugin_change", fallback=False):
log.warning(
Expand All @@ -435,8 +425,6 @@ def webserver(args):
# all writing to the database at the same time, we use the --preload option.
run_args += ["--preload"]

log.info("RUNNING GUNICORN: %s", run_args)

def kill_proc(signum: int, gunicorn_master_proc: psutil.Process | subprocess.Popen) -> NoReturn:
log.info("Received signal: %s. Closing gunicorn.", signum)
gunicorn_master_proc.terminate()
Expand Down
32 changes: 15 additions & 17 deletions airflow/providers/fab/auth_manager/fab_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,11 @@

import connexion
from connexion.options import SwaggerUIOptions
from connexion.resolver import Resolver
from flask import url_for
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload

from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod
from airflow.auth.managers.fab.cli_commands.definition import (
ROLES_COMMANDS,
SYNC_PERM_COMMAND,
USERS_COMMANDS,
)
from airflow.auth.managers.fab.models import Permission, Role, User
from airflow.auth.managers.models.resource_details import (
AccessView,
ConfigurationDetails,
Expand All @@ -51,8 +44,14 @@
GroupCommand,
)
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import DagModel
from airflow.providers.fab.auth_manager.cli_commands.definition import (
ROLES_COMMANDS,
SYNC_PERM_COMMAND,
USERS_COMMANDS,
)
from airflow.providers.fab.auth_manager.models import Permission, Role, User
from airflow.security import permissions
from airflow.security.permissions import (
ACTION_CAN_ACCESS_MENU,
Expand Down Expand Up @@ -88,11 +87,11 @@
from airflow.www.extensions.init_views import _LazyResolver

if TYPE_CHECKING:
from airflow.auth.managers.fab.security_manager.override import FabAirflowSecurityManagerOverride
from airflow.auth.managers.models.base_user import BaseUser
from airflow.cli.cli_config import (
CLICommand,
)
from airflow.providers.fab.auth_manager.security_manager.override import FabAirflowSecurityManagerOverride

_MAP_DAG_ACCESS_ENTITY_TO_FAB_RESOURCE_TYPE: dict[DagAccessEntity, tuple[str, ...]] = {
DagAccessEntity.AUDIT_LOG: (RESOURCE_AUDIT_LOG,),
Expand Down Expand Up @@ -197,9 +196,6 @@ def is_authorized_configuration(
) -> bool:
return self._is_authorized(method=method, resource_type=RESOURCE_CONFIG, user=user)

def is_authorized_cluster_activity(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool:
return self._is_authorized(method=method, resource_type=RESOURCE_CLUSTER_ACTIVITY, user=user)

def is_authorized_connection(
self,
*,
Expand All @@ -226,10 +222,10 @@ def is_authorized_dag(
entity (e.g. DAG runs).
2. ``dag_access`` is provided which means the user wants to access a sub entity of the DAG
(e.g. DAG runs).
a. If ``method`` is GET, then check the user has READ permissions on the DAG and the sub entity.
b. Else, check the user has EDIT permissions on the DAG and ``method`` on the sub entity.
However, if no specific DAG is targeted, just check the sub entity.
a. If ``method`` is GET, then check the user has READ permissions on the DAG and the sub entity.
b. Else, check the user has EDIT permissions on the DAG and ``method`` on the sub entity. However,
if no specific DAG is targeted, just check the sub entity.
:param method: The method to authorize.
:param access_entity: The dag access entity.
Expand Down Expand Up @@ -338,7 +334,9 @@ def get_permitted_dag_ids(
@cached_property
def security_manager(self) -> FabAirflowSecurityManagerOverride:
"""Return the security manager specific to FAB."""
from airflow.auth.managers.fab.security_manager.override import FabAirflowSecurityManagerOverride
from airflow.providers.fab.auth_manager.security_manager.override import (
FabAirflowSecurityManagerOverride,
)
from airflow.www.security import AirflowSecurityManager

sm_from_config = self.appbuilder.get_app.config.get("SECURITY_MANAGER_CLASS")
Expand All @@ -352,7 +350,7 @@ def security_manager(self) -> FabAirflowSecurityManagerOverride:
warnings.warn(
"Please make your custom security manager inherit from "
"FabAirflowSecurityManagerOverride instead of AirflowSecurityManager.",
DeprecationWarning,
AirflowProviderDeprecationWarning,
)
return sm_from_config(self.appbuilder)

Expand Down
2 changes: 1 addition & 1 deletion airflow/utils/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class AirflowJsonProvider(JSONProvider):
def dumps(self, obj, **kwargs):
kwargs.setdefault("ensure_ascii", self.ensure_ascii)
kwargs.setdefault("sort_keys", self.sort_keys)
return json.dumps(obj, **kwargs, cls=WebEncoder)
return json.dumps(obj, cls=WebEncoder)

def loads(self, s: str | bytes, **kwargs):
return json.loads(s, **kwargs)
Expand Down
29 changes: 24 additions & 5 deletions airflow/www/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@
# under the License.
from __future__ import annotations

import connexion
import warnings
from datetime import timedelta

from flask import Flask
import connexion
from connexion.exceptions import ConnexionException
from flask import jsonify
from flask_appbuilder import SQLA
from flask_wtf.csrf import CSRFProtect
from markupsafe import Markup
from sqlalchemy.engine.url import make_url
from starlette.middleware.cors import CORSMiddleware

from airflow import settings
from airflow.api_internal.internal_api_call import InternalApiConfig
Expand Down Expand Up @@ -61,7 +63,6 @@
init_plugins,
)
from airflow.www.extensions.init_wsgi_middlewares import init_wsgi_middleware
from starlette.middleware.cors import CORSMiddleware

app: connexion.FlaskApp | None = None

Expand All @@ -70,6 +71,22 @@
csrf = CSRFProtect()


def custom_jsonify(obj, **kwargs):
# Check if cls key is already present
if "cls" in kwargs:
cls = kwargs.get("cls")
if cls != AirflowJsonProvider:
raise Exception(f"Conflict: cls: {cls} already set")

# Set cls to our custom provider
kwargs["cls"] = AirflowJsonProvider

try:
return jsonify(obj, **kwargs)
except ConnexionException:
raise ConnexionException("Unable to serialize data")


def create_app(config=None, testing=False):
"""Create a new instance of Airflow WWW app."""
connexion_app = connexion.FlaskApp(__name__)
Expand All @@ -83,6 +100,7 @@ def create_app(config=None, testing=False):
allow_headers=conf.get("api", "access_control_allow_headers"),
)

connexion_app.jsonify = custom_jsonify
flask_app = connexion_app.app
flask_app.secret_key = conf.get("webserver", "SECRET_KEY")

Expand Down Expand Up @@ -176,8 +194,9 @@ def create_app(config=None, testing=False):
init_api_internal(connexion_app)
init_api_experimental(flask_app)
init_api_auth_provider(connexion_app)
# needs to be after all api inits to let them add their path first
init_api_error_handlers(connexion_app)
init_api_error_handlers(
connexion_app
) # needs to be after all api inits to let them add their path first

get_auth_manager().init()

Expand Down
18 changes: 9 additions & 9 deletions airflow/www/extensions/init_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from connexion.problem import problem

from airflow.api_connexion.exceptions import problem_error_handler

from airflow.configuration import conf
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.security import permissions
Expand Down Expand Up @@ -214,11 +213,12 @@ def _handle_http_exception(ex: starlette.exceptions.HTTPException) -> ConnexionR
return problem(
title=connexion.http_facts.HTTP_STATUS_CODES.get(ex.status_code),
detail=ex.detail,
status=ex.status_code
status=ex.status_code,
)

def _handle_api_not_found(request: ConnexionRequest, ex: starlette.exceptions.HTTPException) \
-> ConnexionResponse:
def _handle_api_not_found(
request: ConnexionRequest, ex: starlette.exceptions.HTTPException
) -> ConnexionResponse:
if any([request.url.path.startswith(p) for p in base_paths]):
# 404 errors are never handled on the blueprint level
# unless raised from a view func so actual 404 errors,
Expand All @@ -228,8 +228,9 @@ def _handle_api_not_found(request: ConnexionRequest, ex: starlette.exceptions.HT
else:
return views.not_found(ex)

def _handle_method_not_allowed(request: ConnexionRequest, ex: starlette.exceptions.HTTPException) \
-> ConnexionResponse:
def _handle_method_not_allowed(
request: ConnexionRequest, ex: starlette.exceptions.HTTPException
) -> ConnexionResponse:
if any([request.url.path.startswith(p) for p in base_paths]):
return _handle_http_exception(ex)
else:
Expand All @@ -244,14 +245,13 @@ def init_api_connexion(connexion_app: connexion.FlaskApp) -> None:
"""Initialize Stable API."""
base_path = "/api/v1"
base_paths.append(base_path)

with ROOT_APP_DIR.joinpath("api_connexion", "openapi", "v1.yaml").open() as f:
specification = safe_load(f)

swagger_ui_options = SwaggerUIOptions(
swagger_ui=conf.getboolean("webserver", "enable_swagger_ui", fallback=True),
swagger_ui_path=os.fspath(ROOT_APP_DIR.joinpath("www", "static", "dist", "swagger-ui")),
)
log.info("SWAGGER PATH: %s", os.fspath(ROOT_APP_DIR.joinpath("www", "static", "dist", "swagger-ui")))

connexion_app.add_api(
specification=specification,
Expand All @@ -261,6 +261,7 @@ def init_api_connexion(connexion_app: connexion.FlaskApp) -> None:
strict_validation=True,
validate_responses=True,
)

# flask_app = connexion_app.app
# flask_app.extensions["csrf"].exempt(api_bp)

Expand All @@ -273,7 +274,6 @@ def init_api_internal(connexion_app: connexion.FlaskApp, standalone_api: bool =
base_paths.append("/internal_api/v1")
with ROOT_APP_DIR.joinpath("api_internal", "openapi", "internal_api_v1.yaml").open() as f:
specification = safe_load(f)

swagger_ui_options = SwaggerUIOptions(
swagger_ui=conf.getboolean("webserver", "enable_swagger_ui", fallback=True),
)
Expand Down

0 comments on commit 1485608

Please sign in to comment.