Skip to content

Commit

Permalink
Introduce paritionByRequest and bypassLimitByPredicate functions
Browse files Browse the repository at this point in the history
  • Loading branch information
kul committed May 10, 2024
1 parent 5c86608 commit ec425e4
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ public ServletLimiterBuilder partitionByParameter(String name) {
return partitionResolver(request -> Optional.ofNullable(request.getParameter(name)).orElse(null));
}

/**
* Partition the limit by the request instance. Percentages of the limit are partitioned to named
* groups. Group membership is derived from the provided mapping function.
* @param requestToGroup Mapping function from request to a named group.
* @return Chainable builder
*/
public ServletLimiterBuilder partitionByRequest(Function<HttpServletRequest, String> requestToGroup) {
return partitionResolver(request -> Optional.ofNullable(request).map(requestToGroup).orElse(null));
}

/**
* Partition the limit by the full path. Percentages of the limit are partitioned to named
* groups. Group membership is derived from the provided mapping function.
Expand Down Expand Up @@ -142,6 +152,16 @@ public ServletLimiterBuilder bypassLimitByMethod(String method) {
return bypassLimitResolver((context) -> method.equals(context.getMethod()));
}

/**
* Bypass limit if the predicate function returns true.
* @param predicate The predicate function to which {@link HttpServletRequest } instance is passed.
* If the predicate return true, the limit will be bypassed.
* @return Chainable builder
*/
public ServletLimiterBuilder bypassLimitByPredicate(Function<HttpServletRequest, Boolean> predicate) {
return bypassLimitResolver((context) -> predicate.apply(context));
}

@Override
protected ServletLimiterBuilder self() {
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public void beforeEachTest() {
limiter = Mockito.spy(new ServletLimiterBuilder()
.bypassLimitByMethod("GET")
.bypassLimitByPathInfo("/admin/health")
.bypassLimitByPredicate(ctx -> ctx.getMethod().equals("PATCH"))
.named(testName.getMethodName())
.metricRegistry(spectatorMetricRegistry)
.build());
Expand Down Expand Up @@ -130,6 +131,24 @@ public void testDoFilterBypassCheckPassedForPath() throws ServletException, IOEx
verifyCounts(0, 0, 0, 0, 1);
}

@Test
public void testDoFilterBypassCheckPassedForPredicate() throws ServletException, IOException {

ConcurrencyLimitServletFilter filter = new ConcurrencyLimitServletFilter(limiter);

MockHttpServletRequest request = new MockHttpServletRequest();
request.setMethod("PATCH");
request.setPathInfo("/admin/patch");
MockHttpServletResponse response = new MockHttpServletResponse();
MockFilterChain filterChain = new MockFilterChain();

filter.doFilter(request, response, filterChain);

assertEquals("Request should be passed to the downstream chain", request, filterChain.getRequest());
assertEquals("Response should be passed to the downstream chain", response, filterChain.getResponse());
verifyCounts(0, 0, 0, 0, 1);
}

@Test
public void testDoFilterBypassCheckFailed() throws ServletException, IOException {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,44 @@ public void nullPathDoesNotMatchesGroup() {
Mockito.verify(pathToGroup, Mockito.times(0)).get(Mockito.<String>any());
}

@Test
public void requestMatchesGroup() {
Map<String, String> requestMethodToGroup = Mockito.spy(new HashMap<>());
requestMethodToGroup.put("PATCH", "live");

Limiter<HttpServletRequest> limiter = new ServletLimiterBuilder()
.limit(VegasLimit.newDefault())
.partitionByRequest(request -> requestMethodToGroup.get(request.getMethod()))
.partition("live", 0.8)
.partition("batch", 0.2)
.build();

HttpServletRequest request = createMockRequestWithType("PATCH");
Optional<Listener> listener = limiter.acquire(request);

Assert.assertTrue(listener.isPresent());
Mockito.verify(requestMethodToGroup, Mockito.times(1)).get("PATCH");
}

@Test
public void requestDoesNotMatchesGroup() {
Map<String, String> requestMethodToGroup = Mockito.spy(new HashMap<>());
requestMethodToGroup.put("PATCH", "live");

Limiter<HttpServletRequest> limiter = new ServletLimiterBuilder()
.limit(VegasLimit.newDefault())
.partitionByRequest(request -> requestMethodToGroup.get(request.getMethod()))
.partition("live", 0.8)
.partition("batch", 0.2)
.build();

HttpServletRequest request = createMockRequestWithType("PUT");
Optional<Listener> listener = limiter.acquire(request);

Assert.assertTrue(listener.isPresent());
Mockito.verify(requestMethodToGroup, Mockito.times(1)).get("PUT");
}

private HttpServletRequest createMockRequestWithPrincipal(String name) {
HttpServletRequest request = Mockito.mock(HttpServletRequest.class);
Principal principal = Mockito.mock(Principal.class);
Expand All @@ -169,4 +207,11 @@ private HttpServletRequest createMockRequestWithPathInfo(String name) {
Mockito.when(request.getPathInfo()).thenReturn(name);
return request;
}

private HttpServletRequest createMockRequestWithType(String type) {
HttpServletRequest request = Mockito.mock(HttpServletRequest.class);

Mockito.when(request.getMethod()).thenReturn(type);
return request;
}
}

0 comments on commit ec425e4

Please sign in to comment.