Added tests for the MarquezDag library (apache#3)
roaraya8 authored Feb 15, 2019
1 parent 477560f commit 1299964
Showing 3 changed files with 189 additions and 63 deletions.
118 changes: 55 additions & 63 deletions marquez/airflow.py
@@ -1,93 +1,85 @@
 import json
 import pendulum
-import airflow.models
+from airflow.models import DAG, Log
 from airflow.utils.db import provide_session
 from marquez_client.marquez import MarquezClient
 
+from marquez.utils import JobIdMapping
+
 
-class MarquezDag(airflow.models.DAG):
+class MarquezDag(DAG):
+    _job_id_mapping = None
+    _mqz_client = None
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.mqz_client = MarquezClient()
         self.mqz_namespace = kwargs['default_args'].get('mqz_namespace', 'unknown')
         self.mqz_location = kwargs['default_args'].get('mqz_location', 'unknown')
         self.mqz_input_datasets = kwargs['default_args'].get('mqz_input_datasets', [])
         self.mqz_output_datasets = kwargs['default_args'].get('mqz_output_datasets', [])
+        self._job_id_mapping = JobIdMapping()
 
     def create_dagrun(self, *args, **kwargs):
-        job_name = self.dag_id
-        job_run_args = "{}"  # TODO retrieve from DAG/tasks
-        start_time = pendulum.instance(kwargs['execution_date']).to_datetime_string()
-        end_time = None
-
-        self.mqz_client.set_namespace(self.mqz_namespace)
-        self.mqz_client.create_job(job_name, self.mqz_location, self.mqz_input_datasets, self.mqz_output_datasets,
-                                   self.description)
-        mqz_job_run_id = self.mqz_client.create_job_run(job_name, job_run_args=job_run_args,
-                                                        nominal_start_time=start_time,
-                                                        nominal_end_time=end_time).run_id
-        self.mqz_client.mark_job_run_running(mqz_job_run_id)
-
-        self.marquez_log('job_running', json.dumps(
-            {'namespace': self.mqz_namespace,
-             'name': job_name,
-             'description': self.description,
-             'location': self.mqz_location,
-             'runArgs': job_run_args,
-             'nominal_start_time': start_time,
-             'nominal_end_time': end_time,
-             'jobrun_id': mqz_job_run_id,
-             'inputDatasetUrns': self.mqz_input_datasets,
-             'outputDatasetUrns': self.mqz_output_datasets
-             }))
-
-        run = super().create_dagrun(*args, **kwargs)
-        airflow.models.Variable.set(run.run_id, mqz_job_run_id)
-
+        run_args = "{}"  # TODO extract the run Args from the tasks
+        mqz_job_run_id = self.report_jobrun(run_args, kwargs['execution_date'])
+        run = super(MarquezDag, self).create_dagrun(*args, **kwargs)
+        self._job_id_mapping.set(JobIdMapping.make_key(run.dag_id, run.run_id), mqz_job_run_id)
         return run
 
     def handle_callback(self, *args, **kwargs):
+        self.report_jobrun_change(args[0], **kwargs)
+        return super().handle_callback(*args, **kwargs)
+
+    def report_jobrun(self, run_args, execution_date):
+        job_name = self.dag_id
-        mqz_job_run_id = self.get_and_delete(args[0].run_id)
+        job_run_args = run_args
+        start_time = pendulum.instance(execution_date).to_datetime_string()
+        end_time = pendulum.instance(self.following_schedule(execution_date)).to_datetime_string()
+        mqz_client = self.get_mqz_client()
+        mqz_client.set_namespace(self.mqz_namespace)
+        mqz_client.create_job(job_name, self.mqz_location, self.mqz_input_datasets,
+                              self.mqz_output_datasets, self.description)
+        mqz_job_run_id = str(mqz_client.create_job_run(
+            job_name, job_run_args=job_run_args, nominal_start_time=start_time, nominal_end_time=end_time).run_id)
+        mqz_client.mark_job_run_running(mqz_job_run_id)
 
-        if mqz_job_run_id:
+        self.log_marquez_event('job_running',
+                               namespace=self.mqz_namespace,
+                               name=job_name,
+                               description=self.description,
+                               location=self.mqz_location,
+                               runArgs=job_run_args,
+                               nominal_start_time=start_time,
+                               nominal_end_time=end_time,
+                               jobrun_id=mqz_job_run_id,
+                               inputDatasetUrns=self.mqz_input_datasets,
+                               outputDatasetUrns=self.mqz_output_datasets)
+        return mqz_job_run_id
 
+    def report_jobrun_change(self, dagrun, **kwargs):
+        mqz_job_run_id = self._job_id_mapping.pop(JobIdMapping.make_key(dagrun.dag_id, dagrun.run_id))
+        if mqz_job_run_id:
             if kwargs.get('success'):
-                self.mqz_client.mark_job_run_completed(mqz_job_run_id)
-                self.marquez_log('job_state_change',
-                                 json.dumps({'job_name': job_name,
-                                             'jobrun_id': mqz_job_run_id,
-                                             'state': 'COMPLETED'}))
+                self.get_mqz_client().mark_job_run_completed(mqz_job_run_id)
             else:
-                self.mqz_client.mark_job_run_failed(mqz_job_run_id)
-                self.marquez_log('job_state_change',
-                                 json.dumps({'job_name': job_name,
-                                             'jobrun_id': mqz_job_run_id,
-                                             'state': 'FAILED'}))
-
-        else:
-            # TODO warn that the jobrun_id couldn't be found
-            pass
-
-        return super().handle_callback(*args, **kwargs)
-
-    @provide_session
-    def get_and_delete(self, key, session=None):
-        q = session.query(airflow.models.Variable).filter(airflow.models.Variable.key == key)
-        if q.first() is None:
-            return
-        else:
-            val = q.first().val
-            q.delete(synchronize_session=False)
-            return val
+                self.get_mqz_client().mark_job_run_failed(mqz_job_run_id)
+        self.log_marquez_event('job_state_change' if mqz_job_run_id else 'job_state_change_LOST',
+                               job_name=self.dag_id,
+                               jobrun_id=mqz_job_run_id,
+                               state='COMPLETED' if kwargs.get('success') else 'FAILED',
+                               reason=kwargs['reason'])
 
     @provide_session
-    def marquez_log(self, event, extras, session=None):
-        session.add(airflow.models.Log(
+    def log_marquez_event(self, event, session=None, **kwargs):
+        session.add(Log(
             event=event,
             task_instance=None,
             owner="marquez",
-            extra=extras,
+            extra=json.dumps(kwargs),
             task_id=None,
             dag_id=self.dag_id))
+
+    def get_mqz_client(self):
+        if not self._mqz_client:
+            self._mqz_client = MarquezClient()
+        return self._mqz_client
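For context, a minimal usage sketch (not part of this commit) of how a pipeline author would pass Marquez metadata through default_args; the DAG id, namespace, location, and dataset URNs below are hypothetical:

    from datetime import datetime
    from marquez.airflow import MarquezDag

    dag = MarquezDag(
        'example_etl',  # hypothetical DAG id
        schedule_interval='0 * * * *',
        default_args={
            'mqz_namespace': 'example-namespace',           # hypothetical Marquez namespace
            'mqz_location': 'github://org/repo/dags',       # hypothetical source location
            'mqz_input_datasets': ['s3://bucket/input'],    # hypothetical input dataset URNs
            'mqz_output_datasets': ['s3://bucket/output'],  # hypothetical output dataset URNs
            'owner': 'data-eng',
            'start_date': datetime(2019, 2, 1),
        },
        description='example DAG reporting lineage to Marquez')
    # create_dagrun() and handle_callback() now report job runs to Marquez automatically.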
28 changes: 28 additions & 0 deletions marquez/utils.py
@@ -0,0 +1,28 @@
import airflow
from airflow.utils.db import provide_session


class JobIdMapping:
    _instance = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            cls._instance = super(JobIdMapping, cls).__new__(cls, *args, **kwargs)
        return cls._instance

    def set(self, key, val):
        airflow.models.Variable.set(key, val)

    @provide_session
    def pop(self, key, session=None):
        q = session.query(airflow.models.Variable).filter(airflow.models.Variable.key == key)
        if not q.first():
            return
        else:
            val = q.first().val
            q.delete(synchronize_session=False)
            return val

    @staticmethod
    def make_key(job_name, run_id):
        return "mqz_id_mapping-{}-{}".format(job_name, run_id)
106 changes: 106 additions & 0 deletions test/test_dag_extension.py
@@ -0,0 +1,106 @@
from datetime import datetime
from unittest.mock import Mock, create_autospec, patch

import pytest

import airflow.models
import marquez.utils
import pendulum
from airflow.utils.state import State
from croniter import croniter
from marquez.airflow import MarquezDag
from marquez_client.marquez import MarquezClient


class Context:
    location = 'github://test_dag_location'
    dag_id = 'test-dag-1'
    namespace = 'test-namespace-1'
    data_inputs = ["s3://data_input_1", "s3://data_input_2"]
    data_outputs = ["s3://some_output_data"]
    owner = 'test_owner'
    description = 'this is a test DAG'
    airflow_run_id = 'airflow_run_id_123456'
    mqz_run_id = '71d29487-0b54-4ae1-9295-efd87f190c57'
    start_date = datetime(2019, 1, 31, 0, 0, 0)
    execution_date = datetime(2019, 2, 2, 0, 0, 0)
    schedule_interval = '*/10 * * * *'

    dag = None

    def __init__(self):
        self.dag = MarquezDag(
            self.dag_id,
            schedule_interval=self.schedule_interval,
            default_args={'mqz_namespace': self.namespace,
                          'mqz_location': self.location,
                          'mqz_input_datasets': self.data_inputs,
                          'mqz_output_datasets': self.data_outputs,
                          'owner': self.owner,
                          'depends_on_past': False,
                          'start_date': self.start_date},
            description=self.description)


@pytest.fixture(scope="module")
def context():
    return Context()


@patch.object(airflow.models.DAG, 'create_dagrun')
@patch.object(marquez.utils.JobIdMapping, 'set')
def test_create_dagrun(mock_set, mock_dag_run, context):
    dag = context.dag
    mock_mqz_client = make_mock_mqz_client(context.mqz_run_id)
    dag._mqz_client = mock_mqz_client  # Use a mock marquez-python client
    mock_dag_run.return_value = make_mock_airflow_jobrun(dag.dag_id, context.airflow_run_id)

    # trigger an airflow DagRun
    dag.create_dagrun(state=State.RUNNING, run_id=context.airflow_run_id, execution_date=context.execution_date)

    # check Marquez client was called with expected arguments
    mock_mqz_client.set_namespace.assert_called_with(context.namespace)
    mock_mqz_client.create_job.assert_called_once_with(context.dag_id, context.location, context.data_inputs,
                                                       context.data_outputs, context.description)
    mock_mqz_client.create_job_run.assert_called_once_with(
        context.dag_id,
        "{}",
        to_airflow_datetime_str(context.execution_date),
        to_airflow_datetime_str(compute_end_time(context.schedule_interval, context.execution_date)))

    # Test if airflow's create_dagrun() is called with the expected arguments
    mock_dag_run.assert_called_once_with(state=State.RUNNING,
                                         run_id=context.airflow_run_id,
                                         execution_date=context.execution_date)

    # Assert there is a job_id mapping being created
    mock_set.assert_called_once_with(marquez.utils.JobIdMapping.make_key(context.dag_id, context.airflow_run_id),
                                     context.mqz_run_id)


def make_mock_mqz_client(run_id):
    mock_mqz_run = Mock()
    mock_mqz_run.run_id = run_id
    mock_mqz_client = create_autospec(MarquezClient)
    mock_mqz_client.create_job_run.return_value = mock_mqz_run
    return mock_mqz_client


def make_mock_airflow_jobrun(dag_id, airflow_run_id):
    mock_airflow_jobrun = Mock()
    mock_airflow_jobrun.run_id = airflow_run_id
    mock_airflow_jobrun.dag_id = dag_id
    return mock_airflow_jobrun


def compute_end_time(schedule_interval, start_time):
    return datetime.utcfromtimestamp(croniter(schedule_interval, start_time).get_next())


def to_airflow_datetime_str(dt):
    return pendulum.instance(dt).to_datetime_string()


if __name__ == "__main__":
    pytest.main()
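These tests can be run directly, e.g. with python -m pytest test/test_dag_extension.py (assuming airflow, marquez_client, croniter, and pendulum are importable), or via the pytest.main() entry point above.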
