Skip to content

Commit

Permalink
Fix monkeypatching of _FabricModule methods (#19705)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholi <[email protected]>
  • Loading branch information
2 people authored and lantiga committed Apr 11, 2024
1 parent 41bbd23 commit 1218658
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/lightning/fabric/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def __setattr__(self, name: str, value: Any) -> None:
original_has_attr = hasattr(original_module, name)
# Can't use super().__getattr__ because nn.Module only checks _parameters, _buffers, and _modules
# Can't use self.__getattr__ because it would pass through to the original module
fabric_has_attr = name in self.__dict__
fabric_has_attr = name in dir(self)

if not (original_has_attr or fabric_has_attr):
setattr(original_module, name, value)
Expand Down
20 changes: 20 additions & 0 deletions tests/tests_fabric/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ def __init__(self):

# Modify existing attribute on original_module
fabric_module.attribute = 101
# "attribute" is only in the original_module, so it shouldn't get set in the fabric_module
assert "attribute" not in fabric_module.__dict__
assert fabric_module.attribute == 101 # returns it from original_module
assert original_module.attribute == 101

# Check setattr of original_module
Expand All @@ -170,6 +173,23 @@ def __init__(self):
assert linear in fabric_module.modules()
assert linear in original_module.modules()

# Check monkeypatching of methods
fabric_module = _FabricModule(Mock(), Mock())
original = id(fabric_module.forward)
fabric_module.forward = lambda *_: None
assert id(fabric_module.forward) != original
# Check special methods
assert "__repr__" in dir(fabric_module)
assert "__repr__" not in fabric_module.__dict__
assert "__repr__" not in _FabricModule.__dict__
fabric_module.__repr__ = lambda *_: "test"
assert fabric_module.__repr__() == "test"
# needs to be monkeypatched on the class for `repr()` to change
assert repr(fabric_module) == "_FabricModule()"
with mock.patch.object(_FabricModule, "__repr__", return_value="test"):
assert fabric_module.__repr__() == "test"
assert repr(fabric_module) == "test"


def test_fabric_module_state_dict_access():
"""Test that state_dict access passes through to the original module."""
Expand Down

0 comments on commit 1218658

Please sign in to comment.