Skip to content

Commit

Permalink
Add callback to with_exception_handling (#32136)
Browse files Browse the repository at this point in the history
* Add callback to with_exception_handling

* Format

* Format

* Implement feedback
  • Loading branch information
damccorm authored Aug 15, 2024
1 parent d7d9f51 commit 2d14076
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 5 deletions.
32 changes: 27 additions & 5 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1572,7 +1572,9 @@ def with_exception_handling(
use_subprocess=False,
threshold=1,
threshold_windowing=None,
timeout=None):
timeout=None,
on_failure_callback: typing.Optional[typing.Callable[
[Exception, typing.Any], None]] = None):
"""Automatically provides a dead letter output for skipping bad records.
This can allow a pipeline to continue successfully rather than fail or
continuously throw errors on retry when bad elements are encountered.
Expand Down Expand Up @@ -1620,6 +1622,12 @@ def with_exception_handling(
defaults to the windowing of the input.
timeout: If the element has not finished processing in timeout seconds,
raise a TimeoutError. Defaults to None, meaning no time limit.
on_failure_callback: If an element fails or times out,
on_failure_callback will be invoked. It will receive the exception
and the element being processed as arguments. In case of a timeout,
the exception will be of type `TimeoutError`. Be careful with this
callback - if you set a timeout, it will not apply to the callback,
and if the callback fails it will not be retried.
"""
args, kwargs = self.raw_side_inputs
return self.label >> _ExceptionHandlingWrapper(
Expand All @@ -1633,7 +1641,8 @@ def with_exception_handling(
use_subprocess,
threshold,
threshold_windowing,
timeout)
timeout,
on_failure_callback)

def default_type_hints(self):
return self.fn.get_type_hints()
Expand Down Expand Up @@ -2232,7 +2241,8 @@ def __init__(
use_subprocess,
threshold,
threshold_windowing,
timeout):
timeout,
on_failure_callback):
if partial and use_subprocess:
raise ValueError('partial and use_subprocess are mutually incompatible.')
self._fn = fn
Expand All @@ -2246,6 +2256,7 @@ def __init__(
self._threshold = threshold
self._threshold_windowing = threshold_windowing
self._timeout = timeout
self._on_failure_callback = on_failure_callback

def expand(self, pcoll):
if self._use_subprocess:
Expand All @@ -2256,7 +2267,11 @@ def expand(self, pcoll):
wrapped_fn = self._fn
result = pcoll | ParDo(
_ExceptionHandlingWrapperDoFn(
wrapped_fn, self._dead_letter_tag, self._exc_class, self._partial),
wrapped_fn,
self._dead_letter_tag,
self._exc_class,
self._partial,
self._on_failure_callback),
*self._args,
**self._kwargs).with_outputs(
self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True)
Expand Down Expand Up @@ -2295,11 +2310,13 @@ def check_threshold(bad, total, threshold, window=DoFn.WindowParam):


class _ExceptionHandlingWrapperDoFn(DoFn):
def __init__(self, fn, dead_letter_tag, exc_class, partial):
def __init__(
self, fn, dead_letter_tag, exc_class, partial, on_failure_callback):
self._fn = fn
self._dead_letter_tag = dead_letter_tag
self._exc_class = exc_class
self._partial = partial
self._on_failure_callback = on_failure_callback

