Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sort out auth token refresh and setCredentials synchronization #101

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 46 additions & 46 deletions auth/src/main/kotlin/com/tidal/sdk/auth/TokenRepository.kt
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@ import com.tidal.sdk.common.UnexpectedError
import com.tidal.sdk.common.d
import com.tidal.sdk.common.logger
import java.net.HttpURLConnection
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withContext

internal class TokenRepository(
private val authConfig: AuthConfig,
Expand All @@ -32,13 +35,12 @@ internal class TokenRepository(
private val tokenService: TokenService,
private val defaultBackoffPolicy: RetryPolicy,
private val upgradeBackoffPolicy: RetryPolicy,
private val tokenMutex: Mutex,
private val bus: MutableSharedFlow<TidalMessage>,
coroutineDispatcher: CoroutineDispatcher? = null,
) {

/**
* Mutex to ensure that only one thread at a time can update/upgrade the token.
*/
private val tokenMutex = Mutex()
private val dispatcher = coroutineDispatcher ?: Dispatchers.IO

private fun needsCredentialsUpgrade(): Boolean {
val storedCredentials = getLatestTokens()?.credentials
Expand All @@ -59,31 +61,30 @@ internal class TokenRepository(
}
}

internal fun getLatestTokens(): Tokens? {
return tokensStore.getLatestTokens(authConfig.credentialsKey)
}
internal fun getLatestTokens(): Tokens? = tokensStore.getLatestTokens(authConfig.credentialsKey)

@Suppress("UnusedPrivateMember")
suspend fun getCredentials(apiErrorSubStatus: String?): AuthResult<Credentials> {
logger.d { "Received subStatus: $apiErrorSubStatus" }

return tokenMutex.withLock {
var upgradedRefreshToken: String? = null
val credentials = getLatestTokens()

if (credentials != null && needsCredentialsUpgrade()) {
logger.d { "Upgrading credentials" }
val upgradeCredentials = upgradeTokens(credentials)
upgradeCredentials.successData?.let {
upgradedRefreshToken = it.refreshToken
success(it.credentials)
} ?: upgradeCredentials as AuthResult.Failure
} else {
logger.d { "Updating credentials" }
updateCredentials(credentials, apiErrorSubStatus)
}.also {
it.successData?.let { token ->
saveTokensAndNotify(token, upgradedRefreshToken, credentials)
return withContext(dispatcher) {
tokenMutex.withLock {
var upgradedRefreshToken: String? = null
val credentials = getLatestTokens()

if (credentials != null && needsCredentialsUpgrade()) {
logger.d { "Upgrading credentials" }
val upgradeCredentials = upgradeTokens(credentials)
upgradeCredentials.successData?.let {
upgradedRefreshToken = it.refreshToken
success(it.credentials)
} ?: upgradeCredentials as AuthResult.Failure
} else {
logger.d { "Updating credentials" }
updateCredentials(credentials, apiErrorSubStatus)
}.also {
it.successData?.let { token ->
saveTokensAndNotify(token, upgradedRefreshToken, credentials)
}
}
}
}
Expand All @@ -92,26 +93,24 @@ internal class TokenRepository(
private suspend fun updateCredentials(
storedTokens: Tokens?,
apiErrorSubStatus: String?,
): AuthResult<Credentials> {
return when {
storedTokens?.credentials?.isExpired(timeProvider) == false &&
apiErrorSubStatus.shouldRefreshToken().not() -> {
success(storedTokens.credentials)
}
// if a refreshToken is available, we'll use it
storedTokens?.refreshToken != null -> {
val refreshToken = storedTokens.refreshToken
refreshCredentials { refreshUserCredentials(refreshToken) }
}

// if nothing is stored, we will try and refresh using a client secret
authConfig.clientSecret != null -> {
refreshCredentials { getClientCredentials(authConfig.clientSecret) }
}
): AuthResult<Credentials> = when {
storedTokens?.credentials?.isExpired(timeProvider) == false &&
apiErrorSubStatus.shouldRefreshToken().not() -> {
success(storedTokens.credentials)
}
// if a refreshToken is available, we'll use it
storedTokens?.refreshToken != null -> {
val refreshToken = storedTokens.refreshToken
refreshCredentials { refreshUserCredentials(refreshToken) }
}

// as a last resort we return a token-less Credentials, we're logged out
else -> logout()
// if nothing is stored, we will try and refresh using a client secret
authConfig.clientSecret != null -> {
refreshCredentials { getClientCredentials(authConfig.clientSecret) }
}

// as a last resort we return a token-less Credentials, we're logged out
else -> logout()
}

private suspend fun upgradeTokens(storedTokens: Tokens): AuthResult<Tokens> {
Expand Down Expand Up @@ -213,16 +212,17 @@ internal class TokenRepository(
}
}

private suspend fun getClientCredentials(clientSecret: String): AuthResult<RefreshResponse> {
return retryWithPolicy(defaultBackoffPolicy) {
private suspend fun getClientCredentials(clientSecret: String): AuthResult<RefreshResponse> =
retryWithPolicy(
defaultBackoffPolicy,
) {
tokenService.getTokenFromClientSecret(
authConfig.clientId,
clientSecret,
GRANT_TYPE_CLIENT_CREDENTIALS,
authConfig.scopes.toScopesString(),
)
}
}

companion object {

Expand Down
2 changes: 2 additions & 0 deletions auth/src/main/kotlin/com/tidal/sdk/auth/di/AuthComponent.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.tidal.sdk.auth.model.AuthConfig
import dagger.BindsInstance
import dagger.Component
import javax.inject.Singleton
import kotlinx.coroutines.sync.Mutex

@Singleton
@Component(
Expand All @@ -27,6 +28,7 @@ interface AuthComponent {
fun create(
@BindsInstance context: Context,
@BindsInstance config: AuthConfig,
@BindsInstance mutex: Mutex = Mutex(),
): AuthComponent
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import dagger.Provides
import javax.inject.Named
import javax.inject.Singleton
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.sync.Mutex
import retrofit2.Retrofit

@Module
Expand All @@ -38,6 +39,7 @@ internal class CredentialsModule {
tokenService: TokenService,
@Named("default") defaultBackoffPolicy: RetryPolicy,
@Named("upgrade") upgradeBackoffPolicy: RetryPolicy,
mutex: Mutex,
bus: MutableSharedFlow<TidalMessage>,
) = TokenRepository(
authConfig,
Expand All @@ -46,6 +48,7 @@ internal class CredentialsModule {
tokenService,
defaultBackoffPolicy,
upgradeBackoffPolicy,
mutex,
bus,
)
}
3 changes: 3 additions & 0 deletions auth/src/main/kotlin/com/tidal/sdk/auth/di/LoginModule.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import dagger.Reusable
import javax.inject.Named
import javax.inject.Singleton
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.sync.Mutex
import retrofit2.Retrofit

@Module
Expand Down Expand Up @@ -58,6 +59,7 @@ internal class LoginModule {
loginUriBuilder: LoginUriBuilder,
loginService: LoginService,
tokensStore: TokensStore,
mutex: Mutex,
@Named("default") retryPolicy: RetryPolicy,
bus: MutableSharedFlow<TidalMessage>,
): LoginRepository = LoginRepository(
Expand All @@ -68,6 +70,7 @@ internal class LoginModule {
loginService,
tokensStore,
retryPolicy,
mutex,
bus,
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ import com.tidal.sdk.auth.util.toScopesString
import com.tidal.sdk.common.TidalMessage
import com.tidal.sdk.common.d
import com.tidal.sdk.common.logger
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withContext

internal class LoginRepository constructor(
private val authConfig: AuthConfig,
Expand All @@ -29,6 +33,7 @@ internal class LoginRepository constructor(
private val loginService: LoginService,
private val tokensStore: TokensStore,
private val exponentialBackoffPolicy: RetryPolicy,
private val tokenMutex: Mutex,
private val bus: MutableSharedFlow<TidalMessage>,
) {

Expand Down Expand Up @@ -78,14 +83,18 @@ internal class LoginRepository constructor(
}

suspend fun setCredentials(credentials: Credentials, refreshToken: String? = null) {
val storedTokens = tokensStore.getLatestTokens(authConfig.credentialsKey)
if (credentials != storedTokens?.credentials) {
val tokens = Tokens(
credentials,
refreshToken ?: storedTokens?.refreshToken,
)
tokensStore.saveTokens(tokens)
bus.emit(CredentialsUpdatedMessage(tokens.credentials))
withContext(Dispatchers.IO) {
tokenMutex.withLock {
val storedTokens = tokensStore.getLatestTokens(authConfig.credentialsKey)
if (credentials != storedTokens?.credentials) {
val tokens = Tokens(
credentials,
refreshToken ?: storedTokens?.refreshToken,
)
tokensStore.saveTokens(tokens)
bus.emit(CredentialsUpdatedMessage(tokens.credentials))
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,9 @@ internal class LegacyCredentialsMigrator {
val exceptions = mutableListOf<Exception>()
for (operation in operations) {
try {
return operation().also {
println("Successfully decoded using legacy types.")
}
return operation()
} catch (e: Exception) {
logger.i { "Failed to decode using legacy types." }
println("Failed to decode using legacy types.")
exceptions.plus(e)
}
}
Expand Down
7 changes: 4 additions & 3 deletions auth/src/test/kotlin/com/tidal/sdk/auth/FakeTokenService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ internal class FakeTokenService : TokenService {

var calls = mutableListOf<CallType>()
var throwableToThrow: Throwable? = null
var responseDelay = 10L

override suspend fun getTokenFromRefreshToken(
clientId: String,
Expand All @@ -18,7 +19,7 @@ internal class FakeTokenService : TokenService {
scope: String,
): RefreshResponse {
calls.add(CallType.Refresh)
delay(10)
delay(responseDelay)
throwableToThrow?.let {
throw it
} ?: run {
Expand All @@ -41,7 +42,7 @@ internal class FakeTokenService : TokenService {
scope: String,
): RefreshResponse {
calls.add(CallType.Secret)
delay(10)
delay(responseDelay)
throwableToThrow?.let {
throw it
} ?: run {
Expand All @@ -65,7 +66,7 @@ internal class FakeTokenService : TokenService {
grantType: String,
): UpgradeResponse {
calls.add(CallType.Upgrade)
delay(10)
delay(responseDelay)
throwableToThrow?.let {
throw it
} ?: run {
Expand Down
12 changes: 10 additions & 2 deletions auth/src/test/kotlin/com/tidal/sdk/auth/LoginRepositoryTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import io.mockk.mockk
import io.mockk.mockkStatic
import java.util.Locale
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.test.StandardTestDispatcher
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.BeforeAll
Expand Down Expand Up @@ -80,10 +81,16 @@ class LoginRepositoryTest {
authConfig,
timeProvider,
CodeChallengeBuilder(),
LoginUriBuilder(TEST_CLIENT_ID, TEST_CLIENT_UNIQUE_KEY, loginBaseUrl, authConfig.scopes),
LoginUriBuilder(
TEST_CLIENT_ID,
TEST_CLIENT_UNIQUE_KEY,
loginBaseUrl,
authConfig.scopes
),
loginService,
tokensStore,
retryPolicy,
Mutex(),
bus,
)
}
Expand Down Expand Up @@ -503,7 +510,8 @@ class LoginRepositoryTest {
"In case of 5xx returns, initializeDeviceLogin should trigger retries as defined by the retryPolicy"
}
assert(
((result as AuthResult.Failure).message as RetryableError).code == testErrorCode.toString(),
((result as AuthResult.Failure).message as RetryableError).code ==
testErrorCode.toString(),
) {
"When finished retrying, a RetryableError should be returned that cointains the correct error code."
}
Expand Down
Loading
Loading