Support lambdas and local functions as callbacks in parallel ExternalSource #3269

Merged · 13 commits · Aug 26, 2021
16 changes: 11 additions & 5 deletions dali/python/nvidia/dali/_multiproc/pool.py
@@ -18,6 +18,7 @@
import multiprocessing
from collections import OrderedDict
from nvidia.dali import backend as _b
from nvidia.dali import pickling
from nvidia.dali._multiproc.worker import worker
from nvidia.dali._multiproc.messages import ScheduledTasks
from nvidia.dali._multiproc.shared_batch import SharedBatchMeta
@@ -168,7 +169,7 @@ class ProcPool:

def __init__(
self, callbacks, prefetch_queue_depths, num_workers=1, start_method="fork",
initial_chunk_size=1024 * 1024):
initial_chunk_size=1024 * 1024, py_callback_pickler=None):
if len(callbacks) != len(prefetch_queue_depths):
raise RuntimeError("Number of prefetch queues must match number of callbacks")
if any(prefetch_queue_depth <= 0 for prefetch_queue_depth in prefetch_queue_depths):
@@ -187,6 +188,7 @@ def __init__(
"Alternatively you can change Python workers starting method from ``fork`` to ``spawn`` "
"(see DALI Pipeline's ``py_start_method`` option for details). ")
mp = multiprocessing.get_context(start_method)
callback_pickler = pickling.CustomPickler.create(py_callback_pickler)
if num_workers < 1:
raise RuntimeError("num_workers must be a positive integer")
self._num_workers = num_workers
@@ -202,10 +204,14 @@ def __init__(
task_r, task_w = mp.Pipe(duplex=False)
res_r, res_w = mp.Pipe(duplex=False)
sock_reader, sock_writer = socket.socketpair()
if callback_pickler is None:
callbacks_ = callbacks
Contributor:
How about this?

Suggested change
callbacks_ = callbacks
callbacks_arg = callbacks

I'm not a fan of the trailing underscore, and a leading underscore has a meaning in Python which I guess we should not abuse.

Member Author:
done

else:
callbacks_ = callback_pickler.dumps(callbacks)
Contributor:
So now callbacks contains either pickled or raw callbacks?

Member Author (@stiepan, Aug 25, 2021):
Yes, so in the first case it creates a kind of additional layer of serialization, where the multiprocessing pickler just sees already-serialized callbacks. At first I simply set DALI's customized pickler to be used directly by multiprocessing, but that is unfortunately a global change for the whole process, and:

  1. it does not work if multiprocessing already started some actual work before DALI was able to modify the context
  2. there's a question of whether it might interfere with other packages using multiprocessing, like PyTorch (it seemed to work okay, but if we can avoid worrying about that, let's avoid it)
  3. for user-provided serialization packages I would still go for this type of double serialization
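For context, a minimal standalone sketch of the double-serialization idea described above (this is not the DALI implementation; dill is used only as a stand-in for any module providing ``dumps``/``loads``):

```python
import multiprocessing

def worker_entry(payload, loads_fn):
    # Second layer: undo the custom serialization inside the worker process.
    callbacks = loads_fn(payload)
    print([cb(1) for cb in callbacks])

if __name__ == "__main__":
    import dill  # assumption: any dumps/loads provider would do
    callbacks = [lambda x: x + 1, lambda x: x * 2]
    payload = dill.dumps(callbacks)        # first layer: the custom pickler
    ctx = multiprocessing.get_context("spawn")
    # The stdlib pickler used by multiprocessing only ships bytes plus a
    # reference to a module-level function, both of which it can handle.
    p = ctx.Process(target=worker_entry, args=(payload, dill.loads))
    p.start()
    p.join()
```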

process = mp.Process(
target=worker,
args=(i, callbacks, prefetch_queue_depths, initial_chunk_size,
task_r, res_w, sock_writer),
args=(i, callbacks_, prefetch_queue_depths, initial_chunk_size,
task_r, res_w, sock_writer, callback_pickler),
)
self._task_pipes.append(task_w)
self._res_pipes.append(res_r)
@@ -350,7 +356,7 @@ def __init__(self, num_callbacks, queue_depths, pool):
@classmethod
def from_groups(
cls, groups, keep_alive_queue_size, start_method="fork", num_workers=1,
initial_chunk_size=1024 * 1024):
initial_chunk_size=1024 * 1024, py_callback_pickler=None):
"""Creates new WorkerPool instance for given list of ExternalSource groups.

Parameters
@@ -371,7 +377,7 @@
"""
callbacks = [group.callback for group in groups]
queue_depths = [keep_alive_queue_size + group.prefetch_queue_depth for group in groups]
pool = ProcPool(callbacks, queue_depths, num_workers, start_method, initial_chunk_size)
pool = ProcPool(callbacks, queue_depths, num_workers, start_method, initial_chunk_size, py_callback_pickler)
return cls(len(callbacks), queue_depths, pool)

def schedule_batch(self, context_i, batch_i, dst_chunk_i, tasks):
6 changes: 4 additions & 2 deletions dali/python/nvidia/dali/_multiproc/worker.py
@@ -1,4 +1,4 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -183,7 +183,7 @@ def close(self):
chunk.close()


def worker(worker_id, callbacks, prefetch_queue_depths, initial_chunk_size, task_pipe, res_pipe, sock):
def worker(worker_id, callbacks, prefetch_queue_depths, initial_chunk_size, task_pipe, res_pipe, sock, callback_pickler):
"""Entry point of worker process.

Computes the data in the main thread, in separate threads:
@@ -206,6 +206,8 @@ def worker(worker_id, callbacks, prefetch_queue_depths, initial_chunk_size, task
`sock` : socket
Python wrapper around Unix socket used to pass file descriptors identifying shared memory chunk to parent process.
"""
if callback_pickler is not None:
callbacks = callback_pickler.loads(callbacks)
contexts = None
batch_dispatcher = SharedBatchesDispatcher(worker_id, sock, res_pipe)
task_receiver = TaskReceiver(task_pipe)
144 changes: 144 additions & 0 deletions dali/python/nvidia/dali/pickling.py
@@ -0,0 +1,144 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import pickle
import sys
import types
import marshal
import importlib
import io

dummy_lambda = lambda : 0

def set_function_state(fun, state):
fun.__globals__.update(state['global_refs'])
fun.__defaults__ = state['defaults']
fun.__kwdefaults__ = state['kwdefaults']

def function_unpickle(name, qualname, code, closure):
code = marshal.loads(code)
globs = {'__builtins__': __builtins__}
Contributor:
Why "globs"?

Suggested change
globs = {'__builtins__': __builtins__}
builtins = {'__builtins__': __builtins__}

Member Author (@stiepan, Aug 25, 2021):
It is a dictionary containing the "module global scope" of the function, the one available later on as the read-only fun.__globals__ attribute (you cannot replace the dict, but you can update its contents). If one called globals() inside the instantiated function, this is the object that would be returned. So it is rather a coincidence that this dictionary contains only builtins at this point: all other global references (if any) are added at the set-state stage to handle possible cyclic references.
On the other hand, I didn't call it globals, to avoid shadowing globals() inside function_unpickle.
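A minimal standalone sketch (not the DALI code) of the two-stage reconstruction described in this comment: the function is first rebuilt with a globals dict holding only builtins, and its real global references are patched in afterwards, mirroring the set-state step.

```python
import marshal
import types

FACTOR = 3  # a module-level name the function refers to

def scale(x):
    return x * FACTOR

# "dumps": keep only the code object and remember the globals the function uses
code_bytes = marshal.dumps(scale.__code__)
global_refs = {"FACTOR": FACTOR}

# "loads", stage 1: rebuild the function with builtins only
globs = {"__builtins__": __builtins__}
rebuilt = types.FunctionType(marshal.loads(code_bytes), globs, "scale")

# stage 2 (the set-state step): __globals__ cannot be reassigned,
# but the dict it points to can be updated in place
rebuilt.__globals__.update(global_refs)
assert rebuilt(2) == 6
```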

fun = types.FunctionType(code, globs, name, closure=closure)
fun.__qualname__ = qualname
return fun

def function_by_value_reducer(fun):
cl_vars = inspect.getclosurevars(fun)
code = marshal.dumps(fun.__code__)
basic_def = (fun.__name__, fun.__qualname__, code, fun.__closure__)
fun_context = {
'global_refs': cl_vars.globals,
'defaults': fun.__defaults__,
'kwdefaults': fun.__kwdefaults__
}
return function_unpickle, basic_def, fun_context, None, None, set_function_state

def module_unpickle(name, origin, submodule_search_locations):
if name in sys.modules:
return sys.modules[name]
spec = importlib.util.spec_from_file_location(
name, origin,
submodule_search_locations=submodule_search_locations)
module = importlib.util.module_from_spec(spec)
sys.modules[name] = module
spec.loader.exec_module(module)
return module

def module_reducer(module):
spec = module.__spec__
return module_unpickle, (spec.name, spec.origin, spec.submodule_search_locations)

def set_cell_state(cell, state):
cell.cell_contents = state['cell_contents']

def cell_unpickle():
return types.CellType(None)

def cell_reducer(cell):
return (cell_unpickle, tuple(), {'cell_contents': cell.cell_contents}, None, None, set_cell_state)


class DaliCallbackPickler(pickle.Pickler):

Contributor:
Nitpick: weird empty line

Member Author:
The one-line space between the class and its first method...?
I can remove it if you think it somehow clashes with the code formatting in DALI, though I tried to use it everywhere, both in this and other PRs. :D
Most of the classes I can find have docstrings, but those are followed by a single empty line too.

def reducer_override(self, obj):
if inspect.ismodule(obj):
return module_reducer(obj)
if isinstance(obj, types.CellType):
return cell_reducer(obj)
if inspect.isfunction(obj):
if isinstance(obj, type(dummy_lambda)) and obj.__name__ == dummy_lambda.__name__ or \
getattr(obj, '_dali_pickle_by_value', False):
return function_by_value_reducer(obj)
if '<locals>' in obj.__qualname__:
Contributor (@mzient, Aug 22, 2021):
Does it work with regular local functions, such as:

def f():
  def g(): return 42;

?
g.__qualname__ == 'f.<locals>.g', but it's otherwise quite trivial.
How about switching the logic to:

try:
  pickle.dumps(obj)
except (whatever it raises):
  return function_by_value_reducer_(obj)

?

Member Author (@stiepan, Aug 25, 2021):
Yep, g will be serialized by value, and plain pickling fails because pickle recursively steps down the attributes and raises AttributeError when it encounters <locals>.

I'd rather limit this pickle.dumps test to situations where I am nearly sure it is going to fail (or maybe get rid of the test altogether?), for efficiency reasons. If the test succeeded, it would have serialized something only for us to discard that result and proceed with a second pass of serialization.

It would be nice if we could catch the exception by overriding the dump method of the pickler, catch AttributeError, and only then fall back to obj._dali_pickle_by_value = True. Unfortunately that doesn't seem to be (easily) doable, as the pickle.Pickler class is in fact a wrapper around the C implementation of pickle, and an overridden dump method would be called only once, for the top-level object; recursive steps for dependencies/contents would be serialized by C code unaware of the override.
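For reference, a small standalone snippet illustrating the failure mode discussed above: the stdlib pickler refuses a function defined inside another function and raises the AttributeError whose message the fallback relies on.

```python
import pickle

def f():
    def g():
        return 42
    return g

local_fn = f()
assert "<locals>" in local_fn.__qualname__  # 'f.<locals>.g'

try:
    pickle.dumps(local_fn)
except AttributeError as e:
    print(e)  # e.g. "Can't pickle local object 'f.<locals>.g'"
```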

try:
pickle.dumps(obj)
except AttributeError as e:
if "Can't pickle local object" in str(e):
return function_by_value_reducer(obj)
return NotImplemented


def dumps(obj, protocol=None, **kwargs):
f = io.BytesIO()
DaliCallbackPickler(f, protocol, **kwargs).dump(obj)
return f.getvalue()


loads = pickle.loads


class CustomPickler:

Contributor:
Nitpick: weird empty line

@classmethod
def create(cls, py_callback_pickler):
if py_callback_pickler is None or isinstance(py_callback_pickler, cls):
return py_callback_pickler
if hasattr(py_callback_pickler, 'dumps') and hasattr(py_callback_pickler, 'loads'):
return cls.of_reducer(py_callback_pickler)
if isinstance(py_callback_pickler, (tuple, list)):
params = [None] * 3
for i, item in enumerate(py_callback_pickler):
params[i] = item
reducer, kwargs_dumps, kwargs_loads = params
return cls.of_reducer(reducer, kwargs_dumps, kwargs_loads)
raise ValueError("Unsupported py_callback_pickler value provided.")

@classmethod
def of_reducer(cls, reducer, dumps_kwargs=None, loads_kwargs=None):
Contributor:
This is a weird name. Is it required by some API? What does this of stand for?

Member Author:
Nah, it's not required by any API; I meant create_instance_of_this_class_from_reducer and for some reason used "of" rather than "from".

return cls(reducer.dumps, reducer.loads, dumps_kwargs, loads_kwargs)

def __init__(self, dumps, loads, dumps_kwargs, loads_kwargs):
self._dumps = dumps
self._loads = loads
self.dumps_kwargs = dumps_kwargs or {}
self.loads_kwargs = loads_kwargs or {}

def dumps(self, obj):
return self._dumps(obj, **self.dumps_kwargs)

def loads(self, obj):
return self._loads(obj, **self.loads_kwargs)


def pickle_by_value(fun):
"""
Hints parallel external source to serialize a decorated global function by value
rather than by reference, which would be the default behavior of Python's pickler.
"""
if inspect.isfunction(fun):
setattr(fun, '_dali_pickle_by_value', True)
return fun
else:
raise TypeError("Only functions can be explicitly set to be pickled by value")
28 changes: 26 additions & 2 deletions dali/python/nvidia/dali/pipeline.py
@@ -18,6 +18,7 @@
from nvidia.dali import tensors as Tensors
from nvidia.dali import types
from nvidia.dali._multiproc.pool import WorkerPool
from nvidia.dali import pickling as dali_pickle
from nvidia.dali.backend import CheckDLPackCapsule
from threading import local as tls
from . import data_node as _data_node
@@ -135,13 +136,34 @@ class Pipeline(object):
you will need to call :meth:`start_py_workers` before calling :meth:`build` of any
of the pipelines. You can find more details and caveats of both methods in Python's
``multiprocessing`` module documentation.
`py_callback_pickler` : module or tuple, default = nvidia.dali.pickling
If `py_start_method` is set to *spawn*, the callback passed to parallel ExternalSource must be picklable.
If run in Python 3.8 or newer, DALI uses a customized pickler (`nvidia.dali.pickling`) when
Contributor:
Do we want to advertise this module to the users? I think it could be internal.

Member Author (@stiepan, Aug 25, 2021):
I decided to advertise it because, apart from serving as the default callback pickler, it also contains the @pickle_by_value decorator, whose whole purpose is to be available for manual use by users.

1. What about moving everything apart from the decorator and the loads and dumps methods to some new module inside _multiproc?
2. Or renaming all the things that would be moved in 1. to start with an underscore?

serializing callbacks to support serialization of local functions and lambdas.

However, if you need to serialize more complex objects like local classes, or you are running an
older version of Python, you can provide an external serialization package such as dill or cloudpickle
that implements two methods, `dumps` and `loads`, to make DALI use them to serialize
external source callbacks.

Copy link
Contributor

@mzient mzient Aug 22, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
external source callbacks.
external source callbacks. You can even pass a module directly as ``py_callback_pickler``::
import dill
src = fn.external_source(lambda sample_info: 42, batch=False, parallel=True, py_callback_pickler=dill)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Never mind, yo do. :D

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Valid value for `py_callback_pickler` is either a module/object implementing
Contributor:
Suggested change
Valid value for `py_callback_pickler` is either a module/object implementing
A valid value for `py_callback_pickler` is either a module/object implementing

Member Author:
done

dumps and loads methods or a tuple where first item is the module/object and the next
Contributor:
Suggested change
dumps and loads methods or a tuple where first item is the module/object and the next
``dumps`` and ``loads`` methods or a tuple where first item is the module/object and the next

Member Author:
done

two optional parameters are extra kwargs to be passed when calling dumps and loads respectively.
Provided methods and kwargs must themselves be picklable.
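For illustration, a hedged usage sketch of the tuple form described above. Here dill and its ``recurse`` kwarg are assumptions (any module or object exposing ``dumps`` and ``loads`` works), and ``pipeline_def`` is assumed to forward ``py_callback_pickler`` to the Pipeline constructor as added in this PR.

```python
import dill
import numpy as np
from nvidia.dali import pipeline_def, fn

@pipeline_def(batch_size=4, num_threads=2, device_id=0,
              py_num_workers=2, py_start_method="spawn",
              # (module/object, dumps kwargs, loads kwargs); the kwargs items are optional
              py_callback_pickler=(dill, {"recurse": True}))
def my_pipe():
    return fn.external_source(
        lambda si: np.full((2,), si.idx_in_epoch, dtype=np.int32),
        batch=False, parallel=True)
```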

If you run Python3.8 or newer and use default `nvidia.dali.pickling` you can hint DALI to serialize
Contributor:
Suggested change
If you run Python3.8 or newer and use default `nvidia.dali.pickling` you can hint DALI to serialize
If you run Python3.8 or newer and use the default `nvidia.dali.pickling` you can hint DALI to serialize

Member Author:
done

global functions by value rather than by reference by decorating them
with `@dali.pickling.pickle_by_value`. It may be especially useful when working with
a Jupyter notebook, to work around the issue of the worker process being unable to import
the callback defined as a global function inside the notebook.
"""
def __init__(self, batch_size = -1, num_threads = -1, device_id = -1, seed = -1,
exec_pipelined=True, prefetch_queue_depth=2,
exec_async=True, bytes_per_sample=0,
set_affinity=False, max_streams=-1, default_cuda_stream_priority = 0,
*,
enable_memory_stats=False, py_num_workers=1, py_start_method="fork"):
enable_memory_stats=False, py_num_workers=1, py_start_method="fork",
py_callback_pickler=dali_pickle):
Contributor:
Suggested change
py_callback_pickler=dali_pickle):
py_callback_pickler=None):

Member Author:
done

self._sinks = []
self._max_batch_size = batch_size
self._num_threads = num_threads
@@ -172,6 +194,7 @@ def __init__(self, batch_size = -1, num_threads = -1, device_id = -1, seed = -1,
self._default_cuda_stream_priority = default_cuda_stream_priority
self._py_num_workers = py_num_workers
self._py_start_method = py_start_method
self._py_callback_pickler = py_callback_pickler
self._api_type = None
self._skip_api_check = False
self._graph_out = None
@@ -562,7 +585,8 @@ def _start_py_workers(self):
if not self._parallel_input_callbacks:
return
self._py_pool = WorkerPool.from_groups(
self._parallel_input_callbacks, self._prefetch_queue_depth, self._py_start_method, self._py_num_workers)
self._parallel_input_callbacks, self._prefetch_queue_depth, self._py_start_method,
self._py_num_workers, py_callback_pickler=self._py_callback_pickler)
# ensure processes started by the pool are terminated when pipeline is no longer used
weakref.finalize(self, lambda pool : pool.close(), self._py_pool)
self._py_pool_started = True
24 changes: 24 additions & 0 deletions dali/test/python/import_module_test_helper.py
@@ -0,0 +1,24 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Contributor:
Suggested change
# Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

Member Author:
done

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Used in test_external_source_parallel_custom_serialization to check if modules
# are properly imported during callback deserialization. The test only makes sense
# if this module is not automatically imported when the worker process starts, so don't
# import this file globally

import numpy as np

def cb(x):
return np.full((10, 100), x.idx_in_epoch)