Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add TaskStatusReporter class to fix Windows pickle issue #1992

Merged
merged 2 commits into from
Feb 2, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 31 additions & 18 deletions luigi/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,22 +110,21 @@ class TaskProcess(multiprocessing.Process):

Mainly for convenience since this is run in a separate process. """

def __init__(self, task, worker_id, result_queue, tracking_url_callback,
status_message_callback, use_multiprocessing=False, worker_timeout=0):
def __init__(self, task, worker_id, result_queue, status_reporter,
use_multiprocessing=False, worker_timeout=0):
super(TaskProcess, self).__init__()
self.task = task
self.worker_id = worker_id
self.result_queue = result_queue
self.tracking_url_callback = tracking_url_callback
self.status_message_callback = status_message_callback
self.status_reporter = status_reporter
if task.worker_timeout is not None:
worker_timeout = task.worker_timeout
self.timeout_time = time.time() + worker_timeout if worker_timeout else None
self.use_multiprocessing = use_multiprocessing or self.timeout_time is not None

def _run_get_new_deps(self):
self.task.set_tracking_url = self.tracking_url_callback
self.task.set_status_message = self.status_message_callback
self.task.set_tracking_url = self.status_reporter.update_tracking_url
self.task.set_status_message = self.status_reporter.update_status

task_gen = self.task.run()

Expand Down Expand Up @@ -246,6 +245,30 @@ def terminate(self):
return super(TaskProcess, self).terminate()


class TaskStatusReporter(object):
"""
Reports task status information to the scheduler.

This object must be pickle-able for passing to `TaskProcess` on systems
where fork method needs to pickle the process object (e.g. Windows).
"""
def __init__(self, scheduler, task_id, worker_id):
self._task_id = task_id
self._worker_id = worker_id
self._scheduler = scheduler

def update_tracking_url(self, tracking_url):
self._scheduler.add_task(
task_id=self._task_id,
worker=self._worker_id,
status=RUNNING,
tracking_url=tracking_url
)

def update_status(self, message):
self._scheduler.set_task_status_message(self._task_id, message)


class SingleProcessPool(object):
"""
Dummy process pool for using a single processor.
Expand Down Expand Up @@ -869,19 +892,9 @@ def _run_task(self, task_id):
task_process.run()

def _create_task_process(self, task):
def update_tracking_url(tracking_url):
self._scheduler.add_task(
task_id=task.task_id,
worker=self._id,
status=RUNNING,
tracking_url=tracking_url,
)

def update_status_message(message):
self._scheduler.set_task_status_message(task.task_id, message)

reporter = TaskStatusReporter(self._scheduler, task.task_id, self._id)
return TaskProcess(
task, self._id, self._task_result_queue, update_tracking_url, update_status_message,
task, self._id, self._task_result_queue, reporter,
use_multiprocessing=bool(self.worker_processes > 1),
worker_timeout=self._config.timeout
)
Expand Down
6 changes: 3 additions & 3 deletions test/worker_task_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def on_success(self):

task = SuccessTask()
result_queue = multiprocessing.Queue()
task_process = TaskProcess(task, 1, result_queue, lambda: None, lambda: None)
task_process = TaskProcess(task, 1, result_queue, mock.Mock())

with mock.patch.object(result_queue, 'put') as mock_put:
task_process.run()
Expand All @@ -81,7 +81,7 @@ def on_failure(self, exception):

task = FailTask()
result_queue = multiprocessing.Queue()
task_process = TaskProcess(task, 1, result_queue, lambda: None, lambda: None)
task_process = TaskProcess(task, 1, result_queue, mock.Mock())

with mock.patch.object(result_queue, 'put') as mock_put:
task_process.run()
Expand All @@ -100,7 +100,7 @@ def run(self):
queue = mock.Mock()
worker_id = 1

task_process = TaskProcess(task, worker_id, queue, lambda: None, lambda: None)
task_process = TaskProcess(task, worker_id, queue, mock.Mock())
task_process.start()

parent = Process(task_process.pid)
Expand Down