diff --git a/pygls/capabilities.py b/pygls/capabilities.py index fcc9cf02..9557370d 100644 --- a/pygls/capabilities.py +++ b/pygls/capabilities.py @@ -14,13 +14,12 @@ # See the License for the specific language governing permissions and # # limitations under the License. # ############################################################################ -from functools import reduce -from typing import Any, Dict, List, Optional, Set, Union, TypeVar import logging +from functools import reduce +from typing import Any, Dict, List, Optional, Set, TypeVar, Union from lsprotocol import types - logger = logging.getLogger(__name__) T = TypeVar("T") @@ -62,6 +61,7 @@ def __init__( commands: List[str], text_document_sync_kind: types.TextDocumentSyncKind, notebook_document_sync: Optional[types.NotebookDocumentSyncOptions] = None, + position_encoding: types.PositionEncodingKind = types.PositionEncodingKind.Utf16, ): self.client_capabilities = client_capabilities self.features = features @@ -71,12 +71,35 @@ def __init__( self.notebook_document_sync = notebook_document_sync self.server_cap = types.ServerCapabilities() + self.server_cap.position_encoding = position_encoding def _provider_options(self, feature: str, default: T) -> Optional[Union[T, Any]]: if feature in self.features: return self.feature_options.get(feature, default) return None + @classmethod + def choose_position_encoding( + cls, client_capabilities: types.ClientCapabilities + ) -> types.PositionEncodingKind: + server_encoding = types.PositionEncodingKind.Utf16 + + if (general := client_capabilities.general) is None: + return server_encoding + + if (encodings := general.position_encodings) is None: + return server_encoding + + # We match client preference where this an overlap between its and our supported encodings. + for client_encoding in encodings: + if client_encoding in _SUPPORTED_ENCODINGS: + server_encoding = client_encoding + return server_encoding + + logger.warning(f"Unknown `PositionEncoding`s: {encodings}") + + return server_encoding + def _with_text_document_sync(self): open_close = ( types.TEXT_DOCUMENT_DID_OPEN in self.features @@ -415,27 +438,6 @@ def _with_inline_value_provider(self): self.server_cap.inline_value_provider = value return self - def _with_position_encodings(self): - self.server_cap.position_encoding = types.PositionEncodingKind.Utf16 - - general = self.client_capabilities.general - if general is None: - return self - - encodings = general.position_encodings - if encodings is None: - return self - - # We match client preference where this an overlap between its and our supported encodings. - for encoding in encodings: - if encoding in _SUPPORTED_ENCODINGS: - self.server_cap.position_encoding = encoding - return self - - logger.warning(f"Unknown `PositionEncoding`s: {encodings}") - - return self - def _build(self): return self.server_cap @@ -474,6 +476,5 @@ def build(self): ._with_workspace_capabilities() ._with_diagnostic_provider() ._with_inline_value_provider() - ._with_position_encodings() ._build() ) diff --git a/pygls/protocol/language_server.py b/pygls/protocol/language_server.py index 709631d9..9f5b431d 100644 --- a/pygls/protocol/language_server.py +++ b/pygls/protocol/language_server.py @@ -130,7 +130,9 @@ def lsp_exit(self, *args) -> None: sys.exit(returncode) @lsp_method(types.INITIALIZE) - def lsp_initialize(self, params: types.InitializeParams) -> types.InitializeResult: + def lsp_initialize( + self, params: types.InitializeParams + ) -> Generator[Any, Any, types.InitializeResult]: """Method that initializes language server. It will compute and return server capabilities based on registered features. @@ -142,19 +144,9 @@ def lsp_initialize(self, params: types.InitializeParams) -> types.InitializeResu text_document_sync_kind = self._server._text_document_sync_kind notebook_document_sync = self._server._notebook_document_sync - # Initialize server capabilities self.client_capabilities = params.capabilities - self.server_capabilities = ServerCapabilitiesBuilder( - self.client_capabilities, - set({**self.fm.features, **self.fm.builtin_features}.keys()), - self.fm.feature_options, - list(self.fm.commands.keys()), - text_document_sync_kind, - notebook_document_sync, - ).build() - logger.debug( - "Server capabilities: %s", - json.dumps(self.server_capabilities, default=self._serialize_message), + position_encoding = ServerCapabilitiesBuilder.choose_position_encoding( + self.client_capabilities ) root_path = params.root_path @@ -162,13 +154,32 @@ def lsp_initialize(self, params: types.InitializeParams) -> types.InitializeResu if root_path is not None and root_uri is None: root_uri = from_fs_path(root_path) - # Initialize the workspace + # Initialize the workspace before yielding to the user's initialize handler workspace_folders = params.workspace_folders or [] self._workspace = Workspace( root_uri, text_document_sync_kind, workspace_folders, - self.server_capabilities.position_encoding, + position_encoding, + ) + + if (user_handler := self.fm.features.get(types.INITIALIZE)) is not None: + yield user_handler, (params,), None + + # Now that the user has had the opportunity to setup additional features, calculate + # the server's capabilities + self.server_capabilities = ServerCapabilitiesBuilder( + self.client_capabilities, + set({**self.fm.features, **self.fm.builtin_features}.keys()), + self.fm.feature_options, + list(self.fm.commands.keys()), + text_document_sync_kind, + notebook_document_sync, + position_encoding, + ).build() + logger.debug( + "Server capabilities: %s", + json.dumps(self.server_capabilities, default=self._serialize_message), ) return types.InitializeResult( diff --git a/tests/test_feature_manager.py b/tests/test_feature_manager.py index 119af551..5ceadc2a 100644 --- a/tests/test_feature_manager.py +++ b/tests/test_feature_manager.py @@ -18,6 +18,8 @@ from typing import Any import pytest +from lsprotocol import types as lsp + from pygls.capabilities import ServerCapabilitiesBuilder from pygls.exceptions import ( CommandAlreadyRegisteredError, @@ -29,7 +31,6 @@ has_ls_param_or_annotation, wrap_with_server, ) -from lsprotocol import types as lsp class Temp: @@ -704,13 +705,13 @@ def _(): [], None, None, + ServerCapabilitiesBuilder.choose_position_encoding(capabilities), ).build() assert expected == actual def test_register_prepare_rename_no_client_support(feature_manager: FeatureManager): - @feature_manager.feature(lsp.TEXT_DOCUMENT_RENAME) def _(): pass @@ -734,7 +735,6 @@ def _(): def test_register_prepare_rename_with_client_support(feature_manager: FeatureManager): - @feature_manager.feature(lsp.TEXT_DOCUMENT_RENAME) def _(): pass