Skip to content

Commit

Permalink
catch exception to checkpoint data (#2240)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2240

## Why & What
catch exception as checkpoint data in write_checkpoint decorator

Reviewed By: ztlbells

Differential Revision: D44321329

fbshipit-source-id: 4e6fbf96b135069f55313868ad179b3353320868
  • Loading branch information
joe1234wu authored and facebook-github-bot committed Mar 27, 2023
1 parent 9d54c13 commit f2693d1
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
9 changes: 7 additions & 2 deletions fbpcs/common/service/test/test_write_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def add(instance_id: str, x: int, y: int = 2) -> int:


def raise_exception(instance_id: str) -> None:
raise DummyException
raise DummyException("dummy_exception")


class DummyTraceLoggingService(SimpleTraceLoggingService):
Expand Down Expand Up @@ -87,8 +87,13 @@ async def test_basic_usage(self) -> None:
self.assertEqual(res, -1)

with self.subTest("raise exception"):
checkpoint = dummy_checkpoint()
with self.assertRaises(DummyException):
res = dummy_checkpoint()(raise_exception)("instance_id")
checkpoint(raise_exception)("instance_id")

self.assertEqual(
checkpoint.checkpoint_data, {"exception": "dummy_exception"}
)

with self.subTest("async add"):
obj = DummyClass()
Expand Down
12 changes: 10 additions & 2 deletions fbpcs/common/service/write_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,23 @@ def __call__(self, func: Callable) -> Callable: # pyre-ignore
@functools.wraps(func)
async def wrapper_async(*args: Any, **kwargs: Any) -> Any: # pyre-ignore
with self._get_trace_logger_cm(func, *args, **kwargs) as checkpoint_data:
res = await func(*args, **kwargs)
try:
res = await func(*args, **kwargs)
except Exception as ex:
checkpoint_data["exception"] = str(ex)
raise ex
if self.dump_return_val:
checkpoint_data["return_value"] = str(res)
return res

@functools.wraps(func)
def wrapper_sync(*args: Any, **kwargs: Any) -> Any: # pyre-ignore
with self._get_trace_logger_cm(func, *args, **kwargs) as checkpoint_data:
res = func(*args, **kwargs)
try:
res = func(*args, **kwargs)
except Exception as ex:
checkpoint_data["exception"] = str(ex)
raise ex
if self.dump_return_val:
checkpoint_data["return_value"] = str(res)
return res
Expand Down

0 comments on commit f2693d1

Please sign in to comment.