Skip to content

Commit

Permalink
Implement filter_permitted_dag_ids in AWS auth manager (apache#37666)
Browse files Browse the repository at this point in the history
  • Loading branch information
vincbeck authored and howardyoo committed Mar 31, 2024
1 parent 615da3f commit 55564f5
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 13 deletions.
23 changes: 21 additions & 2 deletions airflow/auth/managers/base_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,12 +344,31 @@ def get_permitted_dag_ids(
By default, reads all the DAGs and check individually if the user has permissions to access the DAG.
Can lead to some poor performance. It is recommended to override this method in the auth manager
implementation to provide a more efficient implementation.
:param methods: whether filter readable or writable
:param user: the current user
:param session: the session
"""
dag_ids = {dag.dag_id for dag in session.execute(select(DagModel.dag_id))}
return self.filter_permitted_dag_ids(dag_ids=dag_ids, methods=methods, user=user)

def filter_permitted_dag_ids(
self,
*,
dag_ids: set[str],
methods: Container[ResourceMethod] | None = None,
user=None,
):
"""
Filter readable or writable DAGs for user.
:param dag_ids: the list of DAG ids
:param methods: whether filter readable or writable
:param user: the current user
"""
if not methods:
methods = ["PUT", "GET"]

dag_ids = {dag.dag_id for dag in session.execute(select(DagModel.dag_id))}

if ("GET" in methods and self.is_authorized_dag(method="GET", user=user)) or (
"PUT" in methods and self.is_authorized_dag(method="PUT", user=user)
):
Expand Down
75 changes: 64 additions & 11 deletions airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
from __future__ import annotations

import argparse
from collections import defaultdict
from functools import cached_property
from typing import TYPE_CHECKING, Sequence, cast
from typing import TYPE_CHECKING, Container, Sequence, cast

from flask import session, url_for

Expand Down Expand Up @@ -443,6 +444,60 @@ def batch_is_authorized_variable(
]
return self.avp_facade.batch_is_authorized(requests=facade_requests, user=self.get_user())

def filter_permitted_dag_ids(
self,
*,
dag_ids: set[str],
methods: Container[ResourceMethod] | None = None,
user=None,
):
"""
Filter readable or writable DAGs for user.
:param dag_ids: the list of DAG ids
:param methods: whether filter readable or writable
:param user: the current user
"""
if not methods:
methods = ["PUT", "GET"]

if not user:
user = self.get_user()

requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] = defaultdict(dict)
requests_list: list[IsAuthorizedRequest] = []
for dag_id in dag_ids:
for method in ["GET", "PUT"]:
if method in methods:
request: IsAuthorizedRequest = {
"method": cast(ResourceMethod, method),
"entity_type": AvpEntities.DAG,
"entity_id": dag_id,
}
requests[dag_id][cast(ResourceMethod, method)] = request
requests_list.append(request)

batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
requests=requests_list, user=user
)

def _has_access_to_dag(request: IsAuthorizedRequest):
result = self.avp_facade.get_batch_is_authorized_single_result(
batch_is_authorized_results=batch_is_authorized_results, request=request, user=user
)
return result["decision"] == "ALLOW"

return {
dag_id
for dag_id in dag_ids
if (
"GET" in methods
and _has_access_to_dag(requests[dag_id]["GET"])
or "PUT" in methods
and _has_access_to_dag(requests[dag_id]["PUT"])
)
}

def filter_permitted_menu_items(self, menu_items: list[MenuItem]) -> list[MenuItem]:
"""
Filter menu items based on user permissions.
Expand All @@ -465,19 +520,25 @@ def filter_permitted_menu_items(self, menu_items: list[MenuItem]) -> list[MenuIt
requests=list(requests.values()), user=user
)

def _has_access_to_menu_item(request: IsAuthorizedRequest):
result = self.avp_facade.get_batch_is_authorized_single_result(
batch_is_authorized_results=batch_is_authorized_results, request=request, user=user
)
return result["decision"] == "ALLOW"

accessible_items = []
for menu_item in menu_items:
if menu_item.childs:
accessible_children = []
for child in menu_item.childs:
if self._has_access_to_menu_item(batch_is_authorized_results, requests[child.name], user):
if _has_access_to_menu_item(requests[child.name]):
accessible_children.append(child)
menu_item.childs = accessible_children

# Display the menu if the user has access to at least one sub item
if len(accessible_children) > 0:
accessible_items.append(menu_item)
elif self._has_access_to_menu_item(batch_is_authorized_results, requests[menu_item.name], user):
elif _has_access_to_menu_item(requests[menu_item.name]):
accessible_items.append(menu_item)

return accessible_items
Expand Down Expand Up @@ -511,14 +572,6 @@ def _get_menu_item_request(fab_resource_name: str) -> IsAuthorizedRequest:
else:
raise AirflowException(f"Unknown resource name {fab_resource_name}")

def _has_access_to_menu_item(
self, batch_is_authorized_results: list[dict], request: IsAuthorizedRequest, user: AwsAuthManagerUser
):
result = self.avp_facade.get_batch_is_authorized_single_result(
batch_is_authorized_results=batch_is_authorized_results, request=request, user=user
)
return result["decision"] == "ALLOW"


def get_parser() -> argparse.ArgumentParser:
"""Generate documentation; used by Sphinx argparse."""
Expand Down
59 changes: 59 additions & 0 deletions tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,65 @@ def test_filter_permitted_menu_items_wrong_menu_item(self, mock_get_user, auth_m
]
)

@pytest.mark.parametrize(
"methods, user",
[
(None, None),
(["PUT", "GET"], AwsAuthManagerUser(user_id="test_user_id", groups=[])),
],
)
@patch.object(AwsAuthManager, "get_user")
def test_filter_permitted_dag_ids(self, mock_get_user, methods, user, auth_manager, test_user):
dag_ids = {"dag_1", "dag_2"}
batch_is_authorized_output = [
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id"},
"action": {"actionType": "Airflow::Action", "actionId": "Dag.GET"},
"resource": {"entityType": "Airflow::Dag", "entityId": "dag_1"},
},
"decision": "DENY",
},
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id"},
"action": {"actionType": "Airflow::Action", "actionId": "Dag.PUT"},
"resource": {"entityType": "Airflow::Dag", "entityId": "dag_1"},
},
"decision": "DENY",
},
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id"},
"action": {"actionType": "Airflow::Action", "actionId": "Dag.GET"},
"resource": {"entityType": "Airflow::Dag", "entityId": "dag_2"},
},
"decision": "DENY",
},
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id"},
"action": {"actionType": "Airflow::Action", "actionId": "Dag.PUT"},
"resource": {"entityType": "Airflow::Dag", "entityId": "dag_2"},
},
"decision": "ALLOW",
},
]
auth_manager.avp_facade.get_batch_is_authorized_results = Mock(
return_value=batch_is_authorized_output
)

mock_get_user.return_value = test_user

result = auth_manager.filter_permitted_dag_ids(
dag_ids=dag_ids,
methods=methods,
user=user,
)

auth_manager.avp_facade.get_batch_is_authorized_results.assert_called()
assert result == {"dag_2"}

@patch("airflow.providers.amazon.aws.auth_manager.aws_auth_manager.url_for")
def test_get_url_login(self, mock_url_for, auth_manager):
auth_manager.get_url_login()
Expand Down

0 comments on commit 55564f5

Please sign in to comment.