Skip to content

Commit

Permalink
feat: yield to the user's initialize before calculating capabilities
Browse files Browse the repository at this point in the history
This should enable the dynamic registration of features during
initialization, as discussed in #381
  • Loading branch information
alcarney committed Nov 30, 2024
1 parent e4862c1 commit 1035549
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 43 deletions.
51 changes: 26 additions & 25 deletions pygls/capabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@
# See the License for the specific language governing permissions and #
# limitations under the License. #
############################################################################
from functools import reduce
from typing import Any, Dict, List, Optional, Set, Union, TypeVar
import logging
from functools import reduce
from typing import Any, Dict, List, Optional, Set, TypeVar, Union

from lsprotocol import types


logger = logging.getLogger(__name__)
T = TypeVar("T")

Expand Down Expand Up @@ -62,6 +61,7 @@ def __init__(
commands: List[str],
text_document_sync_kind: types.TextDocumentSyncKind,
notebook_document_sync: Optional[types.NotebookDocumentSyncOptions] = None,
position_encoding: types.PositionEncodingKind = types.PositionEncodingKind.Utf16,
):
self.client_capabilities = client_capabilities
self.features = features
Expand All @@ -71,12 +71,35 @@ def __init__(
self.notebook_document_sync = notebook_document_sync

self.server_cap = types.ServerCapabilities()
self.server_cap.position_encoding = position_encoding

def _provider_options(self, feature: str, default: T) -> Optional[Union[T, Any]]:
if feature in self.features:
return self.feature_options.get(feature, default)
return None

@classmethod
def choose_position_encoding(
cls, client_capabilities: types.ClientCapabilities
) -> types.PositionEncodingKind:
server_encoding = types.PositionEncodingKind.Utf16

if (general := client_capabilities.general) is None:
return server_encoding

if (encodings := general.position_encodings) is None:
return server_encoding

# We match client preference where this an overlap between its and our supported encodings.
for client_encoding in encodings:
if client_encoding in _SUPPORTED_ENCODINGS:
server_encoding = client_encoding
return server_encoding

logger.warning(f"Unknown `PositionEncoding`s: {encodings}")

return server_encoding

def _with_text_document_sync(self):
open_close = (
types.TEXT_DOCUMENT_DID_OPEN in self.features
Expand Down Expand Up @@ -415,27 +438,6 @@ def _with_inline_value_provider(self):
self.server_cap.inline_value_provider = value
return self

def _with_position_encodings(self):
self.server_cap.position_encoding = types.PositionEncodingKind.Utf16

general = self.client_capabilities.general
if general is None:
return self

encodings = general.position_encodings
if encodings is None:
return self

# We match client preference where this an overlap between its and our supported encodings.
for encoding in encodings:
if encoding in _SUPPORTED_ENCODINGS:
self.server_cap.position_encoding = encoding
return self

logger.warning(f"Unknown `PositionEncoding`s: {encodings}")

return self

def _build(self):
return self.server_cap

Expand Down Expand Up @@ -474,6 +476,5 @@ def build(self):
._with_workspace_capabilities()
._with_diagnostic_provider()
._with_inline_value_provider()
._with_position_encodings()
._build()
)
41 changes: 26 additions & 15 deletions pygls/protocol/language_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ def lsp_exit(self, *args) -> None:
sys.exit(returncode)

@lsp_method(types.INITIALIZE)
def lsp_initialize(self, params: types.InitializeParams) -> types.InitializeResult:
def lsp_initialize(
self, params: types.InitializeParams
) -> Generator[Any, Any, types.InitializeResult]:
"""Method that initializes language server.
It will compute and return server capabilities based on
registered features.
Expand All @@ -142,33 +144,42 @@ def lsp_initialize(self, params: types.InitializeParams) -> types.InitializeResu
text_document_sync_kind = self._server._text_document_sync_kind
notebook_document_sync = self._server._notebook_document_sync

# Initialize server capabilities
self.client_capabilities = params.capabilities
self.server_capabilities = ServerCapabilitiesBuilder(
self.client_capabilities,
set({**self.fm.features, **self.fm.builtin_features}.keys()),
self.fm.feature_options,
list(self.fm.commands.keys()),
text_document_sync_kind,
notebook_document_sync,
).build()
logger.debug(
"Server capabilities: %s",
json.dumps(self.server_capabilities, default=self._serialize_message),
position_encoding = ServerCapabilitiesBuilder.choose_position_encoding(
self.client_capabilities
)

root_path = params.root_path
root_uri = params.root_uri
if root_path is not None and root_uri is None:
root_uri = from_fs_path(root_path)

# Initialize the workspace
# Initialize the workspace before yielding to the user's initialize handler
workspace_folders = params.workspace_folders or []
self._workspace = Workspace(
root_uri,
text_document_sync_kind,
workspace_folders,
self.server_capabilities.position_encoding,
position_encoding,
)

if (user_handler := self.fm.features.get(types.INITIALIZE)) is not None:
yield user_handler, (params,), None

# Now that the user has had the opportunity to setup additional features, calculate
# the server's capabilities
self.server_capabilities = ServerCapabilitiesBuilder(
self.client_capabilities,
set({**self.fm.features, **self.fm.builtin_features}.keys()),
self.fm.feature_options,
list(self.fm.commands.keys()),
text_document_sync_kind,
notebook_document_sync,
position_encoding,
).build()
logger.debug(
"Server capabilities: %s",
json.dumps(self.server_capabilities, default=self._serialize_message),
)

return types.InitializeResult(
Expand Down
6 changes: 3 additions & 3 deletions tests/test_feature_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from typing import Any

import pytest
from lsprotocol import types as lsp

from pygls.capabilities import ServerCapabilitiesBuilder
from pygls.exceptions import (
CommandAlreadyRegisteredError,
Expand All @@ -29,7 +31,6 @@
has_ls_param_or_annotation,
wrap_with_server,
)
from lsprotocol import types as lsp


class Temp:
Expand Down Expand Up @@ -704,13 +705,13 @@ def _():
[],
None,
None,
ServerCapabilitiesBuilder.choose_position_encoding(capabilities),
).build()

assert expected == actual


def test_register_prepare_rename_no_client_support(feature_manager: FeatureManager):

@feature_manager.feature(lsp.TEXT_DOCUMENT_RENAME)
def _():
pass
Expand All @@ -734,7 +735,6 @@ def _():


def test_register_prepare_rename_with_client_support(feature_manager: FeatureManager):

@feature_manager.feature(lsp.TEXT_DOCUMENT_RENAME)
def _():
pass
Expand Down

0 comments on commit 1035549

Please sign in to comment.