Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Taks 1 - Refactor get_api_endpoints() #3

Merged
merged 3 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions airflow/auth/managers/base_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@

if TYPE_CHECKING:
import connexion
from flask import Blueprint
from flask_appbuilder.menu import MenuItem
from sqlalchemy.orm import Session

Expand Down Expand Up @@ -82,8 +81,8 @@ def get_cli_commands() -> list[CLICommand]:
"""
return []

def get_api_endpoints(self, connexion_app: connexion.FlaskApp) -> None | Blueprint:
"""Return API endpoint(s) definition for the auth manager."""
def set_api_endpoints(self, connexion_app: connexion.FlaskApp) -> None:
"""Set API endpoint(s) definition for the auth manager."""
return None

def get_user_name(self) -> str:
Expand Down
8 changes: 4 additions & 4 deletions airflow/providers/fab/auth_manager/fab_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from typing import TYPE_CHECKING, Container

from connexion.options import SwaggerUIOptions
from flask import Blueprint, url_for
from flask import url_for
from sqlalchemy import select
from sqlalchemy.orm import Session, joinedload

Expand Down Expand Up @@ -150,7 +150,7 @@ def get_cli_commands() -> list[CLICommand]:
SYNC_PERM_COMMAND, # not in a command group
]

def get_api_endpoints(self, connexion_app: connexion.FlaskApp) -> None | Blueprint:
def set_api_endpoints(self, connexion_app: connexion.FlaskApp) -> None:
folder = Path(__file__).parents[0].resolve() # this is airflow/auth/managers/fab/
with folder.joinpath("openapi", "v1.yaml").open() as f:
specification = safe_load(f)
Expand All @@ -159,15 +159,15 @@ def get_api_endpoints(self, connexion_app: connexion.FlaskApp) -> None | Bluepri
swagger_ui=conf.getboolean("webserver", "enable_swagger_ui", fallback=True),
)

api = connexion_app.add_api(
connexion_app.add_api(
specification=specification,
resolver=_LazyResolver(),
base_path="/auth/fab/v1",
swagger_ui_options=swagger_ui_options,
strict_validation=True,
validate_responses=True,
)
return api.blueprint if api else None
return None
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return None
self.appbuilder.app.extensions["csrf"].exempt(api.blueprint)
return None

According to the proposed solution, you it is suggested to move it here I guess.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't give you KeyError "crsf"?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, you have to add an if statement. But this does not solve the CSRF error

Copy link
Owner

@sudiptob2 sudiptob2 Feb 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

        if api:
            self.appbuilder.app.extensions["csrf"].exempt(api.blueprint)
        return None

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this connexion_app.add_api() doesn't return an object, so I just removed "if" statement.

I kind of understanding Vlada's struggle now.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think that is why, CSRF error is still there. In this comment I think Rob suggested some solution
apache#36052 (comment)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what he meant here. I might need some time to understand it.

"Find a way to add csrf extension to a newly created blueprint using connexion: to retrieve blueprint object from connexion_app variable to save the current logic (flask_app.extensions["csrf"].exempt(blueprint)) or find a way to add this extension on connexion level(check the documentation for available options)."

Seems like we cannot get the blueprint, so we better focus on the second option(add this extension on the connexion level). I'm reading the connexion source code, but no luck yet.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably need to understand this repo. I will let you know if I find something useful.


def get_user_display_name(self) -> str:
"""Return the user's display name associated to the user in session."""
Expand Down
10 changes: 10 additions & 0 deletions airflow/www/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from datetime import timedelta

import connexion
from flask import request
from flask_appbuilder import SQLA
from flask_wtf.csrf import CSRFProtect
from markupsafe import Markup
Expand Down Expand Up @@ -73,6 +74,15 @@ def create_app(config=None, testing=False):
"""Create a new instance of Airflow WWW app."""
connexion_app = connexion.FlaskApp(__name__)

@connexion_app.app.before_request
def before_request():
"""Exempts the view function associated with '/api/v1' requests from CSRF protection."""
if request.path.startswith("/api/v1"): # TODO: make sure this path is correct
view_function = flask_app.view_functions.get(request.endpoint)
if view_function:
# Exempt the view function from CSRF protection
connexion_app.app.extensions["csrf"].exempt(view_function)

connexion_app.add_middleware(
CORSMiddleware,
connexion.middleware.MiddlewarePosition.BEFORE_ROUTING,
Expand Down
6 changes: 1 addition & 5 deletions airflow/www/extensions/init_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,4 @@ def init_api_experimental(app):
def init_api_auth_provider(connexion_app: connexion.FlaskApp):
"""Initialize the API offered by the auth manager."""
auth_mgr = get_auth_manager()
blueprint = auth_mgr.get_api_endpoints(connexion_app)
if blueprint:
base_paths.append(blueprint.url_prefix if blueprint.url_prefix else "")
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not have any blueprint here, so we can not update base_path But we need to handle base path update logic somewhere. Whats the impact of removing this line.

flask_app = connexion_app.app
flask_app.extensions["csrf"].exempt(blueprint)
auth_mgr.set_api_endpoints(connexion_app)
2 changes: 1 addition & 1 deletion docs/apache-airflow/core-concepts/auth-manager.rst
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ Auth managers may vend CLI commands which will be included in the ``airflow`` co
Rest API
^^^^^^^^

Auth managers may vend Rest API endpoints which will be included in the :doc:`/stable-rest-api-ref` by implementing the ``get_api_endpoints`` method. The endpoints can be used to manage resources such as users, groups, roles (if any) handled by your auth manager. Endpoints are only vended for the currently configured auth manager.
Auth managers may vend Rest API endpoints which will be included in the :doc:`/stable-rest-api-ref` by implementing the ``set_api_endpoints`` method. The endpoints can be used to manage resources such as users, groups, roles (if any) handled by your auth manager. Endpoints are only vended for the currently configured auth manager.

Next Steps
^^^^^^^^^^
Expand Down
4 changes: 2 additions & 2 deletions tests/auth/managers/test_base_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ class TestBaseAuthManager:
def test_get_cli_commands_return_empty_list(self, auth_manager):
assert auth_manager.get_cli_commands() == []

def test_get_api_endpoints_return_none(self, auth_manager):
assert auth_manager.get_api_endpoints() is None
def test_set_api_endpoints_return_none(self, auth_manager):
assert auth_manager.set_api_endpoints() is None

def test_get_user_name(self, auth_manager):
user = Mock()
Expand Down