diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index 25f280d2c..61def05fb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -227,6 +227,7 @@ import org.whispersystems.textsecuregcm.storage.MessagesDynamoDb; import org.whispersystems.textsecuregcm.storage.MessagesManager; import org.whispersystems.textsecuregcm.storage.OneTimeDonationsManager; +import org.whispersystems.textsecuregcm.storage.PersistentTimer; import org.whispersystems.textsecuregcm.storage.PhoneNumberIdentifiers; import org.whispersystems.textsecuregcm.storage.Profiles; import org.whispersystems.textsecuregcm.storage.ProfilesManager; @@ -1097,6 +1098,7 @@ protected void configureServer(final ServerBuilder serverBuilder) { log.info("Registered spam filter: {}", filter.getClass().getName()); }); + final PersistentTimer persistentTimer = new PersistentTimer(rateLimitersCluster, clock); final PhoneVerificationTokenManager phoneVerificationTokenManager = new PhoneVerificationTokenManager( phoneNumberIdentifiers, registrationServiceClient, registrationRecoveryPasswordsManager, registrationRecoveryChecker); @@ -1115,7 +1117,7 @@ protected void configureServer(final ServerBuilder serverBuilder) { config.getDeliveryCertificate().ecPrivateKey(), config.getDeliveryCertificate().expiresDays()), zkAuthOperations, callingGenericZkSecretParams, clock), new ChallengeController(rateLimitChallengeManager, challengeConstraintChecker), - new DeviceController(accountsManager, clientPublicKeysManager, rateLimiters, config.getMaxDevices()), + new DeviceController(accountsManager, clientPublicKeysManager, rateLimiters, persistentTimer, config.getMaxDevices()), new DeviceCheckController(clock, backupAuthManager, appleDeviceCheckManager, rateLimiters, config.getDeviceCheck().backupRedemptionLevel(), config.getDeviceCheck().backupRedemptionDuration()), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index 291d7bc92..1a90916e8 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -7,7 +7,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.net.HttpHeaders; import io.dropwizard.auth.Auth; -import io.lettuce.core.RedisException; import io.micrometer.core.instrument.Metrics; import io.micrometer.core.instrument.Tags; import io.micrometer.core.instrument.Timer; @@ -81,6 +80,7 @@ import org.whispersystems.textsecuregcm.storage.DeviceCapability; import org.whispersystems.textsecuregcm.storage.DeviceSpec; import org.whispersystems.textsecuregcm.storage.LinkDeviceTokenAlreadyUsedException; +import org.whispersystems.textsecuregcm.storage.PersistentTimer; import org.whispersystems.textsecuregcm.util.DeviceCapabilityAdapter; import org.whispersystems.textsecuregcm.util.EnumMapUtil; import org.whispersystems.textsecuregcm.util.ExceptionUtils; @@ -100,6 +100,7 @@ public class DeviceController { private final AccountsManager accounts; private final ClientPublicKeysManager clientPublicKeysManager; private final RateLimiters rateLimiters; + private final PersistentTimer persistentTimer; private final Map maxDeviceConfiguration; private final EnumMap linkedDeviceListenersByPlatform; @@ -108,9 +109,11 @@ public class DeviceController { private static final String LINKED_DEVICE_LISTENER_GAUGE_NAME = MetricsUtil.name(DeviceController.class, "linkedDeviceListeners"); + private static final String WAIT_FOR_LINKED_DEVICE_TIMER_NAMESPACE = "wait_for_linked_device"; private static final String WAIT_FOR_LINKED_DEVICE_TIMER_NAME = MetricsUtil.name(DeviceController.class, "waitForLinkedDeviceDuration"); + private static final String WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAMESPACE = "wait_for_transfer_archive"; private static final String WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAME = MetricsUtil.name(DeviceController.class, "waitForTransferArchiveDuration"); @@ -124,11 +127,13 @@ public class DeviceController { public DeviceController(final AccountsManager accounts, final ClientPublicKeysManager clientPublicKeysManager, final RateLimiters rateLimiters, + final PersistentTimer persistentTimer, final Map maxDeviceConfiguration) { this.accounts = accounts; this.clientPublicKeysManager = clientPublicKeysManager; this.rateLimiters = rateLimiters; + this.persistentTimer = persistentTimer; this.maxDeviceConfiguration = maxDeviceConfiguration; linkedDeviceListenersByPlatform = @@ -366,32 +371,30 @@ The amount of time (in seconds) to wait for a response. If the expected device i @HeaderParam(HttpHeaders.USER_AGENT) String userAgent) { final AtomicInteger linkedDeviceListenerCounter = getCounterForLinkedDeviceListeners(userAgent); linkedDeviceListenerCounter.incrementAndGet(); - final Timer.Sample sample = Timer.start(); return rateLimiters.getWaitForLinkedDeviceLimiter() .validateAsync(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI)) - .thenCompose(ignored -> accounts.waitForNewLinkedDevice( - authenticatedDevice.getAccount().getUuid(), - authenticatedDevice.getAuthenticatedDevice(), - tokenIdentifier, - Duration.ofSeconds(timeoutSeconds))) - .thenApply(maybeDeviceInfo -> maybeDeviceInfo - .map(deviceInfo -> Response.status(Response.Status.OK).entity(deviceInfo).build()) - .orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build())) - .exceptionally(ExceptionUtils.exceptionallyHandler(IllegalArgumentException.class, - e -> Response.status(Response.Status.BAD_REQUEST).build())) - .whenComplete((response, throwable) -> { - linkedDeviceListenerCounter.decrementAndGet(); - - if (response != null) { - sample.stop(Timer.builder(WAIT_FOR_LINKED_DEVICE_TIMER_NAME) - .publishPercentileHistogram(true) - .tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent), - io.micrometer.core.instrument.Tag.of("deviceFound", - String.valueOf(response.getStatus() == Response.Status.OK.getStatusCode())))) - .register(Metrics.globalRegistry)); - } - }); + .thenCompose(ignored -> persistentTimer.start(WAIT_FOR_LINKED_DEVICE_TIMER_NAMESPACE, tokenIdentifier)) + .thenCompose(sample -> accounts.waitForNewLinkedDevice( + authenticatedDevice.getAccount().getUuid(), + authenticatedDevice.getAuthenticatedDevice(), + tokenIdentifier, + Duration.ofSeconds(timeoutSeconds)) + .thenApply(maybeDeviceInfo -> maybeDeviceInfo + .map(deviceInfo -> Response.status(Response.Status.OK).entity(deviceInfo).build()) + .orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build())) + .exceptionally(ExceptionUtils.exceptionallyHandler(IllegalArgumentException.class, + e -> Response.status(Response.Status.BAD_REQUEST).build())) + .whenComplete((response, throwable) -> { + linkedDeviceListenerCounter.decrementAndGet(); + + if (response != null && response.getStatus() == Response.Status.OK.getStatusCode()) { + sample.stop(Timer.builder(WAIT_FOR_LINKED_DEVICE_TIMER_NAME) + .publishPercentileHistogram(true) + .tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent))) + .register(Metrics.globalRegistry)); + } + })); } private AtomicInteger getCounterForLinkedDeviceListeners(final String userAgent) { @@ -529,7 +532,8 @@ The amount of time (in seconds) to wait for a response. If a transfer archive fo public CompletionStage recordTransferArchiveUploaded(@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice, @NotNull @Valid final TransferArchiveUploadedRequest transferArchiveUploadedRequest) { - return rateLimiters.getUploadTransferArchiveLimiter().validateAsync(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI)) + return rateLimiters.getUploadTransferArchiveLimiter() + .validateAsync(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI)) .thenCompose(ignored -> accounts.recordTransferArchiveUpload(authenticatedDevice.getAccount(), transferArchiveUploadedRequest.destinationDeviceId(), Instant.ofEpochMilli(transferArchiveUploadedRequest.destinationDeviceCreated()), @@ -568,30 +572,25 @@ The amount of time (in seconds) to wait for a response. If a transfer archive fo @HeaderParam(HttpHeaders.USER_AGENT) @Nullable String userAgent) { - final Timer.Sample sample = Timer.start(); final String rateLimiterKey = authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI) + ":" + authenticatedDevice.getAuthenticatedDevice().getId(); return rateLimiters.getWaitForTransferArchiveLimiter().validateAsync(rateLimiterKey) - .thenCompose(ignored -> accounts.waitForTransferArchive(authenticatedDevice.getAccount(), - authenticatedDevice.getAuthenticatedDevice(), - Duration.ofSeconds(timeoutSeconds))) - .thenApply(maybeTransferArchive -> maybeTransferArchive - .map(transferArchive -> Response.status(Response.Status.OK).entity(transferArchive).build()) - .orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build())) - .whenComplete((response, throwable) -> { - if (response == null) { - return; - } - sample.stop(Timer.builder(WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAME) - .publishPercentileHistogram(true) - .tags(Tags.of( - UserAgentTagUtil.getPlatformTag(userAgent), - io.micrometer.core.instrument.Tag.of( - "archiveUploaded", - String.valueOf(response.getStatus() == Response.Status.OK.getStatusCode())))) - .register(Metrics.globalRegistry)); - }); + .thenCompose(ignored -> persistentTimer.start(WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAMESPACE, rateLimiterKey)) + .thenCompose(sample -> accounts.waitForTransferArchive(authenticatedDevice.getAccount(), + authenticatedDevice.getAuthenticatedDevice(), + Duration.ofSeconds(timeoutSeconds)) + .thenApply(maybeTransferArchive -> maybeTransferArchive + .map(transferArchive -> Response.status(Response.Status.OK).entity(transferArchive).build()) + .orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build())) + .whenComplete((response, throwable) -> { + if (response != null && response.getStatus() == Response.Status.OK.getStatusCode()) { + sample.stop(Timer.builder(WAIT_FOR_TRANSFER_ARCHIVE_TIMER_NAME) + .publishPercentileHistogram(true) + .tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent))) + .register(Metrics.globalRegistry)); + } + })); } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/storage/PersistentTimer.java b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PersistentTimer.java new file mode 100644 index 000000000..105ffb61b --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/storage/PersistentTimer.java @@ -0,0 +1,104 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.storage; + +import com.google.common.annotations.VisibleForTesting; +import io.lettuce.core.SetArgs; +import io.micrometer.core.instrument.Timer; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import javax.annotation.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient; +import org.whispersystems.textsecuregcm.util.Util; + +/** + * Timers for operations that may span machines or requests and require a persistently stored timer start itme + */ +public class PersistentTimer { + + private static final Logger logger = LoggerFactory.getLogger(PersistentTimer.class); + + private static String TIMER_NAMESPACE = "persistent_timer"; + @VisibleForTesting + static final Duration TIMER_TTL = Duration.ofHours(1); + + private final FaultTolerantRedisClusterClient redisClient; + private final Clock clock; + + + public PersistentTimer(final FaultTolerantRedisClusterClient redisClient, final Clock clock) { + this.redisClient = redisClient; + this.clock = clock; + } + + public class Sample { + + private final Instant start; + private final String redisKey; + + public Sample(final Instant start, final String redisKey) { + this.start = start; + this.redisKey = redisKey; + } + + /** + * Stop the timer, recording the duration between now and the first call to start. This deletes the persistent timer. + * + * @param timer The micrometer timer to record the duration to + * @return A future that completes when the resources associated with the persistent timer have been destroyed + */ + public CompletableFuture stop(Timer timer) { + Duration duration = Duration.between(start, clock.instant()); + timer.record(duration); + return redisClient.withCluster(connection -> connection.async().del(redisKey)) + .toCompletableFuture() + .thenRun(Util.NOOP); + } + } + + /** + * Start the timer if a timer with the provided namespaced key has not already been started, otherwise return the + * existing sample. + * + * @param namespace A namespace prefix to use for the timer + * @param key The unique key within the namespace that identifies the timer + * @return A future that completes with a {@link Sample} that can later be used to record the final duration. + */ + public CompletableFuture start(final String namespace, final String key) { + final Instant now = clock.instant(); + final String redisKey = redisKey(namespace, key); + + return redisClient.withCluster(connection -> + connection.async().setGet(redisKey, String.valueOf(now.getEpochSecond()), SetArgs.Builder.nx().ex(TIMER_TTL))) + .toCompletableFuture() + .thenApply(serialized -> new Sample(parseStoredTimestamp(serialized).orElse(now), redisKey)); + } + + @VisibleForTesting + String redisKey(final String namespace, final String key) { + return String.format("%s::%s::%s", TIMER_NAMESPACE, namespace, key); + } + + private static Optional parseStoredTimestamp(final @Nullable String serialized) { + return Optional + .ofNullable(serialized) + .flatMap(s -> { + try { + return Optional.of(Long.parseLong(s)); + } catch (NumberFormatException e) { + logger.warn("Failed to parse stored timestamp {}", s, e); + return Optional.empty(); + } + }) + .map(Instant::ofEpochSecond); + } + +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java index c735cec6e..0cac6fce1 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java @@ -12,7 +12,6 @@ import static org.mockito.ArgumentMatchers.anyByte; import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.clearInvocations; -import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -90,6 +89,7 @@ import org.whispersystems.textsecuregcm.storage.DeviceCapability; import org.whispersystems.textsecuregcm.storage.DeviceSpec; import org.whispersystems.textsecuregcm.storage.LinkDeviceTokenAlreadyUsedException; +import org.whispersystems.textsecuregcm.storage.PersistentTimer; import org.whispersystems.textsecuregcm.tests.util.AccountsHelper; import org.whispersystems.textsecuregcm.tests.util.AuthHelper; import org.whispersystems.textsecuregcm.tests.util.KeysHelper; @@ -104,6 +104,7 @@ class DeviceControllerTest { private static final AccountsManager accountsManager = mock(AccountsManager.class); private static final ClientPublicKeysManager clientPublicKeysManager = mock(ClientPublicKeysManager.class); + private static final PersistentTimer persistentTimer = mock(PersistentTimer.class); private static final RateLimiters rateLimiters = mock(RateLimiters.class); private static final RateLimiter rateLimiter = mock(RateLimiter.class); @SuppressWarnings("unchecked") @@ -123,6 +124,7 @@ class DeviceControllerTest { accountsManager, clientPublicKeysManager, rateLimiters, + persistentTimer, deviceConfiguration); @RegisterExtension @@ -161,6 +163,9 @@ void setup() { when(clientPublicKeysManager.setPublicKey(any(), anyByte(), any())) .thenReturn(CompletableFuture.completedFuture(null)); + when(persistentTimer.start(anyString(), anyString())) + .thenReturn(CompletableFuture.completedFuture(mock(PersistentTimer.Sample.class))); + AccountsHelper.setupMockUpdate(accountsManager); } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/storage/PersistentTimerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/storage/PersistentTimerTest.java new file mode 100644 index 000000000..02b5006ec --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/storage/PersistentTimerTest.java @@ -0,0 +1,100 @@ +/* + * Copyright 2025 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ +package org.whispersystems.textsecuregcm.storage; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import io.micrometer.core.instrument.Timer; +import java.time.Duration; +import java.time.Instant; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.whispersystems.textsecuregcm.redis.RedisClusterExtension; +import org.whispersystems.textsecuregcm.util.TestClock; + +class PersistentTimerTest { + + private static final String NAMESPACE = "namespace"; + private static final String KEY = "key"; + + @RegisterExtension + private static final RedisClusterExtension CLUSTER_EXTENSION = RedisClusterExtension.builder().build(); + private TestClock clock; + private PersistentTimer timer; + + @BeforeEach + public void setup() { + clock = TestClock.pinned(Instant.ofEpochSecond(10)); + timer = new PersistentTimer(CLUSTER_EXTENSION.getRedisCluster(), clock); + } + + @Test + public void testStop() { + PersistentTimer.Sample sample = timer.start(NAMESPACE, KEY).join(); + final String redisKey = timer.redisKey(NAMESPACE, KEY); + + final String actualStartString = CLUSTER_EXTENSION.getRedisCluster() + .withCluster(conn -> conn.sync().get(redisKey)); + final Instant actualStart = Instant.ofEpochSecond(Long.parseLong(actualStartString)); + assertThat(actualStart).isEqualTo(clock.instant()); + + final long ttl = CLUSTER_EXTENSION.getRedisCluster() + .withCluster(conn -> conn.sync().ttl(redisKey)); + + assertThat(ttl).isBetween(0L, PersistentTimer.TIMER_TTL.getSeconds()); + + Timer mockTimer = mock(Timer.class); + clock.pin(clock.instant().plus(Duration.ofSeconds(5))); + sample.stop(mockTimer).join(); + verify(mockTimer).record(Duration.ofSeconds(5)); + + final String afterDeletion = CLUSTER_EXTENSION.getRedisCluster() + .withCluster(conn -> conn.sync().get(redisKey)); + + assertThat(afterDeletion).isNull(); + } + + @Test + public void testNamespace() { + Timer mockTimer = mock(Timer.class); + + clock.pin(Instant.ofEpochSecond(10)); + PersistentTimer.Sample timer1 = timer.start("n1", KEY).join(); + clock.pin(Instant.ofEpochSecond(20)); + PersistentTimer.Sample timer2 = timer.start("n2", KEY).join(); + clock.pin(Instant.ofEpochSecond(30)); + + timer2.stop(mockTimer).join(); + verify(mockTimer).record(Duration.ofSeconds(10)); + + timer1.stop(mockTimer).join(); + verify(mockTimer).record(Duration.ofSeconds(20)); + } + + @Test + public void testMultipleStart() { + Timer mockTimer = mock(Timer.class); + + clock.pin(Instant.ofEpochSecond(10)); + PersistentTimer.Sample timer1 = timer.start(NAMESPACE, KEY).join(); + clock.pin(Instant.ofEpochSecond(11)); + PersistentTimer.Sample timer2 = timer.start(NAMESPACE, KEY).join(); + clock.pin(Instant.ofEpochSecond(12)); + PersistentTimer.Sample timer3 = timer.start(NAMESPACE, KEY).join(); + + clock.pin(Instant.ofEpochSecond(20)); + timer2.stop(mockTimer).join(); + verify(mockTimer).record(Duration.ofSeconds(10)); + + assertThatNoException().isThrownBy(() -> timer1.stop(mockTimer).join()); + assertThatNoException().isThrownBy(() -> timer3.stop(mockTimer).join()); + } + + +}