Skip to content

Commit 7ceefc7

Browse files
committed
Auto merge of #15894 - schrieveslaach:cancelable-initialization, r=Veykril
Cancelable Initialization This commit provides additional initialization methods to Connection in order to support CTRL + C sigterm handling. In the process of adding LSP to Nushell (see nushell/nushell#10941) this gap has been identified.
2 parents 4513651 + 81c2d35 commit 7ceefc7

File tree

3 files changed

+188
-14
lines changed

3 files changed

+188
-14
lines changed

Cargo.lock

+27-5
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

lib/lsp-server/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ crossbeam-channel = "0.5.6"
1414

1515
[dev-dependencies]
1616
lsp-types = "=0.94"
17+
ctrlc = "3.4.1"

lib/lsp-server/src/lib.rs

+160-9
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use std::{
1717
net::{TcpListener, TcpStream, ToSocketAddrs},
1818
};
1919

20-
use crossbeam_channel::{Receiver, Sender};
20+
use crossbeam_channel::{Receiver, RecvTimeoutError, Sender};
2121

2222
pub use crate::{
2323
error::{ExtractError, ProtocolError},
@@ -113,11 +113,62 @@ impl Connection {
113113
/// }
114114
/// ```
115115
pub fn initialize_start(&self) -> Result<(RequestId, serde_json::Value), ProtocolError> {
116-
loop {
117-
break match self.receiver.recv() {
118-
Ok(Message::Request(req)) if req.is_initialize() => Ok((req.id, req.params)),
116+
self.initialize_start_while(|| true)
117+
}
118+
119+
/// Starts the initialization process by waiting for an initialize as described in
120+
/// [`Self::initialize_start`] as long as `running` returns
121+
/// `true` while the return value can be changed through a sig handler such as `CTRL + C`.
122+
///
123+
/// # Example
124+
///
125+
/// ```rust
126+
/// use std::sync::atomic::{AtomicBool, Ordering};
127+
/// use std::sync::Arc;
128+
/// # use std::error::Error;
129+
/// # use lsp_types::{ClientCapabilities, InitializeParams, ServerCapabilities};
130+
/// # use lsp_server::{Connection, Message, Request, RequestId, Response};
131+
/// # fn main() -> Result<(), Box<dyn Error + Sync + Send>> {
132+
/// let running = Arc::new(AtomicBool::new(true));
133+
/// # running.store(true, Ordering::SeqCst);
134+
/// let r = running.clone();
135+
///
136+
/// ctrlc::set_handler(move || {
137+
/// r.store(false, Ordering::SeqCst);
138+
/// }).expect("Error setting Ctrl-C handler");
139+
///
140+
/// let (connection, io_threads) = Connection::stdio();
141+
///
142+
/// let res = connection.initialize_start_while(|| running.load(Ordering::SeqCst));
143+
/// # assert!(res.is_err());
144+
///
145+
/// # Ok(())
146+
/// # }
147+
/// ```
148+
pub fn initialize_start_while<C>(
149+
&self,
150+
running: C,
151+
) -> Result<(RequestId, serde_json::Value), ProtocolError>
152+
where
153+
C: Fn() -> bool,
154+
{
155+
while running() {
156+
let msg = match self.receiver.recv_timeout(std::time::Duration::from_secs(1)) {
157+
Ok(msg) => msg,
158+
Err(RecvTimeoutError::Timeout) => {
159+
continue;
160+
}
161+
Err(e) => {
162+
return Err(ProtocolError(format!(
163+
"expected initialize request, got error: {e}"
164+
)))
165+
}
166+
};
167+
168+
match msg {
169+
Message::Request(req) if req.is_initialize() => return Ok((req.id, req.params)),
119170
// Respond to non-initialize requests with ServerNotInitialized
120-
Ok(Message::Request(req)) => {
171+
Message::Request(req) => {
121172
let resp = Response::new_err(
122173
req.id.clone(),
123174
ErrorCode::ServerNotInitialized as i32,
@@ -126,15 +177,18 @@ impl Connection {
126177
self.sender.send(resp.into()).unwrap();
127178
continue;
128179
}
129-
Ok(Message::Notification(n)) if !n.is_exit() => {
180+
Message::Notification(n) if !n.is_exit() => {
130181
continue;
131182
}
132-
Ok(msg) => Err(ProtocolError(format!("expected initialize request, got {msg:?}"))),
133-
Err(e) => {
134-
Err(ProtocolError(format!("expected initialize request, got error: {e}")))
183+
msg => {
184+
return Err(ProtocolError(format!("expected initialize request, got {msg:?}")));
135185
}
136186
};
137187
}
188+
189+
return Err(ProtocolError(String::from(
190+
"Initialization has been aborted during initialization",
191+
)));
138192
}
139193

140194
/// Finishes the initialization process by sending an `InitializeResult` to the client
@@ -156,6 +210,51 @@ impl Connection {
156210
}
157211
}
158212

213+
/// Finishes the initialization process as described in [`Self::initialize_finish`] as
214+
/// long as `running` returns `true` while the return value can be changed through a sig
215+
/// handler such as `CTRL + C`.
216+
pub fn initialize_finish_while<C>(
217+
&self,
218+
initialize_id: RequestId,
219+
initialize_result: serde_json::Value,
220+
running: C,
221+
) -> Result<(), ProtocolError>
222+
where
223+
C: Fn() -> bool,
224+
{
225+
let resp = Response::new_ok(initialize_id, initialize_result);
226+
self.sender.send(resp.into()).unwrap();
227+
228+
while running() {
229+
let msg = match self.receiver.recv_timeout(std::time::Duration::from_secs(1)) {
230+
Ok(msg) => msg,
231+
Err(RecvTimeoutError::Timeout) => {
232+
continue;
233+
}
234+
Err(e) => {
235+
return Err(ProtocolError(format!(
236+
"expected initialized notification, got error: {e}",
237+
)));
238+
}
239+
};
240+
241+
match msg {
242+
Message::Notification(n) if n.is_initialized() => {
243+
return Ok(());
244+
}
245+
msg => {
246+
return Err(ProtocolError(format!(
247+
r#"expected initialized notification, got: {msg:?}"#
248+
)));
249+
}
250+
}
251+
}
252+
253+
return Err(ProtocolError(String::from(
254+
"Initialization has been aborted during initialization",
255+
)));
256+
}
257+
159258
/// Initialize the connection. Sends the server capabilities
160259
/// to the client and returns the serialized client capabilities
161260
/// on success. If more fine-grained initialization is required use
@@ -198,6 +297,58 @@ impl Connection {
198297
Ok(params)
199298
}
200299

300+
/// Initialize the connection as described in [`Self::initialize`] as long as `running` returns
301+
/// `true` while the return value can be changed through a sig handler such as `CTRL + C`.
302+
///
303+
/// # Example
304+
///
305+
/// ```rust
306+
/// use std::sync::atomic::{AtomicBool, Ordering};
307+
/// use std::sync::Arc;
308+
/// # use std::error::Error;
309+
/// # use lsp_types::ServerCapabilities;
310+
/// # use lsp_server::{Connection, Message, Request, RequestId, Response};
311+
///
312+
/// # fn main() -> Result<(), Box<dyn Error + Sync + Send>> {
313+
/// let running = Arc::new(AtomicBool::new(true));
314+
/// # running.store(true, Ordering::SeqCst);
315+
/// let r = running.clone();
316+
///
317+
/// ctrlc::set_handler(move || {
318+
/// r.store(false, Ordering::SeqCst);
319+
/// }).expect("Error setting Ctrl-C handler");
320+
///
321+
/// let (connection, io_threads) = Connection::stdio();
322+
///
323+
/// let server_capabilities = serde_json::to_value(&ServerCapabilities::default()).unwrap();
324+
/// let initialization_params = connection.initialize_while(
325+
/// server_capabilities,
326+
/// || running.load(Ordering::SeqCst)
327+
/// );
328+
///
329+
/// # assert!(initialization_params.is_err());
330+
/// # Ok(())
331+
/// # }
332+
/// ```
333+
pub fn initialize_while<C>(
334+
&self,
335+
server_capabilities: serde_json::Value,
336+
running: C,
337+
) -> Result<serde_json::Value, ProtocolError>
338+
where
339+
C: Fn() -> bool,
340+
{
341+
let (id, params) = self.initialize_start_while(&running)?;
342+
343+
let initialize_data = serde_json::json!({
344+
"capabilities": server_capabilities,
345+
});
346+
347+
self.initialize_finish_while(id, initialize_data, running)?;
348+
349+
Ok(params)
350+
}
351+
201352
/// If `req` is `Shutdown`, respond to it and return `true`, otherwise return `false`
202353
pub fn handle_shutdown(&self, req: &Request) -> Result<bool, ProtocolError> {
203354
if !req.is_shutdown() {

0 commit comments

Comments
 (0)