Skip to content

Commit 0240a9e

Browse files
authored
Merge pull request djc#6 from juicebox-systems/background_refresh
refresh auth tokens in the background
2 parents 980168e + 7dd925e commit 0240a9e

5 files changed

+123
-18
lines changed

src/authentication_manager.rs

+103-14
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
1+
use std::collections::hash_map::Entry::{Occupied, Vacant};
2+
use std::collections::HashMap;
3+
use std::sync::Arc;
4+
use std::time::{Duration, SystemTime};
5+
16
use async_trait::async_trait;
2-
use tokio::sync::Mutex;
7+
use tokio::sync::{Mutex, OwnedMutexGuard};
8+
use tracing::{debug, info, warn};
39

410
use crate::custom_service_account::CustomServiceAccount;
511
use crate::default_authorized_user::ConfigDefaultCredentials;
@@ -13,6 +19,12 @@ pub(crate) trait ServiceAccount: Send + Sync {
1319
async fn project_id(&self, client: &HyperClient) -> Result<String, Error>;
1420
fn get_token(&self, scopes: &[&str]) -> Option<Token>;
1521
async fn refresh_token(&self, client: &HyperClient, scopes: &[&str]) -> Result<Token, Error>;
22+
fn get_style(&self) -> TokenStyle;
23+
}
24+
25+
pub(crate) enum TokenStyle {
26+
Account,
27+
AccountAndScopes,
1628
}
1729

