Skip to content

Commit

Permalink
Log user code exceptions in Python (#28515)
Browse files Browse the repository at this point in the history
* Log user code exceptions in Python

* address linter

* trigger tests

---------

Co-authored-by: Sam Rohde <[email protected]>
  • Loading branch information
rohdesamuel and Sam Rohde authored Sep 19, 2023
1 parent 6261a00 commit 612bfc4
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 0 deletions.
4 changes: 4 additions & 0 deletions sdks/python/apache_beam/runners/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

# pytype: skip-file

import logging
import sys
import threading
import traceback
Expand Down Expand Up @@ -81,6 +82,8 @@
ENCODED_IMPULSE_VALUE = IMPULSE_VALUE_CODER_IMPL.encode_nested(
GlobalWindows.windowed_value(b''))

_LOGGER = logging.getLogger(__name__)


class NameContext(object):
"""Holds the name information for a step."""
Expand Down Expand Up @@ -1538,6 +1541,7 @@ def _reraise_augmented(self, exn, windowed_value=None):

new_exn = new_exn.with_traceback(tb)
self._maybe_sample_exception(exc_info, windowed_value)
_LOGGER.exception(new_exn)
raise new_exn


Expand Down
151 changes: 151 additions & 0 deletions sdks/python/apache_beam/runners/worker/log_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,97 @@

import grpc

import apache_beam as beam
from apache_beam.coders.coders import FastPrimitivesCoder
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.portability.api import endpoints_pb2
from apache_beam.runners import common
from apache_beam.runners.common import NameContext
from apache_beam.runners.worker import bundle_processor
from apache_beam.runners.worker import log_handler
from apache_beam.runners.worker import operations
from apache_beam.runners.worker import statesampler
from apache_beam.runners.worker.bundle_processor import BeamTransformFactory
from apache_beam.runners.worker.bundle_processor import BundleProcessor
from apache_beam.transforms.window import GlobalWindow
from apache_beam.utils import thread_pool_executor
from apache_beam.utils.windowed_value import WindowedValue

_LOGGER = logging.getLogger(__name__)


@BeamTransformFactory.register_urn('beam:internal:testexn:v1', bytes)
def create_exception_dofn(
    factory, transform_id, transform_proto, payload, consumers):
  """Builds a ParDo operation whose DoFn always raises.

  The transform payload is decoded to text and used as the message of the
  RuntimeError raised whenever the DoFn's process method is called.
  """
  class RaiseException(beam.DoFn):
    """DoFn that raises a RuntimeError carrying a fixed message."""
    def __init__(self, msg):
      # The payload is handed over undecoded; store the decoded text.
      self.msg = msg.decode()

    def process(self, _):
      raise RuntimeError(self.msg)

  failing_dofn = RaiseException(payload)
  return bundle_processor._create_simple_pardo_operation(
      factory, transform_id, transform_proto, consumers, failing_dofn)


class TestOperation(operations.Operation):
  """Test operation that emits its payload to its consumers."""
  class Spec:
    """Minimal spec holding one FastPrimitivesCoder per transform output."""
    def __init__(self, transform_proto):
      self.output_coders = [
          FastPrimitivesCoder() for _unused in transform_proto.outputs
      ]

  def __init__(
      self,
      transform_proto,
      name_context,
      counter_factory,
      state_sampler,
      consumers,
      payload,
  ):
    """Sets up the operation and registers every consumer on output 0."""
    spec = self.Spec(transform_proto)
    super().__init__(name_context, spec, counter_factory, state_sampler)
    self.payload = payload

    # Only the consumer operations matter here; the mapping keys are unused.
    for consumer_ops in consumers.values():
      for consumer in consumer_ops:
        self.add_receiver(consumer, 0)

  def start(self):
    """Starts the operation and, if a payload is set, processes it once."""
    super().start()

    if not self.payload:
      return

    # Not using windowing logic, so just using simple defaults here.
    element = WindowedValue(
        self.payload, timestamp=0, windows=[GlobalWindow()])
    self.process(element)

  def process(self, windowed_value):
    """Passes the value straight through to the operation's output."""
    self.output(windowed_value)


@BeamTransformFactory.register_urn('beam:internal:testop:v1', bytes)
def create_test_op(factory, transform_id, transform_proto, payload, consumers):
  """Creates a TestOperation for the 'beam:internal:testop:v1' URN.

  The name context is built from the transform's unique name and id; the
  counter factory and state sampler are taken from the given factory.
  """
  name_context = common.NameContext(transform_proto.unique_name, transform_id)
  return TestOperation(
      transform_proto=transform_proto,
      name_context=name_context,
      counter_factory=factory.counter_factory,
      state_sampler=factory.state_sampler,
      consumers=consumers,
      payload=payload)


class BeamFnLoggingServicer(beam_fn_api_pb2_grpc.BeamFnLoggingServicer):
def __init__(self):
self.log_records_received = []
Expand Down Expand Up @@ -153,6 +233,77 @@ def test_context(self):
finally:
statesampler.set_current_tracker(None)

def test_extracts_transform_id_during_exceptions(self):
  """Tests that transform ids are captured during user code exceptions.

  Builds a two-transform bundle descriptor: a test operation that injects a
  single element, followed by a DoFn that raises. After the bundle fails,
  the first log entry received by the test logging service must be an ERROR
  that mentions the exception and carries the failing transform's id.
  """
  descriptor = beam_fn_api_pb2.ProcessBundleDescriptor()

  # Boilerplate for the DoFn.
  WINDOWING_ID = 'window'
  WINDOW_CODER_ID = 'cw'
  window = descriptor.windowing_strategies[WINDOWING_ID]
  window.window_fn.urn = common_urns.global_windows.urn
  window.window_coder_id = WINDOW_CODER_ID
  window.trigger.default.SetInParent()
  window_coder = descriptor.coders[WINDOW_CODER_ID]
  window_coder.spec.urn = common_urns.StandardCoders.Enum.GLOBAL_WINDOW.urn

  # Input collection to the exception raising DoFn.
  INPUT_PCOLLECTION_ID = 'pc-in'
  INPUT_CODER_ID = 'c-in'
  input_pcoll = descriptor.pcollections[INPUT_PCOLLECTION_ID]
  input_pcoll.unique_name = INPUT_PCOLLECTION_ID
  input_pcoll.coder_id = INPUT_CODER_ID
  input_pcoll.windowing_strategy_id = WINDOWING_ID
  descriptor.coders[
      INPUT_CODER_ID].spec.urn = common_urns.StandardCoders.Enum.BYTES.urn

  # Output collection to the exception raising DoFn.
  OUTPUT_PCOLLECTION_ID = 'pc-out'
  OUTPUT_CODER_ID = 'c-out'
  output_pcoll = descriptor.pcollections[OUTPUT_PCOLLECTION_ID]
  output_pcoll.unique_name = OUTPUT_PCOLLECTION_ID
  output_pcoll.coder_id = OUTPUT_CODER_ID
  output_pcoll.windowing_strategy_id = WINDOWING_ID
  descriptor.coders[
      OUTPUT_CODER_ID].spec.urn = common_urns.StandardCoders.Enum.BYTES.urn

  # Add a simple transform to inject an element into the fake pipeline.
  TEST_OP_TRANSFORM_ID = 'test_op'
  source_transform = descriptor.transforms[TEST_OP_TRANSFORM_ID]
  source_transform.outputs['None'] = INPUT_PCOLLECTION_ID
  source_transform.spec.urn = 'beam:internal:testop:v1'
  source_transform.spec.payload = b'hello, world!'

  # Add the DoFn to create an exception.
  TEST_EXCEPTION_TRANSFORM_ID = 'test_transform'
  failing_transform = descriptor.transforms[TEST_EXCEPTION_TRANSFORM_ID]
  failing_transform.inputs['0'] = INPUT_PCOLLECTION_ID
  failing_transform.outputs['None'] = OUTPUT_PCOLLECTION_ID
  failing_transform.spec.urn = 'beam:internal:testexn:v1'
  failing_transform.spec.payload = b'expected exception'

  # Create and process a fake bundle. The instruction id doesn't matter
  # here.
  processor = BundleProcessor(descriptor, None, None)

  with self.assertRaisesRegex(RuntimeError, 'expected exception'):
    processor.process_bundle('instruction_id')

  # Close the handler before inspecting what the test service received.
  self.fn_log_handler.close()
  log_entries = [
      entry for batch in self.test_logging_service.log_records_received
      for entry in batch.log_entries
  ]

  # Fail with a readable message instead of an IndexError when the
  # exception was never logged.
  self.assertTrue(
      log_entries, 'expected at least one log entry for the exception')
  actual_log = log_entries[0]

  self.assertEqual(
      actual_log.severity, beam_fn_api_pb2.LogEntry.Severity.ERROR)
  # assertIn reports both operands on failure, unlike assertTrue(x in y).
  self.assertIn('expected exception', actual_log.message)
  self.assertEqual(actual_log.transform_id, TEST_EXCEPTION_TRANSFORM_ID)


# Test cases.
data = {
Expand Down

0 comments on commit 612bfc4

Please sign in to comment.