diff --git a/airflow/jobs/base_job.py b/airflow/jobs/base_job.py index b554b6ecb0b969..c103025202a5e9 100644 --- a/airflow/jobs/base_job.py +++ b/airflow/jobs/base_job.py @@ -155,7 +155,7 @@ def heartbeat(self): This also allows for any job to be killed externally, regardless of who is running it or on which machine it is running. - Note that if your heartbeat is set to 60 seconds and you call this + Note that if your heart rate is set to 60 seconds and you call this method after 10 seconds of processing since the last heartbeat, it will sleep 50 seconds to complete the 60 seconds and keep a steady heart rate. If you go over 60 seconds before calling it, it won't @@ -172,17 +172,14 @@ def heartbeat(self): if self.state == State.SHUTDOWN: self.kill() - is_unit_test = conf.getboolean('core', 'unit_test_mode') - if not is_unit_test: - # Figure out how long to sleep for - sleep_for = 0 - if self.latest_heartbeat: - seconds_remaining = self.heartrate - \ - (timezone.utcnow() - self.latest_heartbeat)\ - .total_seconds() - sleep_for = max(0, seconds_remaining) - - sleep(sleep_for) + # Figure out how long to sleep for + sleep_for = 0 + if self.latest_heartbeat: + seconds_remaining = self.heartrate - \ + (timezone.utcnow() - self.latest_heartbeat)\ + .total_seconds() + sleep_for = max(0, seconds_remaining) + sleep(sleep_for) # Update last heartbeat time with create_session() as session: diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py index eda796c7b5b652..aea022cf11c6f3 100644 --- a/airflow/jobs/local_task_job.py +++ b/airflow/jobs/local_task_job.py @@ -20,7 +20,6 @@ import os import signal -import time from airflow.configuration import conf from airflow.exceptions import AirflowException @@ -112,13 +111,6 @@ def signal_handler(signum, frame): "exceeded limit ({}s)." .format(time_since_last_heartbeat, heartbeat_time_limit)) - - if time_since_last_heartbeat < self.heartrate: - sleep_for = self.heartrate - time_since_last_heartbeat - self.log.debug("Time since last heartbeat(%.2f s) < heartrate(%s s)" - ", sleeping for %s s", time_since_last_heartbeat, - self.heartrate, sleep_for) - time.sleep(sleep_for) finally: self.on_kill() diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py index d8af907226a84d..9aa2f6c5e6308e 100644 --- a/tests/jobs/test_local_task_job.py +++ b/tests/jobs/test_local_task_job.py @@ -43,6 +43,9 @@ class TestLocalTaskJob(unittest.TestCase): def setUp(self): clear_db_runs() + patcher = patch('airflow.jobs.base_job.sleep') + self.addCleanup(patcher.stop) + self.mock_base_job_sleep = patcher.start() def test_localtaskjob_essential_attr(self): """ @@ -111,7 +114,7 @@ def test_localtaskjob_heartbeat(self, mock_pid): session.merge(ti) session.commit() - job1.heartbeat_callback() + job1.heartbeat_callback(session=None) mock_pid.return_value = 2 self.assertRaises(AirflowException, job1.heartbeat_callback) @@ -122,11 +125,7 @@ def test_heartbeat_failed_fast(self, mock_getpid): Test that task heartbeat will sleep when it fails fast """ mock_getpid.return_value = 1 - - heartbeat_records = [] - - def heartbeat_recorder(**kwargs): - heartbeat_records.append(timezone.utcnow()) + self.mock_base_job_sleep.side_effect = time.sleep with create_session() as session: dagbag = models.DagBag( @@ -152,9 +151,10 @@ def heartbeat_recorder(**kwargs): job = LocalTaskJob(task_instance=ti, executor=TestExecutor(do_update=False)) job.heartrate = 2 - job.heartbeat_callback = heartbeat_recorder + heartbeat_records = [] + job.heartbeat_callback = lambda session: heartbeat_records.append(job.latest_heartbeat) job._execute() - self.assertGreater(len(heartbeat_records), 1) + self.assertGreater(len(heartbeat_records), 2) for i in range(1, len(heartbeat_records)): time1 = heartbeat_records[i - 1] time2 = heartbeat_records[i] @@ -242,3 +242,52 @@ def test_localtaskjob_double_trigger(self): self.assertEqual(ti.state, State.RUNNING) session.close() + + def test_localtaskjob_maintain_heart_rate(self): + dagbag = models.DagBag( + dag_folder=TEST_DAG_FOLDER, + include_examples=False, + ) + dag = dagbag.dags.get('test_localtaskjob_double_trigger') + task = dag.get_task('test_localtaskjob_double_trigger_task') + + session = settings.Session() + + dag.clear() + dag.create_dagrun(run_id="test", + state=State.SUCCESS, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session) + + ti_run = TI(task=task, execution_date=DEFAULT_DATE) + ti_run.refresh_from_db() + job1 = LocalTaskJob(task_instance=ti_run, + executor=SequentialExecutor()) + + # this should make sure we only heartbeat once and exit at the second + # loop in _execute() + return_codes = [None, 0] + + def multi_return_code(): + return return_codes.pop(0) + + time_start = time.time() + from airflow.task.task_runner.standard_task_runner import StandardTaskRunner + with patch.object(StandardTaskRunner, 'start', return_value=None) as mock_start: + with patch.object(StandardTaskRunner, 'return_code') as mock_ret_code: + mock_ret_code.side_effect = multi_return_code + job1.run() + self.assertEqual(mock_start.call_count, 1) + self.assertEqual(mock_ret_code.call_count, 2) + time_end = time.time() + + self.assertEqual(self.mock_base_job_sleep.call_count, 1) + self.assertEqual(job1.state, State.SUCCESS) + + # Consider we have patched sleep call, it should not be sleeping to + # keep up with the heart rate in other unpatched places + # + # We already make sure patched sleep call is only called once + self.assertLess(time_end - time_start, job1.heartrate) + session.close()