diff --git a/services/pom.xml b/services/pom.xml index 7b02eac4..b0d3efc4 100644 --- a/services/pom.xml +++ b/services/pom.xml @@ -15,7 +15,7 @@ 1.0 Mfa addon Rest endpoints - 0.64 + 0.69 diff --git a/services/src/main/java/org/exoplatform/mfa/filter/MfaFilter.java b/services/src/main/java/org/exoplatform/mfa/filter/MfaFilter.java index dbfb6a69..44412def 100644 --- a/services/src/main/java/org/exoplatform/mfa/filter/MfaFilter.java +++ b/services/src/main/java/org/exoplatform/mfa/filter/MfaFilter.java @@ -14,6 +14,7 @@ import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; +import org.apache.commons.lang.StringUtils; import org.exoplatform.container.ExoContainer; import org.exoplatform.container.PortalContainer; import org.exoplatform.mfa.api.MfaService; @@ -41,21 +42,26 @@ public void doFilter(ServletRequest request, ServletResponse response, FilterCha HttpServletRequest httpServletRequest = (HttpServletRequest) request; HttpServletResponse httpServletResponse = (HttpServletResponse) response; HttpSession session = httpServletRequest.getSession(); - ExoContainer container = PortalContainer.getInstance(); + PortalContainer container = PortalContainer.getInstance(); MfaService mfaService = container.getComponentInstanceOfType(MfaService.class); String requestUri = httpServletRequest.getRequestURI(); - if (httpServletRequest.getRemoteUser() != null && - mfaService.isMfaFeatureActivated() && - excludedUrls.stream().noneMatch(requestUri::startsWith) && - (mfaService.isProtectedUri(requestUri) || - mfaService.currentUserIsInProtectedGroup(ConversationState.getCurrent().getIdentity())) && - shouldAuthenticateFromSession(session)) { + if (httpServletRequest.getRemoteUser() != null && mfaService.isMfaFeatureActivated() && (mfaService.isProtectedUri(requestUri) + || mfaService.currentUserIsInProtectedGroup(ConversationState.getCurrent().getIdentity()))) { + if (shouldAuthenticateFromSession(session) && excludedUrls.stream().noneMatch(requestUri::startsWith)) { LOG.debug("Mfa Filter must redirect on page to fill token"); - httpServletResponse.sendRedirect(MFA_URI+"?initialUri=" + requestUri); + httpServletResponse.sendRedirect(MFA_URI + "?initialUri=" + requestUri); return; + } else if (!shouldAuthenticateFromSession(session) && requestUri.startsWith(MFA_URI)) { + String queryString = httpServletRequest.getQueryString(); + String initialUri = "/"; + if (StringUtils.isNotBlank(queryString) && queryString.contains("initialUri=")) { + initialUri = queryString.substring(11); + } + httpServletResponse.sendRedirect(initialUri); + return; + } } - chain.doFilter(request, response); } diff --git a/services/src/test/java/org/exoplatform/mfa/filter/MfaFilterTest.java b/services/src/test/java/org/exoplatform/mfa/filter/MfaFilterTest.java new file mode 100644 index 00000000..786d299a --- /dev/null +++ b/services/src/test/java/org/exoplatform/mfa/filter/MfaFilterTest.java @@ -0,0 +1,81 @@ +package org.exoplatform.mfa.filter; + +import org.exoplatform.commons.utils.CommonsUtils; +import org.exoplatform.container.ExoContainer; +import org.exoplatform.container.PortalContainer; +import org.exoplatform.mfa.api.MfaService; +import org.exoplatform.services.listener.ListenerService; +import org.exoplatform.services.security.ConversationState; +import org.exoplatform.services.security.Identity; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.MockedStatic; +import org.mockito.Mockito; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; + +import java.io.IOException; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +public class MfaFilterTest { + + MockedStatic PORTAL_CONTAINER = Mockito.mockStatic(PortalContainer.class); + MockedStatic CONVERSATION_STATE = Mockito.mockStatic(ConversationState.class); + PortalContainer portalContainer = Mockito.mock(PortalContainer.class); + ConversationState conversationState = Mockito.mock(ConversationState.class); + + @Before + public void setUp() throws Exception { + PORTAL_CONTAINER.when(PortalContainer::getInstance).thenReturn(portalContainer); + CONVERSATION_STATE.when(ConversationState::getCurrent).thenReturn(conversationState); + } + + @Test + public void testDoFilter() { + MfaFilter mfaFilter = new MfaFilter(); + HttpServletRequest httpServletRequest = Mockito.mock(HttpServletRequest.class); + HttpServletResponse httpServletResponse = Mockito.mock(HttpServletResponse.class); + FilterChain chain = Mockito.mock(FilterChain.class); + MfaService mfaService = Mockito.mock(MfaService.class); + ConversationState conversationState = Mockito.mock(ConversationState.class); + HttpSession httpSession = Mockito.mock(HttpSession.class); + + when(httpServletRequest.getRemoteUser()).thenReturn("root"); + when(mfaService.isMfaFeatureActivated()).thenReturn(true); + Identity identity = new Identity("root"); + when(conversationState.getIdentity()).thenReturn(identity); + when(ConversationState.getCurrent()).thenReturn(conversationState); + when(mfaService.currentUserIsInProtectedGroup(identity)).thenReturn(true); + when(httpServletRequest.getRequestURI()).thenReturn("/portal/dw/protectedUri"); + when(httpServletRequest.getQueryString()).thenReturn("initialUri=/portal/dw/protectedUri"); + when(portalContainer.getComponentInstanceOfType(MfaService.class)).thenReturn(mfaService); + when(httpServletRequest.getSession()).thenReturn(httpSession); + try { + mfaFilter.doFilter(httpServletRequest, httpServletResponse, chain); + verify(chain,times(0)).doFilter(httpServletRequest, httpServletResponse); + } catch (IOException | ServletException e) { + fail(); + } + when(httpSession.getAttribute("mfaValidated")).thenReturn(Boolean.TRUE); + when(httpServletRequest.getRequestURI()).thenReturn("/portal/dw/mfa-access?initialUri=/portal/dw/spaces"); + try { + mfaFilter.doFilter(httpServletRequest, httpServletResponse, chain); + verify(chain,times(0)).doFilter(httpServletRequest, httpServletResponse); + } catch (IOException | ServletException e) { + fail(); + } + } + + @After + public void tearDown() throws Exception { + PORTAL_CONTAINER.close(); + CONVERSATION_STATE.close(); + } +}