Commit

Fix test_autograd
Flamefire committed Dec 9, 2022
1 parent 452419e commit 8f2d4ab
Showing 2 changed files with 27 additions and 0 deletions.

@@ -25,6 +25,7 @@ patches = [
     'PyTorch-1.11.0_increase-distributed-test-timeout.patch',
     'PyTorch-1.11.0_install-vsx-vec-headers.patch',
     'PyTorch-1.11.1_skip-test_init_from_local_shards.patch',
+    'PyTorch-1.12.1_fix-autograd-thread_shutdown-test.patch',
     'PyTorch-1.12.1_fix-cuda-gcc-version-check.patch',
     'PyTorch-1.12.1_fix-skip-decorators.patch',
     'PyTorch-1.12.1_fix-test_cpp_extensions_jit.patch',
@@ -67,6 +68,8 @@ checksums = [
     'f2e6b9625733d9a471bb75e1ea20e28814cf1380b4f9089aa838ee35ddecf07d', # PyTorch-1.11.0_install-vsx-vec-headers.patch
     # PyTorch-1.11.1_skip-test_init_from_local_shards.patch
     '4aeb1b0bc863d4801b0095cbce69f8794066748f0df27c6aaaf729c5ecba04b7',
+    # PyTorch-1.12.1_fix-autograd-thread_shutdown-test.patch
+    'd97cd6b0570a167ecc3e631dc4ea884d95ace285cc38aa980566f4fec2c0d089',
     # PyTorch-1.12.1_fix-cuda-gcc-version-check.patch
     'a650f4576f06c749f244cada52ff9c02499fa8f182019129488db3845e0756ab',
     'e3ca6e42b2fa592ea095939fb59ab875668a058479407db3f3684cc5c6f4146c', # PyTorch-1.12.1_fix-skip-decorators.patch
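
The new checksum above pins the added patch by its SHA-256 digest, which EasyBuild verifies before applying the patch. As a minimal sketch (assuming the patch file is available locally; the path below is illustrative), the listed digest can be reproduced with:

    import hashlib

    # Hypothetical local path to the new patch file; adjust as needed.
    patch_path = "PyTorch-1.12.1_fix-autograd-thread_shutdown-test.patch"

    with open(patch_path, "rb") as f:
        digest = hashlib.sha256(f.read()).hexdigest()

    # EasyBuild compares this digest against the entry in the checksums list.
    print(digest)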

PyTorch-1.12.1_fix-autograd-thread_shutdown-test.patch (new file)
@@ -0,0 +1,24 @@
Fix flaky test_thread_shutdown in test_autograd

From https://github.com/pytorch/pytorch/pull/86464

Backport: Alexander Grund (TU Dresden)

diff --git a/test/test_autograd.py b/test/test_autograd.py
index da1e859682e..0c0bc4f1a2a 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -4320,8 +4320,12 @@ class MyFunction(Function):
     def backward(ctx, grad):
         return grad

+# Run on cuda if it is available to ensure that the worker thread
+# is properly initialized by the time we exit.
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
 for shape in [(1,), ()]:
-    v = torch.ones(shape, requires_grad=True)
+    v = torch.ones(shape, requires_grad=True, device=device)
     MyFunction.apply(v).backward()
 """
         s = TestCase.runWithPytorchAPIUsageStderr(code)
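
For context on the patched test: the snippet is embedded as a string and handed to TestCase.runWithPytorchAPIUsageStderr, which executes it in a fresh Python process and returns that process's stderr, so that interpreter shutdown (the point where the autograd worker thread must have been started and torn down cleanly) is actually exercised. A rough sketch of that run-in-a-subprocess pattern (an illustrative stand-in, not PyTorch's actual helper):

    import subprocess
    import sys

    def run_snippet_capture_stderr(code: str) -> str:
        # Run the snippet in a separate interpreter so its shutdown path,
        # including autograd worker-thread teardown, is exercised, then
        # return whatever it printed to stderr for inspection.
        proc = subprocess.run(
            [sys.executable, "-c", code],
            capture_output=True,
            text=True,
        )
        return proc.stderr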
