Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add start_ws method to pygls' LanguageClient #503

Merged
merged 3 commits into from
Oct 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 88 additions & 74 deletions poetry.lock

Large diffs are not rendered by default.

48 changes: 48 additions & 0 deletions pygls/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
import json
import logging
import re
import sys
import typing
from threading import Event

from pygls.exceptions import PyglsError, JsonRpcException, JsonRpcInternalError
from pygls.protocol import JsonRPCProtocol, default_converter
from pygls.server import WebSocketTransportAdapter

if typing.TYPE_CHECKING:
from typing import Any
Expand All @@ -35,6 +37,7 @@

from cattrs import Converter

from websockets.asyncio.client import ClientConnection

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -116,6 +119,26 @@ async def start_tcp(self, host: str, port: int):

self._async_tasks.extend([connection])

async def start_ws(self, host: str, port: int):
"""Start communicating with a server over WebSockets."""

try:
from websockets.asyncio.client import connect
except ImportError:
logger.exception(
"Run `pip install pygls[ws]` to install dependencies required for websockets."
)
sys.exit(1)

uri = f"ws://{host}:{port}"
websocket = await connect(uri)

self.protocol._send_only_body = True
self.protocol.connection_made(WebSocketTransportAdapter(websocket)) # type: ignore

connection = asyncio.create_task(self.run_websocket(websocket))
self._async_tasks.extend([connection])

async def run_async(self, reader: asyncio.StreamReader):
"""Run the main message processing loop, asynchronously"""

Expand Down Expand Up @@ -154,6 +177,31 @@ async def run_async(self, reader: asyncio.StreamReader):
# Reset
content_length = 0

async def run_websocket(self, websocket: ClientConnection):
"""Run the main message processing loop, over websockets."""

try:
from websockets.exceptions import ConnectionClosedOK
except ImportError:
logger.exception(
"Run `pip install pygls[ws]` to install dependencies required for websockets."
)
return

while not self._stop_event.is_set():
try:
data = await websocket.recv(decode=False)
except ConnectionClosedOK:
self._stop_event.set()
break

try:
message = json.loads(data, object_hook=self.protocol.structure_message)
self.protocol.handle_message(message)
except Exception as exc:
logger.exception("Unable to handle message")
self._report_server_error(exc, JsonRpcInternalError)

async def _server_exit(self):
"""Cleanup handler that runs when the server process managed by the client exits"""
if self._server is None:
Expand Down
7 changes: 3 additions & 4 deletions pygls/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,12 @@ class WebSocketTransportAdapter:
Write method sends data via the WebSocket interface.
"""

def __init__(self, ws, loop):
def __init__(self, ws):
self._ws = ws
self._loop = loop

def close(self) -> None:
"""Stop the WebSocket server."""
self._ws.close()
asyncio.ensure_future(self._ws.close())

def write(self, data: Any) -> None:
"""Create a task to write specified data into a WebSocket."""
Expand Down Expand Up @@ -290,7 +289,7 @@ def start_ws(self, host: str, port: int) -> None:

async def connection_made(websocket, _):
"""Handle new connection wrapped in the WebSocket."""
self.protocol.transport = WebSocketTransportAdapter(websocket, self.loop)
self.protocol.transport = WebSocketTransportAdapter(websocket)
async for message in websocket:
self.protocol.handle_message(
json.loads(message, object_hook=self.protocol.structure_message)
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ readme = "README.md"
python = ">=3.9"
cattrs = ">=23.1.2"
lsprotocol = "2024.0.0a2"
websockets = { version = ">=11.0.3", optional = true }
websockets = { version = ">=13.0", optional = true }

[tool.poetry.extras]
ws = ["websockets"]
Expand Down Expand Up @@ -68,6 +68,7 @@ poetry_lock_check = "poetry check"
sequence = [
{ cmd = "pytest --cov" },
{ cmd = "pytest tests/e2e --lsp-transport tcp" },
{ cmd = "pytest tests/e2e --lsp-transport websockets" },
]
ignore_fail = "return_non_zero"

Expand Down
11 changes: 10 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def pytest_addoption(parser):
dest="lsp_transport",
action="store",
default="stdio",
choices=("stdio", "tcp"),
choices=("stdio", "tcp", "websockets"),
help="Choose the transport to use with servers under test.",
)

Expand Down Expand Up @@ -212,6 +212,15 @@ async def fn(
await asyncio.sleep(1)
await client.start_tcp(host, port)

elif transport == "websockets":
# TODO: Make host/port configurable?
host, port = "localhost", 8888
server_cmd.extend(["--ws", "--host", host, "--port", f"{port}"])

server = await asyncio.create_subprocess_exec(*server_cmd)
await asyncio.sleep(1)
await client.start_ws(host, port)

else:
raise NotImplementedError(f"Unsupported transport: {transport!r}")

Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/test_threaded_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ async def test_countdown_threaded(
):
"""Ensure that the countdown threaded command is working as expected."""

if IS_WIN and transport == "tcp":
if (IS_WIN and transport == "tcp") or transport == "websockets":
pytest.skip("see https://github.com/openlawlibrary/pygls/issues/502")

client, initialize_result = threaded_handlers
Expand Down
Loading