Skip to content
This repository has been archived by the owner on May 22, 2021. It is now read-only.

Commit

Permalink
[AIRFLOW-5902] avoid unnecessary sleep to maintain local task job hea…
Browse files Browse the repository at this point in the history
…rt rate (apache#6553)

sleep to maintain heart rate is already done by the hearbeat() call
  • Loading branch information
Qingping Hou authored and galuszkak committed Mar 5, 2020
1 parent 3a010b1 commit 8442d26
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 28 deletions.
21 changes: 9 additions & 12 deletions airflow/jobs/base_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def heartbeat(self):
This also allows for any job to be killed externally, regardless
of who is running it or on which machine it is running.
Note that if your heartbeat is set to 60 seconds and you call this
Note that if your heart rate is set to 60 seconds and you call this
method after 10 seconds of processing since the last heartbeat, it
will sleep 50 seconds to complete the 60 seconds and keep a steady
heart rate. If you go over 60 seconds before calling it, it won't
Expand All @@ -172,17 +172,14 @@ def heartbeat(self):
if self.state == State.SHUTDOWN:
self.kill()

is_unit_test = conf.getboolean('core', 'unit_test_mode')
if not is_unit_test:
# Figure out how long to sleep for
sleep_for = 0
if self.latest_heartbeat:
seconds_remaining = self.heartrate - \
(timezone.utcnow() - self.latest_heartbeat)\
.total_seconds()
sleep_for = max(0, seconds_remaining)

sleep(sleep_for)
# Figure out how long to sleep for
sleep_for = 0
if self.latest_heartbeat:
seconds_remaining = self.heartrate - \
(timezone.utcnow() - self.latest_heartbeat)\
.total_seconds()
sleep_for = max(0, seconds_remaining)
sleep(sleep_for)

# Update last heartbeat time
with create_session() as session:
Expand Down
8 changes: 0 additions & 8 deletions airflow/jobs/local_task_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import os
import signal
import time

from airflow.configuration import conf
from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -112,13 +111,6 @@ def signal_handler(signum, frame):
"exceeded limit ({}s)."
.format(time_since_last_heartbeat,
heartbeat_time_limit))

if time_since_last_heartbeat < self.heartrate:
sleep_for = self.heartrate - time_since_last_heartbeat
self.log.debug("Time since last heartbeat(%.2f s) < heartrate(%s s)"
", sleeping for %s s", time_since_last_heartbeat,
self.heartrate, sleep_for)
time.sleep(sleep_for)
finally:
self.on_kill()

Expand Down
65 changes: 57 additions & 8 deletions tests/jobs/test_local_task_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
class TestLocalTaskJob(unittest.TestCase):
def setUp(self):
clear_db_runs()
patcher = patch('airflow.jobs.base_job.sleep')
self.addCleanup(patcher.stop)
self.mock_base_job_sleep = patcher.start()

def test_localtaskjob_essential_attr(self):
"""
Expand Down Expand Up @@ -111,7 +114,7 @@ def test_localtaskjob_heartbeat(self, mock_pid):
session.merge(ti)
session.commit()

job1.heartbeat_callback()
job1.heartbeat_callback(session=None)

mock_pid.return_value = 2
self.assertRaises(AirflowException, job1.heartbeat_callback)
Expand All @@ -122,11 +125,7 @@ def test_heartbeat_failed_fast(self, mock_getpid):
Test that task heartbeat will sleep when it fails fast
"""
mock_getpid.return_value = 1

heartbeat_records = []

def heartbeat_recorder(**kwargs):
heartbeat_records.append(timezone.utcnow())
self.mock_base_job_sleep.side_effect = time.sleep

with create_session() as session:
dagbag = models.DagBag(
Expand All @@ -152,9 +151,10 @@ def heartbeat_recorder(**kwargs):

job = LocalTaskJob(task_instance=ti, executor=TestExecutor(do_update=False))
job.heartrate = 2
job.heartbeat_callback = heartbeat_recorder
heartbeat_records = []
job.heartbeat_callback = lambda session: heartbeat_records.append(job.latest_heartbeat)
job._execute()
self.assertGreater(len(heartbeat_records), 1)
self.assertGreater(len(heartbeat_records), 2)
for i in range(1, len(heartbeat_records)):
time1 = heartbeat_records[i - 1]
time2 = heartbeat_records[i]
Expand Down Expand Up @@ -242,3 +242,52 @@ def test_localtaskjob_double_trigger(self):
self.assertEqual(ti.state, State.RUNNING)

session.close()

def test_localtaskjob_maintain_heart_rate(self):
dagbag = models.DagBag(
dag_folder=TEST_DAG_FOLDER,
include_examples=False,
)
dag = dagbag.dags.get('test_localtaskjob_double_trigger')
task = dag.get_task('test_localtaskjob_double_trigger_task')

session = settings.Session()

dag.clear()
dag.create_dagrun(run_id="test",
state=State.SUCCESS,
execution_date=DEFAULT_DATE,
start_date=DEFAULT_DATE,
session=session)

ti_run = TI(task=task, execution_date=DEFAULT_DATE)
ti_run.refresh_from_db()
job1 = LocalTaskJob(task_instance=ti_run,
executor=SequentialExecutor())

# this should make sure we only heartbeat once and exit at the second
# loop in _execute()
return_codes = [None, 0]

def multi_return_code():
return return_codes.pop(0)

time_start = time.time()
from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
with patch.object(StandardTaskRunner, 'start', return_value=None) as mock_start:
with patch.object(StandardTaskRunner, 'return_code') as mock_ret_code:
mock_ret_code.side_effect = multi_return_code
job1.run()
self.assertEqual(mock_start.call_count, 1)
self.assertEqual(mock_ret_code.call_count, 2)
time_end = time.time()

self.assertEqual(self.mock_base_job_sleep.call_count, 1)
self.assertEqual(job1.state, State.SUCCESS)

# Consider we have patched sleep call, it should not be sleeping to
# keep up with the heart rate in other unpatched places
#
# We already make sure patched sleep call is only called once
self.assertLess(time_end - time_start, job1.heartrate)
session.close()

0 comments on commit 8442d26

Please sign in to comment.