Skip to content

Commit

Permalink
Add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sadilchamishka committed Feb 8, 2025
1 parent 319aec9 commit 62151d1
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 18 deletions.
1 change: 1 addition & 0 deletions components/org.wso2.carbon.identity.oauth/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@
org.apache.commons.collections.map; version="${commons-collections.wso2.osgi.version.range}",
org.apache.commons.lang; version="${commons-lang.wso2.osgi.version.range}",
org.apache.commons.logging; version="${commons-logging.osgi.version.range}",
org.apache.commons.collections; version="${commons-logging.osgi.version.range}",

com.google.gdata.client.authn.oauth; version="${gdata-core.imp.pkg.version.range}",

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.wso2.carbon.identity.oauth2.dao;

import org.apache.commons.codec.digest.DigestUtils;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
Expand Down Expand Up @@ -1514,10 +1515,12 @@ public void revokeAccessTokensInBatch(String[] tokens, boolean isHashedToken) th
}
ps.executeBatch();
IdentityDatabaseUtil.commitTransaction(connection);
// To revoke request objects which have persisted against the access token.
OAuth2TokenUtil.postUpdateAccessTokens(Arrays.asList(tokens), OAuthConstants.TokenStates.
TOKEN_STATE_REVOKED);
if (isTokenCleanupFeatureEnabled) {
if (connection.getMetaData().getDriverName().contains("Microsoft")) {
/* When token is deleted, the request objects get on delete cascade except for the SQL server.
Hence, invoke the event listener to revoke the request objects.*/
revokeRequestObjectEntries(Arrays.asList(tokens));
}
oldTokenCleanupObject.cleanupTokensInBatch(oldTokens, connection);
}
} catch (SQLException e) {
Expand Down Expand Up @@ -1545,14 +1548,13 @@ public void revokeAccessTokensInBatch(String[] tokens, boolean isHashedToken) th


if (isTokenCleanupFeatureEnabled) {
oldTokenCleanupObject.cleanupTokenByTokenValue(
getHashingPersistenceProcessor().getProcessedAccessTokenIdentifier(tokens[0]), connection);
/* When token is deleted, the request objects get on delete cascade except for the SQL server.
Hence, invoke the event listener to revoke the request objects.*/
if (connection.getMetaData().getDriverName().contains("Microsoft")) {
OAuth2TokenUtil.postUpdateAccessTokens(Arrays.asList(tokens), OAuthConstants.TokenStates.
TOKEN_STATE_REVOKED);
/* When token is deleted, the request objects get on delete cascade except for the SQL server.
Hence, invoke the event listener to revoke the request objects.*/
revokeRequestObjectEntries(Arrays.asList(tokens));
}
oldTokenCleanupObject.cleanupTokenByTokenValue(
getHashingPersistenceProcessor().getProcessedAccessTokenIdentifier(tokens[0]), connection);
}
} catch (SQLException e) {
// IdentityDatabaseUtil.rollbackTransaction(connection);
Expand Down Expand Up @@ -1618,13 +1620,13 @@ public void revokeAccessTokensIndividually(String[] tokens, boolean isHashedToke
}
accessTokenId.add(getTokenIdByAccessToken(token));
}
// To revoke request objects which have persisted against the access token.
if (accessTokenId.size() > 0) {
OAuth2TokenUtil.postUpdateAccessTokens(accessTokenId, OAuthConstants.TokenStates.
TOKEN_STATE_REVOKED);
}

if (isTokenCleanupFeatureEnabled) {
if (connection.getMetaData().getDriverName().contains("Microsoft")) {
/* When token is deleted, the request objects get on delete cascade except for the SQL server.
Hence, invoke the event listener to revoke the request objects.*/
revokeRequestObjectEntries(accessTokenId);
}
for (String token : tokens) {
oldTokenCleanupObject.cleanupTokenByTokenValue(
getHashingPersistenceProcessor().getProcessedAccessTokenIdentifier(token), connection);
Expand Down Expand Up @@ -3331,4 +3333,15 @@ private String getRootTenantDomainByOrganizationId(String organizationId) throws
organizationId, e);
}
}

