Skip to content

Commit

Permalink
Revert "Optimize nn.Module __call__ fast path for dynamo (#95931)" (#…
Browse files Browse the repository at this point in the history
…96242)

Reverting due to concerns over silent unsoundness (skipped hooks) if users have directly added hooks dicts without using official torch APIs.

This reverts commit 2604533.

Pull Request resolved: pytorch/pytorch#96242
Approved by: https://github.com/albanD
  • Loading branch information
wconstab authored and cyyever committed Mar 12, 2023
1 parent dd8ed48 commit 92ef868
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 103 deletions.
2 changes: 1 addition & 1 deletion test/dynamo/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1272,7 +1272,7 @@ def guard_fail_fn(failure):
handle.remove()
self.assertEqual(compiled_func(inp), outer_func(inp))
self.assertEqual(compiled_func(inp).item(), 7)
self.assertTrue("hooks" in failure_reason)
self.assertTrue("forward_hooks.keys" in failure_reason)
self.assertEqual(cc.frame_count, 1 + 1)
self.assertEqual(cc.op_count, 6 + 4)

Expand Down
1 change: 1 addition & 0 deletions test/nn/test_module_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,7 @@ def bw_hook(inc, h_module, grad_input, grad_output):
self.assertTrue(isinstance(h_module, module))
self.assertEqual(grad_output[0], torch.ones(5, 5) * 2)
counter['backwards'] += inc

test_fwd = nn.modules.module.register_module_forward_hook(lambda *args: fw_hook(1, *args))

module_1(input)
Expand Down
28 changes: 0 additions & 28 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15945,34 +15945,6 @@ def forward(self, x, y):
with torch.jit.fuser(fuser_name):
self.checkModule(MyModule(), (x, y))

def test_scriptmodule_update_has_hooks(self):

class SimpleModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self):
pass

def forward_hook(self, input: Tuple[()], output: None):
pass

m = SimpleModule()
hook = m.register_forward_hook(forward_hook)
sm = torch.jit.script(m)
self.assertTrue(sm._has_hooks)

# Todo this is bad: ideally the handle would update the scriptmodule too,
# but this is a pre-existing bug
hook.remove()
self.assertTrue(sm._has_hooks)
self.assertFalse(m._has_hooks)

# at least manual use of the update function works
del sm._forward_hooks[0]
sm._update_has_hooks()
self.assertFalse(sm._has_hooks)

# known to be failing in tracer
EXCLUDE_TRACED = {
# The following fail due to #12024.
Expand Down
27 changes: 11 additions & 16 deletions torch/jit/_recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,6 @@
"dump_patches",
]

def _configure_hooks(script_module):
# Copy the forward hooks and pre-hooks to the new ScriptModule
# to allow the hooks to be run from eager as ScriptFunctions
for idx, fn in enumerate(script_module._c._get_forward_pre_hooks()):
script_module._forward_pre_hooks[idx] = fn
for idx, fn in enumerate(script_module._c._get_forward_hooks()):
script_module._forward_hooks[idx] = fn

# The _update_has_hooks method sets _has_hooks attr which is used by _call_impl during python execution
# of a scriptmodule. This needs to be run both on scriptmodule creation and loading.
script_module._update_has_hooks = torch.nn.Module._update_has_hooks.__get__(script_module)
script_module._update_has_hooks()

def _compile_and_register_class(obj, rcb, qualified_name):
script_class = _get_script_class(obj)

Expand Down Expand Up @@ -541,7 +528,6 @@ def init_fn(script_module):
if name in ignored_properties:
continue
item = getattr(nn_module, name, None)

if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item):
unbound_function = getattr(nn_module, name).__func__
bound_method = unbound_function.__get__(script_module)
Expand All @@ -554,6 +540,7 @@ def init_fn(script_module):

# Actually create the ScriptModule, initializing it with the function we just defined
script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)

# Compile methods if necessary
if concrete_type not in concrete_type_store.methods_compiled:
create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
Expand All @@ -563,8 +550,13 @@ def init_fn(script_module):
torch._C._run_emit_module_hook(cpp_module)
concrete_type_store.methods_compiled.add(concrete_type)

