Skip to content

Commit

Permalink
Update with jupyter_server PR 657
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Jan 31, 2022
1 parent 89b68ba commit c7b8805
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 81 deletions.
58 changes: 23 additions & 35 deletions plugins/kernels/fps_kernels/kernel_server/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,31 +97,23 @@ def send_message(msg: Dict[str, Any], sock: Socket, key: str) -> None:


def send_raw_message(parts: List[bytes], sock: Socket, key: str) -> None:
if len(parts) == 4:
msg = parts
buffers = []
else:
msg = parts[:4]
buffers = parts[4:]
msg = parts[:4]
buffers = parts[4:]
to_send = [DELIM, sign(msg, key)] + msg + buffers
sock.send_multipart(to_send)


def get_channel_parts(msg: bytes) -> Tuple[str, List[bytes]]:
layout_len = int.from_bytes(msg[:2], "little")
layout = json.loads(msg[2 : 2 + layout_len]) # noqa
parts: List[bytes] = list(
get_parts(msg[2 + layout_len :], layout["offsets"]) # noqa
)
return layout["channel"], parts


def get_parts(msg, offsets):
i0 = 0
for i1 in offsets:
yield msg[i0:i1]
i0 = i1
yield msg[i0:]
def deserialize_msg_from_ws_v1(ws_msg):
offset_number = int.from_bytes(ws_msg[:8], "little")
offsets = [
int.from_bytes(ws_msg[8 * (i + 1) : 8 * (i + 2)], "little") # noqa
for i in range(offset_number)
]
channel = ws_msg[offsets[0] : offsets[1]].decode("utf-8") # noqa
msg_list = [
ws_msg[offsets[i] : offsets[i + 1]] for i in range(1, offset_number - 1) # noqa
]
return channel, msg_list


