diff --git a/marquez/airflow.py b/marquez/airflow.py
index 9be60de45f3340..62f2e0162965f2 100644
--- a/marquez/airflow.py
+++ b/marquez/airflow.py
@@ -1,93 +1,85 @@
 import json
 import pendulum
-import airflow.models
+from airflow.models import DAG, Log
 from airflow.utils.db import provide_session
 from marquez_client.marquez import MarquezClient
+from marquez.utils import JobIdMapping
 
-class MarquezDag(airflow.models.DAG):
+
+class MarquezDag(DAG):
+    _job_id_mapping = None
+    _mqz_client = None
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.mqz_client = MarquezClient()
         self.mqz_namespace = kwargs['default_args'].get('mqz_namespace', 'unknown')
         self.mqz_location = kwargs['default_args'].get('mqz_location', 'unknown')
         self.mqz_input_datasets = kwargs['default_args'].get('mqz_input_datasets', [])
         self.mqz_output_datasets = kwargs['default_args'].get('mqz_output_datasets', [])
+        self._job_id_mapping = JobIdMapping()
 
     def create_dagrun(self, *args, **kwargs):
-        job_name = self.dag_id
-        job_run_args = "{}"  # TODO retrieve from DAG/tasks
-        start_time = pendulum.instance(kwargs['execution_date']).to_datetime_string()
-        end_time = None
-
-        self.mqz_client.set_namespace(self.mqz_namespace)
-        self.mqz_client.create_job(job_name, self.mqz_location, self.mqz_input_datasets, self.mqz_output_datasets,
-                                   self.description)
-        mqz_job_run_id = self.mqz_client.create_job_run(job_name, job_run_args=job_run_args,
-                                                        nominal_start_time=start_time,
-                                                        nominal_end_time=end_time).run_id
-        self.mqz_client.mark_job_run_running(mqz_job_run_id)
-
-        self.marquez_log('job_running', json.dumps(
-            {'namespace': self.mqz_namespace,
-             'name': job_name,
-             'description': self.description,
-             'location': self.mqz_location,
-             'runArgs': job_run_args,
-             'nominal_start_time': start_time,
-             'nominal_end_time': end_time,
-             'jobrun_id': mqz_job_run_id,
-             'inputDatasetUrns': self.mqz_input_datasets,
-             'outputDatasetUrns': self.mqz_output_datasets
-             }))
-
-        run = super().create_dagrun(*args, **kwargs)
-        airflow.models.Variable.set(run.run_id, mqz_job_run_id)
-
+        run_args = "{}"  # TODO extract the run args from the tasks
+        mqz_job_run_id = self.report_jobrun(run_args, kwargs['execution_date'])
+        run = super(MarquezDag, self).create_dagrun(*args, **kwargs)
+        self._job_id_mapping.set(JobIdMapping.make_key(run.dag_id, run.run_id), mqz_job_run_id)
         return run
 
     def handle_callback(self, *args, **kwargs):
+        self.report_jobrun_change(args[0], **kwargs)
+        return super().handle_callback(*args, **kwargs)
+
+    def report_jobrun(self, run_args, execution_date):
         job_name = self.dag_id
-        mqz_job_run_id = self.get_and_delete(args[0].run_id)
+        job_run_args = run_args
+        start_time = pendulum.instance(execution_date).to_datetime_string()
+        end_time = pendulum.instance(self.following_schedule(execution_date)).to_datetime_string()
+        mqz_client = self.get_mqz_client()
+        mqz_client.set_namespace(self.mqz_namespace)
+        mqz_client.create_job(job_name, self.mqz_location, self.mqz_input_datasets,
+                              self.mqz_output_datasets, self.description)
+        mqz_job_run_id = str(mqz_client.create_job_run(
+            job_name, job_run_args=job_run_args, nominal_start_time=start_time, nominal_end_time=end_time).run_id)
+        mqz_client.mark_job_run_running(mqz_job_run_id)
 
-        if mqz_job_run_id:
+        self.log_marquez_event('job_running',
+                               namespace=self.mqz_namespace,
+                               name=job_name,
+                               description=self.description,
+                               location=self.mqz_location,
+                               runArgs=job_run_args,
+                               nominal_start_time=start_time,
+                               nominal_end_time=end_time,
+                               jobrun_id=mqz_job_run_id,
+                               inputDatasetUrns=self.mqz_input_datasets,
+                               outputDatasetUrns=self.mqz_output_datasets)
+        return mqz_job_run_id
 
+    def report_jobrun_change(self, dagrun, **kwargs):
+        mqz_job_run_id = self._job_id_mapping.pop(JobIdMapping.make_key(dagrun.dag_id, dagrun.run_id))
+        if mqz_job_run_id:
             if kwargs.get('success'):
-                self.mqz_client.mark_job_run_completed(mqz_job_run_id)
-                self.marquez_log('job_state_change',
-                                 json.dumps({'job_name': job_name,
-                                             'jobrun_id': mqz_job_run_id,
-                                             'state': 'COMPLETED'}))
+                self.get_mqz_client().mark_job_run_completed(mqz_job_run_id)
             else:
-                self.mqz_client.mark_job_run_failed(mqz_job_run_id)
-                self.marquez_log('job_state_change',
-                                 json.dumps({'job_name': job_name,
-                                             'jobrun_id': mqz_job_run_id,
-                                             'state': 'FAILED'}))
-
-        else:
-            # TODO warn that the jobrun_id couldn't be found
-            pass
-
-        return super().handle_callback(*args, **kwargs)
-
-    @provide_session
-    def get_and_delete(self, key, session=None):
-        q = session.query(airflow.models.Variable).filter(airflow.models.Variable.key == key)
-        if q.first() is None:
-            return
-        else:
-            val = q.first().val
-            q.delete(synchronize_session=False)
-            return val
+                self.get_mqz_client().mark_job_run_failed(mqz_job_run_id)
+        self.log_marquez_event('job_state_change' if mqz_job_run_id else 'job_state_change_LOST',
+                               job_name=self.dag_id,
+                               jobrun_id=mqz_job_run_id,
+                               state='COMPLETED' if kwargs.get('success') else 'FAILED',
+                               reason=kwargs['reason'])
 
     @provide_session
-    def marquez_log(self, event, extras, session=None):
-        session.add(airflow.models.Log(
+    def log_marquez_event(self, event, session=None, **kwargs):
+        session.add(Log(
             event=event,
             task_instance=None,
             owner="marquez",
-            extra=extras,
+            extra=json.dumps(kwargs),
             task_id=None,
             dag_id=self.dag_id))
+
+    def get_mqz_client(self):
+        if not self._mqz_client:
+            self._mqz_client = MarquezClient()
+        return self._mqz_client
diff --git a/marquez/utils.py b/marquez/utils.py
new file mode 100644
index 00000000000000..b89a9cb896c9be
--- /dev/null
+++ b/marquez/utils.py
@@ -0,0 +1,28 @@
+import airflow
+from airflow.utils.db import provide_session
+
+
+class JobIdMapping:
+    _instance = None
+
+    def __new__(cls, *args, **kwargs):
+        if not cls._instance:
+            cls._instance = super(JobIdMapping, cls).__new__(cls, *args, **kwargs)
+        return cls._instance
+
+    def set(self, key, val):
+        airflow.models.Variable.set(key, val)
+
+    @provide_session
+    def pop(self, key, session=None):
+        q = session.query(airflow.models.Variable).filter(airflow.models.Variable.key == key)
+        if not q.first():
+            return
+        else:
+            val = q.first().val
+            q.delete(synchronize_session=False)
+            return val
+
+    @staticmethod
+    def make_key(job_name, run_id):
+        return "mqz_id_mapping-{}-{}".format(job_name, run_id)
diff --git a/test/test_dag_extension.py b/test/test_dag_extension.py
new file mode 100644
index 00000000000000..03b6323a771966
--- /dev/null
+++ b/test/test_dag_extension.py
@@ -0,0 +1,106 @@
+from datetime import datetime
+from unittest.mock import Mock, create_autospec, patch
+
+import pytest
+
+import airflow.models
+import marquez.utils
+import pendulum
+from airflow.utils.state import State
+from croniter import croniter
+from marquez.airflow import MarquezDag
+from marquez_client.marquez import MarquezClient
+
+
+class Context:
+    location = 'github://test_dag_location'
+    dag_id = 'test-dag-1'
+    namespace = 'test-namespace-1'
+    data_inputs = ["s3://data_input_1", "s3://data_input_2"]
+    data_outputs = ["s3://some_output_data"]
+    owner = 'test_owner'
+    description = 'this is a test DAG'
+    airflow_run_id = 'airflow_run_id_123456'
+    mqz_run_id = '71d29487-0b54-4ae1-9295-efd87f190c57'
+    start_date = datetime(2019, 1, 31, 0, 0, 0)
+    execution_date = datetime(2019, 2, 2, 0, 0, 0)
+    schedule_interval = '*/10 * * * *'
+
+    dag = None
+
+    def __init__(self):
+        self.dag = MarquezDag(
+            self.dag_id,
+            schedule_interval=self.schedule_interval,
+            default_args={'mqz_namespace': self.namespace,
+                          'mqz_location': self.location,
+                          'mqz_input_datasets': self.data_inputs,
+                          'mqz_output_datasets': self.data_outputs,
+                          'owner': self.owner,
+                          'depends_on_past': False,
+                          'start_date': self.start_date},
+            description=self.description)
+
+
+@pytest.fixture(scope="module")
+def context():
+    return Context()
+
+
+@patch.object(airflow.models.DAG, 'create_dagrun')
+@patch.object(marquez.utils.JobIdMapping, 'set')
+def test_create_dagrun(mock_set, mock_dag_run, context):
+
+    dag = context.dag
+    mock_mqz_client = make_mock_mqz_client(context.mqz_run_id)
+    dag._mqz_client = mock_mqz_client  # use a mock marquez-python client
+    mock_dag_run.return_value = make_mock_airflow_jobrun(dag.dag_id, context.airflow_run_id)
+
+    # trigger an airflow DagRun
+    dag.create_dagrun(state=State.RUNNING, run_id=context.airflow_run_id, execution_date=context.execution_date)
+
+    # check that the Marquez client was called with the expected arguments
+    mock_mqz_client.set_namespace.assert_called_with(context.namespace)
+    mock_mqz_client.create_job.assert_called_once_with(context.dag_id, context.location, context.data_inputs,
+                                                       context.data_outputs, context.description)
+    mock_mqz_client.create_job_run.assert_called_once_with(
+        context.dag_id,
+        "{}",
+        to_airflow_datetime_str(context.execution_date),
+        to_airflow_datetime_str(compute_end_time(context.schedule_interval, context.execution_date)))
+
+    # check that airflow's create_dagrun() was called with the expected arguments
+    mock_dag_run.assert_called_once_with(state=State.RUNNING,
+                                         run_id=context.airflow_run_id,
+                                         execution_date=context.execution_date)
+
+    # assert that a job_id mapping was created
+    mock_set.assert_called_once_with(marquez.utils.JobIdMapping.make_key(context.dag_id, context.airflow_run_id),
+                                     context.mqz_run_id)
+
+
+def make_mock_mqz_client(run_id):
+    mock_mqz_run = Mock()
+    mock_mqz_run.run_id = run_id
+    mock_mqz_client = create_autospec(MarquezClient)
+    mock_mqz_client.create_job_run.return_value = mock_mqz_run
+    return mock_mqz_client
+
+
+def make_mock_airflow_jobrun(dag_id, airflow_run_id):
+    mock_airflow_jobrun = Mock()
+    mock_airflow_jobrun.run_id = airflow_run_id
+    mock_airflow_jobrun.dag_id = dag_id
+    return mock_airflow_jobrun
+
+
+def compute_end_time(schedule_interval, start_time):
+    return datetime.utcfromtimestamp(croniter(schedule_interval, start_time).get_next())
+
+
+def to_airflow_datetime_str(dt):
+    return pendulum.instance(dt).to_datetime_string()
+
+
+if __name__ == "__main__":
+    pytest.main()
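
Usage sketch (not part of the diff): a minimal example of how a pipeline author might adopt the change above, assuming the modules it adds are importable. The DAG id, namespace, location, dataset URNs, and task below are hypothetical placeholders, not values from this patch.

from datetime import datetime

from airflow.operators.dummy_operator import DummyOperator

from marquez.airflow import MarquezDag

# All mqz_* values below are hypothetical; MarquezDag.__init__ reads them out
# of default_args and falls back to 'unknown' / [] when a key is missing.
dag = MarquezDag(
    'example_etl',
    schedule_interval='0 * * * *',
    default_args={
        'owner': 'data-eng',
        'start_date': datetime(2019, 2, 1),
        'mqz_namespace': 'example-namespace',
        'mqz_location': 'github://example/repo/example_etl.py',
        'mqz_input_datasets': ['s3://example-bucket/input'],
        'mqz_output_datasets': ['s3://example-bucket/output'],
    },
    description='example DAG whose runs are reported to Marquez')

# Tasks attach exactly as with a plain airflow.models.DAG; the Marquez calls
# happen inside create_dagrun() and handle_callback(), not per task.
task = DummyOperator(task_id='noop', dag=dag)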