Support lambdas and local functions as callbacks in parallel ExternalSource #3269
@@ -18,6 +18,7 @@
 import multiprocessing
 from collections import OrderedDict
 from nvidia.dali import backend as _b
+from nvidia.dali import pickling
 from nvidia.dali._multiproc.worker import worker
 from nvidia.dali._multiproc.messages import ScheduledTasks
 from nvidia.dali._multiproc.shared_batch import SharedBatchMeta
@@ -168,7 +169,7 @@ class ProcPool:
 
     def __init__(
             self, callbacks, prefetch_queue_depths, num_workers=1, start_method="fork",
-            initial_chunk_size=1024 * 1024):
+            initial_chunk_size=1024 * 1024, py_callback_pickler=None):
         if len(callbacks) != len(prefetch_queue_depths):
             raise RuntimeError("Number of prefetch queues must match number of callbacks")
         if any(prefetch_queue_depth <= 0 for prefetch_queue_depth in prefetch_queue_depths):
@@ -187,6 +188,7 @@ def __init__(
                 "Alternatively you can change Python workers starting method from ``fork`` to ``spawn`` "
                 "(see DALI Pipeline's ``py_start_method`` option for details). ")
         mp = multiprocessing.get_context(start_method)
+        callback_pickler = pickling.CustomPickler.create(py_callback_pickler)
         if num_workers < 1:
             raise RuntimeError("num_workers must be a positive integer")
         self._num_workers = num_workers
@@ -202,10 +204,14 @@ def __init__(
             task_r, task_w = mp.Pipe(duplex=False)
             res_r, res_w = mp.Pipe(duplex=False)
             sock_reader, sock_writer = socket.socketpair()
+            if callback_pickler is None:
+                callbacks_ = callbacks
+            else:
+                callbacks_ = callback_pickler.dumps(callbacks)

Review comment: So now …

Reply: Yes, so in the first case this creates a kind of additional layer of serialization, where the multiprocessing pickler only sees already serialized callbacks. At first I simply set the DALI customized pickler to be used directly by multiprocessing, but that is unfortunately a global change for the whole process and: … [a sketch of this two-layer approach follows this hunk]

+            process = mp.Process(
                 target=worker,
-                args=(i, callbacks, prefetch_queue_depths, initial_chunk_size,
-                      task_r, res_w, sock_writer),
+                args=(i, callbacks_, prefetch_queue_depths, initial_chunk_size,
+                      task_r, res_w, sock_writer, callback_pickler),
             )
             self._task_pipes.append(task_w)
             self._res_pipes.append(res_r)
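A minimal, self-contained sketch of the two-layer approach described in the reply above; the names worker_entry and BytesPickler are made up for illustration and are not part of DALI. The callbacks are pre-serialized into a bytes object with the custom pickler, so the multiprocessing machinery only ever pickles a plain bytes blob, and the worker undoes the inner layer itself.

import multiprocessing
import pickle

# Stand-in for the custom pickler: anything with dumps/loads; must itself be picklable.
class BytesPickler:
    def dumps(self, obj):
        return pickle.dumps(obj)
    def loads(self, data):
        return pickle.loads(data)

def double(x):
    return 2 * x

def worker_entry(callbacks_blob, callback_pickler):
    # Inner layer undone inside the worker process.
    if callback_pickler is not None:
        callbacks = callback_pickler.loads(callbacks_blob)
    else:
        callbacks = callbacks_blob
    print([cb(21) for cb in callbacks])

if __name__ == "__main__":
    callback_pickler = BytesPickler()
    # Outer layer: multiprocessing only sees a bytes object in the Process args.
    blob = callback_pickler.dumps([double])
    p = multiprocessing.Process(target=worker_entry, args=(blob, callback_pickler))
    p.start()
    p.join()   # prints [42]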
@@ -350,7 +356,7 @@ def __init__(self, num_callbacks, queue_depths, pool):
     @classmethod
     def from_groups(
             cls, groups, keep_alive_queue_size, start_method="fork", num_workers=1,
-            initial_chunk_size=1024 * 1024):
+            initial_chunk_size=1024 * 1024, py_callback_pickler=None):
         """Creates new WorkerPool instance for given list of ExternalSource groups.
 
         Parameters

@@ -371,7 +377,7 @@ def from_groups(
         """
         callbacks = [group.callback for group in groups]
         queue_depths = [keep_alive_queue_size + group.prefetch_queue_depth for group in groups]
-        pool = ProcPool(callbacks, queue_depths, num_workers, start_method, initial_chunk_size)
+        pool = ProcPool(callbacks, queue_depths, num_workers, start_method, initial_chunk_size, py_callback_pickler)
         return cls(len(callbacks), queue_depths, pool)
 
     def schedule_batch(self, context_i, batch_i, dst_chunk_i, tasks):
@@ -0,0 +1,144 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import inspect
import pickle
import sys
import types
import marshal
import importlib
import io


dummy_lambda = lambda: 0


def set_function_state(fun, state):
    fun.__globals__.update(state['global_refs'])
    fun.__defaults__ = state['defaults']
    fun.__kwdefaults__ = state['kwdefaults']


def function_unpickle(name, qualname, code, closure):
    code = marshal.loads(code)
    globs = {'__builtins__': __builtins__}

Review comment: Why "globs"?

Reply: It is a dictionary containing the "module global scope" of the function, the one available later on as …
    fun = types.FunctionType(code, globs, name, closure=closure)
    fun.__qualname__ = qualname
    return fun


def function_by_value_reducer(fun):
    cl_vars = inspect.getclosurevars(fun)
    code = marshal.dumps(fun.__code__)
    basic_def = (fun.__name__, fun.__qualname__, code, fun.__closure__)
    fun_context = {
        'global_refs': cl_vars.globals,
        'defaults': fun.__defaults__,
        'kwdefaults': fun.__kwdefaults__
    }
    return function_unpickle, basic_def, fun_context, None, None, set_function_state


def module_unpickle(name, origin, submodule_search_locations):
    if name in sys.modules:
        return sys.modules[name]
    spec = importlib.util.spec_from_file_location(
        name, origin,
        submodule_search_locations=submodule_search_locations)
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    spec.loader.exec_module(module)
    return module


def module_reducer(module):
    spec = module.__spec__
    return module_unpickle, (spec.name, spec.origin, spec.submodule_search_locations)


def set_cell_state(cell, state):
    cell.cell_contents = state['cell_contents']


def cell_unpickle():
    return types.CellType(None)


def cell_reducer(cell):
    return (cell_unpickle, tuple(), {'cell_contents': cell.cell_contents}, None, None, set_cell_state)
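As a side note on the cell reducer above, here is a small illustration of plain Python closure cells (not DALI code): the free variables of a closure live in cell objects attached to the function, which is why they need their own reconstruction step when a function is pickled by value.

def make_adder(n):
    def add(x):
        return x + n          # 'n' is a free variable kept in a closure cell
    return add

add5 = make_adder(5)
cell = add5.__closure__[0]    # a cell object (types.CellType)
print(type(cell).__name__)    # cell
print(cell.cell_contents)     # 5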
class DaliCallbackPickler(pickle.Pickler):

Review comment: Nitpick: weird empty line.

Reply: The one-line space between the first method and the …

    def reducer_override(self, obj):
        if inspect.ismodule(obj):
            return module_reducer(obj)
        if isinstance(obj, types.CellType):
            return cell_reducer(obj)
        if inspect.isfunction(obj):
            if isinstance(obj, type(dummy_lambda)) and obj.__name__ == dummy_lambda.__name__ or \
                    getattr(obj, '_dali_pickle_by_value', False):
                return function_by_value_reducer(obj)
            if '<locals>' in obj.__qualname__:

Review comment: Does it work with regular local functions, such as:

    def f():
        def g(): return 42

? How about:

    try:
        pickle.dumps(obj)
    except (whatever it raises):
        return function_by_value_reducer_(obj)

?

Reply: Yep, g will be serialized by value, and plain pickling fails as pickle recursively steps down the attributes and raises AttributeError when encountering … I'd rather limit this … It would be nice if we could catch the exception by overriding the dump method of the pickler, catch AttributeError and only then fall back to obj.__dali_pickle_by_value = True. Unfortunately it doesn't seem to be (easily) doable, as the pickle.Pickler class is in fact a wrapper around the C implementation of pickle, and an overridden dump method would be called only once for a top-level object; recursive steps for dependencies/contents would be serialized by C code unaware of the override.

                try:
                    pickle.dumps(obj)
                except AttributeError as e:
                    if "Can't pickle local object" in str(e):
                        return function_by_value_reducer(obj)
        return NotImplemented
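A hedged, self-contained sketch of the by-value technique used by reducer_override above; this is not the DALI pickler itself, and ByValuePickler and the helper names are made up. It marshals the function's code object and rebuilds it with types.FunctionType on load, so a local function, which plain pickle rejects with AttributeError, survives a round trip.

import io
import marshal
import pickle
import types

def rebuild_function(name, code_bytes):
    code = marshal.loads(code_bytes)
    return types.FunctionType(code, {'__builtins__': __builtins__}, name)

class ByValuePickler(pickle.Pickler):
    # reducer_override requires Python 3.8 or newer
    def reducer_override(self, obj):
        if isinstance(obj, types.FunctionType) and '<locals>' in obj.__qualname__:
            return rebuild_function, (obj.__name__, marshal.dumps(obj.__code__))
        return NotImplemented

def by_value_dumps(obj):
    buf = io.BytesIO()
    ByValuePickler(buf).dump(obj)
    return buf.getvalue()

def make_local():
    def g(x):                 # local function with no free variables
        return x + 42
    return g

g = make_local()
try:
    pickle.dumps(g)
except AttributeError as e:
    print(e)                  # Can't pickle local object 'make_local.<locals>.g'

g2 = pickle.loads(by_value_dumps(g))
print(g2(0))                  # 42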
def dumps(obj, protocol=None, **kwargs):
    f = io.BytesIO()
    DaliCallbackPickler(f, protocol, **kwargs).dump(obj)
    return f.getvalue()


loads = pickle.loads


class CustomPickler:

Review comment: Nitpick: weird empty line.

    @classmethod
    def create(cls, py_callback_pickler):
        if py_callback_pickler is None or isinstance(py_callback_pickler, cls):
            return py_callback_pickler
        if hasattr(py_callback_pickler, 'dumps') and hasattr(py_callback_pickler, 'loads'):
            return cls.of_reducer(py_callback_pickler)
        if isinstance(py_callback_pickler, (tuple, list)):
            params = [None] * 3
            for i, item in enumerate(py_callback_pickler):
                params[i] = item
            reducer, kwargs_dumps, kwargs_loads = params
            return cls.of_reducer(reducer, kwargs_dumps, kwargs_loads)
        raise ValueError("Unsupported py_callback_pickler value provided.")

    @classmethod
    def of_reducer(cls, reducer, dumps_kwargs=None, loads_kwargs=None):

Review comment: This is a weird name. Is it required by some API? What does this …

Reply: Nah, it's not required by any API. I meant create_instance_of_this_class_from_reducer and for some reason used "of" rather than "from".

        return cls(reducer.dumps, reducer.loads, dumps_kwargs, loads_kwargs)

    def __init__(self, dumps, loads, dumps_kwargs, loads_kwargs):
        self._dumps = dumps
        self._loads = loads
        self.dumps_kwargs = dumps_kwargs or {}
        self.loads_kwargs = loads_kwargs or {}

    def dumps(self, obj):
        return self._dumps(obj, **self.dumps_kwargs)

    def loads(self, obj):
        return self._loads(obj, **self.loads_kwargs)
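For reference, a short sketch of the three input forms accepted by CustomPickler.create above; dill here is only an example of a third-party pickler exposing dumps/loads, and having it installed is an assumption.

import dill
from nvidia.dali.pickling import CustomPickler

p1 = CustomPickler.create(dill)                        # module/object with dumps and loads
p2 = CustomPickler.create((dill, {"recurse": True}))   # tuple: (pickler, dumps kwargs[, loads kwargs])
p3 = CustomPickler.create(None)                        # None passes through unchanged

blob = p1.dumps(lambda x: x + 1)                       # dill can serialize lambdas
assert p1.loads(blob)(1) == 2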
def pickle_by_value(fun):
    """
    Hints parallel external source to serialize a decorated global function by value
    rather than by reference, which would be the default behavior of Python's pickler.
    """
    if inspect.isfunction(fun):
        setattr(fun, '_dali_pickle_by_value', True)
        return fun
    else:
        raise TypeError("Only functions can be explicitly set to be pickled by value")
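A usage sketch of the decorator defined above; the callback name and array shape are illustrative, mirroring the test helper added later in this PR.

import numpy as np
from nvidia.dali import pickling as dali_pickle

@dali_pickle.pickle_by_value
def notebook_callback(sample_info):
    # Serialized by value, so a spawned worker does not need to re-import the
    # module (or notebook) in which this function was defined.
    return np.full((10, 100), sample_info.idx_in_epoch)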
@@ -18,6 +18,7 @@
 from nvidia.dali import tensors as Tensors
 from nvidia.dali import types
 from nvidia.dali._multiproc.pool import WorkerPool
+from nvidia.dali import pickling as dali_pickle
 from nvidia.dali.backend import CheckDLPackCapsule
 from threading import local as tls
 from . import data_node as _data_node
@@ -135,13 +136,34 @@ class Pipeline(object):
         you will need to call :meth:`start_py_workers` before calling :meth:`build` of any
         of the pipelines. You can find more details and caveats of both methods in Python's
         ``multiprocessing`` module documentation.
+    `py_callback_pickler` : module or tuple, default = nvidia.dali.pickling
+        If `py_start_method` is set to *spawn*, the callback passed to the parallel
+        ExternalSource must be picklable.
+        If run in Python3.8 or newer, DALI uses a customized pickler (`nvidia.dali.pickling`)
+        when serializing callbacks, to support serialization of local functions and lambdas.

Review comment: Do we want to advertise this module to the users? I think it could be internal.

Reply: I decided to advertise it because, apart from serving as the default callback pickler, it also contains the @pickle_by_value decorator, whose whole purpose is to be available for manual usage by users. 1. What about moving everything apart from the decorator, loads and dumps methods to some new module inside …

+
+        However, if you need to serialize more complex objects such as local classes, or you are
+        running an older version of Python, you can provide an external serialization package such
+        as dill or cloudpickle that implements two methods, `dumps` and `loads`, to make DALI use
+        them to serialize external source callbacks.
+
+        A valid value for `py_callback_pickler` is either a module/object implementing `dumps` and
+        `loads` methods, or a tuple where the first item is the module/object and the next two
+        optional parameters are extra kwargs to be passed when calling dumps and loads respectively.
+        The provided methods and kwargs must themselves be picklable.
+
+        If you run Python3.8 or newer and use the default `nvidia.dali.pickling`, you can hint DALI
+        to serialize global functions by value rather than by reference, by decorating them with
+        `@dali.pickling.pickle_by_value`. It may be especially useful when working with a Jupyter
+        notebook, to work around the issue of the worker process being unable to import the
+        callback defined as a global function inside the notebook.
     """
     def __init__(self, batch_size = -1, num_threads = -1, device_id = -1, seed = -1,
                  exec_pipelined=True, prefetch_queue_depth=2,
                  exec_async=True, bytes_per_sample=0,
                  set_affinity=False, max_streams=-1, default_cuda_stream_priority = 0,
                  *,
-                 enable_memory_stats=False, py_num_workers=1, py_start_method="fork"):
+                 enable_memory_stats=False, py_num_workers=1, py_start_method="fork",
+                 py_callback_pickler=dali_pickle):
         self._sinks = []
         self._max_batch_size = batch_size
         self._num_threads = num_threads
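To illustrate the docstring above, a hedged usage sketch assuming this PR's Pipeline signature and that dill is installed; the other argument values are arbitrary.

import dill
from nvidia.dali.pipeline import Pipeline

# Pass a module/object exposing dumps and loads...
pipe = Pipeline(batch_size=4, num_threads=2, device_id=0,
                py_num_workers=2, py_start_method="spawn",
                py_callback_pickler=dill)

# ...or a tuple with extra kwargs forwarded to dumps (and optionally loads).
pipe = Pipeline(batch_size=4, num_threads=2, device_id=0,
                py_num_workers=2, py_start_method="spawn",
                py_callback_pickler=(dill, {"recurse": True}))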
@@ -172,6 +194,7 @@ def __init__(self, batch_size = -1, num_threads = -1, device_id = -1, seed = -1,
         self._default_cuda_stream_priority = default_cuda_stream_priority
         self._py_num_workers = py_num_workers
         self._py_start_method = py_start_method
+        self._py_callback_pickler = py_callback_pickler
         self._api_type = None
         self._skip_api_check = False
         self._graph_out = None
@@ -562,7 +585,8 @@ def _start_py_workers(self):
         if not self._parallel_input_callbacks:
             return
         self._py_pool = WorkerPool.from_groups(
-            self._parallel_input_callbacks, self._prefetch_queue_depth, self._py_start_method, self._py_num_workers)
+            self._parallel_input_callbacks, self._prefetch_queue_depth, self._py_start_method,
+            self._py_num_workers, py_callback_pickler=self._py_callback_pickler)
         # ensure processes started by the pool are terminated when pipeline is no longer used
         weakref.finalize(self, lambda pool : pool.close(), self._py_pool)
         self._py_pool_started = True
@@ -0,0 +1,24 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Used in test_external_source_parallel_custom_serialization to check if modules
# are properly imported during callback deserialization. Such a test only makes sense
# if this module is not automatically imported when the worker process starts, so don't
# import this file globally.

import numpy as np


def cb(x):
    return np.full((10, 100), x.idx_in_epoch)
Review comment: How about this? I'm not a fan of the trailing underscore, and the leading underscore has a meaning in Python which I guess we should not abuse.

Reply: done