diff --git a/src/client.rs b/src/client.rs index 45b66d5..82557c8 100644 --- a/src/client.rs +++ b/src/client.rs @@ -9,7 +9,7 @@ use flate2::read::GzDecoder; use reqwest::{blocking::Client, StatusCode, Url}; use std::{ io::{Cursor, Read}, - sync::RwLock, + sync::{Condvar, Mutex, RwLock}, time::Duration, }; use tracing::{error, info, trace, warn}; @@ -20,22 +20,22 @@ pub type TarballType = tar::Archive>>; /// Type alias representing a zip archive pub type ZipType = zip::ZipArchive>>; -/// Application state -pub struct State { - /// The current ruleset this client is using - pub rules: yara::Rules, +pub struct AuthState { + pub access_token: RwLock, + pub authenticating: Mutex, + pub cvar: Condvar, +} - /// The GitHub commit hash of the ruleset this client is using +pub struct RulesState { + pub rules: yara::Rules, pub hash: String, - - /// Access token this client is using for authentication - pub access_token: String, } #[warn(clippy::module_name_repetitions)] pub struct DragonflyClient { pub client: Client, - pub state: RwLock, + pub authentication_state: AuthState, + pub rules_state: RwLock, } impl DragonflyClient { @@ -44,21 +44,46 @@ impl DragonflyClient { let auth_response = fetch_access_token(&client)?; let rules_response = fetch_rules(&client, &auth_response.access_token)?; - let state = State { + + let auth_state = AuthState { + access_token: RwLock::new(auth_response.access_token), + authenticating: Mutex::new(false), + cvar: Condvar::new(), + }; + + let rules_state = RwLock::new(RulesState { rules: rules_response.compile()?, hash: rules_response.hash, - access_token: auth_response.access_token, - } - .into(); + }); - Ok(Self { client, state }) + Ok(Self { + client, + authentication_state: auth_state, + rules_state, + }) } - /// Update the state with a new access token, using the given write lock [`RwLockWriteGuard`] + /// Update the state with a new access token. /// /// If an error occurs while reauthenticating, the function retries with an exponential backoff /// described by the equation `min(10 * 60, 2^(x - 1))` where `x` is the number of failed tries. - pub fn reauthenticate(&self) -> String { + pub fn reauthenticate(&self) { + trace!("Trying to lock to check if we're authenticating."); + let mut authing = self.authentication_state.authenticating.lock().unwrap(); + trace!("Acquired lock"); + if *authing { + trace!("Another thread is authenticating. Waiting for it to finish."); + let _guard = self + .authentication_state + .cvar + .wait_while(authing, |authing| *authing); + trace!("Was notified, returning"); + return; + } + trace!("No other thread is authenticating. Trying to reauthenticate."); + *authing = true; + drop(authing); + let access_token; let base = 2_f64; @@ -88,28 +113,31 @@ impl DragonflyClient { } } + trace!("Successfully got new access token!"); + + *self.authentication_state.access_token.write().unwrap() = access_token; + + let mut authing = self.authentication_state.authenticating.lock().unwrap(); + *authing = false; + self.authentication_state.cvar.notify_all(); + info!("Successfully reauthenticated."); - access_token } /// Update the global ruleset. Waits for a write lock. pub fn update_rules(&self) -> Result<(), DragonflyError> { let response = match fetch_rules( self.get_http_client(), - &self.state.read().unwrap().access_token, + &self.authentication_state.access_token.read().unwrap(), ) { Err(err) if err.status() == Some(StatusCode::UNAUTHORIZED) => { info!("Got 401 UNAUTHORIZED while updating rules"); - trace!("Waiting on write lock to update access token"); - let mut state = self.state.write().unwrap(); - trace!("Successfully obtained write lock!"); - trace!("Requesting new access token..."); - let new_access_token = self.reauthenticate(); - trace!("Successfully got new access token!"); - state.access_token = new_access_token; - info!("Successfully updated local access token to new one!"); + self.reauthenticate(); info!("Fetching rules again..."); - fetch_rules(self.get_http_client(), &state.access_token) + fetch_rules( + self.get_http_client(), + &self.authentication_state.access_token.read().unwrap(), + ) } Ok(response) => Ok(response), @@ -117,29 +145,26 @@ impl DragonflyClient { Err(err) => Err(err), }?; - let mut state = self.state.write().unwrap(); - state.rules = response.compile()?; - state.hash = response.hash; + let mut rules_state = self.rules_state.write().unwrap(); + rules_state.rules = response.compile()?; + rules_state.hash = response.hash; Ok(()) } pub fn bulk_get_job(&self, n_jobs: usize) -> reqwest::Result> { - let state = self.state.read().unwrap(); - match fetch_bulk_job(self.get_http_client(), &state.access_token, n_jobs) { + let access_token = self.authentication_state.access_token.read().unwrap(); + match fetch_bulk_job(self.get_http_client(), &access_token, n_jobs) { Err(err) if err.status() == Some(StatusCode::UNAUTHORIZED) => { - drop(state); // Drop the read lock + drop(access_token); // Drop the read lock info!("Got 401 UNAUTHORIZED while doing a bulk fetch job request"); - trace!("Waiting on write lock to update access token"); - let mut state = self.state.write().unwrap(); - trace!("Successfully obtained write lock!"); - trace!("Requesting new access token..."); - let new_access_token = self.reauthenticate(); - trace!("Successfully got new access token!"); - state.access_token = new_access_token; - info!("Successfully updated local access token to new one!"); + self.reauthenticate(); info!("Doing a bulk fetch job again..."); - fetch_bulk_job(self.get_http_client(), &state.access_token, n_jobs) + fetch_bulk_job( + self.get_http_client(), + &self.authentication_state.access_token.read().unwrap(), + n_jobs, + ) } other => other, @@ -148,21 +173,18 @@ impl DragonflyClient { /// Report an error to the server. pub fn send_error(&self, body: &SubmitJobResultsError) -> reqwest::Result<()> { - let state = self.state.read().unwrap(); - match send_error(self.get_http_client(), &state.access_token, body) { + let access_token = self.authentication_state.access_token.read().unwrap(); + match send_error(self.get_http_client(), &access_token, body) { Err(http_err) if http_err.status() == Some(StatusCode::UNAUTHORIZED) => { - drop(state); // Drop the read lock + drop(access_token); // Drop the read lock info!("Got 401 UNAUTHORIZED while sending success"); - trace!("Waiting on write lock to update access token"); - let mut state = self.state.write().unwrap(); - trace!("Successfully obtained write lock!"); - trace!("Requesting new access token..."); - let new_access_token = self.reauthenticate(); - trace!("Successfully got new access token!"); - state.access_token = new_access_token; - info!("Successfully updated local access token to new one!"); + self.reauthenticate(); info!("Sending error body again..."); - send_error(self.get_http_client(), &state.access_token, body) + send_error( + self.get_http_client(), + &self.authentication_state.access_token.read().unwrap(), + body, + ) } other => other, @@ -172,21 +194,18 @@ impl DragonflyClient { /// Submit the results of a scan to the server, given the job and the scan results of each /// distribution pub fn send_success(&self, body: &SubmitJobResultsSuccess) -> reqwest::Result<()> { - let state = self.state.read().unwrap(); - match send_success(self.get_http_client(), &state.access_token, body) { + let access_token = self.authentication_state.access_token.read().unwrap(); + match send_success(self.get_http_client(), &access_token, body) { Err(http_err) if http_err.status() == Some(StatusCode::UNAUTHORIZED) => { - drop(state); // Drop the read lock + drop(access_token); // Drop the read lock info!("Got 401 UNAUTHORIZED while sending success"); - trace!("Waiting on write lock to update access token"); - let mut state = self.state.write().unwrap(); - trace!("Successfully obtained write lock!"); - trace!("Requesting new access token..."); - let new_access_token = self.reauthenticate(); - trace!("Successfully got new access token!"); - state.access_token = new_access_token; - info!("Successfully updated local access token to new one!"); + self.reauthenticate(); info!("Sending success body again..."); - send_success(self.get_http_client(), &state.access_token, body) + send_success( + self.get_http_client(), + &self.authentication_state.access_token.read().unwrap(), + body, + ) } other => other, diff --git a/src/main.rs b/src/main.rs index ffbf512..f9cf61b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -53,8 +53,13 @@ fn scanner( fn runner(client: &DragonflyClient, job: Job, tx: &SyncSender) { let span = span!(Level::INFO, "Job", name = job.name, version = job.version); let _enter = span.enter(); - let state = client.state.read().unwrap(); - let send_result = match scanner(client.get_http_client(), &job, &state.rules, &state.hash) { + let rules_state = client.rules_state.read().unwrap(); + let send_result = match scanner( + client.get_http_client(), + &job, + &rules_state.rules, + &rules_state.hash, + ) { Ok(package_scan_results) => tx.send(SubmitJobResultsBody::Success( package_scan_results.build_body(), )), @@ -142,13 +147,13 @@ fn main() -> Result<(), DragonflyError> { for job in jobs { info!("Submitting {} v{} for execution", job.name, job.version); - let state = client.state.read().unwrap(); - if job.hash != client.state.read().unwrap().hash { + let rules_state = client.rules_state.read().unwrap(); + if job.hash != rules_state.hash { info!( "Must update rules, updating from {} to {}", - state.hash, job.hash + rules_state.hash, job.hash ); - drop(state); + drop(rules_state); if let Err(err) = client.update_rules() { error!("Error while updating rules: {err}"); }