diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index 6936d2d44..291d7bc92 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -343,7 +343,7 @@ public LinkDeviceResponse linkDevice(@HeaderParam(HttpHeaders.AUTHORIZATION) Bas @ApiResponse(responseCode = "204", description = "No device was linked to the account before the call completed; clients may repeat the call to continue waiting") @ApiResponse(responseCode = "400", description = "The given token identifier or timeout was invalid") @ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay") - public CompletableFuture waitForLinkedDevice( + public CompletionStage waitForLinkedDevice( @ReadOnly @Auth final AuthenticatedDevice authenticatedDevice, @PathParam("tokenIdentifier") @@ -363,40 +363,35 @@ The amount of time (in seconds) to wait for a response. If the expected device i given amount of time, this endpoint will return a status of HTTP/204. """) final int timeoutSeconds, - @HeaderParam(HttpHeaders.USER_AGENT) String userAgent) throws RateLimitExceededException { - - rateLimiters.getWaitForLinkedDeviceLimiter().validate(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI)); - + @HeaderParam(HttpHeaders.USER_AGENT) String userAgent) { final AtomicInteger linkedDeviceListenerCounter = getCounterForLinkedDeviceListeners(userAgent); linkedDeviceListenerCounter.incrementAndGet(); - final Timer.Sample sample = Timer.start(); - try { - return accounts.waitForNewLinkedDevice(authenticatedDevice.getAccount().getUuid(), - authenticatedDevice.getAuthenticatedDevice(), tokenIdentifier, Duration.ofSeconds(timeoutSeconds)) - .thenApply(maybeDeviceInfo -> maybeDeviceInfo - .map(deviceInfo -> Response.status(Response.Status.OK).entity(deviceInfo).build()) - .orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build())) - .exceptionally(ExceptionUtils.exceptionallyHandler(IllegalArgumentException.class, - e -> Response.status(Response.Status.BAD_REQUEST).build())) - .whenComplete((response, throwable) -> { - linkedDeviceListenerCounter.decrementAndGet(); - - if (response != null) { - sample.stop(Timer.builder(WAIT_FOR_LINKED_DEVICE_TIMER_NAME) - .publishPercentileHistogram(true) - .tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), - io.micrometer.core.instrument.Tag.of("deviceFound", - String.valueOf(response.getStatus() == Response.Status.OK.getStatusCode())))) - .register(Metrics.globalRegistry)); - } - }); - } catch (final RedisException e) { - // `waitForNewLinkedDevice` could fail synchronously if the Redis circuit breaker is open; prevent counter drift - // if that happens - linkedDeviceListenerCounter.decrementAndGet(); - throw e; - } + + return rateLimiters.getWaitForLinkedDeviceLimiter() + .validateAsync(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI)) + .thenCompose(ignored -> accounts.waitForNewLinkedDevice( + authenticatedDevice.getAccount().getUuid(), + authenticatedDevice.getAuthenticatedDevice(), + tokenIdentifier, + Duration.ofSeconds(timeoutSeconds))) + .thenApply(maybeDeviceInfo -> maybeDeviceInfo + .map(deviceInfo -> Response.status(Response.Status.OK).entity(deviceInfo).build()) + .orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build())) + .exceptionally(ExceptionUtils.exceptionallyHandler(IllegalArgumentException.class, + e -> Response.status(Response.Status.BAD_REQUEST).build())) + .whenComplete((response, throwable) -> { + linkedDeviceListenerCounter.decrementAndGet(); + + if (response != null) { + sample.stop(Timer.builder(WAIT_FOR_LINKED_DEVICE_TIMER_NAME) + .publishPercentileHistogram(true) + .tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), + io.micrometer.core.instrument.Tag.of("deviceFound", + String.valueOf(response.getStatus() == Response.Status.OK.getStatusCode())))) + .register(Metrics.globalRegistry)); + } + }); } private AtomicInteger getCounterForLinkedDeviceListeners(final String userAgent) { diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java index 2a302c4a4..c735cec6e 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java @@ -955,6 +955,8 @@ void waitForLinkedDevice() { .waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any())) .thenReturn(CompletableFuture.completedFuture(Optional.of(deviceInfo))); + when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null)); + try (final Response response = resources.getJerseyTest() .target("/v1/devices/wait_for_linked_device/" + tokenIdentifier) .request() @@ -979,6 +981,8 @@ void waitForLinkedDeviceNoDeviceLinked() { .waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any())) .thenReturn(CompletableFuture.completedFuture(Optional.empty())); + when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null)); + try (final Response response = resources.getJerseyTest() .target("/v1/devices/wait_for_linked_device/" + tokenIdentifier) .request() @@ -997,6 +1001,8 @@ void waitForLinkedDeviceBadTokenIdentifier() { .waitForNewLinkedDevice(eq(AuthHelper.VALID_UUID), eq(AuthHelper.VALID_DEVICE), eq(tokenIdentifier), any())) .thenReturn(CompletableFuture.failedFuture(new IllegalArgumentException())); + when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null)); + try (final Response response = resources.getJerseyTest() .target("/v1/devices/wait_for_linked_device/" + tokenIdentifier) .request() @@ -1042,10 +1048,11 @@ private static List waitForLinkedDeviceBadTokenIdentifierLength() { } @Test - void waitForLinkedDeviceRateLimited() throws RateLimitExceededException { + void waitForLinkedDeviceRateLimited() { final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]); - doThrow(new RateLimitExceededException(null)).when(rateLimiter).validate(AuthHelper.VALID_UUID); + when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)) + .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null))); try (final Response response = resources.getJerseyTest() .target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)