Skip to content

Commit

Permalink
Fix swap_tensors path in _apply for modules that inherit from RNNBase (RNN, GRU, LSTM) (pytorch#122800)
Browse files Browse the repository at this point in the history

Pull Request resolved: pytorch#122800
Approved by: https://github.com/albanD

(cherry picked from commit cc12668)
  • Loading branch information
mikaylagawarecki committed Apr 1, 2024
1 parent 8602990 commit d2648f9
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 15 deletions.
1 change: 1 addition & 0 deletions torch/nn/modules/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def flatten_parameters(self) -> None:
self.batch_first, bool(self.bidirectional))

def _apply(self, fn, recurse=True):
self._flat_weight_refs = []
ret = super()._apply(fn, recurse)

# Resets _flat_weights
Expand Down
16 changes: 1 addition & 15 deletions torch/testing/_internal/common_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -4288,34 +4288,20 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad
train_and_eval_differ=True,
module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=True),
module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
skips=(
# RNNBase overrides `_apply` and adds weakrefs to params
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to', active_if=lambda p: p['swap']),
# RNNBase overrides `_apply` and adds weakrefs to params
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to_empty', active_if=lambda p: p['swap']),),
decorators=rnn_gru_lstm_module_info_decorators
),
ModuleInfo(torch.nn.GRU,
train_and_eval_differ=True,
module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU, is_rnn=False),
module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
skips=(
# RNNBase overrides `_apply` and adds weakrefs to params
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to', active_if=lambda p: p['swap']),
# RNNBase overrides `_apply` and adds weakrefs to params
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to_empty', active_if=lambda p: p['swap']),),
decorators=rnn_gru_lstm_module_info_decorators),
ModuleInfo(torch.nn.LSTM,
train_and_eval_differ=True,
module_inputs_func=module_inputs_torch_nn_LSTM,
module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU,
skips=(
# LSTM with projections is not currently supported with MPS
DecorateInfo(skipMPS),
# RNNBase overrides `_apply` and adds weakrefs to params
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to', active_if=lambda p: p['swap']),
# RNNBase overrides `_apply` and adds weakrefs to params
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_to_empty', active_if=lambda p: p['swap']),),
DecorateInfo(skipMPS),),
decorators=rnn_gru_lstm_module_info_decorators),
ModuleInfo(torch.nn.ReflectionPad1d,
module_inputs_func=module_inputs_torch_nn_ReflectionPad1d,
Expand Down

0 comments on commit d2648f9

Please sign in to comment.