Skip to content

Commit

Permalink
Add rest API to query for providers (#13394)
Browse files Browse the repository at this point in the history
* Add API to query for providers (#12468)

* Improve tests speed (#12468)
  • Loading branch information
nic314 authored May 7, 2021
1 parent b8c0fde commit 9dad095
Show file tree
Hide file tree
Showing 7 changed files with 264 additions and 0 deletions.
47 changes: 47 additions & 0 deletions airflow/api_connexion/endpoints/provider_endpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import re
from typing import Dict, List

from airflow.api_connexion import security
from airflow.api_connexion.schemas.provider_schema import ProviderCollection, provider_collection_schema
from airflow.providers_manager import ProviderInfo, ProvidersManager
from airflow.security import permissions


def _remove_rst_syntax(value: str) -> str:
return re.sub("[`_<>]", "", value.strip(" \n."))


def _provider_mapper(provider: ProviderInfo) -> Dict:
return {
"package_name": provider[1]["package-name"],
"description": _remove_rst_syntax(provider[1]["description"]),
"version": provider[0],
}


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_PROVIDER)])
def get_providers():
"""Get providers"""
providers_info: List[ProviderInfo] = list(ProvidersManager().providers.values())
providers = [_provider_mapper(d) for d in providers_info]
total_entries = len(providers)
return provider_collection_schema.dump(
ProviderCollection(providers=providers, total_entries=total_entries)
)
45 changes: 45 additions & 0 deletions airflow/api_connexion/openapi/v1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ info:
|-|-|
| v2.0 | Initial release |
| v2.0.2 | Added /plugins endpoint |
| v2.1 | New providers endpoint |
# Trying the API
Expand Down Expand Up @@ -849,6 +850,26 @@ paths:
'404':
$ref: '#/components/responses/NotFound'

/providers:
get:
summary: List providers
x-openapi-router-controller: airflow.api_connexion.endpoints.provider_endpoint
operationId: get_providers
tags: [Provider]
responses:
'200':
description: List of providers.
content:
application/json:
schema:
allOf:
- $ref: '#/components/schemas/ProviderCollection'
- $ref: '#/components/schemas/CollectionInfo'
'401':
$ref: '#/components/responses/Unauthenticated'
'403':
$ref: '#/components/responses/PermissionDenied'

/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances:
parameters:
- $ref: '#/components/parameters/DAGID'
Expand Down Expand Up @@ -2016,6 +2037,29 @@ components:
- $ref: '#/components/schemas/CollectionInfo'


Provider:
description: The provider
type: object
properties:
package_name:
type: string
description: The package name of the provider.
description:
type: string
description: The description of the provider.
version:
type: string
description: The version of the provider.

ProviderCollection:
type: object
properties:
providers:
type: array
items:
$ref: '#/components/schemas/Provider'


SLAMiss:
type: object
properties:
Expand Down Expand Up @@ -3390,6 +3434,7 @@ tags:
- name: ImportError
- name: Monitoring
- name: Pool
- name: Provider
- name: TaskInstance
- name: Variable
- name: XCom
Expand Down
46 changes: 46 additions & 0 deletions airflow/api_connexion/schemas/provider_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import List, NamedTuple

from marshmallow import Schema, fields


class ProviderSchema(Schema):
"""Provider schema"""

package_name = fields.String(required=True)
description = fields.String(required=True)
version = fields.String(required=True)


class ProviderCollection(NamedTuple):
"""List of Providers"""

providers: List[ProviderSchema]
total_entries: int


class ProviderCollectionSchema(Schema):
"""Provider Collection schema"""

providers = fields.List(fields.Nested(ProviderSchema))
total_entries = fields.Int()