async def receive_message(
Expand All @@ -148,21 +140,17 @@ def get_msg_from_parts(
return deserialize(parts, parent_header=parent_header)


def get_bin_msg_from_parts(channel: str, parts: List[bytes]) -> List[bytes]:
def serialize_msg_to_ws_v1(msg_list, channel):
msg_list = msg_list[1:]
channel = channel.encode("utf-8")
offsets = []
curr_sum = 0
for part in parts[1:]:
length = len(part)
offsets.append(length + curr_sum)
curr_sum += length
layout = json.dumps(
{
"channel": channel,
"offsets": offsets,
}
).encode("utf-8")
layout_length = len(layout).to_bytes(2, byteorder="little")
bin_msg = [layout_length, layout] + parts[1:]
offsets.append(8 * (1 + 1 + len(msg_list) + 1))
offsets.append(len(channel) + offsets[-1])
for msg in msg_list:
offsets.append(len(msg) + offsets[-1])
offset_number = len(offsets).to_bytes(8, byteorder="little")
offsets = [offset.to_bytes(8, byteorder="little") for offset in offsets]
bin_msg = [offset_number] + offsets + [channel] + msg_list
return bin_msg


Expand Down
88 changes: 44 additions & 44 deletions plugins/kernels/fps_kernels/kernel_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
create_message,
to_binary,
from_binary,
get_channel_parts,
deserialize_msg_from_ws_v1,
get_parent_header,
get_zmq_parts,
get_bin_msg_from_parts,
serialize_msg_to_ws_v1,
get_msg_from_parts,
) # type: ignore

Expand Down Expand Up @@ -143,29 +143,7 @@ async def serve(self, websocket: AcceptedWebSocket, session_id: str):

async def listen_web(self, websocket: AcceptedWebSocket):
try:
if not websocket.accepted_subprotocol:
while True:
msg = await receive_json_or_bytes(websocket.websocket)
msg_type = msg["header"]["msg_type"]
if (msg_type in self.blocked_messages) or (
self.allowed_messages is not None
and msg_type not in self.allowed_messages
):
continue
channel = msg.pop("channel")
if channel == "shell":
send_message(msg, self.shell_channel, self.key)
elif channel == "control":
send_message(msg, self.control_channel, self.key)
elif websocket.accepted_subprotocol == "v1.websocket.jupyter.org":
while True:
msg = await websocket.websocket.receive_bytes()
# FIXME: add back message filtering
channel, parts = get_channel_parts(msg)
if channel == "shell":
send_raw_message(parts, self.shell_channel, self.key)
elif channel == "control":
send_raw_message(parts, self.control_channel, self.key)
await self.send_to_zmq(websocket)
except WebSocketDisconnect:
pass

Expand All @@ -183,31 +161,14 @@ async def listen(self, channel_name: str):
if channel == self.iopub_channel:
# broadcast to all web clients
for websocket in self.sessions.values():
if not websocket.accepted_subprotocol:
# default, "legacy" protocol
msg = get_msg_from_parts(parts, parent_header=parent_header)
msg["channel"] = channel_name
await send_json_or_bytes(websocket.websocket, msg)
elif websocket.accepted_subprotocol == "v1.websocket.jupyter.org":
bin_msg = get_bin_msg_from_parts(channel_name, parts)
try:
await websocket.websocket.send_bytes(bin_msg)
except Exception:
pass
await send_to_ws(websocket, parts, parent_header, channel_name)
# FIXME: add back last_activity update
# or should we request it from the control channel?
else:
session = parent_header["session"]
if session in self.sessions:
websocket = self.sessions[session]
if not websocket.accepted_subprotocol:
# default, "legacy" protocol
msg = get_msg_from_parts(parts, parent_header=parent_header)
msg["channel"] = channel_name
await send_json_or_bytes(websocket.websocket, msg)
elif websocket.accepted_subprotocol == "v1.websocket.jupyter.org":
bin_msg = get_bin_msg_from_parts(channel_name, parts)
await websocket.websocket.send_bytes(bin_msg)
await send_to_ws(websocket, parts, parent_header, channel_name)

async def _wait_for_ready(self):
while True:
Expand All @@ -222,6 +183,45 @@ async def _wait_for_ready(self):
else:
break

async def send_to_zmq(self, websocket):
if not websocket.accepted_subprotocol:
while True:
msg = await receive_json_or_bytes(websocket.websocket)
msg_type = msg["header"]["msg_type"]
if (msg_type in self.blocked_messages) or (
self.allowed_messages is not None
and msg_type not in self.allowed_messages
):
continue
channel = msg.pop("channel")
if channel == "shell":
send_message(msg, self.shell_channel, self.key)
elif channel == "control":
send_message(msg, self.control_channel, self.key)
elif websocket.accepted_subprotocol == "v1.kernel.websocket.jupyter.org":
while True:
msg = await websocket.websocket.receive_bytes()
# FIXME: add back message filtering
channel, parts = deserialize_msg_from_ws_v1(msg)
if channel == "shell":
send_raw_message(parts, self.shell_channel, self.key)
elif channel == "control":
send_raw_message(parts, self.control_channel, self.key)


async def send_to_ws(websocket, parts, parent_header, channel_name):
if not websocket.accepted_subprotocol:
# default, "legacy" protocol
msg = get_msg_from_parts(parts, parent_header=parent_header)
msg["channel"] = channel_name
await send_json_or_bytes(websocket.websocket, msg)
elif websocket.accepted_subprotocol == "v1.kernel.websocket.jupyter.org":
bin_msg = serialize_msg_to_ws_v1(parts, channel_name)
try:
await websocket.websocket.send_bytes(bin_msg)
except Exception:
pass


async def receive_json_or_bytes(websocket):
assert websocket.application_state == WebSocketState.CONNECTED
Expand Down
4 changes: 2 additions & 2 deletions plugins/kernels/fps_kernels/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,8 @@ async def kernel_channels(
accept_websocket = True
if accept_websocket:
subprotocol = (
"v1.websocket.jupyter.org"
if "v1.websocket.jupyter.org" in websocket["subprotocols"]
"v1.kernel.websocket.jupyter.org"
if "v1.kernel.websocket.jupyter.org" in websocket["subprotocols"]
else None
)
await websocket.accept(subprotocol=subprotocol)
Expand Down

0 comments on commit c7b8805

Please sign in to comment.