Implemented missing wait_for_completion feature from the base TriggerDagRunOperator class #32

Open
wants to merge 2 commits into master
54 changes: 47 additions & 7 deletions airflow_multi_dagrun/operators.py
@@ -1,6 +1,8 @@
+import time
 import typing as t
 
 from airflow.api.common.experimental.trigger_dag import trigger_dag
+from airflow.exceptions import AirflowException
 from airflow.models import DagRun
 from airflow.operators.trigger_dagrun import TriggerDagRunOperator
 from airflow.utils import timezone
@@ -12,26 +14,29 @@
 
 
 class TriggerMultiDagRunOperator(TriggerDagRunOperator):
-
     def __init__(self, op_args=None, op_kwargs=None, python_callable=None, *args, **kwargs):
         super(TriggerMultiDagRunOperator, self).__init__(*args, **kwargs)
         self.op_args = op_args or []
         self.op_kwargs = op_kwargs or {}
         self.python_callable = python_callable
 
-    @provide_session
-    def execute(self, context: t.Dict, session=None):
+    @staticmethod
+    def get_multi_dag_run_xcom_key(execution_date) -> str:
+        return f"created_dagrun_key_{execution_date}"
+
+    def execute(self, context: Context):
         context.update(self.op_kwargs)
         self.op_kwargs = determine_kwargs(self.python_callable, self.op_args, context)
 
         created_dr_ids = []
+        created_drs = []
         for conf in self.python_callable(*self.op_args, **self.op_kwargs):
             if not conf:
                 break
 
             execution_date = timezone.utcnow()
 
-            run_id = conf.get('run_id')
+            run_id = conf.get("run_id")
             if not run_id:
                 run_id = DagRun.generate_run_id(DagRunType.MANUAL, execution_date)
 
@@ -47,13 +52,48 @@ def execute(self, context: t.Dict, session=None):
                 self.log.info("Created DagRun %s, %s - %s", dag_run, self.trigger_dag_id, run_id)
             else:
                 dag_run = dag_run[0]
-                self.log.warning("Fetched existed DagRun %s, %s - %s", dag_run, self.trigger_dag_id, run_id)
+                self.log.warning(
+                    "Fetched existed DagRun %s, %s - %s", dag_run, self.trigger_dag_id, run_id
+                )
 
             created_dr_ids.append(dag_run.id)
+            created_drs.append(dag_run)
 
         if created_dr_ids:
-            xcom_key = get_multi_dag_run_xcom_key(context['execution_date'])
-            context['ti'].xcom_push(xcom_key, created_dr_ids)
+            xcom_key = self.get_multi_dag_run_xcom_key(context["execution_date"])
+            context["ti"].xcom_push(xcom_key, created_dr_ids)
             self.log.info("Pushed %s DagRun's ids with key %s", len(created_dr_ids), xcom_key)
 
+            if self.wait_for_completion:
+                failed_dags = {}
+
+                # Hold on while we still have running DAGs...
+                while created_drs:
+                    self.log.info(
+                        "Waiting for DAGs triggered by %s to complete...", self.trigger_dag_id
+                    )
+                    time.sleep(self.poke_interval)
+
+                    # Check every running DAG.
+                    for dag_run in created_drs:
+                        dag_run.refresh_from_db()
+                        state = dag_run.state
+
+                        if state in self.failed_states:
+                            # If the DAG has failed, mark it as such and remove it from list.
+                            failed_dags[self.trigger_dag_id] = state
+                            created_drs.remove(dag_run)
+                        elif state in self.allowed_states:
+                            # If the DAG succeeded, log in the successful event and remove it too.
+                            self.log.info(
+                                "%s finished with allowed state %s", self.trigger_dag_id, state
+                            )
+                            created_drs.remove(dag_run)
+
+                if failed_dags:
+                    # Raise an Airflow exception if any of the DAGs failed.
+                    failures = "; ".join(f"{key}: {value}" for key, value in failed_dags.items())
+                    raise AirflowException(f"Failed DAGs: {failures}")
+
         else:
            self.log.info("No DagRuns created")