Skip to content

Commit

Permalink
Remove warning on no_backward_sync with XLA strategy (#17761)
Browse files Browse the repository at this point in the history
(cherry picked from commit f3c49b8)
  • Loading branch information
carmocca authored and lantiga committed Aug 30, 2023
1 parent ddfd5fe commit 8c72438
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed model parameters getting shared between processes when running with `strategy="ddp_spawn"` and `accelerator="cpu"`; this has a necessary memory impact, as parameters are replicated for each process now ([#18238](https://github.com/Lightning-AI/lightning/pull/18238))


- Removed false positive warning when using `fabric.no_backward_sync` with XLA strategies ([#17761](https://github.com/Lightning-AI/lightning/pull/17761))


## [2.0.7] - 2023-08-14

### Changed
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/fabric/fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ def no_backward_sync(self, module: _FabricModule, enabled: bool = True) -> Gener
"You need to set up the model first before you can call `self.no_backward_sync()`:"
" `model = self.setup(model, ...)`"
)
if not enabled or isinstance(self._strategy, SingleDeviceStrategy):
if not enabled or isinstance(self._strategy, (SingleDeviceStrategy, XLAStrategy)):
context = nullcontext()
elif self._strategy._backward_sync_control is None:
rank_zero_warn(
Expand Down
7 changes: 6 additions & 1 deletion tests/tests_fabric/test_fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,14 +625,19 @@ def test_no_backward_sync():
with fabric.no_backward_sync(model):
pass
fabric._strategy._backward_sync_control.no_backward_sync.assert_not_called()
# same for XLA
fabric._strategy = Mock(spec=XLAStrategy, _backward_sync_control=MagicMock())
with fabric.no_backward_sync(model):
pass
fabric._strategy._backward_sync_control.no_backward_sync.assert_not_called()

# pretend that the strategy supports skipping backward sync
fabric._strategy = Mock(_backward_sync_control=MagicMock())
# disabling the context manager makes it a no-op
with fabric.no_backward_sync(model, enabled=False):
pass
fabric._strategy._backward_sync_control.no_backward_sync.assert_not_called()
# when enabld, the wrapped module gets passed down
# when enabled, the wrapped module gets passed down
with fabric.no_backward_sync(model):
pass
fabric._strategy._backward_sync_control.no_backward_sync.assert_called_once_with(model._forward_module)
Expand Down

0 comments on commit 8c72438

Please sign in to comment.