Skip to content

Commit

Permalink
Don't fail to acquire multi_process_shared object if some attributes …
Browse files Browse the repository at this point in the history
…aren't available on a given platform (apache#27995)

* Add failing test to demonstrate problem

* Fix issue
  • Loading branch information
damccorm authored Aug 16, 2023
1 parent ea364b8 commit 6e1dfa1
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
18 changes: 17 additions & 1 deletion sdks/python/apache_beam/utils/multi_process_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,23 @@ def singletonProxy_release(self):
def __getattr__(self, name):
if not self._SingletonProxy_valid:
raise RuntimeError('Entry was released.')
return getattr(self._SingletonProxy_entry.obj, name)
try:
return getattr(self._SingletonProxy_entry.obj, name)
except AttributeError as e:
# Swallow AttributeError exceptions so that they are ignored when
# calculating public functions. These can occur if __getattr__ is
# overriden, for example to only support some platforms. This will mean
# that these functions will be silently unavailable to the
# MultiProcessShared object, leading to worse errors when someone tries
# to use them, but it will keep them from breaking the whole object for
# functions which are unusable anyways.
logging.info(
'Attribute %s is unavailable as a public function because '
'its __getattr__ function raised the following exception '
'%s',
name,
e)
return None

def __dir__(self):
# Needed for multiprocessing.managers's proxying.
Expand Down
34 changes: 34 additions & 0 deletions sdks/python/apache_beam/utils/multi_process_shared_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import logging
import threading
import unittest
from typing import Any

from apache_beam.utils import multi_process_shared

Expand Down Expand Up @@ -57,6 +58,30 @@ def error(self, msg):
raise RuntimeError(msg)


class CounterWithBadAttr(object):
def __init__(self, start=0):
self.running = start
self.lock = threading.Lock()

def get(self):
return self.running

def increment(self, value=1):
with self.lock:
self.running += value
return self.running

def error(self, msg):
raise RuntimeError(msg)

def __getattribute__(self, __name: str) -> Any:
if __name == 'error':
raise AttributeError('error is not actually supported on this platform')
else:
# Default behaviour
return object.__getattribute__(self, __name)


class MultiProcessSharedTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
Expand All @@ -72,6 +97,15 @@ def test_call(self):
self.assertEqual(self.shared.increment(value=10), 21)
self.assertEqual(self.shared.get(), 21)

def test_call_illegal_attr(self):
shared_handle = multi_process_shared.MultiProcessShared(
CounterWithBadAttr, tag='test_call_illegal_attr', always_proxy=True)
shared = shared_handle.acquire()

self.assertEqual(shared.get(), 0)
self.assertEqual(shared.increment(), 1)
self.assertEqual(shared.get(), 1)

def test_call_callable(self):
self.assertEqual(self.sharedCallable(), 0)
self.assertEqual(self.sharedCallable.increment(), 1)
Expand Down

0 comments on commit 6e1dfa1

Please sign in to comment.