Skip to content

Commit

Permalink
Improve variables API (#14)
Browse files Browse the repository at this point in the history
* Introduce `VariableDescription` when listing variables
`get_variable` now returns a tuple `(data, metadata)`
Improve logging

* Use TypedDict instead of dataclass to reduce API changes

* Automatic application of license header

---------

Co-authored-by: Frédéric Collonval <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 18, 2024
1 parent c18adb6 commit 55b117e
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 66 deletions.
2 changes: 2 additions & 0 deletions jupyter_kernel_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .client import KernelClient
from .konsoleapp import KonsoleApp
from .manager import KernelHttpManager
from .models import VariableDescription
from .snippets import SNIPPETS_REGISTRY, LanguageSnippets
from .wsclient import KernelWebSocketClient

Expand All @@ -17,4 +18,5 @@
"KernelWebSocketClient",
"KonsoleApp",
"LanguageSnippets",
"VariableDescription",
]
50 changes: 27 additions & 23 deletions jupyter_kernel_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
from traitlets.config import LoggingConfigurable

from .constants import REQUEST_TIMEOUT
from .log import get_logger
from .manager import KernelHttpManager
from .models import VariableDescription
from .snippets import SNIPPETS_REGISTRY
from .utils import UTC

logger = logging.getLogger("jupyter_kernel_client")


