From ee4d157834aab0c4b03142987cd4a56dfb5cb850 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Tue, 8 Feb 2022 18:35:00 +0100 Subject: [PATCH] Add message filtering, and last activity update for default protocol --- .../fps_kernels/kernel_server/message.py | 12 ++--- .../fps_kernels/kernel_server/server.py | 50 ++++++++++++------- 2 files changed, 38 insertions(+), 24 deletions(-) diff --git a/plugins/kernels/fps_kernels/kernel_server/message.py b/plugins/kernels/fps_kernels/kernel_server/message.py index b1095b25..03c033f6 100644 --- a/plugins/kernels/fps_kernels/kernel_server/message.py +++ b/plugins/kernels/fps_kernels/kernel_server/message.py @@ -103,7 +103,7 @@ def send_raw_message(parts: List[bytes], sock: Socket, key: str) -> None: sock.send_multipart(to_send) -def deserialize_msg_from_ws_v1(ws_msg): +def deserialize_msg_from_ws_v1(ws_msg: bytes) -> Tuple[str, List[bytes]]: offset_number = int.from_bytes(ws_msg[:8], "little") offsets = [ int.from_bytes(ws_msg[8 * (i + 1) : 8 * (i + 2)], "little") # noqa @@ -140,17 +140,17 @@ def get_msg_from_parts( return deserialize(parts, parent_header=parent_header) -def serialize_msg_to_ws_v1(msg_list, channel): +def serialize_msg_to_ws_v1(msg_list: List[bytes], channel: str) -> List[bytes]: msg_list = msg_list[1:] - channel = channel.encode("utf-8") + channel_b = channel.encode("utf-8") offsets = [] offsets.append(8 * (1 + 1 + len(msg_list) + 1)) - offsets.append(len(channel) + offsets[-1]) + offsets.append(len(channel_b) + offsets[-1]) for msg in msg_list: offsets.append(len(msg) + offsets[-1]) offset_number = len(offsets).to_bytes(8, byteorder="little") - offsets = [offset.to_bytes(8, byteorder="little") for offset in offsets] - bin_msg = [offset_number] + offsets + [channel] + msg_list + offsets_b = [offset.to_bytes(8, byteorder="little") for offset in offsets] + bin_msg = [offset_number] + offsets_b + [channel_b] + msg_list return bin_msg diff --git a/plugins/kernels/fps_kernels/kernel_server/server.py b/plugins/kernels/fps_kernels/kernel_server/server.py index b2c04ec1..8d1c6240 100644 --- a/plugins/kernels/fps_kernels/kernel_server/server.py +++ b/plugins/kernels/fps_kernels/kernel_server/server.py @@ -161,14 +161,12 @@ async def listen(self, channel_name: str): if channel == self.iopub_channel: # broadcast to all web clients for websocket in self.sessions.values(): - await send_to_ws(websocket, parts, parent_header, channel_name) - # FIXME: add back last_activity update - # or should we request it from the control channel? + await self.send_to_ws(websocket, parts, parent_header, channel_name) else: session = parent_header["session"] if session in self.sessions: websocket = self.sessions[session] - await send_to_ws(websocket, parts, parent_header, channel_name) + await self.send_to_ws(websocket, parts, parent_header, channel_name) async def _wait_for_ready(self): while True: @@ -201,26 +199,42 @@ async def send_to_zmq(self, websocket): elif websocket.accepted_subprotocol == "v1.kernel.websocket.jupyter.org": while True: msg = await websocket.websocket.receive_bytes() - # FIXME: add back message filtering channel, parts = deserialize_msg_from_ws_v1(msg) + # NOTE: we parse the header for message filtering + # it is not as bad as parsing the content + header = json.loads(parts[0]) + msg_type = header["msg_type"] + if (msg_type in self.blocked_messages) or ( + self.allowed_messages is not None + and msg_type not in self.allowed_messages + ): + continue if channel == "shell": send_raw_message(parts, self.shell_channel, self.key) elif channel == "control": send_raw_message(parts, self.control_channel, self.key) - -async def send_to_ws(websocket, parts, parent_header, channel_name): - if not websocket.accepted_subprotocol: - # default, "legacy" protocol - msg = get_msg_from_parts(parts, parent_header=parent_header) - msg["channel"] = channel_name - await send_json_or_bytes(websocket.websocket, msg) - elif websocket.accepted_subprotocol == "v1.kernel.websocket.jupyter.org": - bin_msg = serialize_msg_to_ws_v1(parts, channel_name) - try: - await websocket.websocket.send_bytes(bin_msg) - except Exception: - pass + async def send_to_ws(self, websocket, parts, parent_header, channel_name): + if not websocket.accepted_subprotocol: + # default, "legacy" protocol + msg = get_msg_from_parts(parts, parent_header=parent_header) + msg["channel"] = channel_name + await send_json_or_bytes(websocket.websocket, msg) + if channel_name == "iopub": + if "content" in msg and "execution_state" in msg["content"]: + self.last_activity = { + "date": msg["header"]["date"], + "execution_state": msg["content"]["execution_state"], + } + elif websocket.accepted_subprotocol == "v1.kernel.websocket.jupyter.org": + bin_msg = serialize_msg_to_ws_v1(parts, channel_name) + try: + await websocket.websocket.send_bytes(bin_msg) + except Exception: + pass + # FIXME: update last_activity + # but we don't want to parse the content! + # or should we request it from the control channel? async def receive_json_or_bytes(websocket):