diff --git a/auth/src/main/kotlin/com/tidal/sdk/auth/TokenRepository.kt b/auth/src/main/kotlin/com/tidal/sdk/auth/TokenRepository.kt index 976bd422..bfc15f72 100644 --- a/auth/src/main/kotlin/com/tidal/sdk/auth/TokenRepository.kt +++ b/auth/src/main/kotlin/com/tidal/sdk/auth/TokenRepository.kt @@ -20,10 +20,10 @@ import com.tidal.sdk.common.TidalMessage import com.tidal.sdk.common.UnexpectedError import com.tidal.sdk.common.d import com.tidal.sdk.common.logger -import java.net.HttpURLConnection import kotlinx.coroutines.flow.MutableSharedFlow -import kotlinx.coroutines.sync.Mutex -import kotlinx.coroutines.sync.withLock +import kotlinx.coroutines.runBlocking +import java.net.HttpURLConnection +import java.util.concurrent.atomic.AtomicInteger internal class TokenRepository( private val authConfig: AuthConfig, @@ -35,10 +35,12 @@ internal class TokenRepository( private val bus: MutableSharedFlow, ) { - /** - * Mutex to ensure that only one thread at a time can update/upgrade the token. - */ - private val tokenMutex = Mutex() + var getCredentialsCalls = AtomicInteger(0) + var refreshesBranchSkipOrOuterSkip = AtomicInteger(0) + var refreshesBranchToken = AtomicInteger(0) + var refreshesBranchSecret = AtomicInteger(0) + var refreshesBranchLogout = AtomicInteger(0) + var upgrades = AtomicInteger(0) private fun needsCredentialsUpgrade(): Boolean { val storedCredentials = getLatestTokens()?.credentials @@ -65,66 +67,85 @@ internal class TokenRepository( @Suppress("UnusedPrivateMember") suspend fun getCredentials(apiErrorSubStatus: String?): AuthResult { - var upgradedRefreshToken: String? = null - val credentials = getLatestTokens() + getCredentialsCalls.incrementAndGet() logger.d { "Received subStatus: $apiErrorSubStatus" } - return if (credentials != null && needsCredentialsUpgrade()) { - logger.d { "Upgrading credentials" } - val upgradeCredentials = upgradeTokens(credentials) - upgradeCredentials.successData?.let { - upgradedRefreshToken = it.refreshToken - success(it.credentials) - } ?: upgradeCredentials as AuthResult.Failure - } else { - logger.d { "Updating credentials" } - updateCredentials(credentials, apiErrorSubStatus) - }.also { - it.successData?.let { token -> - saveTokensAndNotify(token, upgradedRefreshToken, credentials) + val latestTokens = getLatestTokens() + /** + * Note the double if check. This is to avoid synchronized whenever possible (since it's + * slow). It's the same reason why when you write a singleton you're supposed to do the + * null check both outside and inside the synchronized call. + */ + if ((latestTokens?.credentials?.isExpired(timeProvider) != false) || + needsCredentialsUpgrade() + ) { + return synchronized(this) { + var upgradedRefreshToken: String? = null + val latestTokens = getLatestTokens() + if (latestTokens != null && needsCredentialsUpgrade()) { + val upgradeCredentials = runBlocking { upgradeTokens(latestTokens) } + upgradeCredentials.successData?.let { + upgradedRefreshToken = it.refreshToken + success(it.credentials) + } ?: upgradeCredentials as AuthResult.Failure + } else { + updateCredentials(latestTokens, apiErrorSubStatus) + }.also { + it.successData?.let { token -> + runBlocking { + saveTokensAndNotify(token, upgradedRefreshToken, latestTokens) + } + } + } } } + refreshesBranchSkipOrOuterSkip.incrementAndGet() + return success(latestTokens.credentials) } - private suspend fun updateCredentials( + private fun updateCredentials( storedTokens: Tokens?, apiErrorSubStatus: String?, - ): AuthResult { - return tokenMutex.withLock { - when { - storedTokens?.credentials?.isExpired(timeProvider) == false && - apiErrorSubStatus.shouldRefreshToken().not() -> { - success(storedTokens.credentials) - } - // if a refreshToken is available, we'll use it - storedTokens?.refreshToken != null -> { - val refreshToken = storedTokens.refreshToken - refreshCredentials { refreshUserCredentials(refreshToken) } - } + ) = when { + storedTokens?.credentials?.isExpired(timeProvider) == false && + apiErrorSubStatus.shouldRefreshToken().not() -> { + logger.d { "Refresh skipped" } + refreshesBranchSkipOrOuterSkip.incrementAndGet() + success(storedTokens.credentials) + } + // if a refreshToken is available, we'll use it + storedTokens?.refreshToken != null -> { + val refreshToken = storedTokens.refreshToken + logger.d { "Refreshing via refresh token" } + refreshesBranchToken.incrementAndGet() + runBlocking { refreshCredentials { refreshUserCredentials(refreshToken) } } + } - // if nothing is stored, we will try and refresh using a client secret - authConfig.clientSecret != null -> { - refreshCredentials { getClientCredentials(authConfig.clientSecret) } - } + // if nothing is stored, we will try and refresh using a client secret + authConfig.clientSecret != null -> { + logger.d { "Refreshing via client secret" } + refreshesBranchSecret.incrementAndGet() + runBlocking { refreshCredentials { getClientCredentials(authConfig.clientSecret) } } + } - // as a last resort we return a token-less Credentials, we're logged out - else -> logout() - } + // as a last resort we return a token-less Credentials, we're logged out + else -> { + refreshesBranchLogout.incrementAndGet() + logout() } } private suspend fun upgradeTokens(storedTokens: Tokens): AuthResult { - val response = tokenMutex.withLock { - retryWithPolicy(upgradeBackoffPolicy) { - with(storedTokens) { - tokenService.upgradeToken( - refreshToken = requireNotNull(this.refreshToken), - clientUniqueKey = requireNotNull(authConfig.clientUniqueKey), - clientId = authConfig.clientId, - clientSecret = authConfig.clientSecret, - scopes = authConfig.scopes.toScopesString(), - grantType = GRANT_TYPE_UPGRADE, - ) - } + upgrades.incrementAndGet() + val response = retryWithPolicy(upgradeBackoffPolicy) { + with(storedTokens) { + tokenService.upgradeToken( + refreshToken = requireNotNull(this.refreshToken), + clientUniqueKey = requireNotNull(authConfig.clientUniqueKey), + clientId = authConfig.clientId, + clientSecret = authConfig.clientSecret, + scopes = authConfig.scopes.toScopesString(), + grantType = GRANT_TYPE_UPGRADE, + ) } } diff --git a/auth/src/main/kotlin/com/tidal/sdk/auth/storage/DefaultTokensStore.kt b/auth/src/main/kotlin/com/tidal/sdk/auth/storage/DefaultTokensStore.kt index 1eb4a528..86f38af1 100644 --- a/auth/src/main/kotlin/com/tidal/sdk/auth/storage/DefaultTokensStore.kt +++ b/auth/src/main/kotlin/com/tidal/sdk/auth/storage/DefaultTokensStore.kt @@ -5,16 +5,18 @@ import androidx.security.crypto.EncryptedSharedPreferences import com.tidal.sdk.auth.model.Tokens import com.tidal.sdk.common.logger import com.tidal.sdk.common.w +import kotlinx.serialization.json.Json import javax.inject.Inject +import javax.inject.Singleton import kotlinx.serialization.decodeFromString as decode import kotlinx.serialization.encodeToString as encode -import kotlinx.serialization.json.Json /** * This class uses [EncryptedSharedPreferences] to securely store credentials. * Pass in a [SharedPreferences] instance to use a custom one, by default * we inject an [EncryptedSharedPreferences] instance. */ +@Singleton internal class DefaultTokensStore @Inject constructor( private val credentialsKey: String, private val sharedPreferences: SharedPreferences, diff --git a/auth/src/test/kotlin/com/tidal/sdk/auth/TokenRepositoryTest.kt b/auth/src/test/kotlin/com/tidal/sdk/auth/TokenRepositoryTest.kt index e36c2aba..780cbfc3 100644 --- a/auth/src/test/kotlin/com/tidal/sdk/auth/TokenRepositoryTest.kt +++ b/auth/src/test/kotlin/com/tidal/sdk/auth/TokenRepositoryTest.kt @@ -6,6 +6,7 @@ import com.tidal.sdk.auth.login.FakeTokensStore import com.tidal.sdk.auth.model.ApiErrorSubStatus import com.tidal.sdk.auth.model.AuthConfig import com.tidal.sdk.auth.model.AuthResult +import com.tidal.sdk.auth.model.Credentials import com.tidal.sdk.auth.model.CredentialsUpdatedMessage import com.tidal.sdk.auth.model.Tokens import com.tidal.sdk.auth.util.RetryPolicy @@ -18,14 +19,19 @@ import com.tidal.sdk.util.TEST_CLIENT_ID import com.tidal.sdk.util.TEST_CLIENT_UNIQUE_KEY import com.tidal.sdk.util.TEST_TIME_PROVIDER import com.tidal.sdk.util.makeCredentials +import kotlinx.coroutines.Deferred +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.StandardTestDispatcher import kotlinx.coroutines.test.advanceUntilIdle import kotlinx.coroutines.test.runTest import okio.IOException import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test +import kotlin.test.assertEquals class TokenRepositoryTest { @@ -755,4 +761,42 @@ class TokenRepositoryTest { "No calls to the backend should have been made" } } + + @Test + fun `getCredentials called from many threads`() = runTest { + val credentials = makeCredentials( + userId = "valid", + isExpired = true, + ) + val tokens = Tokens( + credentials, + "refreshToken", + ) + createTokenRepository( + FakeTokenService(), + FakeTokensStore(authConfig.credentialsKey, tokens), + ) + val deferreds = mutableSetOf>>() + val threads = mutableSetOf() + repeat(1_000) { + deferreds.add( + async { tokenRepository.getCredentials(null) }, + ) + threads.add( + Thread { + runBlocking { tokenRepository.getCredentials(null) } + }.apply { + start() + }, + ) + } + deferreds.awaitAll() + threads.onEach { it.join() } + assertEquals(2_000, tokenRepository.getCredentialsCalls.get()) + assertEquals(1_999, tokenRepository.refreshesBranchSkipOrOuterSkip.get()) + assertEquals(1, tokenRepository.refreshesBranchToken.get()) + assertEquals(0, tokenRepository.refreshesBranchSecret.get()) + assertEquals(0, tokenRepository.refreshesBranchLogout.get()) + assertEquals(0, tokenRepository.upgrades.get()) + } }