Skip to content

Commit

Permalink
Allow killing long-running processes
Browse files Browse the repository at this point in the history
  • Loading branch information
shepmaster committed Nov 6, 2023
1 parent b930fc2 commit 10b092d
Show file tree
Hide file tree
Showing 9 changed files with 196 additions and 18 deletions.
98 changes: 89 additions & 9 deletions compiler/base/orchestrator/src/coordinator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,14 +456,15 @@ where

pub async fn begin_execute(
&self,
token: CancellationToken,
request: ExecuteRequest,
) -> Result<ActiveExecution, ExecuteError> {
use execute_error::*;

self.select_channel(request.channel)
.await
.context(CouldNotStartContainerSnafu)?
.begin_execute(request)
.begin_execute(token, request)
.await
}

Expand All @@ -482,14 +483,15 @@ where

pub async fn begin_compile(
&self,
token: CancellationToken,
request: CompileRequest,
) -> Result<ActiveCompilation, CompileError> {
use compile_error::*;

self.select_channel(request.channel)
.await
.context(CouldNotStartContainerSnafu)?
.begin_compile(request)
.begin_compile(token, request)
.await
}

Expand Down Expand Up @@ -603,12 +605,14 @@ impl Container {
&self,
request: ExecuteRequest,
) -> Result<WithOutput<ExecuteResponse>, ExecuteError> {
let token = Default::default();

let ActiveExecution {
task,
stdin_tx,
stdout_rx,
stderr_rx,
} = self.begin_execute(request).await?;
} = self.begin_execute(token, request).await?;