def __getattribute__(self, name):
if (name.startswith('__') or name in self.__dict__ or
Expand All @@ -2316,6 +2333,11 @@ def process(self, *args, **kwargs):
result = list(result)
yield from result
except self._exc_class as exn:
if self._on_failure_callback is not None:
try:
self._on_failure_callback(exn, args[0])
except Exception as e:
logging.warning('on_failure_callback failed with error: %s', e)
yield pvalue.TaggedOutput(
self._dead_letter_tag,
(
Expand Down
71 changes: 71 additions & 0 deletions sdks/python/apache_beam/transforms/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
# pytype: skip-file

import logging
import os
import tempfile
import unittest

import pytest
Expand Down Expand Up @@ -87,6 +89,13 @@ def process(self, element):
yield element


class TestDoFn9(beam.DoFn):
  """DoFn that emits short elements unchanged and rejects long ones."""
  def process(self, element):
    # Anything with a length of at most three is passed through; longer
    # elements raise so that exception handling can be exercised.
    if len(element) <= 3:
      yield element
    else:
      raise ValueError('Not allowed to have long elements')


class CreateTest(unittest.TestCase):
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
Expand Down Expand Up @@ -170,6 +179,68 @@ def test_flatten_mismatched_windows(self):
_ = (source1, source2, source3) | "flatten" >> beam.Flatten()


class ExceptionHandlingTest(unittest.TestCase):
  """Tests for with_exception_handling and its on_failure_callback hook."""
  def test_routes_failures(self):
    """Elements that raise are routed to the bad output, others to good."""
    with beam.Pipeline() as pipeline:
      good, bad = (
          pipeline | beam.Create(['abc', 'long_word', 'foo', 'bar', 'foobar'])
          | beam.ParDo(TestDoFn9()).with_exception_handling()
      )
      bad_elements = bad | beam.Keys()
      assert_that(good, equal_to(['abc', 'foo', 'bar']), 'good')
      assert_that(bad_elements, equal_to(['long_word', 'foobar']), 'bad')

  def test_handles_callbacks(self):
    """The callback is invoked with the exception and the failing element."""
    with tempfile.TemporaryDirectory() as tmp_dirname:
      tmp_path = os.path.join(tmp_dirname, 'tmp_filename')
      file_contents = 'random content'

      def failure_callback(e, el):
        # The callback signals success by writing a marker file. If either
        # check below raises, no file is written and the read at the end of
        # the test fails.
        if type(e) is not ValueError:
          raise Exception(f'Failed to pass in correct exception, received {e}')
        if el != 'foobar':
          raise Exception(f'Failed to pass in correct element, received {el}')
        # Use a context manager so the handle is closed even if logging or
        # writing raises.
        with open(tmp_path, "a") as f:
          logging.warning(tmp_path)
          f.write(file_contents)

      with beam.Pipeline() as pipeline:
        good, bad = (
            pipeline | beam.Create(['abc', 'bcd', 'foo', 'bar', 'foobar'])
            | beam.ParDo(TestDoFn9()).with_exception_handling(
                on_failure_callback=failure_callback)
        )
        bad_elements = bad | beam.Keys()
        assert_that(good, equal_to(['abc', 'bcd', 'foo', 'bar']), 'good')
        assert_that(bad_elements, equal_to(['foobar']), 'bad')
      # The pipeline has finished running once its `with` block exits, so the
      # marker file must exist and hold exactly one write.
      with open(tmp_path) as f:
        s = f.read()
      self.assertEqual(s, file_contents)

  def test_handles_no_callback_triggered(self):
    """The callback is not invoked when no element fails."""
    with tempfile.TemporaryDirectory() as tmp_dirname:
      tmp_path = os.path.join(tmp_dirname, 'tmp_filename')
      file_contents = 'random content'

      def failure_callback(e, el):
        # Any invocation leaves a marker file behind, which the final
        # assertion of this test detects.
        with open(tmp_path, "a") as f:
          logging.warning(tmp_path)
          f.write(file_contents)

      with beam.Pipeline() as pipeline:
        good, bad = (
            pipeline | beam.Create(['abc', 'bcd', 'foo', 'bar'])
            | beam.ParDo(TestDoFn9()).with_exception_handling(
                on_failure_callback=failure_callback)
        )
        bad_elements = bad | beam.Keys()
        assert_that(good, equal_to(['abc', 'bcd', 'foo', 'bar']), 'good')
        assert_that(bad_elements, equal_to([]), 'bad')
      self.assertFalse(os.path.isfile(tmp_path))


class FlatMapTest(unittest.TestCase):
def test_default(self):

Expand Down

0 comments on commit 2d14076

Please sign in to comment.