From 62151d13c952088f1bea5f7a65f5e28979e46ac7 Mon Sep 17 00:00:00 2001 From: sadilchamishka Date: Sat, 8 Feb 2025 06:49:11 +0530 Subject: [PATCH] Add unit tests --- .../org.wso2.carbon.identity.oauth/pom.xml | 1 + .../oauth2/dao/AccessTokenDAOImpl.java | 41 ++++++++----- .../carbon/identity/oauth/OAuthUtilTest.java | 57 +++++++++++++++++ .../oauth2/dao/AccessTokenDAOImplTest.java | 61 +++++++++++++++++-- 4 files changed, 142 insertions(+), 18 deletions(-) diff --git a/components/org.wso2.carbon.identity.oauth/pom.xml b/components/org.wso2.carbon.identity.oauth/pom.xml index 7d6ff14f011..a73494cb2a4 100644 --- a/components/org.wso2.carbon.identity.oauth/pom.xml +++ b/components/org.wso2.carbon.identity.oauth/pom.xml @@ -406,6 +406,7 @@ org.apache.commons.collections.map; version="${commons-collections.wso2.osgi.version.range}", org.apache.commons.lang; version="${commons-lang.wso2.osgi.version.range}", org.apache.commons.logging; version="${commons-logging.osgi.version.range}", + org.apache.commons.collections; version="${commons-logging.osgi.version.range}", com.google.gdata.client.authn.oauth; version="${gdata-core.imp.pkg.version.range}", diff --git a/components/org.wso2.carbon.identity.oauth/src/main/java/org/wso2/carbon/identity/oauth2/dao/AccessTokenDAOImpl.java b/components/org.wso2.carbon.identity.oauth/src/main/java/org/wso2/carbon/identity/oauth2/dao/AccessTokenDAOImpl.java index efed60bc072..eee1458b4c1 100644 --- a/components/org.wso2.carbon.identity.oauth/src/main/java/org/wso2/carbon/identity/oauth2/dao/AccessTokenDAOImpl.java +++ b/components/org.wso2.carbon.identity.oauth/src/main/java/org/wso2/carbon/identity/oauth2/dao/AccessTokenDAOImpl.java @@ -19,6 +19,7 @@ package org.wso2.carbon.identity.oauth2.dao; import org.apache.commons.codec.digest.DigestUtils; +import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang.ArrayUtils; import org.apache.commons.lang.StringUtils; import org.apache.commons.logging.Log; @@ -1514,10 +1515,12 @@ public void revokeAccessTokensInBatch(String[] tokens, boolean isHashedToken) th } ps.executeBatch(); IdentityDatabaseUtil.commitTransaction(connection); - // To revoke request objects which have persisted against the access token. - OAuth2TokenUtil.postUpdateAccessTokens(Arrays.asList(tokens), OAuthConstants.TokenStates. - TOKEN_STATE_REVOKED); if (isTokenCleanupFeatureEnabled) { + if (connection.getMetaData().getDriverName().contains("Microsoft")) { + /* When token is deleted, the request objects get on delete cascade except for the SQL server. + Hence, invoke the event listener to revoke the request objects.*/ + revokeRequestObjectEntries(Arrays.asList(tokens)); + } oldTokenCleanupObject.cleanupTokensInBatch(oldTokens, connection); } } catch (SQLException e) { @@ -1545,14 +1548,13 @@ public void revokeAccessTokensInBatch(String[] tokens, boolean isHashedToken) th if (isTokenCleanupFeatureEnabled) { - oldTokenCleanupObject.cleanupTokenByTokenValue( - getHashingPersistenceProcessor().getProcessedAccessTokenIdentifier(tokens[0]), connection); - /* When token is deleted, the request objects get on delete cascade except for the SQL server. - Hence, invoke the event listener to revoke the request objects.*/ if (connection.getMetaData().getDriverName().contains("Microsoft")) { - OAuth2TokenUtil.postUpdateAccessTokens(Arrays.asList(tokens), OAuthConstants.TokenStates. - TOKEN_STATE_REVOKED); + /* When token is deleted, the request objects get on delete cascade except for the SQL server. + Hence, invoke the event listener to revoke the request objects.*/ + revokeRequestObjectEntries(Arrays.asList(tokens)); } + oldTokenCleanupObject.cleanupTokenByTokenValue( + getHashingPersistenceProcessor().getProcessedAccessTokenIdentifier(tokens[0]), connection); } } catch (SQLException e) { // IdentityDatabaseUtil.rollbackTransaction(connection); @@ -1618,13 +1620,13 @@ public void revokeAccessTokensIndividually(String[] tokens, boolean isHashedToke } accessTokenId.add(getTokenIdByAccessToken(token)); } - // To revoke request objects which have persisted against the access token. - if (accessTokenId.size() > 0) { - OAuth2TokenUtil.postUpdateAccessTokens(accessTokenId, OAuthConstants.TokenStates. - TOKEN_STATE_REVOKED); - } if (isTokenCleanupFeatureEnabled) { + if (connection.getMetaData().getDriverName().contains("Microsoft")) { + /* When token is deleted, the request objects get on delete cascade except for the SQL server. + Hence, invoke the event listener to revoke the request objects.*/ + revokeRequestObjectEntries(accessTokenId); + } for (String token : tokens) { oldTokenCleanupObject.cleanupTokenByTokenValue( getHashingPersistenceProcessor().getProcessedAccessTokenIdentifier(token), connection); @@ -3331,4 +3333,15 @@ private String getRootTenantDomainByOrganizationId(String organizationId) throws organizationId, e); } } + + /* When token is deleted, the request objects get on delete cascade except for the SQL server. + Hence, invoke the event listener to revoke the request objects.*/ + private void revokeRequestObjectEntries(List tokens) throws IdentityOAuth2Exception { + + if (CollectionUtils.isEmpty(tokens)) { + return; + } + OAuth2TokenUtil.postUpdateAccessTokens(tokens, OAuthConstants.TokenStates. + TOKEN_STATE_REVOKED); + } } diff --git a/components/org.wso2.carbon.identity.oauth/src/test/java/org/wso2/carbon/identity/oauth/OAuthUtilTest.java b/components/org.wso2.carbon.identity.oauth/src/test/java/org/wso2/carbon/identity/oauth/OAuthUtilTest.java index 0a4ea394a5f..07d68fb956e 100644 --- a/components/org.wso2.carbon.identity.oauth/src/test/java/org/wso2/carbon/identity/oauth/OAuthUtilTest.java +++ b/components/org.wso2.carbon.identity.oauth/src/test/java/org/wso2/carbon/identity/oauth/OAuthUtilTest.java @@ -38,14 +38,19 @@ import org.wso2.carbon.identity.application.mgt.ApplicationManagementService; import org.wso2.carbon.identity.common.testng.WithCarbonHome; import org.wso2.carbon.identity.common.testng.WithRealmService; +import org.wso2.carbon.identity.core.util.IdentityDatabaseUtil; +import org.wso2.carbon.identity.oauth.cache.AuthorizationGrantCache; import org.wso2.carbon.identity.oauth.cache.CacheEntry; import org.wso2.carbon.identity.oauth.cache.OAuthCache; import org.wso2.carbon.identity.oauth.cache.OAuthCacheKey; import org.wso2.carbon.identity.oauth.internal.OAuthComponentServiceHolder; import org.wso2.carbon.identity.oauth2.dao.AccessTokenDAO; +import org.wso2.carbon.identity.oauth2.dao.AuthorizationCodeDAO; +import org.wso2.carbon.identity.oauth2.dao.AuthorizationCodeDAOImpl; import org.wso2.carbon.identity.oauth2.dao.OAuthTokenPersistenceFactory; import org.wso2.carbon.identity.oauth2.dao.TokenManagementDAO; import org.wso2.carbon.identity.oauth2.model.AccessTokenDO; +import org.wso2.carbon.identity.oauth2.model.AuthzCodeDO; import org.wso2.carbon.identity.oauth2.util.OAuth2Util; import org.wso2.carbon.identity.organization.management.organization.user.sharing.OrganizationUserSharingService; import org.wso2.carbon.identity.organization.management.organization.user.sharing.models.UserAssociation; @@ -64,8 +69,12 @@ import org.wso2.carbon.user.core.util.UserCoreUtil; import org.wso2.carbon.utils.multitenancy.MultitenantConstants; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; @@ -126,6 +135,8 @@ public class OAuthUtilTest { private AutoCloseable closeable; private MockedStatic organizationManagementUtil; private MockedStatic oAuth2Util; + private MockedStatic identityDatabaseUtil; + private MockedStatic authorizationGrantCache; private MockedStatic oAuthTokenPersistenceFactory; @BeforeMethod @@ -140,7 +151,9 @@ public void setUp() throws Exception { OAuthComponentServiceHolder.getInstance().setOrganizationManager(organizationManager); OAuthComponentServiceHolder.getInstance().setRealmService(realmService); oAuth2Util = mockStatic(OAuth2Util.class); + identityDatabaseUtil = mockStatic(IdentityDatabaseUtil.class); oAuthTokenPersistenceFactory = mockStatic(OAuthTokenPersistenceFactory.class); + authorizationGrantCache = mockStatic(AuthorizationGrantCache.class); } @AfterMethod @@ -148,7 +161,9 @@ public void tearDown() throws Exception { organizationManagementUtil.close(); oAuth2Util.close(); + identityDatabaseUtil.close(); oAuthTokenPersistenceFactory.close(); + authorizationGrantCache.close(); reset(organizationUserSharingService); reset(roleManagementService); reset(applicationManagementService); @@ -476,6 +491,48 @@ public void testAuthenticatedUserInSharedUserFlow(boolean isSSOLoginUser, boolea } } + @Test + public void testRevokeAuthzCodes() throws Exception { + + UserStoreManager userStoreManager = mock(UserStoreManager.class); + + // Create a real instance of AuthorizationCodeDAO and spy on it + AuthorizationCodeDAO authorizationCodeDAO = Mockito.spy(new AuthorizationCodeDAOImpl()); + + when(userStoreManager.getTenantId()).thenReturn(-1234); + when(userStoreManager.getRealmConfiguration()).thenReturn(mock(RealmConfiguration.class)); + + OAuthTokenPersistenceFactory mockOAuthTokenPersistenceFactory = mock(OAuthTokenPersistenceFactory.class); + when(OAuthTokenPersistenceFactory.getInstance()).thenReturn(mockOAuthTokenPersistenceFactory); + when(mockOAuthTokenPersistenceFactory.getAuthorizationCodeDAO()).thenReturn(authorizationCodeDAO); + + List authorizationCodes = new ArrayList<>(); + AuthzCodeDO mockAuthzCodeDO = mock(AuthzCodeDO.class); + when(mockAuthzCodeDO.getConsumerKey()).thenReturn("consumer-key"); + when(mockAuthzCodeDO.getAuthorizationCode()).thenReturn("auth-code"); + when(mockAuthzCodeDO.getAuthzCodeId()).thenReturn("auth-code-id"); + + authorizationCodes.add(mockAuthzCodeDO); + + // Mock the getAuthorizationCodesDataByUser method to return the list of authorization codes + Mockito.doReturn(authorizationCodes).when(authorizationCodeDAO) + .getAuthorizationCodesDataByUser(any(AuthenticatedUser.class)); + + when(OAuth2Util.buildCacheKeyStringForAuthzCode(anyString(), anyString())).thenReturn("testAuthzCode"); + + Connection connection = mock(Connection.class); + PreparedStatement preparedStatement = mock(PreparedStatement.class); + when(IdentityDatabaseUtil.getDBConnection()).thenReturn(connection); + when(connection.prepareStatement(anyString())).thenReturn(preparedStatement); + + AuthorizationGrantCache mockAuthorizationGrantCache = mock(AuthorizationGrantCache.class); + when(AuthorizationGrantCache.getInstance()).thenReturn(mockAuthorizationGrantCache); + + boolean result = OAuthUtil.revokeAuthzCodes("testUser", userStoreManager); + // Verify the result + assertTrue(result, "Authorization code revocation failed."); + } + private OAuthCache getOAuthCache(OAuthCacheKey oAuthCacheKey) { // Add some value to OAuthCache. diff --git a/components/org.wso2.carbon.identity.oauth/src/test/java/org/wso2/carbon/identity/oauth2/dao/AccessTokenDAOImplTest.java b/components/org.wso2.carbon.identity.oauth/src/test/java/org/wso2/carbon/identity/oauth2/dao/AccessTokenDAOImplTest.java index 585f625edde..c53e0c703b1 100644 --- a/components/org.wso2.carbon.identity.oauth/src/test/java/org/wso2/carbon/identity/oauth2/dao/AccessTokenDAOImplTest.java +++ b/components/org.wso2.carbon.identity.oauth/src/test/java/org/wso2/carbon/identity/oauth2/dao/AccessTokenDAOImplTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024, WSO2 LLC. (http://www.wso2.org). + * Copyright (c) 2024-2025, WSO2 LLC. (http://www.wso2.org). * * WSO2 Inc. licenses this file to you under the Apache License, * Version 2.0 (the "License"); you may not use this file except @@ -29,14 +29,21 @@ import org.testng.annotations.Test; import org.wso2.carbon.identity.common.testng.WithCarbonHome; import org.wso2.carbon.identity.core.util.IdentityDatabaseUtil; +import org.wso2.carbon.identity.oauth.config.OAuthServerConfiguration; import org.wso2.carbon.identity.oauth2.IdentityOAuth2Exception; import org.wso2.carbon.identity.oauth2.dao.util.DAOUtils; import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.PreparedStatement; +import java.sql.ResultSet; import java.util.HashMap; import java.util.Map; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.when; import static org.testng.Assert.assertEquals; @WithCarbonHome @@ -51,6 +58,7 @@ public class AccessTokenDAOImplTest { public static final String DB_NAME = "AccessTokenDB"; Connection connection = null; private MockedStatic identityDatabaseUtil; + private MockedStatic oAuthServerConfiguration; @BeforeClass public void initTest() throws Exception { @@ -60,13 +68,20 @@ public void initTest() throws Exception { } catch (Exception e) { throw new IdentityOAuth2Exception("Error while initializing the data source", e); } - accessTokenDAO = new AccessTokenDAOImpl(); } @BeforeMethod public void setUp() throws Exception { connection = DAOUtils.getConnection(DB_NAME); + identityDatabaseUtil = mockStatic(IdentityDatabaseUtil.class); + oAuthServerConfiguration = mockStatic(OAuthServerConfiguration.class); + + OAuthServerConfiguration mockOAuthServerConfiguration = mock(OAuthServerConfiguration.class); + oAuthServerConfiguration.when(OAuthServerConfiguration::getInstance) + .thenReturn(mockOAuthServerConfiguration); + when(mockOAuthServerConfiguration.isTokenCleanupEnabled()).thenReturn(true); + accessTokenDAO = new AccessTokenDAOImpl(); } @AfterMethod @@ -76,6 +91,7 @@ public void closeup() throws Exception { connection.close(); } identityDatabaseUtil.close(); + oAuthServerConfiguration.close(); } @AfterClass @@ -88,8 +104,6 @@ public void tearDown() throws Exception { public void getSessionIdentifierByTokenId() throws Exception { connection = DAOUtils.getConnection(DB_NAME); - identityDatabaseUtil = mockStatic(IdentityDatabaseUtil.class); - identityDatabaseUtil.when(() -> IdentityDatabaseUtil.getDBConnection(false)) .thenReturn(connection); assertEquals(accessTokenDAO.getSessionIdentifierByTokenId("2sa9a678f890877856y66e75f605d456"), @@ -103,4 +117,43 @@ private static void closeH2Base(String databaseName) throws Exception { dataSource.close(); } } + + @Test + public void testRevokeAccessTokensIndividually() throws Exception { + + String[] tokens = {}; + boolean isHashedToken = false; + + Connection mockConnection = mock(Connection.class); + DatabaseMetaData mockDatabaseMetaData = mock(DatabaseMetaData.class); + when(mockConnection.getMetaData()).thenReturn(mockDatabaseMetaData); + when(mockDatabaseMetaData.getDriverName()).thenReturn("Microsoft SQL Server"); + identityDatabaseUtil.when(IdentityDatabaseUtil::getDBConnection).thenReturn(mockConnection); + accessTokenDAO.revokeAccessTokensIndividually(tokens, isHashedToken); + } + + @Test + public void testRevokeAccessTokensInBatch() throws Exception { + + String[] tokens = {"token1", "token2"}; + boolean isHashedToken = true; + + Connection mockConnection = mock(Connection.class); + PreparedStatement preparedStatement = mock(PreparedStatement.class); + ResultSet resultSet = mock(ResultSet.class); + when(preparedStatement.executeQuery()).thenReturn(resultSet); + when(mockConnection.prepareStatement(anyString())).thenReturn(preparedStatement); + + DatabaseMetaData mockDatabaseMetaData = mock(DatabaseMetaData.class); + when(mockConnection.getMetaData()).thenReturn(mockDatabaseMetaData); + when(mockDatabaseMetaData.getDriverName()).thenReturn("Microsoft SQL Server"); + identityDatabaseUtil.when(IdentityDatabaseUtil::getDBConnection).thenReturn(mockConnection); + + OAuthServerConfiguration mockOAuthServerConfiguration = mock(OAuthServerConfiguration.class); + oAuthServerConfiguration.when(OAuthServerConfiguration::getInstance) + .thenReturn(mockOAuthServerConfiguration); + when(mockOAuthServerConfiguration.getHashAlgorithm()).thenReturn("SHA-256"); + + accessTokenDAO.revokeAccessTokensInBatch(tokens, isHashedToken); + } }