Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix concurrency management within getCredentials #25

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 75 additions & 54 deletions auth/src/main/kotlin/com/tidal/sdk/auth/TokenRepository.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ import com.tidal.sdk.common.TidalMessage
import com.tidal.sdk.common.UnexpectedError
import com.tidal.sdk.common.d
import com.tidal.sdk.common.logger
import java.net.HttpURLConnection
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.runBlocking
import java.net.HttpURLConnection
import java.util.concurrent.atomic.AtomicInteger

internal class TokenRepository(
private val authConfig: AuthConfig,
Expand All @@ -35,10 +35,12 @@ internal class TokenRepository(
private val bus: MutableSharedFlow<TidalMessage>,
) {

/**
* Mutex to ensure that only one thread at a time can update/upgrade the token.
*/
private val tokenMutex = Mutex()
var getCredentialsCalls = AtomicInteger(0)
var refreshesBranchSkipOrOuterSkip = AtomicInteger(0)
var refreshesBranchToken = AtomicInteger(0)
var refreshesBranchSecret = AtomicInteger(0)
var refreshesBranchLogout = AtomicInteger(0)
var upgrades = AtomicInteger(0)

private fun needsCredentialsUpgrade(): Boolean {
val storedCredentials = getLatestTokens()?.credentials
Expand All @@ -65,66 +67,85 @@ internal class TokenRepository(

@Suppress("UnusedPrivateMember")
suspend fun getCredentials(apiErrorSubStatus: String?): AuthResult<Credentials> {
var upgradedRefreshToken: String? = null
val credentials = getLatestTokens()
getCredentialsCalls.incrementAndGet()
logger.d { "Received subStatus: $apiErrorSubStatus" }
return if (credentials != null && needsCredentialsUpgrade()) {
logger.d { "Upgrading credentials" }
val upgradeCredentials = upgradeTokens(credentials)
upgradeCredentials.successData?.let {
upgradedRefreshToken = it.refreshToken
success(it.credentials)
} ?: upgradeCredentials as AuthResult.Failure
} else {
logger.d { "Updating credentials" }
updateCredentials(credentials, apiErrorSubStatus)
}.also {
it.successData?.let { token ->
saveTokensAndNotify(token, upgradedRefreshToken, credentials)
val latestTokens = getLatestTokens()
/**
* Note the double if check. This is to avoid synchronized whenever possible (since it's
* slow). It's the same reason why when you write a singleton you're supposed to do the
* null check both outside and inside the synchronized call.
*/
if ((latestTokens?.credentials?.isExpired(timeProvider) != false) ||
needsCredentialsUpgrade()
) {
return synchronized(this) {
var upgradedRefreshToken: String? = null
val latestTokens = getLatestTokens()
if (latestTokens != null && needsCredentialsUpgrade()) {
val upgradeCredentials = runBlocking { upgradeTokens(latestTokens) }
upgradeCredentials.successData?.let {
upgradedRefreshToken = it.refreshToken
success(it.credentials)
} ?: upgradeCredentials as AuthResult.Failure
} else {
updateCredentials(latestTokens, apiErrorSubStatus)
}.also {
it.successData?.let { token ->
runBlocking {
saveTokensAndNotify(token, upgradedRefreshToken, latestTokens)
}
}
}
}
}
refreshesBranchSkipOrOuterSkip.incrementAndGet()
return success(latestTokens.credentials)
}

private suspend fun updateCredentials(
private fun updateCredentials(
storedTokens: Tokens?,
apiErrorSubStatus: String?,
): AuthResult<Credentials> {
return tokenMutex.withLock {
when {
storedTokens?.credentials?.isExpired(timeProvider) == false &&
apiErrorSubStatus.shouldRefreshToken().not() -> {
success(storedTokens.credentials)
}
// if a refreshToken is available, we'll use it
storedTokens?.refreshToken != null -> {
val refreshToken = storedTokens.refreshToken
refreshCredentials { refreshUserCredentials(refreshToken) }
}
) = when {
storedTokens?.credentials?.isExpired(timeProvider) == false &&
apiErrorSubStatus.shouldRefreshToken().not() -> {
logger.d { "Refresh skipped" }
refreshesBranchSkipOrOuterSkip.incrementAndGet()
success(storedTokens.credentials)
}
// if a refreshToken is available, we'll use it
storedTokens?.refreshToken != null -> {
val refreshToken = storedTokens.refreshToken
logger.d { "Refreshing via refresh token" }
refreshesBranchToken.incrementAndGet()
runBlocking { refreshCredentials { refreshUserCredentials(refreshToken) } }
}

// if nothing is stored, we will try and refresh using a client secret
authConfig.clientSecret != null -> {
refreshCredentials { getClientCredentials(authConfig.clientSecret) }
}
// if nothing is stored, we will try and refresh using a client secret
authConfig.clientSecret != null -> {
logger.d { "Refreshing via client secret" }
refreshesBranchSecret.incrementAndGet()
runBlocking { refreshCredentials { getClientCredentials(authConfig.clientSecret) } }
}

// as a last resort we return a token-less Credentials, we're logged out
else -> logout()
}
// as a last resort we return a token-less Credentials, we're logged out
else -> {
refreshesBranchLogout.incrementAndGet()
logout()
}
}

private suspend fun upgradeTokens(storedTokens: Tokens): AuthResult<Tokens> {
val response = tokenMutex.withLock {
retryWithPolicy(upgradeBackoffPolicy) {
with(storedTokens) {
tokenService.upgradeToken(
refreshToken = requireNotNull(this.refreshToken),
clientUniqueKey = requireNotNull(authConfig.clientUniqueKey),
clientId = authConfig.clientId,
clientSecret = authConfig.clientSecret,
scopes = authConfig.scopes.toScopesString(),
grantType = GRANT_TYPE_UPGRADE,
)
}
upgrades.incrementAndGet()
val response = retryWithPolicy(upgradeBackoffPolicy) {
with(storedTokens) {
tokenService.upgradeToken(
refreshToken = requireNotNull(this.refreshToken),
clientUniqueKey = requireNotNull(authConfig.clientUniqueKey),
clientId = authConfig.clientId,
clientSecret = authConfig.clientSecret,
scopes = authConfig.scopes.toScopesString(),
grantType = GRANT_TYPE_UPGRADE,
)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@ import androidx.security.crypto.EncryptedSharedPreferences
import com.tidal.sdk.auth.model.Tokens
import com.tidal.sdk.common.logger
import com.tidal.sdk.common.w
import kotlinx.serialization.json.Json
import javax.inject.Inject
import javax.inject.Singleton
import kotlinx.serialization.decodeFromString as decode
import kotlinx.serialization.encodeToString as encode
import kotlinx.serialization.json.Json

/**
* This class uses [EncryptedSharedPreferences] to securely store credentials.
* Pass in a [SharedPreferences] instance to use a custom one, by default
* we inject an [EncryptedSharedPreferences] instance.
*/
@Singleton
internal class DefaultTokensStore @Inject constructor(
private val credentialsKey: String,
private val sharedPreferences: SharedPreferences,
Expand Down
44 changes: 44 additions & 0 deletions auth/src/test/kotlin/com/tidal/sdk/auth/TokenRepositoryTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.tidal.sdk.auth.login.FakeTokensStore
import com.tidal.sdk.auth.model.ApiErrorSubStatus
import com.tidal.sdk.auth.model.AuthConfig
import com.tidal.sdk.auth.model.AuthResult
import com.tidal.sdk.auth.model.Credentials
import com.tidal.sdk.auth.model.CredentialsUpdatedMessage
import com.tidal.sdk.auth.model.Tokens
import com.tidal.sdk.auth.util.RetryPolicy
Expand All @@ -18,14 +19,19 @@ import com.tidal.sdk.util.TEST_CLIENT_ID
import com.tidal.sdk.util.TEST_CLIENT_UNIQUE_KEY
import com.tidal.sdk.util.TEST_TIME_PROVIDER
import com.tidal.sdk.util.makeCredentials
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.StandardTestDispatcher
import kotlinx.coroutines.test.advanceUntilIdle
import kotlinx.coroutines.test.runTest
import okio.IOException
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import kotlin.test.assertEquals

class TokenRepositoryTest {

Expand Down Expand Up @@ -755,4 +761,42 @@ class TokenRepositoryTest {
"No calls to the backend should have been made"
}
}

@Test
fun `getCredentials called from many threads`() = runTest {
val credentials = makeCredentials(
userId = "valid",
isExpired = true,
)
val tokens = Tokens(
credentials,
"refreshToken",
)
createTokenRepository(
FakeTokenService(),
FakeTokensStore(authConfig.credentialsKey, tokens),
)
val deferreds = mutableSetOf<Deferred<AuthResult<Credentials>>>()
val threads = mutableSetOf<Thread>()
repeat(1_000) {
deferreds.add(
async { tokenRepository.getCredentials(null) },
)
threads.add(
Thread {
runBlocking { tokenRepository.getCredentials(null) }
}.apply {
start()
},
)
}
deferreds.awaitAll()
threads.onEach { it.join() }
assertEquals(2_000, tokenRepository.getCredentialsCalls.get())
assertEquals(1_999, tokenRepository.refreshesBranchSkipOrOuterSkip.get())
assertEquals(1, tokenRepository.refreshesBranchToken.get())
assertEquals(0, tokenRepository.refreshesBranchSecret.get())
assertEquals(0, tokenRepository.refreshesBranchLogout.get())
assertEquals(0, tokenRepository.upgrades.get())
}
}
Loading