# Copy the forward hooks and pre-hooks to the new ScriptModule
# to allow the hooks to be run from eager as ScriptFunctions
for idx, fn in enumerate(script_module._c._get_forward_pre_hooks()):
script_module._forward_pre_hooks[idx] = fn
for idx, fn in enumerate(script_module._c._get_forward_hooks()):
script_module._forward_hooks[idx] = fn

_configure_hooks(script_module)

# Special handling so methods like __len__ work in script methods on classes derived from containers
if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)) and \
Expand Down Expand Up @@ -889,7 +881,10 @@ def init_fn(script_module):
setattr(script_module, name, wrap_cpp_module(cpp_module))
script_module._concrete_type = torch._C.ConcreteModuleType.from_jit_type(script_module._c._type())

_configure_hooks(script_module)
for idx, fn in enumerate(script_module._c._get_forward_pre_hooks()):
script_module._forward_pre_hooks[idx] = fn
for idx, fn in enumerate(script_module._c._get_forward_hooks()):
script_module._forward_hooks[idx] = fn

return torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)

Expand Down
52 changes: 10 additions & 42 deletions torch/nn/modules/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,6 @@ def __setstate__(self, state: Dict):
_global_is_full_backward_hook: Optional[bool] = None
_global_forward_pre_hooks: Dict[int, Callable] = OrderedDict()
_global_forward_hooks: Dict[int, Callable] = OrderedDict()
_has_global_hooks: bool = False

def _update_has_global_hooks():
global _has_global_hooks
_has_global_hooks = bool(
_global_backward_pre_hooks
or _global_backward_hooks
or _global_forward_hooks
or _global_forward_pre_hooks
)

_EXTRA_STATE_KEY_SUFFIX = '_extra_state'

Expand Down Expand Up @@ -209,7 +199,6 @@ def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHand
"""
handle = hooks.RemovableHandle(_global_forward_pre_hooks)
_global_forward_pre_hooks[handle.id] = hook
_update_has_global_hooks()
return handle


Expand Down Expand Up @@ -242,7 +231,6 @@ def register_module_forward_hook(hook: Callable[..., None]) -> RemovableHandle:
"""
handle = hooks.RemovableHandle(_global_forward_hooks)
_global_forward_hooks[handle.id] = hook
_update_has_global_hooks()
return handle


Expand All @@ -267,9 +255,9 @@ def register_module_backward_hook(
"global Module hook. Please use only one of them.")

_global_is_full_backward_hook = False

handle = hooks.RemovableHandle(_global_backward_hooks)
_global_backward_hooks[handle.id] = hook
_update_has_global_hooks()
return handle


Expand Down Expand Up @@ -305,7 +293,6 @@ def register_module_full_backward_pre_hook(
``handle.remove()``
"""
_update_has_global_hooks()
handle = hooks.RemovableHandle(_global_backward_pre_hooks)
_global_backward_pre_hooks[handle.id] = hook
return handle
Expand Down Expand Up @@ -445,12 +432,8 @@ def forward(self, x):
_state_dict_pre_hooks: Dict[int, Callable]
_load_state_dict_post_hooks: Dict[int, Callable]
_modules: Dict[str, Optional['Module']]
_has_hooks: bool = False
call_super_init: bool = False

# we want _has_hooks to be updated properly by _update_has_hooks in jit ScriptModules
__jit_ignored_attributes__ = ["_has_hooks"]

def __init__(self, *args, **kwargs) -> None:
"""
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Expand Down Expand Up @@ -494,14 +477,6 @@ def __init__(self, *args, **kwargs) -> None:

forward: Callable[..., Any] = _forward_unimplemented

def _update_has_hooks(self):
self._has_hooks = bool(
self._backward_hooks
or self._backward_pre_hooks
or self._forward_hooks
or self._forward_pre_hooks
)

def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
r"""Adds a buffer to the module.
Expand Down Expand Up @@ -1212,11 +1187,10 @@ def register_full_backward_pre_hook(
``handle.remove()``
"""
handle = hooks.RemovableHandle(self._backward_pre_hooks, module=self)
handle = hooks.RemovableHandle(self._backward_pre_hooks)
self._backward_pre_hooks[handle.id] = hook
if prepend:
self._backward_pre_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined]
self._update_has_hooks()
return handle

