Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace rouille with warp #2241

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,075 changes: 582 additions & 493 deletions Cargo.lock

Large diffs are not rendered by default.

12 changes: 7 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ regex = "1.10.3"
reqsign = { version = "0.16.0", optional = true }
reqwest = { version = "0.12", features = [
"json",
"blocking",
"native-tls",
"stream",
"rustls-tls",
"rustls-tls-native-roots",
Expand Down Expand Up @@ -107,6 +107,7 @@ zstd = "0.13"

# dist-server only
memmap2 = "0.9.4"
native-tls = "0.2.12"
nix = { version = "0.28.0", optional = true, features = [
"mount",
"user",
Expand All @@ -115,11 +116,10 @@ nix = { version = "0.28.0", optional = true, features = [
"process",
] }
object = "0.32"
rouille = { version = "3.6", optional = true, default-features = false, features = [
"ssl",
] }
syslog = { version = "6", optional = true }
thiserror = { version = "1.0.63", optional = true }
version-compare = { version = "0.1.1", optional = true }
warp = { version = "0.3.7", optional = true, features = ["tls"] }

[dev-dependencies]
assert_cmd = "2.0.13"
Expand Down Expand Up @@ -190,15 +190,17 @@ dist-client = [
]
# Enables the sccache-dist binary
dist-server = [
"reqwest/blocking",
"jwt",
"flate2",
"libmount",
"nix",
"openssl",
"reqwest",
"rouille",
"syslog",
"version-compare",
"warp",
"thiserror",
]
# Enables dist tests with external requirements
dist-tests = ["dist-client", "dist-server"]
Expand Down
2 changes: 1 addition & 1 deletion src/bin/sccache-dist/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ impl OverlayBuilder {
for (tc, _) in entries {
warn!("Removing old un-compressed toolchain: {:?}", tc);
assert!(toolchain_dir_map.remove(tc).is_some());
fs::remove_dir_all(&self.dir.join("toolchains").join(&tc.archive_id))
fs::remove_dir_all(self.dir.join("toolchains").join(&tc.archive_id))
.context("Failed to remove old toolchain directory")?;
}
}
Expand Down
49 changes: 30 additions & 19 deletions src/bin/sccache-dist/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
extern crate log;

use anyhow::{bail, Context, Error, Result};
use async_trait::async_trait;
use base64::Engine;
use cmdline::{AuthSubcommand, Command};
use rand::{rngs::OsRng, RngCore};
use sccache::config::{
scheduler as scheduler_config, server as server_config, INSECURE_DIST_CLIENT_TOKEN,
Expand All @@ -22,17 +24,16 @@ use std::env;
use std::io;
use std::path::Path;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Mutex, MutexGuard};
use std::sync::{Arc, Mutex, MutexGuard};
use std::time::{Duration, Instant};
use tokio::runtime::Runtime;

#[cfg_attr(target_os = "freebsd", path = "build_freebsd.rs")]
mod build;

mod cmdline;
mod token_check;

use cmdline::{AuthSubcommand, Command};

pub const INSECURE_DIST_SERVER_TOKEN: &str = "dangerously_insecure_server";

// Only supported on x86_64 Linux machines and on FreeBSD
Expand Down Expand Up @@ -184,10 +185,10 @@ fn run(command: Command) -> Result<i32> {
scheduler_config::ServerAuth::Insecure => {
warn!("Scheduler starting with DANGEROUSLY_INSECURE server authentication");
let token = INSECURE_DIST_SERVER_TOKEN;
Box::new(move |server_token| check_server_token(server_token, token))
Arc::new(move |server_token| check_server_token(server_token, token))
}
scheduler_config::ServerAuth::Token { token } => {
Box::new(move |server_token| check_server_token(server_token, &token))
Arc::new(move |server_token| check_server_token(server_token, &token))
}
scheduler_config::ServerAuth::JwtHS256 { secret_key } => {
let secret_key = BASE64_URL_SAFE_ENGINE
Expand All @@ -203,7 +204,7 @@ fn run(command: Command) -> Result<i32> {
validation.validate_nbf = false;
validation
};
Box::new(move |server_token| {
Arc::new(move |server_token| {
check_jwt_server_token(server_token, &secret_key, &validation)
})
}
Expand All @@ -217,7 +218,10 @@ fn run(command: Command) -> Result<i32> {
check_client_auth,
check_server_auth,
);
http_scheduler.start()?;

// Create runtime after daemonize because Tokio doesn't work well with daemonize
let runtime = Runtime::new().context("Failed to create Tokio runtime")?;
runtime.block_on(async { http_scheduler.start().await })?;
unreachable!();
}

Expand Down Expand Up @@ -294,7 +298,8 @@ fn run(command: Command) -> Result<i32> {
server,
)
.context("Failed to create sccache HTTP server instance")?;
http_server.start()?;
let runtime = Runtime::new().context("Failed to create Tokio runtime")?;
runtime.block_on(async { http_server.start().await })?;
unreachable!();
}
}
Expand Down Expand Up @@ -399,8 +404,9 @@ impl Default for Scheduler {
}
}

#[async_trait]
impl SchedulerIncoming for Scheduler {
fn handle_alloc_job(
async fn handle_alloc_job(
&self,
requester: &dyn SchedulerOutgoing,
tc: Toolchain,
Expand Down Expand Up @@ -499,6 +505,7 @@ impl SchedulerIncoming for Scheduler {
need_toolchain,
} = requester
.do_assign_job(server_id, job_id, tc, auth.clone())
.await
.with_context(|| {
// LOCKS
let mut servers = self.servers.lock().unwrap();
Expand Down Expand Up @@ -717,7 +724,7 @@ impl SchedulerIncoming for Scheduler {
pub struct Server {
builder: Box<dyn BuilderIncoming>,
cache: Mutex<TcCache>,
job_toolchains: Mutex<HashMap<JobId, Toolchain>>,
job_toolchains: tokio::sync::Mutex<HashMap<JobId, Toolchain>>,
}

impl Server {
Expand All @@ -731,18 +738,19 @@ impl Server {
Ok(Server {
builder,
cache: Mutex::new(cache),
job_toolchains: Mutex::new(HashMap::new()),
job_toolchains: tokio::sync::Mutex::new(HashMap::new()),
})
}
}

#[async_trait]
impl ServerIncoming for Server {
fn handle_assign_job(&self, job_id: JobId, tc: Toolchain) -> Result<AssignJobResult> {
async fn handle_assign_job(&self, job_id: JobId, tc: Toolchain) -> Result<AssignJobResult> {
let need_toolchain = !self.cache.lock().unwrap().contains_toolchain(&tc);
assert!(self
.job_toolchains
.lock()
.unwrap()
.await
.insert(job_id, tc)
.is_none());
let state = if need_toolchain {
Expand All @@ -756,18 +764,19 @@ impl ServerIncoming for Server {
need_toolchain,
})
}
fn handle_submit_toolchain(
async fn handle_submit_toolchain(
&self,
requester: &dyn ServerOutgoing,
job_id: JobId,
tc_rdr: ToolchainReader,
tc_rdr: ToolchainReader<'_>,
) -> Result<SubmitToolchainResult> {
requester
.do_update_job_state(job_id, JobState::Ready)
.await
.context("Updating job state failed")?;
// TODO: need to lock the toolchain until the container has started
// TODO: can start prepping container
let tc = match self.job_toolchains.lock().unwrap().get(&job_id).cloned() {
let tc = match self.job_toolchains.lock().await.get(&job_id).cloned() {
Some(tc) => tc,
None => return Ok(SubmitToolchainResult::JobNotFound),
};
Expand All @@ -783,18 +792,19 @@ impl ServerIncoming for Server {
.map(|_| SubmitToolchainResult::Success)
.unwrap_or(SubmitToolchainResult::CannotCache))
}
fn handle_run_job(
async fn handle_run_job(
&self,
requester: &dyn ServerOutgoing,
job_id: JobId,
command: CompileCommand,
outputs: Vec<String>,
inputs_rdr: InputsReader,
inputs_rdr: InputsReader<'_>,
) -> Result<RunJobResult> {
requester
.do_update_job_state(job_id, JobState::Started)
.await
.context("Updating job state failed")?;
let tc = self.job_toolchains.lock().unwrap().remove(&job_id);
let tc = self.job_toolchains.lock().await.remove(&job_id);
let res = match tc {
None => Ok(RunJobResult::JobNotFound),
Some(tc) => {
Expand All @@ -812,6 +822,7 @@ impl ServerIncoming for Server {
};
requester
.do_update_job_state(job_id, JobState::Complete)
.await
.context("Updating job state failed")?;
res
}
Expand Down
41 changes: 25 additions & 16 deletions src/bin/sccache-dist/token_check.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use anyhow::{bail, Context, Result};
use async_trait::async_trait;
use base64::Engine;
use sccache::dist::http::{ClientAuthCheck, ClientVisibleMsg};
use sccache::util::{new_reqwest_blocking_client, BASE64_URL_SAFE_ENGINE};
use sccache::util::new_reqwest_client;
use sccache::util::BASE64_URL_SAFE_ENGINE;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::result::Result as StdResult;
Expand Down Expand Up @@ -54,8 +56,9 @@ pub struct EqCheck {
s: String,
}

#[async_trait]
impl ClientAuthCheck for EqCheck {
fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg> {
async fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg> {
if self.s == token {
Ok(())
} else {
Expand All @@ -80,14 +83,15 @@ const MOZ_USERINFO_ENDPOINT: &str = "https://auth.mozilla.auth0.com/userinfo";
/// Mozilla-specific check by forwarding the token onto the auth0 userinfo endpoint
pub struct MozillaCheck {
// token, token_expiry
auth_cache: Mutex<HashMap<String, Instant>>,
client: reqwest::blocking::Client,
auth_cache: tokio::sync::Mutex<HashMap<String, Instant>>,
client: reqwest::Client,
required_groups: Vec<String>,
}

#[async_trait]
impl ClientAuthCheck for MozillaCheck {
fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg> {
self.check_mozilla(token).map_err(|e| {
async fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg> {
self.check_mozilla(token).await.map_err(|e| {
warn!("Mozilla token validation failed: {}", e);
ClientVisibleMsg::from_nonsensitive(
"Failed to validate Mozilla OAuth token, run sccache --dist-auth".to_owned(),
Expand All @@ -99,13 +103,13 @@ impl ClientAuthCheck for MozillaCheck {
impl MozillaCheck {
pub fn new(required_groups: Vec<String>) -> Self {
Self {
auth_cache: Mutex::new(HashMap::new()),
client: new_reqwest_blocking_client(),
auth_cache: tokio::sync::Mutex::new(HashMap::new()),
client: new_reqwest_client(),
required_groups,
}
}

fn check_mozilla(&self, token: &str) -> Result<()> {
async fn check_mozilla(&self, token: &str) -> Result<()> {
// azp == client_id
// {
// "iss": "https://auth.mozilla.auth0.com/",
Expand Down Expand Up @@ -139,7 +143,7 @@ impl MozillaCheck {
}

// If the token is cached and not expired, return it
let mut auth_cache = self.auth_cache.lock().unwrap();
let mut auth_cache = self.auth_cache.lock().await;
if let Some(cached_at) = auth_cache.get(token) {
if cached_at.elapsed() < MOZ_SESSION_TIMEOUT {
return Ok(());
Expand All @@ -158,10 +162,12 @@ impl MozillaCheck {
.get(url.clone())
.bearer_auth(token)
.send()
.await
.context("Failed to make request to mozilla userinfo")?;
let status = res.status();
let res_text = res
.text()
.await
.context("Failed to interpret response from mozilla userinfo as string")?;
if !status.is_success() {
bail!("JWT forwarded to {} returned {}: {}", url, status, res_text)
Expand Down Expand Up @@ -245,14 +251,15 @@ fn test_auth_verify_check_mozilla_profile() {
// Don't check a token is valid (it may not even be a JWT) just forward it to
// an API and check for success
pub struct ProxyTokenCheck {
client: reqwest::blocking::Client,
client: reqwest::Client,
maybe_auth_cache: Option<Mutex<(HashMap<String, Instant>, Duration)>>,
url: String,
}

#[async_trait]
impl ClientAuthCheck for ProxyTokenCheck {
fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg> {
match self.check_token_with_forwarding(token) {
async fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg> {
match self.check_token_with_forwarding(token).await {
Ok(()) => Ok(()),
Err(e) => {
warn!("Proxying token validation failed: {}", e);
Expand All @@ -269,13 +276,13 @@ impl ProxyTokenCheck {
let maybe_auth_cache: Option<Mutex<(HashMap<String, Instant>, Duration)>> =
cache_secs.map(|secs| Mutex::new((HashMap::new(), Duration::from_secs(secs))));
Self {
client: new_reqwest_blocking_client(),
client: new_reqwest_client(),
maybe_auth_cache,
url,
}
}

fn check_token_with_forwarding(&self, token: &str) -> Result<()> {
async fn check_token_with_forwarding(&self, token: &str) -> Result<()> {
trace!("Validating token by forwarding to {}", self.url);
// If the token is cached and not cache has not expired, return it
if let Some(ref auth_cache) = self.maybe_auth_cache {
Expand All @@ -294,6 +301,7 @@ impl ProxyTokenCheck {
.get(&self.url)
.bearer_auth(token)
.send()
.await
.context("Failed to make request to proxying url")?;
if !res.status().is_success() {
bail!("Token forwarded to {} returned {}", self.url, res.status());
Expand All @@ -315,8 +323,9 @@ pub struct ValidJWTCheck {
kid_to_pkcs1: HashMap<String, Vec<u8>>,
}

#[async_trait]
impl ClientAuthCheck for ValidJWTCheck {
fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg> {
async fn check(&self, token: &str) -> StdResult<(), ClientVisibleMsg> {
match self.check_jwt_validity(token) {
Ok(()) => Ok(()),
Err(e) => {
Expand Down
1 change: 0 additions & 1 deletion src/compiler/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1460,7 +1460,6 @@ mod test {
use std::io::{Cursor, Write};
use std::sync::Arc;
use std::time::Duration;
use std::u64;
use test_case::test_case;
use tokio::runtime::Runtime;

Expand Down
1 change: 0 additions & 1 deletion src/compiler/rust.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ use fs_err as fs;
use log::Level::Trace;
use once_cell::sync::Lazy;
#[cfg(feature = "dist-client")]
#[cfg(feature = "dist-client")]
use std::borrow::Borrow;
use std::borrow::Cow;
#[cfg(feature = "dist-client")]
Expand Down
Loading