def output_hook(outputs: list[dict[str, t.Any]], message: dict[str, t.Any]) -> set[int]: # noqa: C901
"""Callback on messages captured during a code snippet execution.
Expand Down Expand Up @@ -161,7 +161,7 @@ class KernelClient(LoggingConfigurable):
def __init__(
self, kernel_id: str | None = None, log: logging.Logger | None = None, **kwargs
) -> None:
super().__init__(log=log or logger)
super().__init__(log=log or get_logger())
self._manager = self.kernel_manager_class(parent=self, kernel_id=kernel_id, **kwargs)
# Set it after the manager as if a kernel_id is provided,
# we will try to connect to it.
Expand Down Expand Up @@ -408,7 +408,9 @@ def stop(
#
# Variables related methods
#
def get_variable(self, name: str, mimetype: str | None = None) -> dict[str, t.Any]:
def get_variable(
self, name: str, mimetype: str | None = None
) -> tuple[dict[str, t.Any], dict[str, t.Any]]:
"""Get a kernel variable.
Args:
Expand All @@ -417,8 +419,8 @@ def get_variable(self, name: str, mimetype: str | None = None) -> dict[str, t.An
i.e. returns all known serialization.
Returns:
A dictionary for which keys are mimetype and values the variable value
serialized in that mimetype.
A tuple of dictionaries for which keys are mimetype and values the variable value
serialized in that mimetype for the first dictionary and metadata in the second one.
Even if a mimetype is specified, the dictionary may not contain it if
the kernel introspection failed to get the variable in the specified format.
Raises:
Expand All @@ -441,27 +443,25 @@ def get_variable(self, name: str, mimetype: str | None = None) -> dict[str, t.An

if results["status"] == "ok" and results["outputs"]:
if mimetype is None:
return results["outputs"][0]["data"]
return results["outputs"][0]["data"], results["outputs"][0].get("metadata", {})
else:
has_mimetype = mimetype in results["outputs"][0]["data"]
return {mimetype: results["outputs"][0]["data"][mimetype]} if has_mimetype else {}

def filter_dict(d: dict, mimetype: str) -> dict:
if mimetype in d:
return {mimetype: d[mimetype]}
else:
return {}

return (
filter_dict(results["outputs"][0]["data"], mimetype),
filter_dict(results["outputs"][0].get("metadata", {}), mimetype),
)
else:
raise RuntimeError(f"Failed to get variable {name} with type {mimetype}.")

def list_variables(self) -> list[dict[str, t.Any]]:
def list_variables(self) -> list[VariableDescription]:
"""List the kernel global variables.
A variable is defined by a dictionary with the schema:
{
"type": "object",
"properties": {
"name": {"title": "Variable name", "type": "string"},
"type": {"title": "Variable type", "type": "string"},
"size": {"title": "Variable size in bytes.", "type": "number"}
},
"required": ["name", "type"]
}
Returns:
The list of global variables.
Raises:
Expand All @@ -485,10 +485,14 @@ def list_variables(self) -> list[dict[str, t.Any]]:
if (
results["status"] == "ok"
and results["outputs"]
and "application/json" in results["outputs"][0]["data"]
and "application/json" in results["outputs"][-1]["data"]
):
return sorted(
results["outputs"][0]["data"]["application/json"], key=lambda v: v["name"]
(
VariableDescription(**v)
for v in results["outputs"][-1]["data"]["application/json"]
),
key=lambda v: v["name"],
)
else:
raise RuntimeError("Failed to list variables.")
13 changes: 13 additions & 0 deletions jupyter_kernel_client/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2023-2024 Datalayer, Inc.
#
# BSD 3-Clause License

from __future__ import annotations

from typing import TypedDict


class VariableDescription(TypedDict):
    """Description of a kernel global variable, as returned when listing variables."""

    # Variable name.
    name: str
    # Variable type as a pair (type module or None, type qualified name).
    # NOTE(review): when built from a JSON payload this presumably arrives as a
    # two-item list rather than a tuple — confirm against callers.
    type: tuple[str | None, str]
    # Variable size in bytes; None when the size is not provided.
    size: int | None
32 changes: 26 additions & 6 deletions jupyter_kernel_client/snippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,20 @@ class LanguageSnippets:
"type": "object",
"properties": {
"name": {"title": "Variable name", "type": "string"},
"type": {"title": "Variable type", "type": "string"},
"size": {"title": "Variable size in bytes.", "type": "number"}
"type": {
"title": "Variable type",
"type": "array",
"prefixItems": [
{"title": "Type module", "oneOf": [{"type": "string"}, {"type": "null"}]},
{"title": "Type name", "type": "string"}
]
},
"size": {
"title": "Variable size in bytes.",
"oneOf": [{"type": "number"}, {"type": "null"}]
}
},
"required": ["name", "type"]
"required": ["name", "type", "size"]
}
}
"""
Expand Down Expand Up @@ -123,9 +133,19 @@ def get_get_variable(self, language: str) -> str:
(_n == 'Out' and isinstance(_v, dict))
):
try:
_vars.append({"name": _n, "type": type(_v).__qualname__})
except BaseException:
...
variable_type = type(_v)
_vars.append(
{
"name": _n,
"type": (
getattr(variable_type, "__module__", None),
variable_type.__qualname__
),
"size": None,
}
)
except BaseException as e:
print(e)
display({"application/json": _vars}, raw=True)
Expand Down
54 changes: 29 additions & 25 deletions jupyter_kernel_client/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pytest

from jupyter_kernel_client import KernelClient
from jupyter_kernel_client import KernelClient, VariableDescription


def test_execution_as_context_manager(jupyter_server):
Expand Down Expand Up @@ -73,32 +73,36 @@ def test_list_variables(jupyter_server):
variables = kernel.list_variables()

assert variables == [
{
"name": "a",
"type": "float",
},
{
"name": "b",
"type": "str",
},
{
"name": "c",
"type": "set",
},
{
"name": "d",
"type": "dict",
},
VariableDescription(
name="a",
type=["builtins", "float"],
size=None
),
VariableDescription(
name="b",
type=["builtins", "str"],
size=None
),
VariableDescription(
name="c",
type=["builtins", "set"],
size=None,
),
VariableDescription(
name="d",
type=["builtins", "dict"],
size=None,
),
]


@pytest.mark.parametrize(
"variable,set_variable,expected",
(
("a", "a = 1.0", {"text/plain": "1.0"}),
("b", 'b = "hello the world"', {"text/plain": "'hello the world'"}),
("c", "c = {3, 4, 5}", {"text/plain": "{3, 4, 5}"}),
("d", "d = {'name': 'titi'}", {"text/plain": "{'name': 'titi'}"}),
("a", "a = 1.0", ({"text/plain": "1.0"}, {})),
("b", 'b = "hello the world"', ({"text/plain": "'hello the world'"}, {})),
("c", "c = {3, 4, 5}", ({"text/plain": "{3, 4, 5}"}, {})),
("d", "d = {'name': 'titi'}", ({"text/plain": "{'name': 'titi'}"}, {})),
),
)
def test_get_all_mimetype_variables(jupyter_server, variable, set_variable, expected):
Expand All @@ -115,10 +119,10 @@ def test_get_all_mimetype_variables(jupyter_server, variable, set_variable, expe
@pytest.mark.parametrize(
"variable,set_variable,expected",
(
("a", "a = 1.0", {"text/plain": "1.0"}),
("b", 'b = "hello the world"', {"text/plain": "'hello the world'"}),
("c", "c = {3, 4, 5}", {"text/plain": "{3, 4, 5}"}),
("d", "d = {'name': 'titi'}", {"text/plain": "{'name': 'titi'}"}),
("a", "a = 1.0", ({"text/plain": "1.0"}, {})),
("b", 'b = "hello the world"', ({"text/plain": "'hello the world'"}, {})),
("c", "c = {3, 4, 5}", ({"text/plain": "{3, 4, 5}"}, {})),
("d", "d = {'name': 'titi'}", ({"text/plain": "{'name': 'titi'}"}, {})),
),
)
def test_get_textplain_variables(jupyter_server, variable, set_variable, expected):
Expand Down
21 changes: 9 additions & 12 deletions jupyter_kernel_client/wsclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import logging
import os
import pprint
import queue
import signal
import sys
Expand All @@ -35,6 +34,12 @@
class WSSession(Session):
"""WebSocket session."""

def __init__(self, log: logging.Logger | None = None, **kwargs):
    """Initialize the WebSocket session.

    Args:
        log: Logger used by the session; ``get_logger()`` is used when omitted.
        **kwargs: Forwarded unchanged to the parent ``Session`` initializer.
    """
    super().__init__(**kwargs)
    self.log = log or get_logger()
    if not self.debug:
        # Enable session debugging whenever the logger would emit DEBUG records.
        # ``isEnabledFor`` honours the effective level (inherited from a parent
        # logger when this one is NOTSET) and levels more verbose than DEBUG,
        # both of which a strict ``level == logging.DEBUG`` comparison misses.
        self.debug = self.log.isEnabledFor(logging.DEBUG)

def serialize(self, msg: dict[str, t.Any], **kwargs) -> list[bytes]: # type:ignore[override,no-untyped-def]
"""Serialize the message components to bytes.
Expand Down Expand Up @@ -127,10 +132,7 @@ def deserialize( # type:ignore[override,no-untyped-def]
message["content"] = msg_list[3]
buffers = [memoryview(b) for b in msg_list[4:]]
message["buffers"] = buffers
if self.debug:
# Keep pprint instead of logging as this is the method used by `Session`
pprint.pprint("WSSession.deserialize") # noqa: T203
pprint.pprint(message) # noqa: T203
self.log.debug("WSSession.deserialize\n%s", message)
# adapt to the current version
return adapt(message)

Expand Down Expand Up @@ -225,12 +227,7 @@ def send( # type:ignore[override]

stream.send_bytes(serialize_msg_to_ws_v1(to_send, channel))

if self.debug:
# Keep pprint instead of logging as this is the method used by `Session`
pprint.pprint("WSSession.send") # noqa: T203
pprint.pprint(msg) # noqa: T203
pprint.pprint(to_send) # noqa: T203
pprint.pprint(buffers) # noqa: T203
self.log.debug("WSSession.send\n%s\n%s\n%s", msg, to_send, buffers)

return msg

Expand Down Expand Up @@ -459,7 +456,7 @@ def __init__( # type:ignore[no-untyped-def]
self._message_received = Event()
self.interrupt_thread = None
self.timeout = timeout
self.session = WSSession()
self.session = WSSession(log=self.log)
self.session.username = username or ""
self.session.debug = debug_session

Expand Down

0 comments on commit 55b117e

Please sign in to comment.