provider_collection_schema = ProviderCollectionSchema()
provider_schema = ProviderSchema()
1 change: 1 addition & 0 deletions airflow/security/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
RESOURCE_PERMISSION_VIEW = "Permission Views" # Refers to a Perm <-> View mapping, not an MVC View.
RESOURCE_POOL = "Pools"
RESOURCE_PLUGIN = "Plugins"
RESOURCE_PROVIDER = "Providers"
RESOURCE_ROLE = "Roles"
RESOURCE_SLA_MISS = "SLA Misses"
RESOURCE_TASK_INSTANCE = "Task Instances"
Expand Down
1 change: 1 addition & 0 deletions airflow/www/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ class AirflowSecurityManager(SecurityManager, LoggingMixin): # pylint: disable=
(permissions.ACTION_CAN_READ, permissions.RESOURCE_POOL),
(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_POOL),
(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_POOL),
(permissions.ACTION_CAN_READ, permissions.RESOURCE_PROVIDER),
(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_VARIABLE),
(permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE),
(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_VARIABLE),
Expand Down
1 change: 1 addition & 0 deletions docs/apache-airflow/security/access-control.rst
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ Endpoint
/pools/{pool_name} DELETE Pool.can_delete Op
/pools/{pool_name} GET Pool.can_read Op
/pools/{pool_name} PATCH Pool.can_edit Op
/providers GET Provider.can_read Op
/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances GET DAGs.can_read, DAG Runs.can_read, Task Instances.can_read Viewer
/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id} GET DAGs.can_read, DAG Runs.can_read, Task Instances.can_read Viewer
/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/links GET DAGs.can_read, DAG Runs.can_read, Task Instances.can_read Viewer
Expand Down
123 changes: 123 additions & 0 deletions tests/api_connexion/endpoints/test_provider_endpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from collections import OrderedDict
from unittest import mock

import pytest

from airflow.security import permissions
from tests.test_utils.api_connexion_utils import create_user, delete_user

MOCK_PROVIDERS = OrderedDict(
[
(
'apache-airflow-providers-amazon',
(
'1.0.0',
{
'package-name': 'apache-airflow-providers-amazon',
'name': 'Amazon',
'description': '`Amazon Web Services (AWS) <https://aws.amazon.com/>`__.\n',
'versions': ['1.0.0'],
},
),
),
(
'apache-airflow-providers-apache-cassandra',
(
'1.0.0',
{
'package-name': 'apache-airflow-providers-apache-cassandra',
'name': 'Apache Cassandra',
'description': '`Apache Cassandra <http://cassandra.apache.org/>`__.\n',
'versions': ['1.0.0'],
},
),
),
]
)


@pytest.fixture(scope="module")
def configured_app(minimal_app_for_api):
app = minimal_app_for_api
create_user(
app, # type: ignore
username="test",
role_name="Test",
permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_PROVIDER)],
)
create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore

yield app

delete_user(app, username="test") # type: ignore
delete_user(app, username="test_no_permissions") # type: ignore


class TestBaseProviderEndpoint:
@pytest.fixture(autouse=True)
def setup_attrs(self, configured_app) -> None:
self.app = configured_app
self.client = self.app.test_client() # type:ignore


class TestGetProviders(TestBaseProviderEndpoint):
@mock.patch(
"airflow.providers_manager.ProvidersManager.providers",
new_callable=mock.PropertyMock,
return_value={},
)
def test_response_200_empty_list(self, mock_providers):
response = self.client.get("/api/v1/providers", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
assert response.json == {"providers": [], "total_entries": 0}

@mock.patch(
"airflow.providers_manager.ProvidersManager.providers",
new_callable=mock.PropertyMock,
return_value=MOCK_PROVIDERS,
)
def test_response_200(self, mock_providers):
response = self.client.get("/api/v1/providers", environ_overrides={'REMOTE_USER': "test"})
assert response.status_code == 200
assert response.json == {
'providers': [
{
'description': 'Amazon Web Services (AWS) https://aws.amazon.com/',
'package_name': 'apache-airflow-providers-amazon',
'version': '1.0.0',
},
{
'description': 'Apache Cassandra http://cassandra.apache.org/',
'package_name': 'apache-airflow-providers-apache-cassandra',
'version': '1.0.0',
},
],
'total_entries': 2,
}

def test_should_raises_401_unauthenticated(self):
response = self.client.get("/api/v1/providers")
assert response.status_code == 401

def test_should_raise_403_forbidden(self):
response = self.client.get(
"/api/v1/providers", environ_overrides={'REMOTE_USER': "test_no_permissions"}
)
assert response.status_code == 403

0 comments on commit 9dad095

Please sign in to comment.