Skip to content

Commit

Permalink
Merge pull request #154 from davidbrochart/protocol_alignment
Browse files Browse the repository at this point in the history
Protocol alignment
  • Loading branch information
davidbrochart authored Feb 8, 2022
2 parents 7d01da1 + ee4d157 commit 780b85c
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 57 deletions.
21 changes: 20 additions & 1 deletion plugins/kernels/fps_kernels/kernel_server/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import json
import asyncio
import uuid
from typing import Dict, Tuple, Union
from typing import Dict, Tuple, Union, Optional

import zmq
import zmq.asyncio
from zmq.sugar.socket import Socket

from fastapi import WebSocket


channel_socket_types = {
"hb": zmq.REQ,
Expand Down Expand Up @@ -101,3 +103,20 @@ def connect_channel(channel_name: str, cfg: cfg_t) -> Socket:
if channel_name == "iopub":
sock.setsockopt(zmq.SUBSCRIBE, b"")
return sock


class AcceptedWebSocket:
_websocket: WebSocket
_accepted_subprotocol: Optional[str]

def __init__(self, websocket, accepted_subprotocol):
self._websocket = websocket
self._accepted_subprotocol = accepted_subprotocol

@property
def websocket(self):
return self._websocket

@property
def accepted_subprotocol(self):
return self._accepted_subprotocol
59 changes: 57 additions & 2 deletions plugins/kernels/fps_kernels/kernel_server/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,18 @@ def serialize(msg: Dict[str, Any], key: str) -> List[bytes]:
return to_send


def deserialize(msg_list: List[bytes]) -> Dict[str, Any]:
def deserialize(
msg_list: List[bytes], parent_header: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
message: Dict[str, Any] = {}
header = unpack(msg_list[1])
message["header"] = header
message["msg_id"] = header["msg_id"]
message["msg_type"] = header["msg_type"]
message["parent_header"] = unpack(msg_list[2])
if parent_header:
message["parent_header"] = parent_header
else:
message["parent_header"] = unpack(msg_list[2])
message["metadata"] = unpack(msg_list[3])
message["content"] = unpack(msg_list[4])
message["buffers"] = [memoryview(b) for b in msg_list[5:]]
Expand All @@ -91,6 +96,26 @@ def send_message(msg: Dict[str, Any], sock: Socket, key: str) -> None:
sock.send_multipart(serialize(msg, key), copy=True)


def send_raw_message(parts: List[bytes], sock: Socket, key: str) -> None:
msg = parts[:4]
buffers = parts[4:]
to_send = [DELIM, sign(msg, key)] + msg + buffers
sock.send_multipart(to_send)


def deserialize_msg_from_ws_v1(ws_msg: bytes) -> Tuple[str, List[bytes]]:
offset_number = int.from_bytes(ws_msg[:8], "little")
offsets = [
int.from_bytes(ws_msg[8 * (i + 1) : 8 * (i + 2)], "little") # noqa
for i in range(offset_number)
]
channel = ws_msg[offsets[0] : offsets[1]].decode("utf-8") # noqa
msg_list = [
ws_msg[offsets[i] : offsets[i + 1]] for i in range(1, offset_number - 1) # noqa
]
return channel, msg_list


async def receive_message(
sock: Socket, timeout: float = float("inf")
) -> Optional[Dict[str, Any]]:
Expand All @@ -103,6 +128,36 @@ async def receive_message(
return None


async def get_zmq_parts(socket: Socket) -> List[bytes]:
parts = await socket.recv_multipart()
idents, parts = feed_identities(parts)
return parts


def get_msg_from_parts(
parts: List[bytes], parent_header: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
return deserialize(parts, parent_header=parent_header)


def serialize_msg_to_ws_v1(msg_list: List[bytes], channel: str) -> List[bytes]:
msg_list = msg_list[1:]
channel_b = channel.encode("utf-8")
offsets = []
offsets.append(8 * (1 + 1 + len(msg_list) + 1))
offsets.append(len(channel_b) + offsets[-1])
for msg in msg_list:
offsets.append(len(msg) + offsets[-1])
offset_number = len(offsets).to_bytes(8, byteorder="little")
offsets_b = [offset.to_bytes(8, byteorder="little") for offset in offsets]
bin_msg = [offset_number] + offsets_b + [channel_b] + msg_list
return bin_msg


def get_parent_header(parts: List[bytes]) -> Dict[str, Any]:
return unpack(parts[2])


def utcnow() -> datetime:
return datetime.utcnow().replace(tzinfo=timezone.utc)

Expand Down
140 changes: 89 additions & 51 deletions plugins/kernels/fps_kernels/kernel_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from datetime import datetime
from typing import Iterable, Optional, List, Dict, cast

from fastapi import WebSocket, WebSocketDisconnect # type: ignore
from fastapi import WebSocketDisconnect # type: ignore
from starlette.websockets import WebSocketState

from .connect import (
Expand All @@ -14,13 +14,20 @@
launch_kernel,
connect_channel,
cfg_t,
AcceptedWebSocket,
) # type: ignore
from .message import (
receive_message,
send_message,
send_raw_message,
create_message,
to_binary,
from_binary,
deserialize_msg_from_ws_v1,
get_parent_header,
get_zmq_parts,
serialize_msg_to_ws_v1,
get_msg_from_parts,
) # type: ignore


Expand All @@ -42,7 +49,7 @@ def __init__(
self.connection_file = connection_file
self.write_connection_file = write_connection_file
self.channel_tasks: List[asyncio.Task] = []
self.sessions: Dict[str, WebSocket] = {}
self.sessions: Dict[str, AcceptedWebSocket] = {}
# blocked messages and allowed messages are mutually exclusive
self.blocked_messages: List[str] = []
self.allowed_messages: Optional[
Expand Down Expand Up @@ -103,9 +110,9 @@ async def start(self) -> None:
self.iopub_channel = connect_channel("iopub", self.connection_cfg)
await self._wait_for_ready()
self.channel_tasks += [
asyncio.create_task(self.listen_shell()),
asyncio.create_task(self.listen_control()),
asyncio.create_task(self.listen_iopub()),
asyncio.create_task(self.listen("shell")),
asyncio.create_task(self.listen("control")),
asyncio.create_task(self.listen("iopub")),
]

async def stop(self) -> None:
Expand All @@ -129,61 +136,37 @@ async def restart(self) -> None:
self.setup_connection_file()
await self.start()

async def serve(self, websocket: WebSocket, session_id: str):
async def serve(self, websocket: AcceptedWebSocket, session_id: str):
self.sessions[session_id] = websocket
await self.listen_web(websocket)
del self.sessions[session_id]

async def listen_web(self, websocket: WebSocket):
async def listen_web(self, websocket: AcceptedWebSocket):
try:
while True:
msg = await receive_json_or_bytes(websocket)
msg_type = msg["header"]["msg_type"]
if (msg_type in self.blocked_messages) or (
self.allowed_messages is not None
and msg_type not in self.allowed_messages
):
continue
channel = msg.pop("channel")
if channel == "shell":
send_message(msg, self.shell_channel, self.key)
elif channel == "control":
send_message(msg, self.control_channel, self.key)
await self.send_to_zmq(websocket)
except WebSocketDisconnect:
pass

async def listen_shell(self):
while True:
msg = await receive_message(self.shell_channel)
msg["channel"] = "shell"
session = msg["parent_header"]["session"]
if session in self.sessions:
websocket = self.sessions[session]
await send_json_or_bytes(websocket, msg)

async def listen_control(self):
while True:
msg = await receive_message(self.control_channel)
msg["channel"] = "control"
session = msg["parent_header"]["session"]
if session in self.sessions:
websocket = self.sessions[session]
await send_json_or_bytes(websocket, msg)

async def listen_iopub(self):
async def listen(self, channel_name: str):
if channel_name == "shell":
channel = self.shell_channel
elif channel_name == "control":
channel = self.control_channel
elif channel_name == "iopub":
channel = self.iopub_channel

while True:
msg = await receive_message(self.iopub_channel)
msg["channel"] = "iopub"
for websocket in self.sessions.values():
try:
await send_json_or_bytes(websocket, msg)
except Exception:
pass
if "content" in msg and "execution_state" in msg["content"]:
self.last_activity = {
"date": msg["header"]["date"],
"execution_state": msg["content"]["execution_state"],
}
parts = await get_zmq_parts(channel)
parent_header = get_parent_header(parts)
if channel == self.iopub_channel:
# broadcast to all web clients
for websocket in self.sessions.values():
await self.send_to_ws(websocket, parts, parent_header, channel_name)
else:
session = parent_header["session"]
if session in self.sessions:
websocket = self.sessions[session]
await self.send_to_ws(websocket, parts, parent_header, channel_name)

async def _wait_for_ready(self):
while True:
Expand All @@ -198,6 +181,61 @@ async def _wait_for_ready(self):
else:
break

async def send_to_zmq(self, websocket):
if not websocket.accepted_subprotocol:
while True:
msg = await receive_json_or_bytes(websocket.websocket)
msg_type = msg["header"]["msg_type"]
if (msg_type in self.blocked_messages) or (
self.allowed_messages is not None
and msg_type not in self.allowed_messages
):
continue
channel = msg.pop("channel")
if channel == "shell":
send_message(msg, self.shell_channel, self.key)
elif channel == "control":
send_message(msg, self.control_channel, self.key)
elif websocket.accepted_subprotocol == "v1.kernel.websocket.jupyter.org":
while True:
msg = await websocket.websocket.receive_bytes()
channel, parts = deserialize_msg_from_ws_v1(msg)
# NOTE: we parse the header for message filtering
# it is not as bad as parsing the content
header = json.loads(parts[0])
msg_type = header["msg_type"]
if (msg_type in self.blocked_messages) or (
self.allowed_messages is not None
and msg_type not in self.allowed_messages
):
continue
if channel == "shell":
send_raw_message(parts, self.shell_channel, self.key)
elif channel == "control":
send_raw_message(parts, self.control_channel, self.key)

async def send_to_ws(self, websocket, parts, parent_header, channel_name):
if not websocket.accepted_subprotocol:
# default, "legacy" protocol
msg = get_msg_from_parts(parts, parent_header=parent_header)
msg["channel"] = channel_name
await send_json_or_bytes(websocket.websocket, msg)
if channel_name == "iopub":
if "content" in msg and "execution_state" in msg["content"]:
self.last_activity = {
"date": msg["header"]["date"],
"execution_state": msg["content"]["execution_state"],
}
elif websocket.accepted_subprotocol == "v1.kernel.websocket.jupyter.org":
bin_msg = serialize_msg_to_ws_v1(parts, channel_name)
try:
await websocket.websocket.send_bytes(bin_msg)
except Exception:
pass
# FIXME: update last_activity
# but we don't want to parse the content!
# or should we request it from the control channel?


async def receive_json_or_bytes(websocket):
assert websocket.application_state == WebSocketState.CONNECTED
Expand Down
12 changes: 9 additions & 3 deletions plugins/kernels/fps_kernels/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from fps_auth.config import get_auth_config # type: ignore
from fps_lab.config import get_lab_config # type: ignore

from .kernel_server.server import KernelServer, kernels # type: ignore
from .kernel_server.server import AcceptedWebSocket, KernelServer, kernels # type: ignore
from .models import Session

router = APIRouter()
Expand Down Expand Up @@ -202,10 +202,16 @@ async def kernel_channels(
if user:
accept_websocket = True
if accept_websocket:
await websocket.accept()
subprotocol = (
"v1.kernel.websocket.jupyter.org"
if "v1.kernel.websocket.jupyter.org" in websocket["subprotocols"]
else None
)
await websocket.accept(subprotocol=subprotocol)
accepted_websocket = AcceptedWebSocket(websocket, subprotocol)
if kernel_id in kernels:
kernel_server = kernels[kernel_id]["server"]
await kernel_server.serve(websocket, session_id)
await kernel_server.serve(accepted_websocket, session_id)
else:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)

Expand Down

0 comments on commit 780b85c

Please sign in to comment.