1
+ use std:: collections:: hash_map:: Entry :: { Occupied , Vacant } ;
2
+ use std:: collections:: HashMap ;
3
+ use std:: sync:: Arc ;
4
+ use std:: time:: { Duration , SystemTime } ;
5
+
1
6
use async_trait:: async_trait;
2
- use tokio:: sync:: Mutex ;
7
+ use tokio:: sync:: { Mutex , OwnedMutexGuard } ;
8
+ use tracing:: { debug, info, warn} ;
3
9
4
10
use crate :: custom_service_account:: CustomServiceAccount ;
5
11
use crate :: default_authorized_user:: ConfigDefaultCredentials ;
@@ -13,6 +19,12 @@ pub(crate) trait ServiceAccount: Send + Sync {
13
19
async fn project_id ( & self , client : & HyperClient ) -> Result < String , Error > ;
14
20
fn get_token ( & self , scopes : & [ & str ] ) -> Option < Token > ;
15
21
async fn refresh_token ( & self , client : & HyperClient , scopes : & [ & str ] ) -> Result < Token , Error > ;
22
+ fn get_style ( & self ) -> TokenStyle ;
23
+ }
24
+
25
+ pub ( crate ) enum TokenStyle {
26
+ Account ,
27
+ AccountAndScopes ,
16
28
}
17
29
18
30
/// Authentication manager is responsible for caching and obtaining credentials for the required
@@ -21,10 +33,13 @@ pub(crate) trait ServiceAccount: Send + Sync {
21
33
/// Construct the authentication manager with [`AuthenticationManager::new()`] or by creating
22
34
/// a [`CustomServiceAccount`], then converting it into an `AuthenticationManager` using the `From`
23
35
/// impl.
24
- pub struct AuthenticationManager {
25
- pub ( crate ) client : HyperClient ,
26
- pub ( crate ) service_account : Box < dyn ServiceAccount > ,
27
- refresh_mutex : Mutex < ( ) > ,
36
+ #[ derive( Clone ) ]
37
+ pub struct AuthenticationManager ( Arc < AuthManagerInner > ) ;
38
+
39
+ struct AuthManagerInner {
40
+ client : HyperClient ,
41
+ service_account : Box < dyn ServiceAccount > ,
42
+ refresh_lock : RefreshLock ,
28
43
}
29
44
30
45
impl AuthenticationManager {
@@ -80,40 +95,82 @@ impl AuthenticationManager {
80
95
}
81
96
82
97
fn build ( client : HyperClient , service_account : impl ServiceAccount + ' static ) -> Self {
83
- Self {
98
+ let refresh_lock = RefreshLock :: new ( service_account. get_style ( ) ) ;
99
+ Self ( Arc :: new ( AuthManagerInner {
84
100
client,
85
101
service_account : Box :: new ( service_account) ,
86
- refresh_mutex : Mutex :: new ( ( ) ) ,
87
- }
102
+ refresh_lock ,
103
+ } ) )
88
104
}
89
105
90
106
/// Requests Bearer token for the provided scope
91
107
///
92
108
/// Token can be used in the request authorization header in format "Bearer {token}"
93
109
pub async fn get_token ( & self , scopes : & [ & str ] ) -> Result < Token , Error > {
94
- let token = self . service_account . get_token ( scopes) ;
110
+ let token = self . 0 . service_account . get_token ( scopes) ;
111
+
95
112
if let Some ( token) = token. filter ( |token| !token. has_expired ( ) ) {
113
+ let valid_for = token
114
+ . expires_at ( )
115
+ . duration_since ( SystemTime :: now ( ) )
116
+ . unwrap_or_default ( ) ;
117
+ if valid_for < Duration :: from_secs ( 60 ) {
118
+ debug ! ( ?valid_for, "gcp_auth token expires soon!" ) ;
119
+
120
+ let lock = self . 0 . refresh_lock . lock_for_scopes ( scopes) . await ;
121
+ match lock. try_lock_owned ( ) {
122
+ Err ( _) => {
123
+ // already being refreshed.
124
+ }
125
+ Ok ( guard) => {
126
+ let inner = self . clone ( ) ;
127
+ let scopes: Vec < String > = scopes. iter ( ) . map ( |s| s. to_string ( ) ) . collect ( ) ;
128
+ tokio:: spawn ( async move {
129
+ inner. background_refresh ( scopes, guard) . await ;
130
+ } ) ;
131
+ }
132
+ }
133
+ }
96
134
return Ok ( token) ;
97
135
}
98
136
99
- let _guard = self . refresh_mutex . lock ( ) . await ;
137
+ warn ! ( "starting inline refresh of gcp auth token" ) ;
138
+ let lock = self . 0 . refresh_lock . lock_for_scopes ( scopes) . await ;
139
+ let _guard = lock. lock ( ) . await ;
100
140
101
141
// Check if refresh happened while we were waiting.
102
- let token = self . service_account . get_token ( scopes) ;
142
+ let token = self . 0 . service_account . get_token ( scopes) ;
103
143
if let Some ( token) = token. filter ( |token| !token. has_expired ( ) ) {
104
144
return Ok ( token) ;
105
145
}
106
146
107
- self . service_account
108
- . refresh_token ( & self . client , scopes)
147
+ self . 0
148
+ . service_account
149
+ . refresh_token ( & self . 0 . client , scopes)
109
150
. await
110
151
}
111
152
153
+ async fn background_refresh ( & self , scopes : Vec < String > , _lock : OwnedMutexGuard < ( ) > ) {
154
+ info ! ( "gcp_auth starting background refresh of auth token" ) ;
155
+ let scope_refs: Vec < & str > = scopes. iter ( ) . map ( |s| s. as_str ( ) ) . collect ( ) ;
156
+ match self
157
+ . 0
158
+ . service_account
159
+ . refresh_token ( & self . 0 . client , & scope_refs)
160
+ . await
161
+ {
162
+ Ok ( t) => {
163
+ info ! ( valid_for=?t. expires_at( ) . duration_since( SystemTime :: now( ) ) , "gcp auth completed background token refresh" )
164
+ }
165
+ Err ( err) => warn ! ( ?err, "gcp_auth background token refresh failed" ) ,
166
+ }
167
+ }
168
+
112
169
/// Request the project ID for the authenticating account
113
170
///
114
171
/// This is only available for service account-based authentication methods.
115
172
pub async fn project_id ( & self ) -> Result < String , Error > {
116
- self . service_account . project_id ( & self . client ) . await
173
+ self . 0 . service_account . project_id ( & self . 0 . client ) . await
117
174
}
118
175
}
119
176
@@ -122,3 +179,35 @@ impl From<CustomServiceAccount> for AuthenticationManager {
122
179
Self :: build ( types:: client ( ) , service_account)
123
180
}
124
181
}
182
+
183
+ enum RefreshLock {
184
+ One ( Arc < Mutex < ( ) > > ) ,
185
+ ByScopes ( Mutex < HashMap < Vec < String > , Arc < Mutex < ( ) > > > > ) ,
186
+ }
187
+
188
+ impl RefreshLock {
189
+ fn new ( style : TokenStyle ) -> Self {
190
+ match style {
191
+ TokenStyle :: Account => RefreshLock :: One ( Arc :: new ( Mutex :: new ( ( ) ) ) ) ,
192
+ TokenStyle :: AccountAndScopes => RefreshLock :: ByScopes ( Mutex :: new ( HashMap :: new ( ) ) ) ,
193
+ }
194
+ }
195
+
196
+ async fn lock_for_scopes ( & self , scopes : & [ & str ] ) -> Arc < Mutex < ( ) > > {
197
+ match self {
198
+ RefreshLock :: One ( mutex) => mutex. clone ( ) ,
199
+ RefreshLock :: ByScopes ( mutexes) => {
200
+ let scopes_key: Vec < _ > = scopes. iter ( ) . map ( |s| s. to_string ( ) ) . collect ( ) ;
201
+ let mut scope_locks = mutexes. lock ( ) . await ;
202
+ match scope_locks. entry ( scopes_key) {
203
+ Occupied ( e) => e. get ( ) . clone ( ) ,
204
+ Vacant ( v) => {
205
+ let lock = Arc :: new ( Mutex :: new ( ( ) ) ) ;
206
+ v. insert ( lock. clone ( ) ) ;
207
+ lock
208
+ }
209
+ }
210
+ }
211
+ }
212
+ }
213
+ }
0 commit comments