diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py
index dbae45dc79af05..0e8ea1770f2f1d 100644
--- a/airflow/utils/dag_processing.py
+++ b/airflow/utils/dag_processing.py
@@ -612,6 +612,9 @@ def start(self):
         user code.
         """
 
+        # Start a new process group
+        os.setpgid(0, 0)
+
         self.log.info("Processing files using up to %s processes at a time ", self._parallelism)
         self.log.info("Process each file at most once every %s seconds", self._file_process_interval)
         self.log.info(
diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py
index 1c5765a7dd2bef..f2dffd378c1590 100644
--- a/airflow/utils/helpers.py
+++ b/airflow/utils/helpers.py
@@ -262,71 +262,82 @@ def f(t):
     return s
 
 
-def reap_process_group(pid, log, sig=signal.SIGTERM,
+def reap_process_group(pgid, log, sig=signal.SIGTERM,
                        timeout=DEFAULT_TIME_TO_WAIT_AFTER_SIGTERM):
     """
-    Tries really hard to terminate all children (including grandchildren). Will send
+    Tries really hard to terminate all processes in the group (including grandchildren). Will send
     sig (SIGTERM) to the process group of pid. If any process is alive after timeout
     a SIGKILL will be send.
 
     :param log: log handler
-    :param pid: pid to kill
+    :param pgid: process group id to kill
     :param sig: signal type
     :param timeout: how much time a process has to terminate
     """
 
+    returncodes = {}
+
     def on_terminate(p):
         log.info("Process %s (%s) terminated with exit code %s", p, p.pid, p.returncode)
+        returncodes[p.pid] = p.returncode
 
-    if pid == os.getpid():
-        raise RuntimeError("I refuse to kill myself")
-
-    try:
-        parent = psutil.Process(pid)
-    except psutil.NoSuchProcess:
-        # Race condition - the process already exited
-        return
+    def signal_procs(sig):
+        try:
+            os.killpg(pgid, sig)
+        except OSError as err:
+            # If operation not permitted error is thrown due to run_as_user,
+            # use sudo -n(--non-interactive) to kill the process
+            if err.errno == errno.EPERM:
+                subprocess.check_call(
+                    ["sudo", "-n", "kill", "-" + str(int(sig))] + [str(p.pid) for p in children]
+                )
+            else:
+                raise
 
-    children = parent.children(recursive=True)
-    children.append(parent)
+    if pgid == os.getpgid(0):
+        raise RuntimeError("I refuse to kill myself")
 
     try:
-        pg = os.getpgid(pid)
-    except OSError as err:
-        # Skip if not such process - we experience a race and it just terminated
-        if err.errno == errno.ESRCH:
-            return
-        raise
+        parent = psutil.Process(pgid)
 
-    log.info("Sending %s to GPID %s", sig, pg)
+        children = parent.children(recursive=True)
+        children.append(parent)
+    except psutil.NoSuchProcess:
+        # The process already exited, but maybe its children haven't.
+        children = []
+        for p in psutil.process_iter():
+            try:
+                if os.getpgid(p.pid) == pgid and p.pid != 0:
+                    children.append(p)
+            except OSError:
+                pass
+
+    log.info("Sending %s to GPID %s", sig, pgid)
     try:
-        os.killpg(os.getpgid(pid), sig)
+        signal_procs(sig)
     except OSError as err:
+        # No such process, which means there is no such process group - our job
+        # is done
         if err.errno == errno.ESRCH:
-            return
-        # If operation not permitted error is thrown due to run_as_user,
-        # use sudo -n(--non-interactive) to kill the process
-        if err.errno == errno.EPERM:
-            subprocess.check_call(["sudo", "-n", "kill", "-" + str(sig), str(os.getpgid(pid))])
-        raise
+            return returncodes
 
     _, alive = psutil.wait_procs(children, timeout=timeout, callback=on_terminate)
 
     if alive:
         for p in alive:
-            log.warning("process %s (%s) did not respond to SIGTERM. Trying SIGKILL", p, pid)
+            log.warning("process %s did not respond to SIGTERM. Trying SIGKILL", p)
 
         try:
-            os.killpg(os.getpgid(pid), signal.SIGKILL)
+            signal_procs(signal.SIGKILL)
         except OSError as err:
-            if err.errno == errno.ESRCH:
-                return
-            raise
+            if err.errno != errno.ESRCH:
+                raise
 
-        gone, alive = psutil.wait_procs(alive, timeout=timeout, callback=on_terminate)
+        _, alive = psutil.wait_procs(alive, timeout=timeout, callback=on_terminate)
         if alive:
             for p in alive:
                 log.error("Process %s (%s) could not be killed. Giving up.", p, p.pid)
+    return returncodes
 
 
 def parse_template_string(template_string):
diff --git a/tests/dags/test_on_kill.py b/tests/dags/test_on_kill.py
index c5ae29852b8519..04b6d437766dfc 100644
--- a/tests/dags/test_on_kill.py
+++ b/tests/dags/test_on_kill.py
@@ -25,6 +25,11 @@
 
 class DummyWithOnKill(DummyOperator):
     def execute(self, context):
+        import os
+        # This runs extra processes, so that we can be sure that we correctly
+        # tidy up all processes launched by a task when killing
+        if not os.fork():
+            os.system('sleep 10')
         time.sleep(10)
 
     def on_kill(self):
diff --git a/tests/task/task_runner/test_standard_task_runner.py b/tests/task/task_runner/test_standard_task_runner.py
index 46ae1b1efe5909..24484b228cb492 100644
--- a/tests/task/task_runner/test_standard_task_runner.py
+++ b/tests/task/task_runner/test_standard_task_runner.py
@@ -75,7 +75,7 @@ def test_start_and_terminate(self):
         local_task_job.task_instance = mock.MagicMock()
         local_task_job.task_instance.run_as_user = None
         local_task_job.task_instance.command_as_list.return_value = [
-            'airflow', 'tasks', 'test', 'test_mark_success', 'task1', '2016-01-01'
+            'airflow', 'tasks', 'test', 'test_on_kill', 'task1', '2016-01-01'
         ]
 
         runner = StandardTaskRunner(local_task_job)
@@ -84,19 +84,14 @@ def test_start_and_terminate(self):
 
         pgid = os.getpgid(runner.process.pid)
         self.assertGreater(pgid, 0)
+        self.assertNotEqual(pgid, os.getpgid(0), "Task should be in a different process group to us")
 
-        procs = []
-        for p in psutil.process_iter():
-            try:
-                if os.getpgid(p.pid) == pgid and p.pid != 0:
-                    procs.append(p)
-            except OSError:
-                pass
+        processes = list(self._procs_in_pgroup(pgid))
 
         runner.terminate()
 
-        for p in procs:
-            self.assertFalse(psutil.pid_exists(p.pid), "{} is still alive".format(p))
+        for process in processes:
+            self.assertFalse(psutil.pid_exists(process.pid), "{} is still alive".format(process))
 
         self.assertIsNotNone(runner.return_code())
 
@@ -105,23 +100,19 @@ def test_start_and_terminate_run_as_user(self):
         local_task_job.task_instance = mock.MagicMock()
         local_task_job.task_instance.run_as_user = getpass.getuser()
         local_task_job.task_instance.command_as_list.return_value = [
-            'airflow', 'tasks', 'test', 'test_mark_success', 'task1', '2016-01-01'
+            'airflow', 'tasks', 'test', 'test_on_kill', 'task1', '2016-01-01'
         ]
 
         runner = StandardTaskRunner(local_task_job)
+
         runner.start()
         time.sleep(0.5)
 
         pgid = os.getpgid(runner.process.pid)
         self.assertGreater(pgid, 0)
+        self.assertNotEqual(pgid, os.getpgid(0), "Task should be in a different process group to us")
 
-        processes = []
-        for process in psutil.process_iter():
-            try:
-                if os.getpgid(process.pid) == pgid and process.pid != 0:
-                    processes.append(process)
-            except OSError:
-                pass
+        processes = list(self._procs_in_pgroup(pgid))
 
         runner.terminate()
 
@@ -162,8 +153,15 @@ def test_on_kill(self):
         runner = StandardTaskRunner(job1)
         runner.start()
 
-        # Give the task some time to startup
-        time.sleep(10)
+        # give the task some time to startup
+        time.sleep(3)
+
+        pgid = os.getpgid(runner.process.pid)
+        self.assertGreater(pgid, 0)
+        self.assertNotEqual(pgid, os.getpgid(0), "Task should be in a different process group to us")
+
+        processes = list(self._procs_in_pgroup(pgid))
+
         runner.terminate()
 
         # Wait some time for the result
@@ -175,6 +173,18 @@ def test_on_kill(self):
         with open(path, "r") as f:
             self.assertEqual("ON_KILL_TEST", f.readline())
 
+        for process in processes:
+            self.assertFalse(psutil.pid_exists(process.pid), "{} is still alive".format(process))
+
+    @staticmethod
+    def _procs_in_pgroup(pgid):
+        for p in psutil.process_iter(attrs=['pid', 'name']):
+            try:
+                if os.getpgid(p.pid) == pgid and p.pid != 0:
+                    yield p
+            except OSError:
+                pass
+
 
 if __name__ == '__main__':
     unittest.main()