/* When token is deleted, the request objects get on delete cascade except for the SQL server.
Hence, invoke the event listener to revoke the request objects.*/
private void revokeRequestObjectEntries(List<String> tokens) throws IdentityOAuth2Exception {

if (CollectionUtils.isEmpty(tokens)) {
return;
}
OAuth2TokenUtil.postUpdateAccessTokens(tokens, OAuthConstants.TokenStates.
TOKEN_STATE_REVOKED);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,19 @@
import org.wso2.carbon.identity.application.mgt.ApplicationManagementService;
import org.wso2.carbon.identity.common.testng.WithCarbonHome;
import org.wso2.carbon.identity.common.testng.WithRealmService;
import org.wso2.carbon.identity.core.util.IdentityDatabaseUtil;
import org.wso2.carbon.identity.oauth.cache.AuthorizationGrantCache;
import org.wso2.carbon.identity.oauth.cache.CacheEntry;
import org.wso2.carbon.identity.oauth.cache.OAuthCache;
import org.wso2.carbon.identity.oauth.cache.OAuthCacheKey;
import org.wso2.carbon.identity.oauth.internal.OAuthComponentServiceHolder;
import org.wso2.carbon.identity.oauth2.dao.AccessTokenDAO;
import org.wso2.carbon.identity.oauth2.dao.AuthorizationCodeDAO;
import org.wso2.carbon.identity.oauth2.dao.AuthorizationCodeDAOImpl;
import org.wso2.carbon.identity.oauth2.dao.OAuthTokenPersistenceFactory;
import org.wso2.carbon.identity.oauth2.dao.TokenManagementDAO;
import org.wso2.carbon.identity.oauth2.model.AccessTokenDO;
import org.wso2.carbon.identity.oauth2.model.AuthzCodeDO;
import org.wso2.carbon.identity.oauth2.util.OAuth2Util;
import org.wso2.carbon.identity.organization.management.organization.user.sharing.OrganizationUserSharingService;
import org.wso2.carbon.identity.organization.management.organization.user.sharing.models.UserAssociation;
Expand All @@ -64,8 +69,12 @@
import org.wso2.carbon.user.core.util.UserCoreUtil;
import org.wso2.carbon.utils.multitenancy.MultitenantConstants;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

Expand Down Expand Up @@ -126,6 +135,8 @@ public class OAuthUtilTest {
private AutoCloseable closeable;
private MockedStatic<OrganizationManagementUtil> organizationManagementUtil;
private MockedStatic<OAuth2Util> oAuth2Util;
private MockedStatic<IdentityDatabaseUtil> identityDatabaseUtil;
private MockedStatic<AuthorizationGrantCache> authorizationGrantCache;
private MockedStatic<OAuthTokenPersistenceFactory> oAuthTokenPersistenceFactory;

@BeforeMethod
Expand All @@ -140,15 +151,19 @@ public void setUp() throws Exception {
OAuthComponentServiceHolder.getInstance().setOrganizationManager(organizationManager);
OAuthComponentServiceHolder.getInstance().setRealmService(realmService);
oAuth2Util = mockStatic(OAuth2Util.class);
identityDatabaseUtil = mockStatic(IdentityDatabaseUtil.class);
oAuthTokenPersistenceFactory = mockStatic(OAuthTokenPersistenceFactory.class);
authorizationGrantCache = mockStatic(AuthorizationGrantCache.class);
}

@AfterMethod
public void tearDown() throws Exception {

organizationManagementUtil.close();
oAuth2Util.close();
identityDatabaseUtil.close();
oAuthTokenPersistenceFactory.close();
authorizationGrantCache.close();
reset(organizationUserSharingService);
reset(roleManagementService);
reset(applicationManagementService);
Expand Down Expand Up @@ -476,6 +491,48 @@ public void testAuthenticatedUserInSharedUserFlow(boolean isSSOLoginUser, boolea
}
}

@Test
public void testRevokeAuthzCodes() throws Exception {

UserStoreManager userStoreManager = mock(UserStoreManager.class);

// Create a real instance of AuthorizationCodeDAO and spy on it
AuthorizationCodeDAO authorizationCodeDAO = Mockito.spy(new AuthorizationCodeDAOImpl());

when(userStoreManager.getTenantId()).thenReturn(-1234);
when(userStoreManager.getRealmConfiguration()).thenReturn(mock(RealmConfiguration.class));

OAuthTokenPersistenceFactory mockOAuthTokenPersistenceFactory = mock(OAuthTokenPersistenceFactory.class);
when(OAuthTokenPersistenceFactory.getInstance()).thenReturn(mockOAuthTokenPersistenceFactory);
when(mockOAuthTokenPersistenceFactory.getAuthorizationCodeDAO()).thenReturn(authorizationCodeDAO);

List<AuthzCodeDO> authorizationCodes = new ArrayList<>();
AuthzCodeDO mockAuthzCodeDO = mock(AuthzCodeDO.class);
when(mockAuthzCodeDO.getConsumerKey()).thenReturn("consumer-key");
when(mockAuthzCodeDO.getAuthorizationCode()).thenReturn("auth-code");
when(mockAuthzCodeDO.getAuthzCodeId()).thenReturn("auth-code-id");

authorizationCodes.add(mockAuthzCodeDO);

// Mock the getAuthorizationCodesDataByUser method to return the list of authorization codes
Mockito.doReturn(authorizationCodes).when(authorizationCodeDAO)
.getAuthorizationCodesDataByUser(any(AuthenticatedUser.class));

when(OAuth2Util.buildCacheKeyStringForAuthzCode(anyString(), anyString())).thenReturn("testAuthzCode");

Connection connection = mock(Connection.class);
PreparedStatement preparedStatement = mock(PreparedStatement.class);
when(IdentityDatabaseUtil.getDBConnection()).thenReturn(connection);
when(connection.prepareStatement(anyString())).thenReturn(preparedStatement);

AuthorizationGrantCache mockAuthorizationGrantCache = mock(AuthorizationGrantCache.class);
when(AuthorizationGrantCache.getInstance()).thenReturn(mockAuthorizationGrantCache);

boolean result = OAuthUtil.revokeAuthzCodes("testUser", userStoreManager);
// Verify the result
assertTrue(result, "Authorization code revocation failed.");
}

private OAuthCache getOAuthCache(OAuthCacheKey oAuthCacheKey) {

// Add some value to OAuthCache.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, WSO2 LLC. (http://www.wso2.org).
* Copyright (c) 2024-2025, WSO2 LLC. (http://www.wso2.org).
*
* WSO2 Inc. licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
Expand Down Expand Up @@ -29,14 +29,21 @@
import org.testng.annotations.Test;
import org.wso2.carbon.identity.common.testng.WithCarbonHome;
import org.wso2.carbon.identity.core.util.IdentityDatabaseUtil;
import org.wso2.carbon.identity.oauth.config.OAuthServerConfiguration;
import org.wso2.carbon.identity.oauth2.IdentityOAuth2Exception;
import org.wso2.carbon.identity.oauth2.dao.util.DAOUtils;

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.util.HashMap;
import java.util.Map;

import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.when;
import static org.testng.Assert.assertEquals;

@WithCarbonHome
Expand All @@ -51,6 +58,7 @@ public class AccessTokenDAOImplTest {
public static final String DB_NAME = "AccessTokenDB";
Connection connection = null;
private MockedStatic<IdentityDatabaseUtil> identityDatabaseUtil;
private MockedStatic<OAuthServerConfiguration> oAuthServerConfiguration;

@BeforeClass
public void initTest() throws Exception {
Expand All @@ -60,13 +68,20 @@ public void initTest() throws Exception {
} catch (Exception e) {
throw new IdentityOAuth2Exception("Error while initializing the data source", e);
}
accessTokenDAO = new AccessTokenDAOImpl();
}

@BeforeMethod
public void setUp() throws Exception {

connection = DAOUtils.getConnection(DB_NAME);
identityDatabaseUtil = mockStatic(IdentityDatabaseUtil.class);
oAuthServerConfiguration = mockStatic(OAuthServerConfiguration.class);

OAuthServerConfiguration mockOAuthServerConfiguration = mock(OAuthServerConfiguration.class);
oAuthServerConfiguration.when(OAuthServerConfiguration::getInstance)
.thenReturn(mockOAuthServerConfiguration);
when(mockOAuthServerConfiguration.isTokenCleanupEnabled()).thenReturn(true);
accessTokenDAO = new AccessTokenDAOImpl();
}

@AfterMethod
Expand All @@ -76,6 +91,7 @@ public void closeup() throws Exception {
connection.close();
}
identityDatabaseUtil.close();
oAuthServerConfiguration.close();
}

@AfterClass
Expand All @@ -88,8 +104,6 @@ public void tearDown() throws Exception {
public void getSessionIdentifierByTokenId() throws Exception {

connection = DAOUtils.getConnection(DB_NAME);
identityDatabaseUtil = mockStatic(IdentityDatabaseUtil.class);

identityDatabaseUtil.when(() -> IdentityDatabaseUtil.getDBConnection(false))
.thenReturn(connection);
assertEquals(accessTokenDAO.getSessionIdentifierByTokenId("2sa9a678f890877856y66e75f605d456"),
Expand All @@ -103,4 +117,43 @@ private static void closeH2Base(String databaseName) throws Exception {
dataSource.close();
}
}

@Test
public void testRevokeAccessTokensIndividually() throws Exception {

String[] tokens = {};
boolean isHashedToken = false;

Connection mockConnection = mock(Connection.class);
DatabaseMetaData mockDatabaseMetaData = mock(DatabaseMetaData.class);
when(mockConnection.getMetaData()).thenReturn(mockDatabaseMetaData);
when(mockDatabaseMetaData.getDriverName()).thenReturn("Microsoft SQL Server");
identityDatabaseUtil.when(IdentityDatabaseUtil::getDBConnection).thenReturn(mockConnection);
accessTokenDAO.revokeAccessTokensIndividually(tokens, isHashedToken);
}

@Test
public void testRevokeAccessTokensInBatch() throws Exception {

String[] tokens = {"token1", "token2"};
boolean isHashedToken = true;

Connection mockConnection = mock(Connection.class);
PreparedStatement preparedStatement = mock(PreparedStatement.class);
ResultSet resultSet = mock(ResultSet.class);
when(preparedStatement.executeQuery()).thenReturn(resultSet);
when(mockConnection.prepareStatement(anyString())).thenReturn(preparedStatement);

DatabaseMetaData mockDatabaseMetaData = mock(DatabaseMetaData.class);
when(mockConnection.getMetaData()).thenReturn(mockDatabaseMetaData);
when(mockDatabaseMetaData.getDriverName()).thenReturn("Microsoft SQL Server");
identityDatabaseUtil.when(IdentityDatabaseUtil::getDBConnection).thenReturn(mockConnection);

OAuthServerConfiguration mockOAuthServerConfiguration = mock(OAuthServerConfiguration.class);
oAuthServerConfiguration.when(OAuthServerConfiguration::getInstance)
.thenReturn(mockOAuthServerConfiguration);
when(mockOAuthServerConfiguration.getHashAlgorithm()).thenReturn("SHA-256");

accessTokenDAO.revokeAccessTokensInBatch(tokens, isHashedToken);
}
}

0 comments on commit 62151d1

Please sign in to comment.