def register_backward_hook(
Expand All @@ -1239,9 +1213,8 @@ def register_backward_hook(

self._is_full_backward_hook = False

handle = hooks.RemovableHandle(self._backward_hooks, module=self)
handle = hooks.RemovableHandle(self._backward_hooks)
self._backward_hooks[handle.id] = hook
self._update_has_hooks()
return handle

def register_full_backward_hook(
Expand Down Expand Up @@ -1298,11 +1271,10 @@ def register_full_backward_hook(

self._is_full_backward_hook = True

handle = hooks.RemovableHandle(self._backward_hooks, module=self)
handle = hooks.RemovableHandle(self._backward_hooks)
self._backward_hooks[handle.id] = hook
if prepend:
self._backward_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined]
self._update_has_hooks()
return handle

def _get_backward_hooks(self):
Expand Down Expand Up @@ -1428,16 +1400,14 @@ def register_forward_pre_hook(
"""
handle = hooks.RemovableHandle(
self._forward_pre_hooks,
extra_dict=self._forward_pre_hooks_with_kwargs,
module=self
extra_dict=self._forward_pre_hooks_with_kwargs
)
self._forward_pre_hooks[handle.id] = hook
if with_kwargs:
self._forward_pre_hooks_with_kwargs[handle.id] = True

if prepend:
self._forward_pre_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined]
self._update_has_hooks()
return handle

def register_forward_hook(
Expand Down Expand Up @@ -1491,16 +1461,14 @@ def register_forward_hook(
"""
handle = hooks.RemovableHandle(
self._forward_hooks,
extra_dict=self._forward_hooks_with_kwargs,
module=self
extra_dict=self._forward_hooks_with_kwargs
)
self._forward_hooks[handle.id] = hook
if with_kwargs:
self._forward_hooks_with_kwargs[handle.id] = True

if prepend:
self._forward_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined]
self._update_has_hooks()
return handle

def _slow_forward(self, *input, **kwargs):
Expand All @@ -1526,10 +1494,10 @@ def _slow_forward(self, *input, **kwargs):
def _call_impl(self, *args, **kwargs):
forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
# If we don't have any hooks, we want to skip the rest of the logic in
# this function, and just call forward. It's slow for dynamo to guard on the state
# of all these hook dicts individually, so instead it can guard on 2 bools and we just
# have to promise to keep them up to date when hooks are added or removed via official means.
if not self._has_hooks and not _has_global_hooks:
# this function, and just call forward.
if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
or _global_backward_pre_hooks or _global_backward_hooks
or _global_forward_hooks or _global_forward_pre_hooks):
return forward_call(*args, **kwargs)
# Do not call functions when jit is used
full_backward_hooks, non_full_backward_hooks = [], []
Expand Down
17 changes: 1 addition & 16 deletions torch/utils/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,14 @@ class RemovableHandle:
hooks_dict (dict): A dictionary of hooks, indexed by hook ``id``.
extra_dict (dict): An additional dictionary whose keys will be deleted
when the same keys are removed from ``hooks_dict``.
module (nn.Module): If passed, the hook dict corresponds to that module,
otherwise it is a global hook dict.
"""

id: int
next_id: int = 0

def __init__(self, hooks_dict: Any, *, extra_dict: Any = None, module=None) -> None:
def __init__(self, hooks_dict: Any, *, extra_dict: Any = None) -> None:
self.hooks_dict_ref = weakref.ref(hooks_dict)
self.id = RemovableHandle.next_id

# TODO: we don't pickle/unpickle this field, which means the 'update_has_hooks'
# functionality (which is an optimization) decays after pickling. Can we fix this?

self.module_ref = weakref.ref(module) if module is not None else None
RemovableHandle.next_id += 1

self.extra_dict_ref = (
Expand All @@ -47,12 +40,6 @@ def remove(self) -> None:
if extra_dict is not None and self.id in extra_dict:
del extra_dict[self.id]

if self.module_ref is not None:
module = self.module_ref()
if module is not None:
module._update_has_hooks()
torch.nn.modules.module._update_has_global_hooks()

def __getstate__(self):
return (
(self.hooks_dict_ref(), self.id)
Expand All @@ -74,8 +61,6 @@ def __setstate__(self, state) -> None:
if len(state) < 3
else weakref.ref(OrderedDict() if state[2] is None else state[2])
)
# TODO can we actually restore module_ref after unpickling? Do we care?
self.module_ref = None

def __enter__(self) -> "RemovableHandle":
return self
Expand Down

0 comments on commit 92ef868

Please sign in to comment.