Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reauth race condition #50

Merged
merged 3 commits into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 86 additions & 67 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use flate2::read::GzDecoder;
use reqwest::{blocking::Client, StatusCode, Url};
use std::{
io::{Cursor, Read},
sync::RwLock,
sync::{Condvar, Mutex, RwLock},
time::Duration,
};
use tracing::{error, info, trace, warn};
Expand All @@ -20,22 +20,22 @@ pub type TarballType = tar::Archive<Cursor<Vec<u8>>>;
/// Type alias representing a zip archive
pub type ZipType = zip::ZipArchive<Cursor<Vec<u8>>>;

/// Application state
pub struct State {
/// The current ruleset this client is using
pub rules: yara::Rules,
pub struct AuthState {
pub access_token: RwLock<String>,
pub authenticating: Mutex<bool>,
pub cvar: Condvar,
}

/// The GitHub commit hash of the ruleset this client is using
pub struct RulesState {
pub rules: yara::Rules,
pub hash: String,

/// Access token this client is using for authentication
pub access_token: String,
}

#[warn(clippy::module_name_repetitions)]
pub struct DragonflyClient {
pub client: Client,
pub state: RwLock<State>,
pub authentication_state: AuthState,
pub rules_state: RwLock<RulesState>,
}

impl DragonflyClient {
Expand All @@ -44,21 +44,46 @@ impl DragonflyClient {

let auth_response = fetch_access_token(&client)?;
let rules_response = fetch_rules(&client, &auth_response.access_token)?;
let state = State {

let auth_state = AuthState {
access_token: RwLock::new(auth_response.access_token),
authenticating: Mutex::new(false),
cvar: Condvar::new(),
};

let rules_state = RwLock::new(RulesState {
rules: rules_response.compile()?,
hash: rules_response.hash,
access_token: auth_response.access_token,
}
.into();
});

Ok(Self { client, state })
Ok(Self {
client,
authentication_state: auth_state,
rules_state,
})
}

/// Update the state with a new access token, using the given write lock [`RwLockWriteGuard`]
/// Update the state with a new access token.
///
/// If an error occurs while reauthenticating, the function retries with an exponential backoff
/// described by the equation `min(10 * 60, 2^(x - 1))` where `x` is the number of failed tries.
pub fn reauthenticate(&self) -> String {
pub fn reauthenticate(&self) {
trace!("Trying to lock to check if we're authenticating.");
let mut authing = self.authentication_state.authenticating.lock().unwrap();
trace!("Acquired lock");
if *authing {
trace!("Another thread is authenticating. Waiting for it to finish.");
let _guard = self
.authentication_state
.cvar
.wait_while(authing, |authing| *authing);
trace!("Was notified, returning");
return;
}
trace!("No other thread is authenticating. Trying to reauthenticate.");
*authing = true;
drop(authing);

let access_token;

let base = 2_f64;
Expand Down Expand Up @@ -88,58 +113,58 @@ impl DragonflyClient {
}
}

trace!("Successfully got new access token!");

*self.authentication_state.access_token.write().unwrap() = access_token;

let mut authing = self.authentication_state.authenticating.lock().unwrap();
*authing = false;
self.authentication_state.cvar.notify_all();

info!("Successfully reauthenticated.");
access_token
}

/// Update the global ruleset. Waits for a write lock.
pub fn update_rules(&self) -> Result<(), DragonflyError> {
let response = match fetch_rules(
self.get_http_client(),
&self.state.read().unwrap().access_token,
&self.authentication_state.access_token.read().unwrap(),
) {
Err(err) if err.status() == Some(StatusCode::UNAUTHORIZED) => {
info!("Got 401 UNAUTHORIZED while updating rules");
trace!("Waiting on write lock to update access token");
let mut state = self.state.write().unwrap();
trace!("Successfully obtained write lock!");
trace!("Requesting new access token...");
let new_access_token = self.reauthenticate();
trace!("Successfully got new access token!");
state.access_token = new_access_token;
info!("Successfully updated local access token to new one!");
self.reauthenticate();
info!("Fetching rules again...");
fetch_rules(self.get_http_client(), &state.access_token)
fetch_rules(
self.get_http_client(),
&self.authentication_state.access_token.read().unwrap(),
)
}

Ok(response) => Ok(response),

Err(err) => Err(err),
}?;

let mut state = self.state.write().unwrap();
state.rules = response.compile()?;
state.hash = response.hash;
let mut rules_state = self.rules_state.write().unwrap();
rules_state.rules = response.compile()?;
rules_state.hash = response.hash;

Ok(())
}

pub fn bulk_get_job(&self, n_jobs: usize) -> reqwest::Result<Vec<Job>> {
let state = self.state.read().unwrap();
match fetch_bulk_job(self.get_http_client(), &state.access_token, n_jobs) {
let access_token = self.authentication_state.access_token.read().unwrap();
match fetch_bulk_job(self.get_http_client(), &access_token, n_jobs) {
Err(err) if err.status() == Some(StatusCode::UNAUTHORIZED) => {
drop(state); // Drop the read lock
drop(access_token); // Drop the read lock
info!("Got 401 UNAUTHORIZED while doing a bulk fetch job request");
trace!("Waiting on write lock to update access token");
let mut state = self.state.write().unwrap();
trace!("Successfully obtained write lock!");
trace!("Requesting new access token...");
let new_access_token = self.reauthenticate();
trace!("Successfully got new access token!");
state.access_token = new_access_token;
info!("Successfully updated local access token to new one!");
self.reauthenticate();
info!("Doing a bulk fetch job again...");
fetch_bulk_job(self.get_http_client(), &state.access_token, n_jobs)
fetch_bulk_job(
self.get_http_client(),
&self.authentication_state.access_token.read().unwrap(),
n_jobs,
)
}

other => other,
Expand All @@ -148,21 +173,18 @@ impl DragonflyClient {

/// Report an error to the server.
pub fn send_error(&self, body: &SubmitJobResultsError) -> reqwest::Result<()> {
let state = self.state.read().unwrap();
match send_error(self.get_http_client(), &state.access_token, body) {
let access_token = self.authentication_state.access_token.read().unwrap();
match send_error(self.get_http_client(), &access_token, body) {
Err(http_err) if http_err.status() == Some(StatusCode::UNAUTHORIZED) => {
drop(state); // Drop the read lock
drop(access_token); // Drop the read lock
info!("Got 401 UNAUTHORIZED while sending success");
trace!("Waiting on write lock to update access token");
let mut state = self.state.write().unwrap();
trace!("Successfully obtained write lock!");
trace!("Requesting new access token...");
let new_access_token = self.reauthenticate();
trace!("Successfully got new access token!");
state.access_token = new_access_token;
info!("Successfully updated local access token to new one!");
self.reauthenticate();
info!("Sending error body again...");
send_error(self.get_http_client(), &state.access_token, body)
send_error(
self.get_http_client(),
&self.authentication_state.access_token.read().unwrap(),
body,
)
}

other => other,
Expand All @@ -172,21 +194,18 @@ impl DragonflyClient {
/// Submit the results of a scan to the server, given the job and the scan results of each
/// distribution
pub fn send_success(&self, body: &SubmitJobResultsSuccess) -> reqwest::Result<()> {
let state = self.state.read().unwrap();
match send_success(self.get_http_client(), &state.access_token, body) {
let access_token = self.authentication_state.access_token.read().unwrap();
match send_success(self.get_http_client(), &access_token, body) {
Err(http_err) if http_err.status() == Some(StatusCode::UNAUTHORIZED) => {
drop(state); // Drop the read lock
drop(access_token); // Drop the read lock
info!("Got 401 UNAUTHORIZED while sending success");
trace!("Waiting on write lock to update access token");
let mut state = self.state.write().unwrap();
trace!("Successfully obtained write lock!");
trace!("Requesting new access token...");
let new_access_token = self.reauthenticate();
trace!("Successfully got new access token!");
state.access_token = new_access_token;
info!("Successfully updated local access token to new one!");
self.reauthenticate();
info!("Sending success body again...");
send_success(self.get_http_client(), &state.access_token, body)
send_success(
self.get_http_client(),
&self.authentication_state.access_token.read().unwrap(),
body,
)
}

other => other,
Expand Down
17 changes: 11 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,13 @@ fn scanner(
fn runner(client: &DragonflyClient, job: Job, tx: &SyncSender<SubmitJobResultsBody>) {
let span = span!(Level::INFO, "Job", name = job.name, version = job.version);
let _enter = span.enter();
let state = client.state.read().unwrap();
let send_result = match scanner(client.get_http_client(), &job, &state.rules, &state.hash) {
let rules_state = client.rules_state.read().unwrap();
let send_result = match scanner(
client.get_http_client(),
&job,
&rules_state.rules,
&rules_state.hash,
) {
Ok(package_scan_results) => tx.send(SubmitJobResultsBody::Success(
package_scan_results.build_body(),
)),
Expand Down Expand Up @@ -142,13 +147,13 @@ fn main() -> Result<(), DragonflyError> {

for job in jobs {
info!("Submitting {} v{} for execution", job.name, job.version);
let state = client.state.read().unwrap();
if job.hash != client.state.read().unwrap().hash {
let rules_state = client.rules_state.read().unwrap();
if job.hash != rules_state.hash {
info!(
"Must update rules, updating from {} to {}",
state.hash, job.hash
rules_state.hash, job.hash
);
drop(state);
drop(rules_state);
if let Err(err) = client.update_rules() {
error!("Error while updating rules: {err}");
}
Expand Down