diff --git a/airflow/www/views.py b/airflow/www/views.py index 81c54d70659461..efdd8aab1f4704 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -2289,7 +2289,8 @@ def _mark_dagrun_state_as_success(self, dag_id, dag_run_id, confirmed): return htmlsafe_json_dumps(details, separators=(",", ":")) - def _mark_dagrun_state_as_queued(self, dag_id: str, dag_run_id: str, confirmed: bool): + @provide_session + def _mark_dagrun_state_as_queued(self, dag_id: str, dag_run_id: str, confirmed: bool, session=None): if not dag_run_id: return {"status": "error", "message": "Invalid dag_run_id"} @@ -2298,13 +2299,23 @@ def _mark_dagrun_state_as_queued(self, dag_id: str, dag_run_id: str, confirmed: if not dag: return {"status": "error", "message": f"Cannot find DAG: {dag_id}"} - new_dag_state = set_dag_run_state_to_queued(dag=dag, run_id=dag_run_id, commit=confirmed) + set_dag_run_state_to_queued(dag=dag, run_id=dag_run_id, commit=confirmed) if confirmed: return {"status": "success", "message": "Marked the DagRun as queued."} else: - details = [str(t) for t in new_dag_state] + # Identify tasks that will be queued up to run when confirmed + all_task_ids = [task.task_id for task in dag.tasks] + + existing_tis = session.query(TaskInstance.task_id).filter( + TaskInstance.dag_id == dag.dag_id, + TaskInstance.run_id == dag_run_id, + ) + + completed_tis_ids = [task_id for task_id, in existing_tis] + tasks_with_no_state = list(set(all_task_ids) - set(completed_tis_ids)) + details = [str(t) for t in tasks_with_no_state] return htmlsafe_json_dumps(details, separators=(",", ":")) diff --git a/tests/www/views/test_views_dagrun.py b/tests/www/views/test_views_dagrun.py index 504ee7b09b9562..e647bf23098ad0 100644 --- a/tests/www/views/test_views_dagrun.py +++ b/tests/www/views/test_views_dagrun.py @@ -28,7 +28,7 @@ from airflow.utils.session import create_session from airflow.www.views import DagRunModelView from tests.test_utils.api_connexion_utils import create_user, delete_roles, delete_user -from tests.test_utils.www import check_content_in_response, client_with_login +from tests.test_utils.www import check_content_in_response, check_content_not_in_response, client_with_login from tests.www.views.test_views_tasks import _get_appbuilder_pk_string @@ -126,6 +126,31 @@ def running_dag_run(session): return dr +@pytest.fixture() +def completed_dag_run_with_missing_task(session): + dag = DagBag().get_dag("example_bash_operator") + execution_date = timezone.datetime(2016, 1, 9) + dr = dag.create_dagrun( + state="success", + execution_date=execution_date, + data_interval=(execution_date, execution_date), + run_id="test_dag_runs_action", + session=session, + ) + session.add(dr) + tis = [ + TaskInstance(dag.get_task("runme_0"), run_id=dr.run_id, state="success"), + TaskInstance(dag.get_task("runme_1"), run_id=dr.run_id, state="success"), + TaskInstance(dag.get_task("also_run_this"), run_id=dr.run_id, state="success"), + TaskInstance(dag.get_task("run_after_loop"), run_id=dr.run_id, state="success"), + TaskInstance(dag.get_task("this_will_skip"), run_id=dr.run_id, state="success"), + TaskInstance(dag.get_task("run_this_last"), run_id=dr.run_id, state="success"), + ] + session.bulk_save_objects(tis) + session.commit() + return dag, dr + + def test_delete_dagrun(session, admin_client, running_dag_run): composite_key = _get_appbuilder_pk_string(DagRunModelView, running_dag_run) assert session.query(DagRun).filter(DagRun.dag_id == running_dag_run.dag_id).count() == 1 @@ -235,3 +260,15 @@ def test_set_dag_runs_action_permission_denied(client_dr_without_dag_edit, runni follow_redirects=True, ) check_content_in_response(f"Access denied for dag_id {running_dag_run.dag_id}", resp) + + +def test_dag_runs_queue_new_tasks_action(session, admin_client, completed_dag_run_with_missing_task): + dag, dag_run = completed_dag_run_with_missing_task + resp = admin_client.post( + "/dagrun_queued", + data={"dag_id": dag.dag_id, "dag_run_id": dag_run.run_id, "confirmed": False}, + ) + + check_content_in_response("runme_2", resp) + check_content_not_in_response("runme_1", resp) + assert resp.status_code == 200