Skip to content

Commit

Permalink
Improve output_hook API (#5)
Browse files Browse the repository at this point in the history
Co-authored-by: Frédéric Collonval <[email protected]>
  • Loading branch information
fcollonval and fcollonval authored Dec 7, 2024
1 parent 0029647 commit c768f69
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions jupyter_kernel_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
from .utils import UTC


def output_hook(outputs: list[dict[str, t.Any]], message: dict[str, t.Any]) -> None:
def output_hook(outputs: list[dict[str, t.Any]], message: dict[str, t.Any]) -> set[int]: # noqa: C901
"""Callback on messages captured during a code snippet execution.
The return list of updated output will be empty if no output where changed.
It will equal all indexes if the outputs was cleared.
Example:
This callback is meant to be used with ``KernelClient.execute_interactive``::
Expand All @@ -35,6 +38,9 @@ def output_hook(outputs: list[dict[str, t.Any]], message: dict[str, t.Any]) -> N
Args:
outputs: List in which to append the output
message: A kernel message
Returns:
list of output indexed updated
"""
msg_type = message["header"]["msg_type"]
content = message["content"]
Expand All @@ -49,6 +55,7 @@ def output_hook(outputs: list[dict[str, t.Any]], message: dict[str, t.Any]) -> N
"execution_count": content["execution_count"],
}
elif msg_type == "stream":
# FIXME Logic is quite complex at https://github.com/jupyterlab/jupyterlab/blob/7ae2d436fc410b0cff51042a3350ba71f54f4445/packages/outputarea/src/model.ts#L518
output = {
"output_type": msg_type,
"name": content["name"],
Expand All @@ -70,18 +77,27 @@ def output_hook(outputs: list[dict[str, t.Any]], message: dict[str, t.Any]) -> N
}
elif msg_type == "clear_output":
# Ignore wait as we run without display
size = len(outputs)
outputs.clear()
return set(range(size))
elif msg_type == "update_display_data":
display_id = content.get("transient", {}).get("display_id")
indexes = set()
if display_id:
for obsolete_update in filter(
lambda o: o.get("transient", {}).get("display_id") == display_id, outputs
):
obsolete_update["metadata"] = content["metadata"]
obsolete_update["data"] = content["data"]
for index, obsolete_update in enumerate(outputs):
if obsolete_update.get("transient", {}).get("display_id") == display_id:
obsolete_update["metadata"] = content["metadata"]
obsolete_update["data"] = content["data"]
indexes.add(index)

return indexes

if output:
index = len(outputs)
outputs.append(output)
return {index}

return set()


class KernelClient(LoggingConfigurable):
Expand Down

0 comments on commit c768f69

Please sign in to comment.