Skip to content

Commit

Permalink
Add clear DagRun endpoint. (#23451)
Browse files Browse the repository at this point in the history
  • Loading branch information
tirkarthi authored May 24, 2022
1 parent f352ee6 commit b83cc9b
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 0 deletions.
56 changes: 56 additions & 0 deletions airflow/api_connexion/endpoints/dag_run_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,16 @@
from airflow.api_connexion.parameters import apply_sorting, check_limit, format_datetime, format_parameters
from airflow.api_connexion.schemas.dag_run_schema import (
DAGRunCollection,
clear_dagrun_form_schema,
dagrun_collection_schema,
dagrun_schema,
dagruns_batch_form_schema,
set_dagrun_state_form_schema,
)
from airflow.api_connexion.schemas.task_instance_schema import (
TaskInstanceReferenceCollection,
task_instance_reference_collection_schema,
)
from airflow.api_connexion.types import APIResponse
from airflow.models import DagModel, DagRun
from airflow.security import permissions
Expand Down Expand Up @@ -318,3 +323,54 @@ def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW
set_dag_run_state_to_failed(dag=dag, run_id=dag_run.run_id, commit=True)
dag_run = session.query(DagRun).get(dag_run.id)
return dagrun_schema.dump(dag_run)


@security.requires_access(
[
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN),
],
)
@provide_session
def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Clear a dag run."""
dag_run: Optional[DagRun] = (
session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none()
)
if dag_run is None:
error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}'
raise NotFound(error_message)
try:
post_body = clear_dagrun_form_schema.load(request.json)
except ValidationError as err:
raise BadRequest(detail=str(err))

dry_run = post_body.get('dry_run', False)
dag = current_app.dag_bag.get_dag(dag_id)
start_date = dag_run.logical_date
end_date = dag_run.logical_date

if dry_run:
task_instances = dag.clear(
start_date=start_date,
end_date=end_date,
task_ids=None,
include_subdags=True,
include_parentdag=True,
only_failed=False,
dry_run=True,
)
return task_instance_reference_collection_schema.dump(
TaskInstanceReferenceCollection(task_instances=task_instances)
)
else:
dag.clear(
start_date=start_date,
end_date=end_date,
task_ids=None,
include_subdags=True,
include_parentdag=True,
only_failed=False,
)
dag_run.refresh_from_db()
return dagrun_schema.dump(dag_run)
47 changes: 47 additions & 0 deletions airflow/api_connexion/openapi/v1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,43 @@ paths:
'404':
$ref: '#/components/responses/NotFound'

/dags/{dag_id}/dagRuns/{dag_run_id}/clear:
parameters:
- $ref: '#/components/parameters/DAGID'
- $ref: '#/components/parameters/DAGRunID'

post:
summary: Clear a DAG run
description: |
Clear a DAG run.
*New in version 2.4.0*
x-openapi-router-controller: airflow.api_connexion.endpoints.dag_run_endpoint
operationId: clear_dag_run
tags: [DAGRun]
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/ClearDagRun'

responses:
'200':
description: Success.
content:
application/json:
schema:
$ref: '#/components/schemas/DAGRun'
'400':
$ref: '#/components/responses/BadRequest'
'401':
$ref: '#/components/responses/Unauthenticated'
'403':
$ref: '#/components/responses/PermissionDenied'
'404':
$ref: '#/components/responses/NotFound'

/eventLogs:
get:
summary: List log entries
Expand Down Expand Up @@ -3311,6 +3348,16 @@ components:
nullable: true

# Form
ClearDagRun:
type: object
properties:
dry_run:
description: |
If set, don't actually run this operation. The response will contain a list of task instances
planned to be cleaned, but not modified in any way.
type: boolean
default: true

ClearTaskInstance:
type: object
properties:
Expand Down
7 changes: 7 additions & 0 deletions airflow/api_connexion/schemas/dag_run_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ class SetDagRunStateFormSchema(Schema):
)


class ClearDagRunStateFormSchema(Schema):
"""Schema for handling the request of clearing a DAG run"""

dry_run = fields.Boolean(load_default=True)


class DAGRunCollection(NamedTuple):
"""List of DAGRuns with metadata"""

Expand Down Expand Up @@ -158,4 +164,5 @@ class Meta:
dagrun_schema = DAGRunSchema()
dagrun_collection_schema = DAGRunCollectionSchema()
set_dagrun_state_form_schema = SetDagRunStateFormSchema()
clear_dagrun_form_schema = ClearDagRunStateFormSchema()
dagruns_batch_form_schema = DagRunsBatchFormSchema()
114 changes: 114 additions & 0 deletions tests/api_connexion/endpoints/test_dag_run_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1369,3 +1369,117 @@ def test_should_respond_404(self):
environ_overrides={"REMOTE_USER": "test"},
)
assert response.status_code == 404


class TestClearDagRun(TestDagRunEndpoint):
def test_should_respond_200(self, dag_maker, session):
dag_id = "TEST_DAG_ID"
dag_run_id = "TEST_DAG_RUN_ID"
with dag_maker(dag_id) as dag:
task = EmptyOperator(task_id="task_id", dag=dag)
self.app.dag_bag.bag_dag(dag, root_dag=dag)
dr = dag_maker.create_dagrun(run_id=dag_run_id)
ti = dr.get_task_instance(task_id="task_id")
ti.task = task
ti.state = State.SUCCESS
session.merge(ti)
session.commit()

request_json = {"dry_run": False}

response = self.client.post(
f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/clear",
json=request_json,
environ_overrides={"REMOTE_USER": "test"},
)

dr = session.query(DagRun).filter(DagRun.run_id == dr.run_id).first()
assert response.status_code == 200
assert response.json == {
"conf": {},
"dag_id": dag_id,
"dag_run_id": dag_run_id,
"end_date": None,
"execution_date": dr.execution_date.isoformat(),
"external_trigger": False,
"logical_date": dr.logical_date.isoformat(),
"start_date": dr.logical_date.isoformat(),
"state": "queued",
"data_interval_start": dr.data_interval_start.isoformat(),
"data_interval_end": dr.data_interval_end.isoformat(),
"last_scheduling_decision": None,
"run_type": dr.run_type,
}

ti.refresh_from_db()
assert ti.state is None

def test_dry_run(self, dag_maker, session):
"""Test that dry_run being True returns TaskInstances without clearing DagRun"""
dag_id = "TEST_DAG_ID"
dag_run_id = "TEST_DAG_RUN_ID"
with dag_maker(dag_id) as dag:
task = EmptyOperator(task_id="task_id", dag=dag)
self.app.dag_bag.bag_dag(dag, root_dag=dag)
dr = dag_maker.create_dagrun(run_id=dag_run_id)
ti = dr.get_task_instance(task_id="task_id")
ti.task = task
ti.state = State.SUCCESS
session.merge(ti)
session.commit()

request_json = {"dry_run": True}

response = self.client.post(
f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/clear",
json=request_json,
environ_overrides={"REMOTE_USER": "test"},
)

assert response.status_code == 200
assert response.json == {
"task_instances": [
{
"dag_id": dag_id,
"dag_run_id": dag_run_id,
"execution_date": dr.execution_date.isoformat(),
"task_id": "task_id",
}
]
}

ti.refresh_from_db()
assert ti.state == State.SUCCESS

dr = session.query(DagRun).filter(DagRun.run_id == dr.run_id).first()
assert dr.state == "running"

def test_should_raises_401_unauthenticated(self, session):
response = self.client.post(
"api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/clear",
json={
"dry_run": True,
},
)

assert_401(response)

def test_should_raise_403_forbidden(self):
response = self.client.post(
"api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/clear",
json={
"dry_run": True,
},
environ_overrides={"REMOTE_USER": "test_no_permissions"},
)
assert response.status_code == 403

def test_should_respond_404(self):
response = self.client.post(
"api/v1/dags/INVALID_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/clear",
json={
"dry_run": True,
},
environ_overrides={"REMOTE_USER": "test"},
)
assert response.status_code == 404

0 comments on commit b83cc9b

Please sign in to comment.