drop(stdin_tx);
WithOutput::try_absorb(task, stdout_rx, stderr_rx).await
Expand All @@ -617,6 +621,7 @@ impl Container {
#[instrument(skip_all)]
async fn begin_execute(
&self,
token: CancellationToken,
request: ExecuteRequest,
) -> Result<ActiveExecution, ExecuteError> {
use execute_error::*;
Expand All @@ -642,7 +647,7 @@ impl Container {
stdout_rx,
stderr_rx,
} = self
.spawn_cargo_task(execute_cargo)
.spawn_cargo_task(token, execute_cargo)
.await
.context(CouldNotStartCargoSnafu)?;

Expand Down Expand Up @@ -673,18 +678,21 @@ impl Container {
&self,
request: CompileRequest,
) -> Result<WithOutput<CompileResponse>, CompileError> {
let token = Default::default();

let ActiveCompilation {
task,
stdout_rx,
stderr_rx,
} = self.begin_compile(request).await?;
} = self.begin_compile(token, request).await?;

WithOutput::try_absorb(task, stdout_rx, stderr_rx).await
}

#[instrument(skip_all)]
async fn begin_compile(
&self,
token: CancellationToken,
request: CompileRequest,
) -> Result<ActiveCompilation, CompileError> {
use compile_error::*;
Expand Down Expand Up @@ -715,7 +723,7 @@ impl Container {
stdout_rx,
stderr_rx,
} = self
.spawn_cargo_task(execute_cargo)
.spawn_cargo_task(token, execute_cargo)
.await
.context(CouldNotStartCargoSnafu)?;

Expand Down Expand Up @@ -761,6 +769,7 @@ impl Container {

async fn spawn_cargo_task(
&self,
token: CancellationToken,
execute_cargo: ExecuteCommandRequest,
) -> Result<SpawnCargo, SpawnCargoError> {
use spawn_cargo_error::*;
Expand All @@ -777,10 +786,19 @@ impl Container {

let task = tokio::spawn({
async move {
let mut already_cancelled = false;
let mut stdin_open = true;

loop {
select! {
() = token.cancelled(), if !already_cancelled => {
already_cancelled = true;

let msg = CoordinatorMessage::Kill;
trace!("processing {msg:?}");
to_worker_tx.send(msg).await.context(KillSnafu)?;
},

stdin = stdin_rx.recv(), if stdin_open => {
let msg = match stdin {
Some(stdin) => {
Expand Down Expand Up @@ -952,6 +970,9 @@ pub enum SpawnCargoError {

#[snafu(display("Unable to send stdin message"))]
Stdin { source: MultiplexedSenderError },

#[snafu(display("Unable to send kill message"))]
Kill { source: MultiplexedSenderError },
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -1787,12 +1808,13 @@ mod tests {
..ARBITRARY_EXECUTE_REQUEST
};

let token = Default::default();
let ActiveExecution {
task,
stdin_tx,
stdout_rx,
stderr_rx,
} = coordinator.begin_execute(request).await.unwrap();
} = coordinator.begin_execute(token, request).await.unwrap();

stdin_tx.send("this is stdin\n".into()).await.unwrap();
// Purposefully not dropping stdin_tx early -- a user might forget
Expand Down Expand Up @@ -1836,12 +1858,13 @@ mod tests {
..ARBITRARY_EXECUTE_REQUEST
};

let token = Default::default();
let ActiveExecution {
task,
stdin_tx,
stdout_rx,
stderr_rx,
} = coordinator.begin_execute(request).await.unwrap();
} = coordinator.begin_execute(token, request).await.unwrap();

for i in 0..3 {
stdin_tx.send(format!("line {i}\n")).await.unwrap();
Expand Down Expand Up @@ -1870,6 +1893,62 @@ mod tests {
Ok(())
}

#[tokio::test]
#[snafu::report]
async fn execute_kill() -> Result<()> {
let coordinator = new_coordinator().await;

let request = ExecuteRequest {
code: r#"
fn main() {
println!("Before");
loop {
std::thread::sleep(std::time::Duration::from_secs(1));
}
println!("After");
}
"#
.into(),
..ARBITRARY_EXECUTE_REQUEST
};

let token = CancellationToken::new();
let ActiveExecution {
task,
stdin_tx: _,
mut stdout_rx,
stderr_rx,
} = coordinator
.begin_execute(token.clone(), request)
.await
.unwrap();

// Wait for some output before killing
let early_stdout = stdout_rx.recv().await.unwrap();

token.cancel();

let WithOutput {
response,
stdout,
stderr,
} = WithOutput::try_absorb(task, stdout_rx, stderr_rx)
.with_timeout()
.await
.unwrap();

assert!(!response.success, "{stderr}");
assert_contains!(response.exit_detail, "kill");

assert_contains!(early_stdout, "Before");
assert_not_contains!(stdout, "Before");
assert_not_contains!(stdout, "After");

coordinator.shutdown().await?;

Ok(())
}

const HELLO_WORLD_CODE: &str = r#"fn main() { println!("Hello World!"); }"#;

const ARBITRARY_COMPILE_REQUEST: CompileRequest = CompileRequest {
Expand Down Expand Up @@ -1914,11 +1993,12 @@ mod tests {
..ARBITRARY_COMPILE_REQUEST
};

let token = Default::default();
let ActiveCompilation {
task,
stdout_rx,
stderr_rx,
} = coordinator.begin_compile(req).await.unwrap();
} = coordinator.begin_compile(token, req).await.unwrap();

let WithOutput {
response,
Expand Down
1 change: 1 addition & 0 deletions compiler/base/orchestrator/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub enum CoordinatorMessage {
ExecuteCommand(ExecuteCommandRequest),
StdinPacket(String),
StdinClose,
Kill,
}

impl_narrow_to_broad!(
Expand Down
38 changes: 37 additions & 1 deletion compiler/base/orchestrator/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ use tokio::{
sync::mpsc,
task::JoinSet,
};
use tokio_util::sync::CancellationToken;

use crate::{
bincode_input_closed,
Expand Down Expand Up @@ -194,6 +195,14 @@ async fn handle_coordinator_message(
.drop_error_details()
.context(UnableToSendStdinCloseSnafu)?;
}

CoordinatorMessage::Kill => {
process_tx
.send(Multiplexed(job_id, ProcessCommand::Kill))
.await
.drop_error_details()
.context(UnableToSendKillSnafu)?;
}
}
}

Expand Down Expand Up @@ -227,6 +236,9 @@ pub enum HandleCoordinatorMessageError {
#[snafu(display("Failed to send stdin close request to the command task"))]
UnableToSendStdinClose { source: mpsc::error::SendError<()> },

#[snafu(display("Failed to send kill request to the command task"))]
UnableToSendKill { source: mpsc::error::SendError<()> },

#[snafu(display("A coordinator command handler background task panicked"))]
TaskPanicked { source: tokio::task::JoinError },
}
Expand Down Expand Up @@ -383,13 +395,15 @@ enum ProcessCommand {
Start(ExecuteCommandRequest, MultiplexingSender),
Stdin(String),
StdinClose,
Kill,
}

struct ProcessState {
project_path: PathBuf,
processes: JoinSet<Result<(), ProcessError>>,
stdin_senders: HashMap<JobId, mpsc::Sender<String>>,
stdin_shutdown_tx: mpsc::Sender<JobId>,
kill_tokens: HashMap<JobId, CancellationToken>,
}

impl ProcessState {
Expand All @@ -399,6 +413,7 @@ impl ProcessState {
processes: Default::default(),
stdin_senders: Default::default(),
stdin_shutdown_tx,
kill_tokens: Default::default(),
}
}

Expand All @@ -410,6 +425,8 @@ impl ProcessState {
) -> Result<(), ProcessError> {
use process_error::*;

let token = CancellationToken::new();

let RunningChild {
child,
stdin_rx,
Expand All @@ -432,11 +449,13 @@ impl ProcessState {

let task_set = stream_stdio(worker_msg_tx.clone(), stdin_rx, stdin, stdout, stderr);

self.kill_tokens.insert(job_id, token.clone());

self.processes.spawn({
let stdin_shutdown_tx = self.stdin_shutdown_tx.clone();
async move {
worker_msg_tx
.send(process_end(child, task_set, stdin_shutdown_tx, job_id).await)
.send(process_end(token, child, task_set, stdin_shutdown_tx, job_id).await)
.await
.context(UnableToSendExecuteCommandResponseSnafu)
}
Expand Down Expand Up @@ -470,6 +489,12 @@ impl ProcessState {
let process = self.processes.join_next().await?;
Some(process.context(ProcessTaskPanickedSnafu).and_then(|e| e))
}

fn kill(&mut self, job_id: JobId) {
if let Some(token) = self.kill_tokens.get(&job_id) {
token.cancel();
}
}
}

async fn manage_processes(
Expand All @@ -492,6 +517,8 @@ async fn manage_processes(
ProcessCommand::Stdin(packet) => state.stdin(job_id, packet).await?,

ProcessCommand::StdinClose => state.stdin_close(job_id),

ProcessCommand::Kill => state.kill(job_id),
}
}

Expand Down Expand Up @@ -560,13 +587,19 @@ fn process_begin(
}

async fn process_end(
token: CancellationToken,
mut child: Child,
mut task_set: JoinSet<Result<(), StdioError>>,
stdin_shutdown_tx: mpsc::Sender<JobId>,
job_id: JobId,
) -> Result<ExecuteCommandResponse, ProcessError> {
use process_error::*;

select! {
() = token.cancelled() => child.kill().await.context(KillChildSnafu)?,
_ = child.wait() => {},
};

let status = child.wait().await.context(WaitChildSnafu)?;

stdin_shutdown_tx
Expand Down Expand Up @@ -706,6 +739,9 @@ pub enum ProcessError {
#[snafu(display("Failed to send stdin data"))]
UnableToSendStdinData { source: mpsc::error::SendError<()> },

#[snafu(display("Failed to kill the child process"))]
KillChild { source: std::io::Error },

#[snafu(display("Failed to wait for child process exiting"))]
WaitChild { source: std::io::Error },

Expand Down
Loading

0 comments on commit 10b092d

Please sign in to comment.