From 2d140762eaa44c7f95a468acd5682c9a57301c71 Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Thu, 15 Aug 2024 22:51:07 +0200 Subject: [PATCH] Add callback to with_exception_handling (#32136) * Add callback to with_exception_handling * Format * Format * Implement feedback --- sdks/python/apache_beam/transforms/core.py | 32 +++++++-- .../apache_beam/transforms/core_test.py | 71 +++++++++++++++++++ 2 files changed, 98 insertions(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 9f6902cdc23f..68c9eecd9f3f 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -1572,7 +1572,9 @@ def with_exception_handling( use_subprocess=False, threshold=1, threshold_windowing=None, - timeout=None): + timeout=None, + on_failure_callback: typing.Optional[typing.Callable[ + [Exception, typing.Any], None]] = None): """Automatically provides a dead letter output for skipping bad records. This can allow a pipeline to continue successfully rather than fail or continuously throw errors on retry when bad elements are encountered. @@ -1620,6 +1622,12 @@ def with_exception_handling( defaults to the windowing of the input. timeout: If the element has not finished processing in timeout seconds, raise a TimeoutError. Defaults to None, meaning no time limit. + on_failure_callback: If an element fails or times out, + on_failure_callback will be invoked. It will receive the exception + and the element being processed in as args. In case of a timeout, + the exception will be of type `TimeoutError`. Be careful with this + callback - if you set a timeout, it will not apply to the callback, + and if the callback fails it will not be retried. """ args, kwargs = self.raw_side_inputs return self.label >> _ExceptionHandlingWrapper( @@ -1633,7 +1641,8 @@ def with_exception_handling( use_subprocess, threshold, threshold_windowing, - timeout) + timeout, + on_failure_callback) def default_type_hints(self): return self.fn.get_type_hints() @@ -2232,7 +2241,8 @@ def __init__( use_subprocess, threshold, threshold_windowing, - timeout): + timeout, + on_failure_callback): if partial and use_subprocess: raise ValueError('partial and use_subprocess are mutually incompatible.') self._fn = fn @@ -2246,6 +2256,7 @@ def __init__( self._threshold = threshold self._threshold_windowing = threshold_windowing self._timeout = timeout + self._on_failure_callback = on_failure_callback def expand(self, pcoll): if self._use_subprocess: @@ -2256,7 +2267,11 @@ def expand(self, pcoll): wrapped_fn = self._fn result = pcoll | ParDo( _ExceptionHandlingWrapperDoFn( - wrapped_fn, self._dead_letter_tag, self._exc_class, self._partial), + wrapped_fn, + self._dead_letter_tag, + self._exc_class, + self._partial, + self._on_failure_callback), *self._args, **self._kwargs).with_outputs( self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True) @@ -2295,11 +2310,13 @@ def check_threshold(bad, total, threshold, window=DoFn.WindowParam): class _ExceptionHandlingWrapperDoFn(DoFn): - def __init__(self, fn, dead_letter_tag, exc_class, partial): + def __init__( + self, fn, dead_letter_tag, exc_class, partial, on_failure_callback): self._fn = fn self._dead_letter_tag = dead_letter_tag self._exc_class = exc_class self._partial = partial + self._on_failure_callback = on_failure_callback def __getattribute__(self, name): if (name.startswith('__') or name in self.__dict__ or @@ -2316,6 +2333,11 @@ def process(self, *args, **kwargs): result = list(result) yield from result except self._exc_class as exn: + if self._on_failure_callback is not None: + try: + self._on_failure_callback(exn, args[0]) + except Exception as e: + logging.warning('on_failure_callback failed with error: %s', e) yield pvalue.TaggedOutput( self._dead_letter_tag, ( diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py index b0f54b8bb66d..b492ab0938cc 100644 --- a/sdks/python/apache_beam/transforms/core_test.py +++ b/sdks/python/apache_beam/transforms/core_test.py @@ -19,6 +19,8 @@ # pytype: skip-file import logging +import os +import tempfile import unittest import pytest @@ -87,6 +89,13 @@ def process(self, element): yield element +class TestDoFn9(beam.DoFn): + def process(self, element): + if len(element) > 3: + raise ValueError('Not allowed to have long elements') + yield element + + class CreateTest(unittest.TestCase): @pytest.fixture(autouse=True) def inject_fixtures(self, caplog): @@ -170,6 +179,68 @@ def test_flatten_mismatched_windows(self): _ = (source1, source2, source3) | "flatten" >> beam.Flatten() +class ExceptionHandlingTest(unittest.TestCase): + def test_routes_failures(self): + with beam.Pipeline() as pipeline: + good, bad = ( + pipeline | beam.Create(['abc', 'long_word', 'foo', 'bar', 'foobar']) + | beam.ParDo(TestDoFn9()).with_exception_handling() + ) + bad_elements = bad | beam.Keys() + assert_that(good, equal_to(['abc', 'foo', 'bar']), 'good') + assert_that(bad_elements, equal_to(['long_word', 'foobar']), 'bad') + + def test_handles_callbacks(self): + with tempfile.TemporaryDirectory() as tmp_dirname: + tmp_path = os.path.join(tmp_dirname, 'tmp_filename') + file_contents = 'random content' + + def failure_callback(e, el): + if type(e) is not ValueError: + raise Exception(f'Failed to pass in correct exception, received {e}') + if el != 'foobar': + raise Exception(f'Failed to pass in correct element, received {el}') + f = open(tmp_path, "a") + logging.warning(tmp_path) + f.write(file_contents) + f.close() + + with beam.Pipeline() as pipeline: + good, bad = ( + pipeline | beam.Create(['abc', 'bcd', 'foo', 'bar', 'foobar']) + | beam.ParDo(TestDoFn9()).with_exception_handling( + on_failure_callback=failure_callback) + ) + bad_elements = bad | beam.Keys() + assert_that(good, equal_to(['abc', 'bcd', 'foo', 'bar']), 'good') + assert_that(bad_elements, equal_to(['foobar']), 'bad') + with open(tmp_path) as f: + s = f.read() + self.assertEqual(s, file_contents) + + def test_handles_no_callback_triggered(self): + with tempfile.TemporaryDirectory() as tmp_dirname: + tmp_path = os.path.join(tmp_dirname, 'tmp_filename') + file_contents = 'random content' + + def failure_callback(e, el): + f = open(tmp_path, "a") + logging.warning(tmp_path) + f.write(file_contents) + f.close() + + with beam.Pipeline() as pipeline: + good, bad = ( + pipeline | beam.Create(['abc', 'bcd', 'foo', 'bar']) + | beam.ParDo(TestDoFn9()).with_exception_handling( + on_failure_callback=failure_callback) + ) + bad_elements = bad | beam.Keys() + assert_that(good, equal_to(['abc', 'bcd', 'foo', 'bar']), 'good') + assert_that(bad_elements, equal_to([]), 'bad') + self.assertFalse(os.path.isfile(tmp_path)) + + class FlatMapTest(unittest.TestCase): def test_default(self):