Skip to content

Commit

Permalink
[rpc] Wrap exception creation with try/catch (pytorch#87224)
Browse files Browse the repository at this point in the history
Sometimes, we cannot recreate the exception with only string (for example if it is a custom exception type). Ideal situation would be to carry over all details on how to recreate the remote end's exception and throw that on client, but for now, we raise a RuntimeError with the original error msg when we cannot reconstruct.

Created from CodeHub with https://fburl.com/edit-in-codehub

Differential Revision: [D40353274](https://our.internmc.facebook.com/intern/diff/D40353274/)
Pull Request resolved: pytorch#87224
Approved by: https://github.com/fduwjj
  • Loading branch information
rohan-varma authored and pytorchmergebot committed Oct 20, 2022
1 parent c97ffcf commit 07bd053
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 1 deletion.
17 changes: 16 additions & 1 deletion torch/distributed/rpc/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ def _run_function(python_udf):
if isinstance(python_udf, AttributeError):
raise python_udf
result = python_udf.func(*python_udf.args, **python_udf.kwargs)
# TODO (rohan-varma): This should probably be BaseException, but change can
# cause BC issues.
except Exception as e:
# except str = exception info + traceback string
except_str = (
Expand All @@ -218,7 +220,20 @@ def _run_function(python_udf):

def _handle_exception(result):
if isinstance(result, RemoteException):
raise result.exception_type(result.msg.encode("utf-8").decode("unicode_escape"))
exception_msg = result.msg.encode("utf-8").decode("unicode_escape")
# We wrap exception re-creation here in case some exception classes
# cannot be constructed directly from a string.
exc = None
try:
exc = result.exception_type(exception_msg)
except BaseException as e:
raise RuntimeError( # noqa: B904
f"Failed to create original exception type. Error msg was {str(e)}"
f" Original exception on remote side was {exception_msg}"
)

if exc is not None:
raise exc


def _build_rpc_profiling_key(
Expand Down
38 changes: 38 additions & 0 deletions torch/testing/_internal/distributed/rpc/rpc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,20 @@ def my_script_func(tensor):


expected_err = "Expected error"

# Note that it needs to inherit from Exception, not BaseException. See comment
# in rpc/internal.py
class CustomException(Exception):
def __init__(self, bool, msg):
self.bool = bool
super().__init__(msg)

def raise_func():
raise ValueError(expected_err)

def custom_raise_func():
raise CustomException(True, "foo")

@torch.jit.script
def raise_func_script(expected_err: str) -> torch.Tensor:
raise ValueError(expected_err)
Expand Down Expand Up @@ -3567,6 +3578,33 @@ def test_wait_all_raise_in_body(self):
raise_func()
self.assertFalse(hasattr(_thread_local_var, "future_list"))

@dist_init
def test_custom_exception_throw_during_reconstruction(self):
"""
Test that we still throw info about the remote side exception even when
we cannot recreate it on client side.
"""
initialize_pg(self.file_init_method, self.rank, self.world_size)
if self.rank != 0:
exc_caught = False
dst = worker_name(0)
try:
rpc.rpc_sync(dst, custom_raise_func, args=())
except RuntimeError as e:
exc_caught = True
msg = str(e)
print(f"Got msg {msg}")
self.assertTrue("Original exception on remote side was" in msg)
self.assertTrue("CustomException" in msg)
except BaseException as e:
raise RuntimeError(
f"Failure - expected RuntimeError, got {e}"
) from e
finally:
self.assertTrue(exc_caught)

dist.barrier()


timed_out_rpc_event = None

Expand Down

0 comments on commit 07bd053

Please sign in to comment.