1830
/// Authentication manager is responsible for caching and obtaining credentials for the required
@@ -21,10 +33,13 @@ pub(crate) trait ServiceAccount: Send + Sync {
2133
/// Construct the authentication manager with [`AuthenticationManager::new()`] or by creating
2234
/// a [`CustomServiceAccount`], then converting it into an `AuthenticationManager` using the `From`
2335
/// impl.
24-
pub struct AuthenticationManager {
25-
pub(crate) client: HyperClient,
26-
pub(crate) service_account: Box<dyn ServiceAccount>,
27-
refresh_mutex: Mutex<()>,
36+
#[derive(Clone)]
37+
pub struct AuthenticationManager(Arc<AuthManagerInner>);
38+
39+
struct AuthManagerInner {
40+
client: HyperClient,
41+
service_account: Box<dyn ServiceAccount>,
42+
refresh_lock: RefreshLock,
2843
}
2944

3045
impl AuthenticationManager {
@@ -80,40 +95,82 @@ impl AuthenticationManager {
8095
}
8196

8297
fn build(client: HyperClient, service_account: impl ServiceAccount + 'static) -> Self {
83-
Self {
98+
let refresh_lock = RefreshLock::new(service_account.get_style());
99+
Self(Arc::new(AuthManagerInner {
84100
client,
85101
service_account: Box::new(service_account),
86-
refresh_mutex: Mutex::new(()),
87-
}
102+
refresh_lock,
103+
}))
88104
}
89105

90106
/// Requests Bearer token for the provided scope
91107
///
92108
/// Token can be used in the request authorization header in format "Bearer {token}"
93109
pub async fn get_token(&self, scopes: &[&str]) -> Result<Token, Error> {
94-
let token = self.service_account.get_token(scopes);
110+
let token = self.0.service_account.get_token(scopes);
111+
95112
if let Some(token) = token.filter(|token| !token.has_expired()) {
113+
let valid_for = token
114+
.expires_at()
115+
.duration_since(SystemTime::now())
116+
.unwrap_or_default();
117+
if valid_for < Duration::from_secs(60) {
118+
debug!(?valid_for, "gcp_auth token expires soon!");
119+
120+
let lock = self.0.refresh_lock.lock_for_scopes(scopes).await;
121+
match lock.try_lock_owned() {
122+
Err(_) => {
123+
// already being refreshed.
124+
}
125+
Ok(guard) => {
126+
let inner = self.clone();
127+
let scopes: Vec<String> = scopes.iter().map(|s| s.to_string()).collect();
128+
tokio::spawn(async move {
129+
inner.background_refresh(scopes, guard).await;
130+
});
131+
}
132+
}
133+
}
96134
return Ok(token);
97135
}
98136

99-
let _guard = self.refresh_mutex.lock().await;
137+
warn!("starting inline refresh of gcp auth token");
138+
let lock = self.0.refresh_lock.lock_for_scopes(scopes).await;
139+
let _guard = lock.lock().await;
100140

101141
// Check if refresh happened while we were waiting.
102-
let token = self.service_account.get_token(scopes);
142+
let token = self.0.service_account.get_token(scopes);
103143
if let Some(token) = token.filter(|token| !token.has_expired()) {
104144
return Ok(token);
105145
}
106146

107-
self.service_account
108-
.refresh_token(&self.client, scopes)
147+
self.0
148+
.service_account
149+
.refresh_token(&self.0.client, scopes)
109150
.await
110151
}
111152

153+
async fn background_refresh(&self, scopes: Vec<String>, _lock: OwnedMutexGuard<()>) {
154+
info!("gcp_auth starting background refresh of auth token");
155+
let scope_refs: Vec<&str> = scopes.iter().map(|s| s.as_str()).collect();
156+
match self
157+
.0
158+
.service_account
159+
.refresh_token(&self.0.client, &scope_refs)
160+
.await
161+
{
162+
Ok(t) => {
163+
info!(valid_for=?t.expires_at().duration_since(SystemTime::now()), "gcp auth completed background token refresh")
164+
}
165+
Err(err) => warn!(?err, "gcp_auth background token refresh failed"),
166+
}
167+
}
168+
112169
/// Request the project ID for the authenticating account
113170
///
114171
/// This is only available for service account-based authentication methods.
115172
pub async fn project_id(&self) -> Result<String, Error> {
116-
self.service_account.project_id(&self.client).await
173+
self.0.service_account.project_id(&self.0.client).await
117174
}
118175
}
119176

@@ -122,3 +179,35 @@ impl From<CustomServiceAccount> for AuthenticationManager {
122179
Self::build(types::client(), service_account)
123180
}
124181
}
182+
183+
enum RefreshLock {
184+
One(Arc<Mutex<()>>),
185+
ByScopes(Mutex<HashMap<Vec<String>, Arc<Mutex<()>>>>),
186+
}
187+
188+
impl RefreshLock {
189+
fn new(style: TokenStyle) -> Self {
190+
match style {
191+
TokenStyle::Account => RefreshLock::One(Arc::new(Mutex::new(()))),
192+
TokenStyle::AccountAndScopes => RefreshLock::ByScopes(Mutex::new(HashMap::new())),
193+
}
194+
}
195+
196+
async fn lock_for_scopes(&self, scopes: &[&str]) -> Arc<Mutex<()>> {
197+
match self {
198+
RefreshLock::One(mutex) => mutex.clone(),
199+
RefreshLock::ByScopes(mutexes) => {
200+
let scopes_key: Vec<_> = scopes.iter().map(|s| s.to_string()).collect();
201+
let mut scope_locks = mutexes.lock().await;
202+
match scope_locks.entry(scopes_key) {
203+
Occupied(e) => e.get().clone(),
204+
Vacant(v) => {
205+
let lock = Arc::new(Mutex::new(()));
206+
v.insert(lock.clone());
207+
lock
208+
}
209+
}
210+
}
211+
}
212+
}
213+
}

src/custom_service_account.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use std::sync::RwLock;
66
use async_trait::async_trait;
77
use serde::{Deserialize, Serialize};
88

9-
use crate::authentication_manager::ServiceAccount;
9+
use crate::authentication_manager::{ServiceAccount, TokenStyle};
1010
use crate::error::Error;
1111
use crate::types::{HyperClient, SecretString, Signer, Token};
1212
use crate::util::HyperExt;
@@ -80,6 +80,10 @@ impl CustomServiceAccount {
8080

8181
#[async_trait]
8282
impl ServiceAccount for CustomServiceAccount {
83+
fn get_style(&self) -> TokenStyle {
84+
TokenStyle::AccountAndScopes
85+
}
86+
8387
async fn project_id(&self, _: &HyperClient) -> Result<String, Error> {
8488
match &self.credentials.project_id {
8589
Some(pid) => Ok(pid.clone()),

src/default_authorized_user.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use hyper::body::Body;
77
use hyper::{Method, Request};
88
use serde::{Deserialize, Serialize};
99

10-
use crate::authentication_manager::ServiceAccount;
10+
use crate::authentication_manager::{ServiceAccount, TokenStyle};
1111
use crate::error::Error;
1212
use crate::types::{HyperClient, SecretString, Token};
1313
use crate::util::HyperExt;
@@ -77,6 +77,10 @@ impl ConfigDefaultCredentials {
7777

7878
#[async_trait]
7979
impl ServiceAccount for ConfigDefaultCredentials {
80+
fn get_style(&self) -> TokenStyle {
81+
TokenStyle::Account
82+
}
83+
8084
async fn project_id(&self, _: &HyperClient) -> Result<String, Error> {
8185
self.credentials
8286
.quota_project_id

src/default_service_account.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use async_trait::async_trait;
55
use hyper::body::Body;
66
use hyper::{Method, Request};
77

8-
use crate::authentication_manager::ServiceAccount;
8+
use crate::authentication_manager::{ServiceAccount, TokenStyle};
99
use crate::error::Error;
1010
use crate::types::{HyperClient, Token};
1111
use crate::util::HyperExt;
@@ -62,6 +62,10 @@ impl MetadataServiceAccount {
6262

6363
#[async_trait]
6464
impl ServiceAccount for MetadataServiceAccount {
65+
fn get_style(&self) -> TokenStyle {
66+
TokenStyle::Account
67+
}
68+
6569
async fn project_id(&self, client: &HyperClient) -> Result<String, Error> {
6670
tracing::debug!("Getting project ID from GCP instance metadata server");
6771
let req = Self::build_token_request(Self::DEFAULT_PROJECT_ID_GCP_URI);

src/gcloud_authorized_user.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::sync::RwLock;
55
use async_trait::async_trait;
66
use std::time::Duration;
77

8-
use crate::authentication_manager::ServiceAccount;
8+
use crate::authentication_manager::{ServiceAccount, TokenStyle};
99
use crate::error::Error;
1010
use crate::error::Error::{GCloudError, GCloudParseError};
1111
use crate::types::{HyperClient, SecretString};
@@ -45,6 +45,10 @@ impl GCloudAuthorizedUser {
4545

4646
#[async_trait]
4747
impl ServiceAccount for GCloudAuthorizedUser {
48+
fn get_style(&self) -> TokenStyle {
49+
TokenStyle::Account
50+
}
51+
4852
async fn project_id(&self, _: &HyperClient) -> Result<String, Error> {
4953
self.project_id.clone().ok_or(Error::NoProjectId)
5054
}

0 commit comments

Comments
 (0)