Switch to Connexion 3 framework
This is a huge PR, the result of over 100 commits
made by a number of people in apache#36052 and apache#37638. It
switches to Connexion 3 as the driving backend
implementation for both the Airflow REST APIs and the Flask
app that powers the Airflow UI. It should be largely
backwards compatible when it comes to the behaviour of
both the APIs and the Airflow Webserver views; however, due to
decisions made by the Connexion 3 maintainers, it heavily
changes the technology stack used under the hood:

1) The Connexion app is an ASGI-compatible, OpenAPI spec-first
   framework using ASGI as the interface between the webserver
   and the Python web application. ASGI is the asynchronous
   successor to WSGI.

2) Connexion itself uses Starlette to run asynchronous
   web services in Python.

3) We continue using the Gunicorn application server, which still
   speaks the WSGI standard, which means we can keep using Flask;
   we use the standard Uvicorn ASGI worker, which bridges between
   the ASGI interface and Gunicorn's WSGI interface.
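Since ASGI is just a calling convention, the interface this stack plugs into can be sketched in a few lines of plain Python. This is an illustrative sketch of the ASGI protocol itself, not code from this PR:

```python
import asyncio

# A minimal ASGI application: an async callable taking scope, receive, send.
async def app(scope, receive, send):
    assert scope["type"] == "http"
    await send({
        "type": "http.response.start",
        "status": 200,
        "headers": [(b"content-type", b"text/plain")],
    })
    await send({"type": "http.response.body", "body": b"Hello from ASGI"})

async def main():
    # Drive the app with stub receive/send callables -- no real server needed.
    sent = []

    async def receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message):
        sent.append(message)

    await app({"type": "http", "method": "GET", "path": "/"}, receive, send)
    return sent

messages = asyncio.run(main())
print(messages[0]["status"], messages[1]["body"])  # 200 b'Hello from ASGI'
```

Uvicorn, Starlette, and Connexion all speak exactly this message protocol to each other, which is what makes the layering above possible.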

Some of the problems handled in this PR:

There were two problems with session handling:

* get_session_cookie did not retrieve the right cookie: it returned
  the literal "session" string. The right fix was to change cookie_jar
  into cookie.jar, because that is apparently where Starlette's
  TestClient holds the cookies (visible when you debug).

* The client does not accept a "set_cookie" method; it accepts cookies
  passed via a "cookies" dictionary, as the usual httpx client does
  (see https://www.starlette.io/testclient/), so we have to set the
  cookie directly in the get method to try it out.

Add "flask_client_with_login" for tests that need a Flask client

Some tests require functionality not available in the Starlette test
client, because they use features specific to the Flask test client;
for those we have an option to get a Flask test client instead of the
Starlette one.

Fix error handling for the new Connexion 3 approach

Error handling for Connexion 3 integration needed to be reworked.

The way it behaves is much the same as in main:

* for API errors we get application/problem+json responses
* for UI errors we have rendered views
* for redirection we have the correct Location header (it had been
  missing)
* the API error handler had not been added as available middleware
  in the www tests

It should fix all test_views_base.py tests, which were failing
due to the missing Location header on redirects.

Fix wrong response in tests_view_cluster_activity

The problem in the test was that the Starlette test client opens a new
connection and starts a new session, while the Flask test client
uses the same database session. The test did not show data because
the data had not been committed and the session had not been closed,
which also made local SQLite tests fail with a "database is locked"
error.
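The visibility issue can be reproduced with two plain sqlite3 connections standing in for the two clients: the Flask test client shared the test's session, while the Starlette TestClient opens its own connection, so uncommitted data is invisible to it. A minimal sketch, not code from the PR:

```python
import os
import sqlite3
import tempfile

# One connection writes (the test setup), a second one reads (the client).
path = os.path.join(tempfile.mkdtemp(), "demo.db")

writer = sqlite3.connect(path)
writer.execute("CREATE TABLE dag_run (id INTEGER PRIMARY KEY, state TEXT)")
writer.commit()

writer.execute("INSERT INTO dag_run (state) VALUES ('running')")  # not committed

reader = sqlite3.connect(path)  # a "new client" connection
rows_before = reader.execute("SELECT count(*) FROM dag_run").fetchone()[0]

writer.commit()  # only now does the row become visible to other connections
rows_after = reader.execute("SELECT count(*) FROM dag_run").fetchone()[0]

print(rows_before, rows_after)  # 0 1
```

Leaving the write transaction open is also what can escalate into SQLite's "database is locked" error when another connection tries to write.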

Fix test_extra_links

The tests were failing, again because the created dagrun had not been
committed and the session had not been closed. This happened to work
with the Flask client, which accidentally reused the same session, but
did not work with the test client from Starlette. It also caused
"database is locked" errors in SQLite / local tests.

Switch to non-deprecated auth manager

Fix to test_views_log.py

This PR partially fixes sessions and the request parameter for
test_views_log. Some tests are still failing, but for different
reasons, to be investigated.

Fix views_custom_user_views tests

The problem in those tests was that the check in the security manager
assumed the security manager was shared between the client and the
test Flask application, because they came from the same Flask app. But
when we use Starlette, the call goes to a newly started process and
the user is deleted in the database, so the shortcut of checking the
security manager did not work.

The change is that we now check whether the user is deleted by
calling /users/show (which requires a new users READ permission);
this way we go to the database and check whether the user was indeed
deleted.

Fix test_task_instance_endpoint tests

There were two reasons for the test failures:

* when the Job was added to the task instance, the task instance was
  not merged into the session, which means that the commit did not
  store the added Job

* some of the tests expected a call with a specific session, and they
  failed because the session was different. Replacing the session with
  mock.ANY tells the mock assertion that this argument can be
  anything; the session will be different when the call is made via
  ASGI/Starlette
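The mock.ANY trick works because ANY compares equal to every value, so the assertion can pin down all arguments except the session the app created internally. A self-contained sketch (names are illustrative, not Airflow's):

```python
from unittest import mock

# Stand-in for a helper the endpoint calls with its own session object.
endpoint_helper = mock.Mock()

internal_session = object()  # a session created inside the app; the test never sees it
endpoint_helper(["task_instance_1"], session=internal_session)

# Passes even though the test has no reference to internal_session,
# because mock.ANY == anything.
endpoint_helper.assert_called_once_with(["task_instance_1"], session=mock.ANY)
print("assertion passed")
```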

Fix parameter validation

* added a default value for the limit parameter across the board.
  Connexion 3 does not like a parameter having no default when we have
  not provided one, even if our custom decorator was adding it. Adding
  a default value and updating our decorator to treat None as the
  default fixed a number of problems where limits were not passed

* swapped the OpenAPI specification entries for /datasets/{uri} and
  /datasets/events. Since {uri} was defined first, Connexion matched
  events against {uri} and chose the parameter definitions from {uri},
  not events
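The limit handling described in the first bullet can be sketched in a few lines. The two constants stand in for Airflow's [api] fallback_page_limit and maximum_page_limit settings; their values here are assumptions, not the real defaults:

```python
FALLBACK_PAGE_LIMIT = 100   # assumed stand-in for [api] fallback_page_limit
MAXIMUM_PAGE_LIMIT = 1000   # assumed stand-in for [api] maximum_page_limit

def check_limit(value):
    # Connexion 3 passes None when the client supplied no limit at all.
    if value is None:
        return FALLBACK_PAGE_LIMIT
    # Cap anything above the configured maximum page size.
    if value > MAXIMUM_PAGE_LIMIT:
        return MAXIMUM_PAGE_LIMIT
    return value

print(check_limit(None), check_limit(50), check_limit(99_999))  # 100 50 1000
```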
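The route-ordering pitfall in the second bullet can be reproduced with a toy first-match router. This is only a sketch; Connexion's real path matching is more involved:

```python
import re

def compile_route(template):
    # Turn "/datasets/{uri}" into a regex with a named group for {uri}.
    pattern = re.sub(r"\{(\w+)\}", r"(?P<\1>[^/]+)", template)
    return re.compile("^" + pattern + "$")

def first_match(routes, path):
    # Return the first registered template whose pattern matches the path.
    for template in routes:
        if compile_route(template).match(path):
            return template
    return None

# {uri} declared first shadows the literal path...
shadowed = first_match(["/datasets/{uri}", "/datasets/events"], "/datasets/events")
# ...while declaring the literal path first resolves it correctly.
correct = first_match(["/datasets/events", "/datasets/{uri}"], "/datasets/events")
print(shadowed, correct)  # /datasets/{uri} /datasets/events
```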

Fix test_log_endpoint tests

The problem here was that some sessions should have been
committed/closed; also, to run the tests standalone, we wanted to
create log templates in the database, as the tests implicitly relied
on log templates created by other tests.

Fix test_views_dagrun, test_views_tasks and test_views_log

Fixed by switching to the Flask client for testing rather than the
Starlette one. The Starlette client in this case has side effects that
also impact SQLite: the session is created in a different thread and
deleted by the close_all_sessions fixture.


Co-authored-by: sudipto baral <[email protected]>
Co-authored-by: satoshi-sh <[email protected]>
Co-authored-by: Maksim Yermakou <[email protected]>
Co-authored-by: Ulada Zakharava <[email protected]>

Better API initialization, including vending of the API specification.

The way paths are added and initialized is better (for example, FAB
contributes its paths via a new method in the Auth Manager).

This also adds back-compatibility to the FAB auth manager so it
continues working on Airflow 2.9.
Ulada Zakharava authored and potiuk committed May 29, 2024
1 parent aba8def commit 7f06d57
Showing 128 changed files with 2,932 additions and 2,653 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/basic-tests.yml
@@ -148,6 +148,8 @@ jobs:
env:
HATCH_ENV: "test"
working-directory: ./clients/python
- name: Compile www assets
run: breeze compile-www-assets
- name: "Install Airflow in editable mode with fab for webserver tests"
run: pip install -e ".[fab]"
- name: "Install Python client"
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/connection_endpoint.py
@@ -91,7 +91,7 @@ def get_connection(*, connection_id: str, session: Session = NEW_SESSION) -> API
@provide_session
def get_connections(
*,
limit: int,
limit: int | None = None,
offset: int = 0,
order_by: str = "id",
session: Session = NEW_SESSION,
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/dag_endpoint.py
@@ -94,7 +94,7 @@ def get_dag_details(
@provide_session
def get_dags(
*,
limit: int,
limit: int | None = None,
offset: int = 0,
tags: Collection[str] | None = None,
dag_id_pattern: str | None = None,
11 changes: 7 additions & 4 deletions airflow/api_connexion/endpoints/dag_parsing.py
@@ -19,7 +19,8 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Sequence

from flask import Response, current_app
from connexion import NoContent
from flask import current_app
from itsdangerous import BadSignature, URLSafeSerializer
from sqlalchemy import exc, select

@@ -39,7 +40,9 @@

@security.requires_access_dag("PUT")
@provide_session
def reparse_dag_file(*, file_token: str, session: Session = NEW_SESSION) -> Response:
def reparse_dag_file(
*, file_token: str, session: Session = NEW_SESSION
) -> tuple[str | NoContent, HTTPStatus]:
"""Request re-parsing a DAG file."""
secret_key = current_app.config["SECRET_KEY"]
auth_s = URLSafeSerializer(secret_key)
@@ -65,5 +68,5 @@ def reparse_dag_file(*, file_token: str, session: Session = NEW_SESSION) -> Resp
session.commit()
except exc.IntegrityError:
session.rollback()
return Response("Duplicate request", HTTPStatus.CREATED)
return Response(status=HTTPStatus.CREATED)
return "Duplicate request", HTTPStatus.CREATED
return NoContent, HTTPStatus.CREATED
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/dag_warning_endpoint.py
@@ -43,7 +43,7 @@
@provide_session
def get_dag_warnings(
*,
limit: int,
limit: int | None = None,
dag_id: str | None = None,
warning_type: str | None = None,
offset: int | None = None,
6 changes: 3 additions & 3 deletions airflow/api_connexion/endpoints/dataset_endpoint.py
@@ -82,7 +82,7 @@ def get_dataset(*, uri: str, session: Session = NEW_SESSION) -> APIResponse:
@provide_session
def get_datasets(
*,
limit: int,
limit: int | None = None,
offset: int = 0,
uri_pattern: str | None = None,
dag_ids: str | None = None,
@@ -113,11 +113,11 @@ def get_datasets(


@security.requires_access_dataset("GET")
@provide_session
@format_parameters({"limit": check_limit})
@provide_session
def get_dataset_events(
*,
limit: int,
limit: int | None = None,
offset: int = 0,
order_by: str = "timestamp",
dataset_id: int | None = None,
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/event_log_endpoint.py
@@ -64,7 +64,7 @@ def get_event_logs(
included_events: str | None = None,
before: str | None = None,
after: str | None = None,
limit: int,
limit: int | None = None,
offset: int | None = None,
order_by: str = "event_log_id",
session: Session = NEW_SESSION,
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/import_error_endpoint.py
@@ -77,7 +77,7 @@ def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) ->
@provide_session
def get_import_errors(
*,
limit: int,
limit: int | None = None,
offset: int | None = None,
order_by: str = "import_error_id",
session: Session = NEW_SESSION,
5 changes: 4 additions & 1 deletion airflow/api_connexion/endpoints/log_endpoint.py
@@ -107,7 +107,10 @@ def get_log(
logs = logs[0] if task_try_number is not None else logs
# we must have token here, so we can safely ignore it
token = URLSafeSerializer(key).dumps(metadata) # type: ignore[assignment]
return logs_schema.dump(LogResponseObject(continuation_token=token, content=logs))
return Response(
logs_schema.dumps(LogResponseObject(continuation_token=token, content=logs)),
headers={"Content-Type": "application/json"},
)
# text/plain. Stream
logs = task_log_reader.read_log_stream(ti, task_try_number, metadata)

2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/pool_endpoint.py
@@ -68,7 +68,7 @@ def get_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIResponse:
@provide_session
def get_pools(
*,
limit: int,
limit: int | None = None,
order_by: str = "id",
offset: int | None = None,
session: Session = NEW_SESSION,
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -307,7 +307,7 @@ def _apply_range_filter(query: Select, key: ClauseElement, value_range: tuple[T,
@provide_session
def get_task_instances(
*,
limit: int,
limit: int | None = None,
dag_id: str | None = None,
dag_run_id: str | None = None,
execution_date_gte: str | None = None,
55 changes: 23 additions & 32 deletions airflow/api_connexion/exceptions.py
@@ -19,13 +19,12 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Any

import werkzeug
from connexion import FlaskApi, ProblemException, problem
from connexion import ProblemException, problem

from airflow.utils.docs import get_docs_url

if TYPE_CHECKING:
import flask
from connexion.lifecycle import ConnexionRequest, ConnexionResponse

doc_link = get_docs_url("stable-rest-api-ref.html")

@@ -40,37 +39,29 @@
}


def common_error_handler(exception: BaseException) -> flask.Response:
def problem_error_handler(_request: ConnexionRequest, exception: ProblemException) -> ConnexionResponse:
"""Use to capture connexion exceptions and add link to the type field."""
if isinstance(exception, ProblemException):
link = EXCEPTIONS_LINK_MAP.get(exception.status)
if link:
response = problem(
status=exception.status,
title=exception.title,
detail=exception.detail,
type=link,
instance=exception.instance,
headers=exception.headers,
ext=exception.ext,
)
else:
response = problem(
status=exception.status,
title=exception.title,
detail=exception.detail,
type=exception.type,
instance=exception.instance,
headers=exception.headers,
ext=exception.ext,
)
link = EXCEPTIONS_LINK_MAP.get(exception.status)
if link:
return problem(
status=exception.status,
title=exception.title,
detail=exception.detail,
type=link,
instance=exception.instance,
headers=exception.headers,
ext=exception.ext,
)
else:
if not isinstance(exception, werkzeug.exceptions.HTTPException):
exception = werkzeug.exceptions.InternalServerError()

response = problem(title=exception.name, detail=exception.description, status=exception.code)

return FlaskApi.get_response(response)
return problem(
status=exception.status,
title=exception.title,
detail=exception.detail,
type=exception.type,
instance=exception.instance,
headers=exception.headers,
ext=exception.ext,
)


class NotFound(ProblemException):
61 changes: 35 additions & 26 deletions airflow/api_connexion/openapi/v1.yaml
@@ -1453,6 +1453,10 @@ paths:
responses:
"204":
description: Success.
content:
text/html:
schema:
type: string
"400":
$ref: "#/components/responses/BadRequest"
"401":
@@ -1829,6 +1833,10 @@ paths:
responses:
"204":
description: Success.
content:
text/html:
schema:
type: string
"400":
$ref: "#/components/responses/BadRequest"
"401":
@@ -1971,8 +1979,8 @@ paths:
response = self.client.get(
request_url,
query_string={"token": token},
headers={"Accept": "text/plain"},
environ_overrides={"REMOTE_USER": "test"},
headers={"Accept": "text/plain","REMOTE_USER": "test"},
)
continuation_token = response.json["continuation_token"]
metadata = URLSafeSerializer(key).loads(continuation_token)
@@ -2106,7 +2114,7 @@ paths:
properties:
content:
type: string
plain/text:
text/plain:
schema:
type: string

@@ -2192,29 +2200,6 @@ paths:
"403":
$ref: "#/components/responses/PermissionDenied"

/datasets/{uri}:
parameters:
- $ref: "#/components/parameters/DatasetURI"
get:
summary: Get a dataset
description: Get a dataset by uri.
x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint
operationId: get_dataset
tags: [Dataset]
responses:
"200":
description: Success.
content:
application/json:
schema:
$ref: "#/components/schemas/Dataset"
"401":
$ref: "#/components/responses/Unauthenticated"
"403":
$ref: "#/components/responses/PermissionDenied"
"404":
$ref: "#/components/responses/NotFound"

/datasets/events:
get:
summary: Get dataset events
@@ -2272,6 +2257,30 @@
'404':
$ref: '#/components/responses/NotFound'

/datasets/{uri}:
parameters:
- $ref: "#/components/parameters/DatasetURI"
get:
summary: Get a dataset
description: Get a dataset by uri.
x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint
operationId: get_dataset
tags: [Dataset]
responses:
"200":
description: Success.
content:
application/json:
schema:
$ref: "#/components/schemas/Dataset"
"401":
$ref: "#/components/responses/Unauthenticated"
"403":
$ref: "#/components/responses/PermissionDenied"
"404":
$ref: "#/components/responses/NotFound"


/config:
get:
summary: Get current configuration
14 changes: 9 additions & 5 deletions airflow/api_connexion/parameters.py
@@ -41,7 +41,7 @@ def validate_istimezone(value: datetime) -> None:
raise BadRequest("Invalid datetime format", detail="Naive datetime is disallowed")


def format_datetime(value: str) -> datetime:
def format_datetime(value: str | None) -> datetime | None:
"""
Format datetime objects.
@@ -50,6 +50,8 @@ def format_datetime(value: str) -> datetime:
This should only be used within connection views because it raises 400
"""
if value is None:
return None
value = value.strip()
if value[-1] != "Z":
value = value.replace(" ", "+")
@@ -59,7 +61,7 @@
raise BadRequest("Incorrect datetime argument", detail=str(err))


def check_limit(value: int) -> int:
def check_limit(value: int | None) -> int:
"""
Check the limit does not exceed configured value.
@@ -68,7 +70,8 @@ def check_limit(value: int) -> int:
"""
max_val = conf.getint("api", "maximum_page_limit") # user configured max page limit
fallback = conf.getint("api", "fallback_page_limit")

if value is None:
return fallback
if value > max_val:
log.warning(
"The limit param value %s passed in API exceeds the configured maximum page limit %s",
@@ -99,8 +102,9 @@ def format_parameters_decorator(func: T) -> T:
@wraps(func)
def wrapped_function(*args, **kwargs):
for key, formatter in params_formatters.items():
if key in kwargs:
kwargs[key] = formatter(kwargs[key])
value = formatter(kwargs.get(key))
if value:
kwargs[key] = value
return func(*args, **kwargs)

return cast(T, wrapped_function)
13 changes: 11 additions & 2 deletions airflow/auth/managers/base_auth_manager.py
@@ -19,7 +19,7 @@

from abc import abstractmethod
from functools import cached_property
from typing import TYPE_CHECKING, Container, Literal, Sequence
from typing import TYPE_CHECKING, Any, Container, Literal, Sequence

from flask_appbuilder.menu import MenuItem
from sqlalchemy import select
@@ -82,7 +82,7 @@ def get_cli_commands() -> list[CLICommand]:
return []

def get_api_endpoints(self) -> None | Blueprint:
"""Return API endpoint(s) definition for the auth manager."""
"""Return API endpoint(s) definition for the auth manager for Airflow 2.9."""
return None

def get_user_name(self) -> str:
@@ -442,3 +442,12 @@ def security_manager(self) -> AirflowSecurityManagerV2:
from airflow.www.security_manager import AirflowSecurityManagerV2

return AirflowSecurityManagerV2(self.appbuilder)

def get_auth_manager_api_specification(self) -> tuple[str | None, dict[Any, Any]]:
"""
Return the mount point and specification (openapi) for auth manager contributed API (Airflow 2.10).
By default is raises NotImplementedError which produces a warning in airflow logs when auth manager is
initialized, but you can return None, {} if the auth manager does not contribute API.
"""
raise NotImplementedError
